Browse Source

feat: plugin migrations

Yeuoly 4 months ago
parent
commit
6e73ad2fc6

+ 39 - 0
api/commands.py

@@ -677,3 +677,42 @@ def extract_plugins(output_file: str, workers: int):
     PluginMigration.extract_plugins(output_file, workers)
 
     click.echo(click.style("Extract plugins completed.", fg="green"))
+
+
+@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
+@click.option(
+    "--output_file",
+    prompt=True,
+    help="The file to store the extracted unique identifiers.",
+    default="unique_identifiers.json",
+)
+@click.option(
+    "--input_file", prompt=True, help="The file containing the extracted plugins.", default="plugins.jsonl"
+)
+def extract_unique_plugins(output_file: str, input_file: str):
+    """
+    Resolve the plugins listed in ``input_file`` to unique identifiers and save them.
+
+    Thin CLI wrapper around ``PluginMigration.extract_unique_plugins_to_file``.
+
+    :param output_file: path of the JSON file the result is written to
+    :param input_file: path of the JSON Lines file the extracted plugins are read from
+    """
+    click.echo(click.style("Starting extract unique plugins.", fg="white"))
+
+    # PluginMigration.extract_unique_plugins() accepts a single path and only returns
+    # the mapping; the *_to_file variant is the one that takes an output path and writes it.
+    PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
+
+    click.echo(click.style("Extract unique plugins completed.", fg="green"))
+
+
+@click.command("install-plugins", help="Install plugins.")
+@click.option(
+    "--input_file", prompt=True, help="The file containing the extracted plugins.", default="plugins.jsonl"
+)
+@click.option(
+    "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
+)
+def install_plugins(input_file: str, output_file: str):
+    """
+    Install the plugins listed in ``input_file`` and write the outcome to ``output_file``.
+
+    Thin CLI wrapper around ``PluginMigration.install_plugins``.
+
+    :param input_file: path of the file the extracted plugins are read from
+    :param output_file: path of the file the installation result is written to
+    """
+    click.echo(click.style("Starting install plugins.", fg="white"))
+
+    PluginMigration.install_plugins(input_file, output_file)
+
+    click.echo(click.style("Install plugins completed.", fg="green"))

+ 9 - 1
api/core/model_runtime/entities/provider_entities.py

@@ -2,7 +2,7 @@ from collections.abc import Sequence
 from enum import Enum
 from typing import Optional
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -134,6 +134,14 @@ class ProviderEntity(BaseModel):
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
 
+    @field_validator("models", mode="before")
+    @classmethod
+    def validate_models(cls, v):
+        """
+        Replace a falsy ``models`` input with an empty list; pass anything else through.
+        """
+        return v or []
+
     def to_simple_provider(self) -> SimpleProviderEntity:
         """
         Convert to simple provider.

+ 6 - 2
api/core/plugin/manager/plugin.py

@@ -76,7 +76,11 @@ class PluginInstallationManager(BasePluginManager):
         )
 
     def install_from_identifiers(
-        self, tenant_id: str, identifiers: Sequence[str], source: PluginInstallationSource, meta: dict
+        self,
+        tenant_id: str,
+        identifiers: Sequence[str],
+        source: PluginInstallationSource,
+        metas: list[dict],
     ) -> PluginInstallTaskStartResponse:
         """
         Install a plugin from an identifier.
@@ -89,7 +93,7 @@ class PluginInstallationManager(BasePluginManager):
             data={
                 "plugin_unique_identifiers": identifiers,
                 "source": source,
-                "meta": meta,
+                "metas": metas,
             },
             headers={"Content-Type": "application/json"},
         )

+ 4 - 0
api/extensions/ext_commands.py

@@ -7,6 +7,7 @@ def init_app(app: DifyApp):
         convert_to_agent_apps,
         create_tenant,
         extract_plugins,
+        extract_unique_plugins,
         fix_app_site_missing,
         migrate_data_for_plugin,
         reset_email,
@@ -14,6 +15,7 @@ def init_app(app: DifyApp):
         reset_password,
         upgrade_db,
         vdb_migrate,
+        install_plugins,
     )
 
     cmds_to_register = [
@@ -28,6 +30,8 @@ def init_app(app: DifyApp):
         fix_app_site_missing,
         migrate_data_for_plugin,
         extract_plugins,
+        extract_unique_plugins,
+        install_plugins,
     ]
     for cmd in cmds_to_register:
         app.cli.add_command(cmd)

+ 227 - 3
api/services/plugin/plugin_migration.py

@@ -1,14 +1,25 @@
+from concurrent.futures import ThreadPoolExecutor
 import datetime
 import json
 import logging
 from collections.abc import Sequence
+from pathlib import Path
+import sys
+import time
+from typing import Any, Mapping, Optional
+from uuid import uuid4
 
 import click
+import tqdm
 from flask import Flask, current_app
 from sqlalchemy.orm import Session
 
 from core.agent.entities import AgentToolEntity
 from core.entities import DEFAULT_PLUGIN_ID
+from core.helper import marketplace
+from core.plugin.entities.plugin import PluginInstallationSource
+from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
+from core.plugin.manager.plugin import PluginInstallationManager
 from core.tools.entities.tool_entities import ToolProviderType
 from models.account import Tenant
 from models.engine import db
@@ -199,7 +210,7 @@ class PluginMigration:
                     if provider_name == "google":
                         provider_name = "gemini"
 
-                    result.append(DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name)
+                    result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)
                 elif provider_name:
                     result.append(provider_name)
 
@@ -215,7 +226,7 @@ class PluginMigration:
             result = []
             for row in rs:
                 if "/" not in row.provider:
-                    result.append(DEFAULT_PLUGIN_ID + "/" + row.provider + "/" + row.provider)
+                    result.append(DEFAULT_PLUGIN_ID + "/" + row.provider)
                 else:
                     result.append(row.provider)
 
@@ -234,7 +245,7 @@ class PluginMigration:
             provider_name = "stepfun_tool"
 
         if "/" not in provider_name:
-            return DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name
+            return DEFAULT_PLUGIN_ID + "/" + provider_name
         else:
             return provider_name
 
@@ -297,3 +308,216 @@ class PluginMigration:
                                 continue
 
             return result
+
+    @classmethod
+    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
+        """
+        Look up a plugin on the marketplace and return its latest package identifier.
+
+        :param plugin_id: id of the plugin to look up
+        :return: ``latest_package_identifier`` of the first manifest returned, or None
+            when the lookup returns nothing
+        """
+        manifests = marketplace.batch_fetch_plugin_manifests([plugin_id])
+        return manifests[0].latest_package_identifier if manifests else None
+
+    @classmethod
+    def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
+        """
+        Run ``extract_unique_plugins`` on ``extracted_plugins`` and write the result to
+        ``output_file`` as JSON.
+        """
+        unique_plugins = cls.extract_unique_plugins(extracted_plugins)
+        serialized = json.dumps(unique_plugins)
+        Path(output_file).write_text(serialized)
+
+    @classmethod
+    def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
+        """
+        Collect the distinct plugin ids from an extraction file and resolve them on the marketplace.
+
+        :param extracted_plugins: path of a JSON Lines file; the "plugins" list of every line is read
+        :return: {"plugins": {plugin_id: unique_identifier}, "plugin_not_exist": [plugin_id, ...]}
+        """
+        logger.info(f"Extracting unique plugins from {extracted_plugins}")
+
+        # dict keys de-duplicate in O(1) per id while keeping first-seen order
+        seen_plugin_ids: dict[str, None] = {}
+        with open(extracted_plugins) as f:
+            for line in f:
+                if not line.strip():
+                    # tolerate blank lines, e.g. a trailing newline at the end of the file
+                    continue
+                data = json.loads(line)
+                seen_plugin_ids.update(dict.fromkeys(data.get("plugins") or []))
+
+        plugin_ids = list(seen_plugin_ids)
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            # executor.map yields results in input order, so the output does not depend
+            # on which worker thread happens to finish first
+            unique_identifiers = list(
+                tqdm.tqdm(
+                    executor.map(cls._fetch_plugin_unique_identifier, plugin_ids),
+                    total=len(plugin_ids),
+                )
+            )
+
+        plugins: dict[str, str] = {}
+        plugin_not_exist: list[str] = []
+        for plugin_id, unique_identifier in zip(plugin_ids, unique_identifiers):
+            if unique_identifier:
+                plugins[plugin_id] = unique_identifier
+            else:
+                plugin_not_exist.append(plugin_id)
+
+        return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
+
+    @classmethod
+    def install_plugins(cls, extracted_plugins: str, output_file: str) -> None:
+        """
+        Install the plugins listed in an extraction file for every tenant it names.
+
+        All resolvable plugins are first installed for a throw-away tenant, then for each
+        tenant in the file; afterwards the throw-away tenant's installations are removed
+        and a JSON report is written.
+
+        :param extracted_plugins: path of a JSON Lines file; "tenant_id" and "plugins" are read from each line
+        :param output_file: path of the JSON report with keys "not_installed" and "plugin_install_failed"
+        """
+        manager = PluginInstallationManager()
+
+        plugins = cls.extract_unique_plugins(extracted_plugins)
+        # plugin id -> plugin unique identifier, only for plugins that could be resolved
+        plugin_identifiers: Mapping[str, str] = plugins["plugins"]
+        not_installed = []
+        plugin_install_failed = []
+
+        # use a fake tenant id to install all the plugins
+        fake_tenant_id = uuid4().hex
+        logger.info(f"Installing {len(plugin_identifiers)} plugin instances for fake tenant {fake_tenant_id}")
+
+        response = cls.handle_plugin_instance_install(fake_tenant_id, plugin_identifiers)
+        plugin_install_failed.extend(response.get("failed", []))
+
+        def install(tenant_id: str, plugin_ids: list[str]) -> None:
+            logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
+            try:
+                # at most 64 plugins one batch
+                for i in range(0, len(plugin_ids), 64):
+                    batch_plugin_identifiers = [plugin_identifiers[plugin_id] for plugin_id in plugin_ids[i : i + 64]]
+                    manager.install_from_identifiers(
+                        tenant_id,
+                        batch_plugin_identifiers,
+                        PluginInstallationSource.Marketplace,
+                        metas=[
+                            {
+                                "plugin_unique_identifier": identifier,
+                            }
+                            for identifier in batch_plugin_identifiers
+                        ],
+                    )
+            except Exception:
+                # the future returned by submit() is never read, so without this the
+                # error would disappear silently
+                logger.exception(f"Failed to install plugins for tenant {tenant_id}")
+
+        # leaving the block waits for every submitted installation and shuts the pool
+        # down, also when reading the file fails half-way
+        with ThreadPoolExecutor(max_workers=40) as thread_pool, open(extracted_plugins) as f:
+            # read line by line, and install plugins for each tenant
+            for line in f:
+                if not line.strip():
+                    continue
+
+                data = json.loads(line)
+                tenant_id = data.get("tenant_id")
+                plugin_ids = data.get("plugins") or []
+
+                # only plugins that were resolved to a unique identifier can be installed;
+                # the others are reported per tenant
+                installable = [plugin_id for plugin_id in plugin_ids if plugin_id in plugin_identifiers]
+                plugin_not_exist = [plugin_id for plugin_id in plugin_ids if plugin_id not in plugin_identifiers]
+
+                if plugin_not_exist:
+                    not_installed.append(
+                        {
+                            "tenant_id": tenant_id,
+                            "plugin_not_exist": plugin_not_exist,
+                        }
+                    )
+
+                if installable:
+                    thread_pool.submit(install, tenant_id, installable)
+
+        logger.info("Uninstall plugins")
+
+        # remove everything that was installed for the fake tenant; list again after
+        # each pass until nothing is returned
+        try:
+            installation = manager.list_plugins(fake_tenant_id)
+            while installation:
+                for plugin in installation:
+                    manager.uninstall(fake_tenant_id, plugin.installation_id)
+
+                installation = manager.list_plugins(fake_tenant_id)
+        except Exception:
+            logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")
+
+        Path(output_file).write_text(
+            json.dumps(
+                {
+                    "not_installed": not_installed,
+                    "plugin_install_failed": plugin_install_failed,
+                }
+            )
+        )
+
+    @classmethod
+    def handle_plugin_instance_install(
+        cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
+    ) -> Mapping[str, Any]:
+        """
+        Install plugins for a tenant.
+
+        Every package is downloaded from the marketplace and uploaded first; the plugins
+        are then installed in batches and each batch's install task is polled until it
+        reports success or failure.
+
+        :param tenant_id: tenant to install the plugins for
+        :param plugin_identifiers_map: plugin id -> plugin unique identifier
+        :return: {"success": [...], "failed": [...]}, both lists holding plugin ids
+        :raises Exception: if a package cannot be downloaded or uploaded
+        """
+        manager = PluginInstallationManager()
+
+        def download_and_upload(plugin_identifier: str) -> None:
+            plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
+            if not plugin_package:
+                raise Exception(f"Failed to download plugin {plugin_identifier}")
+
+            # upload
+            manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
+
+        # download all the plugins and upload; the context manager shuts the pool down
+        # even when one of the transfers fails
+        with ThreadPoolExecutor(max_workers=10) as thread_pool:
+            futures = [
+                thread_pool.submit(download_and_upload, plugin_identifier)
+                for plugin_identifier in plugin_identifiers_map.values()
+            ]
+            for future in futures:
+                future.result()  # re-raises any exception that occurred in the worker
+
+        success: list[str] = []
+        failed: list[str] = []
+
+        reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
+        plugin_ids = list(plugin_identifiers_map)
+
+        # at most 64 plugins one batch
+        for i in range(0, len(plugin_ids), 64):
+            batch_plugin_ids = plugin_ids[i : i + 64]
+            batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
+
+            try:
+                response = manager.install_from_identifiers(
+                    tenant_id=tenant_id,
+                    identifiers=batch_plugin_identifiers,
+                    source=PluginInstallationSource.Marketplace,
+                    metas=[
+                        {
+                            "plugin_unique_identifier": identifier,
+                        }
+                        for identifier in batch_plugin_identifiers
+                    ],
+                )
+            except Exception:
+                logger.exception(f"Failed to start installing {len(batch_plugin_ids)} plugins for tenant {tenant_id}")
+                # report plugin ids, the same kind of value the per-plugin results below use
+                failed.extend(batch_plugin_ids)
+                continue
+
+            if response.all_installed:
+                success.extend(batch_plugin_ids)
+                continue
+
+            # poll the install task once per second until it reaches a final state
+            task_id = response.task_id
+            while True:
+                status = manager.fetch_plugin_installation_task(tenant_id, task_id)
+                if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
+                    break
+                time.sleep(1)
+
+            for plugin in status.plugins:
+                # fall back to the identifier itself if the task reports one we did not send
+                plugin_id = reverse_map.get(plugin.plugin_unique_identifier, plugin.plugin_unique_identifier)
+                if plugin.status == PluginInstallTaskStatus.Success:
+                    success.append(plugin_id)
+                else:
+                    failed.append(plugin_id)
+                    logger.error(
+                        f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
+                    )
+
+        return {"success": success, "failed": failed}

+ 14 - 9
api/services/plugin/plugin_service.py

@@ -232,7 +232,7 @@ class PluginService:
             tenant_id,
             plugin_unique_identifiers,
             PluginInstallationSource.Package,
-            {},
+            [{}],
         )
 
     @staticmethod
@@ -246,11 +246,13 @@ class PluginService:
             tenant_id,
             [plugin_unique_identifier],
             PluginInstallationSource.Github,
-            {
-                "repo": repo,
-                "version": version,
-                "package": package,
-            },
+            [
+                {
+                    "repo": repo,
+                    "version": version,
+                    "package": package,
+                }
+            ],
         )
 
     @staticmethod
@@ -277,9 +279,12 @@ class PluginService:
             tenant_id,
             plugin_unique_identifiers,
             PluginInstallationSource.Marketplace,
-            {
-                "plugin_unique_identifier": plugin_unique_identifier,
-            },
+            [
+                {
+                    "plugin_unique_identifier": plugin_unique_identifier,
+                }
+                for plugin_unique_identifier in plugin_unique_identifiers
+            ],
         )
 
     @staticmethod