import datetime
import json
import logging
from collections.abc import Sequence

import click
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
from models.model import App, AppMode, AppModelConfig
from models.tools import BuiltinToolProvider
from models.workflow import Workflow

logger = logging.getLogger(__name__)

excluded_providers = ["time", "audio", "code", "webscraper"]


class PluginMigration:
    @classmethod
    def extract_plugins(cls, filepath: str) -> None:
        """
        Extract the installed plugin IDs of every tenant and append them to a JSONL file.
        """
        click.echo(click.style("Migrating models/tools to the new plugin mechanism", fg="white"))
        ended_at = datetime.datetime.now()
        # Fixed lower bound for scanning tenants by created_at
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()

        handled_tenant_count = 0
        while current_time < ended_at:
            # Initial interval of 1 day, will be dynamically adjusted based on tenant count
            interval = datetime.timedelta(days=1)

            # Process tenants in this batch
            with Session(db.engine) as session:
                # Calculate the tenant count in the next batch with the current interval.
                # Try progressively smaller intervals until one yields a reasonable tenant count.
                test_intervals = [
                    datetime.timedelta(days=1),
                    datetime.timedelta(hours=12),
                    datetime.timedelta(hours=6),
                    datetime.timedelta(hours=3),
                    datetime.timedelta(hours=1),
                ]

                for test_interval in test_intervals:
                    tenant_count = (
                        session.query(Tenant.id)
                        .filter(Tenant.created_at.between(current_time, current_time + test_interval))
                        .count()
                    )
                    if tenant_count <= 100:
                        interval = test_interval
                        break
                else:
                    # If all intervals have too many tenants, use the minimum interval
                    interval = datetime.timedelta(hours=1)

                # Adjust the interval to target ~100 tenants per batch
                if tenant_count > 0:
                    # Scale the interval based on the ratio to the target count
                    interval = min(
                        datetime.timedelta(days=1),  # Max 1 day
                        max(
                            datetime.timedelta(hours=1),  # Min 1 hour
                            interval * (100 / tenant_count),  # Scale to target 100
                        ),
                    )

                batch_end = min(current_time + interval, ended_at)

                rs = (
                    session.query(Tenant.id)
                    .filter(Tenant.created_at.between(current_time, batch_end))
                    .order_by(Tenant.created_at)
                )

                tenants = []
                for row in rs:
                    tenant_id = str(row.id)
                    try:
                        tenants.append(tenant_id)
                    except Exception:
                        logger.exception(f"Failed to process tenant {tenant_id}")
                        continue

                for tenant_id in tenants:
                    plugins = cls.extract_installed_plugin_ids(tenant_id)
                    # Append to the output file; one JSON object per line (JSONL)
                    with open(filepath, "a") as f:
                        f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")

                handled_tenant_count += len(tenants)
                click.echo(
                    click.style(
                        f"Processed {handled_tenant_count} tenants "
                        f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                        f"{handled_tenant_count}/{total_tenant_count}",
                        fg="green",
                    )
                )

            current_time = batch_end

    @classmethod
    def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract installed plugin ids.
        """
        tools = cls.extract_tool_tables(tenant_id)
        models = cls.extract_model_tables(tenant_id)
        workflows = cls.extract_workflow_tables(tenant_id)
        apps = cls.extract_app_tables(tenant_id)

        return list({*tools, *models, *workflows, *apps})

    @classmethod
    def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract model tables.
        NOTE: the "google" model provider is renamed to "gemini".
        """
        models = []
        table_pairs = [
            ("providers", "provider_name"),
            ("provider_models", "provider_name"),
            ("provider_orders", "provider_name"),
            ("tenant_default_models", "provider_name"),
            ("tenant_preferred_model_providers", "provider_name"),
            ("provider_model_settings", "provider_name"),
            ("load_balancing_model_configs", "provider_name"),
        ]

        for table, column in table_pairs:
            models.extend(cls.extract_model_table(tenant_id, table, column))

        # Deduplicate models
        models = list(set(models))

        return models

    @classmethod
    def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
        """
        Extract provider names from a single model table.

        `table` and `column` are taken from the hardcoded pairs in `extract_model_tables`,
        so interpolating them into the SQL text does not expose user input.
        """
        with Session(db.engine) as session:
            rs = session.execute(
                db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"),
                {"tenant_id": tenant_id},
            )
            result = []
            for row in rs:
                provider_name = str(row[0])
                if provider_name and "/" not in provider_name:
                    if provider_name == "google":
                        provider_name = "gemini"
                    result.append(DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name)
                elif provider_name:
                    result.append(provider_name)

            return result

    @classmethod
    def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract tool tables.
        """
        with Session(db.engine) as session:
            rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                if "/" not in row.provider:
                    result.append(DEFAULT_PLUGIN_ID + "/" + row.provider + "/" + row.provider)
                else:
                    result.append(row.provider)

            return result

    @classmethod
    def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
        """
        Convert a builtin tool provider name to its plugin ID.
        """
        # Builtin tool providers whose plugin names differ from their legacy names
        if provider_name == "jina":
            provider_name = "jina_tool"
        elif provider_name == "siliconflow":
            provider_name = "siliconflow_tool"
        elif provider_name == "stepfun":
            provider_name = "stepfun_tool"

        if "/" not in provider_name:
            return DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name
        else:
            return provider_name

    @classmethod
    def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract workflow tables; only tool nodes are relevant.
        """
        with Session(db.engine) as session:
            rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                graph = row.graph_dict
                # Get the nodes of the workflow graph
                nodes = graph.get("nodes", [])
                for node in nodes:
                    data = node.get("data", {})
                    if data.get("type") == "tool":
                        provider_name = data.get("provider_name")
                        provider_type = data.get("provider_type")
                        if (
                            provider_name not in excluded_providers
                            and provider_type == ToolProviderType.BUILT_IN.value
                        ):
                            provider_name = cls._handle_builtin_tool_provider(provider_name)
                            result.append(provider_name)

            return result

    @classmethod
    def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract app tables.
""" with Session(db.engine) as session: apps = session.query(App).filter(App.tenant_id == tenant_id).all() if not apps: return [] agent_app_model_config_ids = [ app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value ] rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all() result = [] for row in rs: agent_config = row.agent_mode_dict if "tools" in agent_config and isinstance(agent_config["tools"], list): for tool in agent_config["tools"]: if isinstance(tool, dict): try: tool_entity = AgentToolEntity(**tool) if ( tool_entity.provider_type == ToolProviderType.BUILT_IN.value and tool_entity.provider_id not in excluded_providers ): result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id)) except Exception: logger.exception(f"Failed to process tool {tool}") continue return result