plugin_migration.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import datetime
  2. import logging
  3. from collections.abc import Sequence
  4. import click
  5. from sqlalchemy.orm import Session
  6. from core.agent.entities import AgentToolEntity
  7. from core.entities import DEFAULT_PLUGIN_ID
  8. from core.tools.entities.tool_entities import ToolProviderType
  9. from models.account import Tenant
  10. from models.engine import db
  11. from models.model import App, AppMode, AppModelConfig
  12. from models.tools import BuiltinToolProvider
  13. from models.workflow import Workflow
  14. logger = logging.getLogger(__name__)
  15. excluded_providers = ["time", "audio", "code", "webscraper"]
  16. class PluginMigration:
  17. @classmethod
  18. def extract_plugins(cls) -> None:
  19. """
  20. Migrate plugin.
  21. """
  22. click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
  23. ended_at = datetime.datetime.now()
  24. started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
  25. current_time = started_at
  26. with Session(db.engine) as session:
  27. total_tenant_count = session.query(Tenant.id).count()
  28. handled_tenant_count = 0
  29. while current_time < ended_at:
  30. # Initial interval of 1 day, will be dynamically adjusted based on tenant count
  31. interval = datetime.timedelta(days=1)
  32. # Process tenants in this batch
  33. with Session(db.engine) as session:
  34. # Calculate tenant count in next batch with current interval
  35. # Try different intervals until we find one with a reasonable tenant count
  36. test_intervals = [
  37. datetime.timedelta(days=1),
  38. datetime.timedelta(hours=12),
  39. datetime.timedelta(hours=6),
  40. datetime.timedelta(hours=3),
  41. datetime.timedelta(hours=1),
  42. ]
  43. for test_interval in test_intervals:
  44. tenant_count = (
  45. session.query(Tenant.id)
  46. .filter(Tenant.created_at.between(current_time, current_time + test_interval))
  47. .count()
  48. )
  49. if tenant_count <= 100:
  50. interval = test_interval
  51. break
  52. else:
  53. # If all intervals have too many tenants, use minimum interval
  54. interval = datetime.timedelta(hours=1)
  55. # Adjust interval to target ~100 tenants per batch
  56. if tenant_count > 0:
  57. # Scale interval based on ratio to target count
  58. interval = min(
  59. datetime.timedelta(days=1), # Max 1 day
  60. max(
  61. datetime.timedelta(hours=1), # Min 1 hour
  62. interval * (100 / tenant_count), # Scale to target 100
  63. ),
  64. )
  65. batch_end = min(current_time + interval, ended_at)
  66. rs = (
  67. session.query(Tenant.id)
  68. .filter(Tenant.created_at.between(current_time, batch_end))
  69. .order_by(Tenant.created_at)
  70. )
  71. tenants = []
  72. for row in rs:
  73. tenant_id = str(row.id)
  74. try:
  75. tenants.append(tenant_id)
  76. except Exception:
  77. logger.exception(f"Failed to process tenant {tenant_id}")
  78. continue
  79. for tenant_id in tenants:
  80. plugins = cls.extract_installed_plugin_ids(tenant_id)
  81. print(plugins)
  82. handled_tenant_count += len(tenants)
  83. click.echo(
  84. click.style(
  85. f"Processed {handled_tenant_count} tenants ({(handled_tenant_count/total_tenant_count)*100:.1f}%), {handled_tenant_count}/{total_tenant_count}",
  86. fg="green",
  87. )
  88. )
  89. current_time = batch_end
  90. @classmethod
  91. def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
  92. """
  93. Extract installed plugin ids.
  94. """
  95. tools = cls.extract_tool_tables(tenant_id)
  96. models = cls.extract_model_tables(tenant_id)
  97. workflows = cls.extract_workflow_tables(tenant_id)
  98. apps = cls.extract_app_tables(tenant_id)
  99. return list({*tools, *models, *workflows, *apps})
  100. @classmethod
  101. def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
  102. """
  103. Extract model tables.
  104. NOTE: rename google to gemini
  105. """
  106. models = []
  107. table_pairs = [
  108. ("providers", "provider_name"),
  109. ("provider_models", "provider_name"),
  110. ("provider_orders", "provider_name"),
  111. ("tenant_default_models", "provider_name"),
  112. ("tenant_preferred_model_providers", "provider_name"),
  113. ("provider_model_settings", "provider_name"),
  114. ("load_balancing_model_configs", "provider_name"),
  115. ]
  116. for table, column in table_pairs:
  117. models.extend(cls.extract_model_table(tenant_id, table, column))
  118. # duplicate models
  119. models = list(set(models))
  120. return models
  121. @classmethod
  122. def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
  123. """
  124. Extract model table.
  125. """
  126. with Session(db.engine) as session:
  127. rs = session.execute(
  128. db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
  129. )
  130. result = []
  131. for row in rs:
  132. provider_name = str(row[0])
  133. if provider_name and "/" not in provider_name:
  134. if provider_name == "google":
  135. provider_name = "gemini"
  136. result.append(DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name)
  137. elif provider_name:
  138. result.append(provider_name)
  139. return result
  140. @classmethod
  141. def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
  142. """
  143. Extract tool tables.
  144. """
  145. with Session(db.engine) as session:
  146. rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  147. result = []
  148. for row in rs:
  149. if "/" not in row.provider:
  150. result.append(DEFAULT_PLUGIN_ID + "/" + row.provider + "/" + row.provider)
  151. else:
  152. result.append(row.provider)
  153. return result
  154. @classmethod
  155. def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
  156. """
  157. Handle builtin tool provider.
  158. """
  159. if provider_name == "jina":
  160. provider_name = "jina_tool"
  161. elif provider_name == "siliconflow":
  162. provider_name = "siliconflow_tool"
  163. elif provider_name == "stepfun":
  164. provider_name = "stepfun_tool"
  165. if "/" not in provider_name:
  166. return DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name
  167. else:
  168. return provider_name
  169. @classmethod
  170. def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
  171. """
  172. Extract workflow tables, only ToolNode is required.
  173. """
  174. with Session(db.engine) as session:
  175. rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
  176. result = []
  177. for row in rs:
  178. graph = row.graph_dict
  179. # get nodes
  180. nodes = graph.get("nodes", [])
  181. for node in nodes:
  182. data = node.get("data", {})
  183. if data.get("type") == "tool":
  184. provider_name = data.get("provider_name")
  185. provider_type = data.get("provider_type")
  186. if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
  187. provider_name = cls._handle_builtin_tool_provider(provider_name)
  188. result.append(provider_name)
  189. return result
  190. @classmethod
  191. def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
  192. """
  193. Extract app tables.
  194. """
  195. with Session(db.engine) as session:
  196. apps = session.query(App).filter(App.tenant_id == tenant_id).all()
  197. if not apps:
  198. return []
  199. agent_app_model_config_ids = [
  200. app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
  201. ]
  202. rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
  203. result = []
  204. for row in rs:
  205. agent_config = row.agent_mode_dict
  206. if "tools" in agent_config and isinstance(agent_config["tools"], list):
  207. for tool in agent_config["tools"]:
  208. if isinstance(tool, dict):
  209. try:
  210. tool_entity = AgentToolEntity(**tool)
  211. if (
  212. tool_entity.provider_type == ToolProviderType.BUILT_IN.value
  213. and tool_entity.provider_id not in excluded_providers
  214. ):
  215. result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
  216. except Exception:
  217. logger.exception(f"Failed to process tool {tool}")
  218. continue
  219. return result