# plugin_migration.py
  1. import datetime
  2. import json
  3. import logging
  4. from collections.abc import Sequence
  5. import click
  6. from flask import Flask, current_app
  7. from sqlalchemy.orm import Session
  8. from core.agent.entities import AgentToolEntity
  9. from core.entities import DEFAULT_PLUGIN_ID
  10. from core.tools.entities.tool_entities import ToolProviderType
  11. from models.account import Tenant
  12. from models.engine import db
  13. from models.model import App, AppMode, AppModelConfig
  14. from models.tools import BuiltinToolProvider
  15. from models.workflow import Workflow
  16. logger = logging.getLogger(__name__)
  17. excluded_providers = ["time", "audio", "code", "webscraper"]
  18. class PluginMigration:
  19. @classmethod
  20. def extract_plugins(cls, filepath: str, workers: int) -> None:
  21. """
  22. Migrate plugin.
  23. """
  24. import concurrent.futures
  25. from threading import Lock
  26. click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
  27. ended_at = datetime.datetime.now()
  28. started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
  29. current_time = started_at
  30. with Session(db.engine) as session:
  31. total_tenant_count = session.query(Tenant.id).count()
  32. click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
  33. handled_tenant_count = 0
  34. file_lock = Lock()
  35. counter_lock = Lock()
  36. thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
  37. def process_tenant(flask_app: Flask, tenant_id: str) -> None:
  38. with flask_app.app_context():
  39. nonlocal handled_tenant_count
  40. try:
  41. plugins = cls.extract_installed_plugin_ids(tenant_id)
  42. # Use lock when writing to file
  43. with file_lock:
  44. with open(filepath, "a") as f:
  45. f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
  46. # Use lock when updating counter
  47. with counter_lock:
  48. nonlocal handled_tenant_count
  49. handled_tenant_count += 1
  50. click.echo(
  51. click.style(
  52. f"[{datetime.datetime.now()}] "
  53. f"Processed {handled_tenant_count} tenants "
  54. f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
  55. f"{handled_tenant_count}/{total_tenant_count}",
  56. fg="green",
  57. )
  58. )
  59. except Exception:
  60. logger.exception(f"Failed to process tenant {tenant_id}")
  61. while current_time < ended_at:
  62. click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
  63. # Initial interval of 1 day, will be dynamically adjusted based on tenant count
  64. interval = datetime.timedelta(days=1)
  65. # Process tenants in this batch
  66. with Session(db.engine) as session:
  67. # Calculate tenant count in next batch with current interval
  68. # Try different intervals until we find one with a reasonable tenant count
  69. test_intervals = [
  70. datetime.timedelta(days=1),
  71. datetime.timedelta(hours=12),
  72. datetime.timedelta(hours=6),
  73. datetime.timedelta(hours=3),
  74. datetime.timedelta(hours=1),
  75. ]
  76. for test_interval in test_intervals:
  77. tenant_count = (
  78. session.query(Tenant.id)
  79. .filter(Tenant.created_at.between(current_time, current_time + test_interval))
  80. .count()
  81. )
  82. if tenant_count <= 100:
  83. interval = test_interval
  84. break
  85. else:
  86. # If all intervals have too many tenants, use minimum interval
  87. interval = datetime.timedelta(hours=1)
  88. # Adjust interval to target ~100 tenants per batch
  89. if tenant_count > 0:
  90. # Scale interval based on ratio to target count
  91. interval = min(
  92. datetime.timedelta(days=1), # Max 1 day
  93. max(
  94. datetime.timedelta(hours=1), # Min 1 hour
  95. interval * (100 / tenant_count), # Scale to target 100
  96. ),
  97. )
  98. batch_end = min(current_time + interval, ended_at)
  99. rs = (
  100. session.query(Tenant.id)
  101. .filter(Tenant.created_at.between(current_time, batch_end))
  102. .order_by(Tenant.created_at)
  103. )
  104. tenants = []
  105. for row in rs:
  106. tenant_id = str(row.id)
  107. try:
  108. tenants.append(tenant_id)
  109. except Exception:
  110. logger.exception(f"Failed to process tenant {tenant_id}")
  111. continue
  112. # Process batch with thread pool
  113. thread_pool.map(lambda tenant_id: process_tenant(current_app, tenant_id), tenants)
  114. current_time = batch_end
  115. # wait for all threads to finish
  116. thread_pool.shutdown(wait=True)
  117. @classmethod
  118. def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
  119. """
  120. Extract installed plugin ids.
  121. """
  122. tools = cls.extract_tool_tables(tenant_id)
  123. models = cls.extract_model_tables(tenant_id)
  124. workflows = cls.extract_workflow_tables(tenant_id)
  125. apps = cls.extract_app_tables(tenant_id)
  126. return list({*tools, *models, *workflows, *apps})
  127. @classmethod
  128. def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
  129. """
  130. Extract model tables.
  131. NOTE: rename google to gemini
  132. """
  133. models = []
  134. table_pairs = [
  135. ("providers", "provider_name"),
  136. ("provider_models", "provider_name"),
  137. ("provider_orders", "provider_name"),
  138. ("tenant_default_models", "provider_name"),
  139. ("tenant_preferred_model_providers", "provider_name"),
  140. ("provider_model_settings", "provider_name"),
  141. ("load_balancing_model_configs", "provider_name"),
  142. ]
  143. for table, column in table_pairs:
  144. models.extend(cls.extract_model_table(tenant_id, table, column))
  145. # duplicate models
  146. models = list(set(models))
  147. return models
  148. @classmethod
  149. def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
  150. """
  151. Extract model table.
  152. """
  153. with Session(db.engine) as session:
  154. rs = session.execute(
  155. db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
  156. )
  157. result = []
  158. for row in rs:
  159. provider_name = str(row[0])
  160. if provider_name and "/" not in provider_name:
  161. if provider_name == "google":
  162. provider_name = "gemini"
  163. result.append(DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name)
  164. elif provider_name:
  165. result.append(provider_name)
  166. return result
  167. @classmethod
  168. def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
  169. """
  170. Extract tool tables.
  171. """
  172. with Session(db.engine) as session:
  173. rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  174. result = []
  175. for row in rs:
  176. if "/" not in row.provider:
  177. result.append(DEFAULT_PLUGIN_ID + "/" + row.provider + "/" + row.provider)
  178. else:
  179. result.append(row.provider)
  180. return result
  181. @classmethod
  182. def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
  183. """
  184. Handle builtin tool provider.
  185. """
  186. if provider_name == "jina":
  187. provider_name = "jina_tool"
  188. elif provider_name == "siliconflow":
  189. provider_name = "siliconflow_tool"
  190. elif provider_name == "stepfun":
  191. provider_name = "stepfun_tool"
  192. if "/" not in provider_name:
  193. return DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name
  194. else:
  195. return provider_name
  196. @classmethod
  197. def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
  198. """
  199. Extract workflow tables, only ToolNode is required.
  200. """
  201. with Session(db.engine) as session:
  202. rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
  203. result = []
  204. for row in rs:
  205. graph = row.graph_dict
  206. # get nodes
  207. nodes = graph.get("nodes", [])
  208. for node in nodes:
  209. data = node.get("data", {})
  210. if data.get("type") == "tool":
  211. provider_name = data.get("provider_name")
  212. provider_type = data.get("provider_type")
  213. if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
  214. provider_name = cls._handle_builtin_tool_provider(provider_name)
  215. result.append(provider_name)
  216. return result
  217. @classmethod
  218. def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
  219. """
  220. Extract app tables.
  221. """
  222. with Session(db.engine) as session:
  223. apps = session.query(App).filter(App.tenant_id == tenant_id).all()
  224. if not apps:
  225. return []
  226. agent_app_model_config_ids = [
  227. app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
  228. ]
  229. rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
  230. result = []
  231. for row in rs:
  232. agent_config = row.agent_mode_dict
  233. if "tools" in agent_config and isinstance(agent_config["tools"], list):
  234. for tool in agent_config["tools"]:
  235. if isinstance(tool, dict):
  236. try:
  237. tool_entity = AgentToolEntity(**tool)
  238. if (
  239. tool_entity.provider_type == ToolProviderType.BUILT_IN.value
  240. and tool_entity.provider_id not in excluded_providers
  241. ):
  242. result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
  243. except Exception:
  244. logger.exception(f"Failed to process tool {tool}")
  245. continue
  246. return result