Ver código fonte

optimize: migrate speed

Yeuoly 5 meses atrás
pai
commit
ac5e3caebc
2 arquivos alterados com 47 adições e 21 exclusões
  1. 4 3
      api/commands.py
  2. 43 18
      api/services/plugin/plugin_migration.py

+ 4 - 3
api/commands.py

@@ -661,13 +661,14 @@ def migrate_data_for_plugin():
 
 
 @click.command("extract-plugins", help="Extract plugins.")
-@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.")
-def extract_plugins(output_file: str):
+@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
+@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
+def extract_plugins(output_file: str, workers: int):
     """
     Extract plugins.
     """
     click.echo(click.style("Starting extract plugins.", fg="white"))
 
-    PluginMigration.extract_plugins(output_file)
+    PluginMigration.extract_plugins(output_file, workers)
 
     click.echo(click.style("Extract plugins completed.", fg="green"))

+ 43 - 18
api/services/plugin/plugin_migration.py

@@ -14,6 +14,7 @@ from models.engine import db
 from models.model import App, AppMode, AppModelConfig
 from models.tools import BuiltinToolProvider
 from models.workflow import Workflow
+from flask import Flask, current_app
 
 logger = logging.getLogger(__name__)
 
@@ -22,10 +23,13 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
 
 class PluginMigration:
     @classmethod
-    def extract_plugins(cls, filepath: str) -> None:
+    def extract_plugins(cls, filepath: str, workers: int) -> None:
         """
         Migrate plugin.
         """
+        import concurrent.futures
+        from threading import Lock
+
         click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
         ended_at = datetime.datetime.now()
         started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
@@ -34,9 +38,42 @@ class PluginMigration:
         with Session(db.engine) as session:
             total_tenant_count = session.query(Tenant.id).count()
 
+        click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
+
         handled_tenant_count = 0
+        file_lock = Lock()
+        counter_lock = Lock()
+
+        thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
+
+        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
+            with flask_app.app_context():
+                nonlocal handled_tenant_count
+                try:
+                    plugins = cls.extract_installed_plugin_ids(tenant_id)
+                    # Use lock when writing to file
+                    with file_lock:
+                        with open(filepath, "a") as f:
+                            f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
+
+                    # Use lock when updating counter
+                    with counter_lock:
+                        nonlocal handled_tenant_count
+                        handled_tenant_count += 1
+                        click.echo(
+                            click.style(
+                                f"[{datetime.datetime.now()}] "
+                                f"Processed {handled_tenant_count} tenants "
+                                f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
+                                f"{handled_tenant_count}/{total_tenant_count}",
+                                fg="green",
+                            )
+                        )
+                except Exception:
+                    logger.exception(f"Failed to process tenant {tenant_id}")
 
         while current_time < ended_at:
+            click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
             # Initial interval of 1 day, will be dynamically adjusted based on tenant count
             interval = datetime.timedelta(days=1)
             # Process tenants in this batch
@@ -84,7 +121,6 @@ class PluginMigration:
                 )
 
                 tenants = []
-
                 for row in rs:
                     tenant_id = str(row.id)
                     try:
@@ -93,25 +129,14 @@ class PluginMigration:
                         logger.exception(f"Failed to process tenant {tenant_id}")
                         continue
 
-            for tenant_id in tenants:
-                plugins = cls.extract_installed_plugin_ids(tenant_id)
-                # append to file, it's a jsonl file
-                with open(filepath, "a") as f:
-                    f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
-
-            handled_tenant_count += len(tenants)
-
-            click.echo(
-                click.style(
-                    f"Processed {handled_tenant_count} tenants "
-                    f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
-                    f"{handled_tenant_count}/{total_tenant_count}",
-                    fg="green",
-                )
-            )
+            # Process batch with thread pool
+            thread_pool.map(lambda tenant_id: process_tenant(current_app, tenant_id), tenants)
 
             current_time = batch_end
 
+        # wait for all threads to finish
+        thread_pool.shutdown(wait=True)
+
     @classmethod
     def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
         """