plugin_migration.py

import datetime
import json
import logging
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

import click
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
from models.model import App, AppMode, AppModelConfig
from models.tools import BuiltinToolProvider
from models.workflow import Workflow

logger = logging.getLogger(__name__)
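
# Providers skipped during extraction; presumably these capabilities stay
# built into core after the migration and never need a marketplace plugin.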
excluded_providers = ["time", "audio", "code", "webscraper"]


class PluginMigration:
    @classmethod
    def extract_plugins(cls, filepath: str, workers: int) -> None:
        """
        Extract the plugin ids installed by every tenant and append them to
        `filepath`, one JSON object per line.
        """
        import concurrent.futures
        from threading import Lock

        click.echo(click.style("Migrating models/tools to new plugin mechanism", fg="white"))
        # Scan tenants created between this fixed start date and now.
        ended_at = datetime.datetime.now()
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()
            click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

        handled_tenant_count = 0
        file_lock = Lock()
        counter_lock = Lock()

        thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=workers)

        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
            with flask_app.app_context():
                nonlocal handled_tenant_count
                try:
                    plugins = cls.extract_installed_plugin_ids(tenant_id)
                    # Use lock when writing to file
                    with file_lock:
                        with open(filepath, "a") as f:
                            f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")

                    # Use lock when updating counter
                    with counter_lock:
                        handled_tenant_count += 1
                        click.echo(
                            click.style(
                                f"[{datetime.datetime.now()}] "
                                f"Processed {handled_tenant_count} tenants "
                                f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                                f"{handled_tenant_count}/{total_tenant_count}",
                                fg="green",
                            )
                        )
                except Exception:
                    logger.exception(f"Failed to process tenant {tenant_id}")

        futures = []

        while current_time < ended_at:
            click.echo(click.style(f"Batch start: {current_time} (now: {datetime.datetime.now()})", fg="white"))
            # Initial interval of 1 day, will be dynamically adjusted based on tenant count
            interval = datetime.timedelta(days=1)

            # Process tenants in this batch
            with Session(db.engine) as session:
                # Calculate tenant count in next batch with current interval
                # Try different intervals until we find one with a reasonable tenant count
                test_intervals = [
                    datetime.timedelta(days=1),
                    datetime.timedelta(hours=12),
                    datetime.timedelta(hours=6),
                    datetime.timedelta(hours=3),
                    datetime.timedelta(hours=1),
                ]

                for test_interval in test_intervals:
                    tenant_count = (
                        session.query(Tenant.id)
                        .filter(Tenant.created_at.between(current_time, current_time + test_interval))
                        .count()
                    )
                    if tenant_count <= 100:
                        interval = test_interval
                        break
                else:
                    # If all intervals have too many tenants, use minimum interval
                    interval = datetime.timedelta(hours=1)

                # Adjust interval to target ~100 tenants per batch
                if tenant_count > 0:
                    # Scale interval based on ratio to target count
                    interval = min(
                        datetime.timedelta(days=1),  # Max 1 day
                        max(
                            datetime.timedelta(hours=1),  # Min 1 hour
                            interval * (100 / tenant_count),  # Scale to target 100
                        ),
                    )

                batch_end = min(current_time + interval, ended_at)

                rs = (
                    session.query(Tenant.id)
                    .filter(Tenant.created_at.between(current_time, batch_end))
                    .order_by(Tenant.created_at)
                )

                for row in rs:
                    tenant_id = str(row.id)
                    futures.append(
                        thread_pool.submit(
                            process_tenant,
                            current_app._get_current_object(),  # type: ignore[attr-defined]
                            tenant_id,
                        )
                    )

            current_time = batch_end

        # wait for all threads to finish
        for future in futures:
            future.result()
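
    # Illustrative contents of `filepath` after `extract_plugins` runs (one
    # JSON object per line). The ids below are hypothetical; real plugin ids
    # without a "/" are prefixed with DEFAULT_PLUGIN_ID, assumed here to be
    # something like "langgenius":
    #
    #   {"tenant_id": "2f6e...", "plugins": ["langgenius/gemini", "langgenius/jina_tool"]}
    #   {"tenant_id": "91ab...", "plugins": []}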

    @classmethod
    def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract installed plugin ids.
        """
        tools = cls.extract_tool_tables(tenant_id)
        models = cls.extract_model_tables(tenant_id)
        workflows = cls.extract_workflow_tables(tenant_id)
        apps = cls.extract_app_tables(tenant_id)

        return list({*tools, *models, *workflows, *apps})

    @classmethod
    def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract model tables.

        NOTE: rename google to gemini
        """
        models = []
        table_pairs = [
            ("providers", "provider_name"),
            ("provider_models", "provider_name"),
            ("provider_orders", "provider_name"),
            ("tenant_default_models", "provider_name"),
            ("tenant_preferred_model_providers", "provider_name"),
            ("provider_model_settings", "provider_name"),
            ("load_balancing_model_configs", "provider_name"),
        ]

        for table, column in table_pairs:
            models.extend(cls.extract_model_table(tenant_id, table, column))

        # deduplicate models
        models = list(set(models))

        return models

    @classmethod
    def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
        """
        Extract model table.
        """
        with Session(db.engine) as session:
            rs = session.execute(
                db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"),
                {"tenant_id": tenant_id},
            )
            result = []
            for row in rs:
                provider_name = str(row[0])
                if provider_name and "/" not in provider_name:
                    if provider_name == "google":
                        provider_name = "gemini"

                    result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)
                elif provider_name:
                    result.append(provider_name)

            return result

    @classmethod
    def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract tool tables.
        """
        with Session(db.engine) as session:
            rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                if "/" not in row.provider:
                    result.append(DEFAULT_PLUGIN_ID + "/" + row.provider)
                else:
                    result.append(row.provider)

            return result

    @classmethod
    def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
        """
        Handle builtin tool provider.
        """
        if provider_name == "jina":
            provider_name = "jina_tool"
        elif provider_name == "siliconflow":
            provider_name = "siliconflow_tool"
        elif provider_name == "stepfun":
            provider_name = "stepfun_tool"

        if "/" not in provider_name:
            return DEFAULT_PLUGIN_ID + "/" + provider_name
        else:
            return provider_name

    @classmethod
    def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract workflow tables; only tool nodes are relevant.
        """
        with Session(db.engine) as session:
            rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                graph = row.graph_dict
                # get nodes
                nodes = graph.get("nodes", [])

                for node in nodes:
                    data = node.get("data", {})
                    if data.get("type") == "tool":
                        provider_name = data.get("provider_name")
                        provider_type = data.get("provider_type")
                        if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
                            provider_name = cls._handle_builtin_tool_provider(provider_name)
                            result.append(provider_name)

            return result
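
    # Shape of a workflow graph node matched above (illustrative; only the
    # fields the loop actually reads are shown, and the values are assumptions
    # based on the checks performed):
    #   {"data": {"type": "tool", "provider_type": "builtin", "provider_name": "jina"}}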

    @classmethod
    def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract app tables.
        """
        with Session(db.engine) as session:
            apps = session.query(App).filter(App.tenant_id == tenant_id).all()
            if not apps:
                return []

            agent_app_model_config_ids = [
                app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
            ]

            rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
            result = []
            for row in rs:
                agent_config = row.agent_mode_dict
                if "tools" in agent_config and isinstance(agent_config["tools"], list):
                    for tool in agent_config["tools"]:
                        if isinstance(tool, dict):
                            try:
                                tool_entity = AgentToolEntity(**tool)
                                if (
                                    tool_entity.provider_type == ToolProviderType.BUILT_IN.value
                                    and tool_entity.provider_id not in excluded_providers
                                ):
                                    result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
                            except Exception:
                                logger.exception(f"Failed to process tool {tool}")
                                continue

            return result
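
    # An agent_mode_dict "tools" entry as validated above (illustrative; only
    # the two fields this method checks are shown, the remaining AgentToolEntity
    # fields are omitted):
    #   {"provider_type": "builtin", "provider_id": "jina", ...}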

    @classmethod
    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
        """
        Fetch plugin unique identifier using plugin id.
        """
        plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
        if not plugin_manifest:
            return None

        return plugin_manifest[0].latest_package_identifier

    @classmethod
    def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
        """
        Extract unique plugins.
        """
        Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))

    @classmethod
    def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
        """
        Deduplicate plugin ids across tenants and resolve each one to its
        marketplace unique identifier.
        """
        plugins: dict[str, str] = {}
        plugin_ids = []
        plugin_not_exist = []
        logger.info(f"Extracting unique plugins from {extracted_plugins}")

        with open(extracted_plugins) as f:
            for line in f:
                data = json.loads(line)
                new_plugin_ids = data.get("plugins", [])
                for plugin_id in new_plugin_ids:
                    if plugin_id not in plugin_ids:
                        plugin_ids.append(plugin_id)

        def fetch_plugin(plugin_id):
            unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
            if unique_identifier:
                plugins[plugin_id] = unique_identifier
            else:
                plugin_not_exist.append(plugin_id)

        with ThreadPoolExecutor(max_workers=10) as executor:
            list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))

        return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
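
    # Return shape of `extract_unique_plugins` (the identifier format is an
    # assumption based on `latest_package_identifier` from the marketplace
    # manifest):
    #   {
    #       "plugins": {"<plugin_id>": "<latest_package_identifier>"},
    #       "plugin_not_exist": ["<plugin_id not found on the marketplace>"],
    #   }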

    @classmethod
    def install_plugins(cls, extracted_plugins: str, output_file: str) -> None:
        """
        Install every tenant's plugins, pre-warming all packages once via a
        fake tenant, then write a report of what could not be installed.
        """
        manager = PluginInstallationManager()

        plugins = cls.extract_unique_plugins(extracted_plugins)
        not_installed = []
        plugin_install_failed = []

        # use a fake tenant id to install all the plugins
        fake_tenant_id = uuid4().hex
        logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}")

        thread_pool = ThreadPoolExecutor(max_workers=40)

        response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
        if response.get("failed"):
            plugin_install_failed.extend(response.get("failed", []))

        def install(tenant_id: str, plugin_ids: list[str]) -> None:
            logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
            # at most 64 plugins one batch
            for i in range(0, len(plugin_ids), 64):
                batch_plugin_ids = plugin_ids[i : i + 64]
                # skip plugins that could not be resolved on the marketplace
                batch_plugin_identifiers = [
                    plugins["plugins"][plugin_id]
                    for plugin_id in batch_plugin_ids
                    if plugin_id in plugins["plugins"]
                ]
                if not batch_plugin_identifiers:
                    continue

                manager.install_from_identifiers(
                    tenant_id,
                    batch_plugin_identifiers,
                    PluginInstallationSource.Marketplace,
                    metas=[
                        {
                            "plugin_unique_identifier": identifier,
                        }
                        for identifier in batch_plugin_identifiers
                    ],
                )

        with open(extracted_plugins) as f:
            # Read line by line, and install plugins for each tenant.
            for line in f:
                data = json.loads(line)
                tenant_id = data.get("tenant_id")
                plugin_ids = data.get("plugins", [])
                current_not_installed = {
                    "tenant_id": tenant_id,
                    "plugin_not_exist": [],
                }
                # record plugins that have no marketplace unique identifier
                for plugin_id in plugin_ids:
                    unique_identifier = plugins["plugins"].get(plugin_id)
                    if not unique_identifier:
                        current_not_installed["plugin_not_exist"].append(plugin_id)

                if current_not_installed["plugin_not_exist"]:
                    not_installed.append(current_not_installed)

                thread_pool.submit(install, tenant_id, plugin_ids)

        thread_pool.shutdown(wait=True)

        logger.info("Uninstalling plugins from the fake tenant")

        # remove everything that was installed for the fake tenant
        try:
            installation = manager.list_plugins(fake_tenant_id)
            while installation:
                for plugin in installation:
                    manager.uninstall(fake_tenant_id, plugin.installation_id)

                installation = manager.list_plugins(fake_tenant_id)
        except Exception:
            logger.exception(f"Failed to uninstall plugins for fake tenant {fake_tenant_id}")

        Path(output_file).write_text(
            json.dumps(
                {
                    "not_installed": not_installed,
                    "plugin_install_failed": plugin_install_failed,
                }
            )
        )

    @classmethod
    def handle_plugin_instance_install(
        cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
    ) -> Mapping[str, Any]:
        """
        Install plugins for a tenant.
        """
        manager = PluginInstallationManager()

        # download all the plugins and upload
        thread_pool = ThreadPoolExecutor(max_workers=10)
        futures = []

        def download_and_upload(tenant_id, plugin_id, plugin_identifier):
            plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
            if not plugin_package:
                raise Exception(f"Failed to download plugin {plugin_identifier}")

            # upload
            manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)

        for plugin_id, plugin_identifier in plugin_identifiers_map.items():
            futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))

        # Wait for all downloads to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred

        thread_pool.shutdown(wait=True)

        success = []
        failed = []

        reverse_map = {v: k for k, v in plugin_identifiers_map.items()}

        # at most 64 plugins one batch
        for i in range(0, len(plugin_identifiers_map), 64):
            batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 64]
            batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]

            try:
                response = manager.install_from_identifiers(
                    tenant_id=tenant_id,
                    identifiers=batch_plugin_identifiers,
                    source=PluginInstallationSource.Marketplace,
                    metas=[
                        {
                            "plugin_unique_identifier": identifier,
                        }
                        for identifier in batch_plugin_identifiers
                    ],
                )
            except Exception:
                # add to failed
                failed.extend(batch_plugin_identifiers)
                continue

            if response.all_installed:
                success.extend(batch_plugin_identifiers)
                continue

            task_id = response.task_id
            done = False
            while not done:
                status = manager.fetch_plugin_installation_task(tenant_id, task_id)
                if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
                    for plugin in status.plugins:
                        if plugin.status == PluginInstallTaskStatus.Success:
                            success.append(reverse_map[plugin.plugin_unique_identifier])
                        else:
                            failed.append(reverse_map[plugin.plugin_unique_identifier])
                            logger.error(
                                f"Failed to install plugin {plugin.plugin_unique_identifier}, "
                                f"error: {plugin.message}"
                            )
                    done = True
                else:
                    time.sleep(1)

        return {"success": success, "failed": failed}
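

# Minimal CLI sketch for running the three phases in order (extract ->
# dedupe -> install). This wiring is hypothetical: the host application may
# expose these classmethods through its own command framework instead, and
# `extract` additionally requires a configured Flask app context for database
# access.
if __name__ == "__main__":

    @click.group()
    def cli() -> None:
        """Plugin migration helpers."""

    @cli.command()
    @click.option("--filepath", default="extracted_plugins.jsonl", help="JSONL output, one tenant per line.")
    @click.option("--workers", default=10, type=int, help="Tenant-processing thread count.")
    def extract(filepath: str, workers: int) -> None:
        PluginMigration.extract_plugins(filepath, workers)

    @cli.command()
    @click.option("--extracted-plugins", default="extracted_plugins.jsonl")
    @click.option("--output-file", default="unique_plugins.json")
    def dedupe(extracted_plugins: str, output_file: str) -> None:
        PluginMigration.extract_unique_plugins_to_file(extracted_plugins, output_file)

    @cli.command()
    @click.option("--extracted-plugins", default="extracted_plugins.jsonl")
    @click.option("--output-file", default="install_report.json")
    def install(extracted_plugins: str, output_file: str) -> None:
        PluginMigration.install_plugins(extracted_plugins, output_file)

    cli()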