Quellcode durchsuchen

fix: deleted_tools

Yeuoly vor 6 Monaten
Ursprung
Commit
c2ce8e638e

+ 3 - 1
api/core/plugin/entities/plugin.py

@@ -139,6 +139,7 @@ class GenericProviderID:
     organization: str
     plugin_name: str
     provider_name: str
+    is_hardcoded: bool
 
     def to_string(self) -> str:
         return str(self)
@@ -146,7 +147,7 @@ class GenericProviderID:
     def __str__(self) -> str:
         return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
 
-    def __init__(self, value: str) -> None:
+    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
         # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
         if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
             # check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
@@ -156,6 +157,7 @@ class GenericProviderID:
                 raise ValueError("Invalid plugin id")
 
         self.organization, self.plugin_name, self.provider_name = value.split("/")
+        self.is_hardcoded = is_hardcoded
 
     @property
     def plugin_id(self) -> str:

+ 21 - 0
api/core/plugin/manager/plugin.py

@@ -4,6 +4,7 @@ from pydantic import BaseModel
 
 from core.plugin.entities.bundle import PluginBundleDependency
 from core.plugin.entities.plugin import (
+    GenericProviderID,
     PluginDeclaration,
     PluginEntity,
     PluginInstallation,
@@ -224,3 +225,23 @@ class PluginInstallationManager(BasePluginManager):
             },
             headers={"Content-Type": "application/json"},
         )
+
+    def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
+        """
+        Check if the tools exist
+        """
+        return self._request_with_plugin_daemon_response(
+            "POST",
+            f"plugin/{tenant_id}/management/tools/check_existence",
+            list[bool],
+            data={
+                "provider_ids": [
+                    {
+                        "plugin_id": provider_id.plugin_id,
+                        "provider_name": provider_id.provider_name,
+                    }
+                    for provider_id in provider_ids
+                ]
+            },
+            headers={"Content-Type": "application/json"},
+        )

+ 8 - 1
api/fields/app_fields.py

@@ -149,6 +149,12 @@ site_fields = {
     "updated_at": TimestampField,
 }
 
+deleted_tool_fields = {
+    "type": fields.String,
+    "tool_name": fields.String,
+    "provider_id": fields.String,
+}
+
 app_detail_fields_with_site = {
     "id": fields.String,
     "name": fields.String,
@@ -169,9 +175,10 @@ app_detail_fields_with_site = {
     "created_at": TimestampField,
     "updated_by": fields.String,
     "updated_at": TimestampField,
-    "deleted_tools": fields.List(fields.String),
+    "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
 }
 
+
 app_site_fields = {
     "app_id": fields.String,
     "access_token": fields.String(attribute="code"),

+ 86 - 12
api/models/model.py

@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 import uuid
 from collections.abc import Mapping
@@ -6,6 +7,10 @@ from datetime import datetime
 from enum import Enum
 from typing import TYPE_CHECKING, Optional
 
+from core.plugin.entities.plugin import GenericProviderID
+from core.tools.entities.tool_entities import ToolProviderType
+from services.plugin.plugin_service import PluginService
+
 if TYPE_CHECKING:
     from models.workflow import Workflow
 
@@ -16,7 +21,7 @@ import sqlalchemy as sa
 from flask import request
 from flask_login import UserMixin
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from configs import dify_config
 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
@@ -30,6 +35,8 @@ from models.enums import CreatedByRole
 from .account import Account, Tenant
 from .types import StringUUID
 
+logger = logging.getLogger(__name__)
+
 
 class DifySetup(Base):
     __tablename__ = "dify_setups"
@@ -162,47 +169,114 @@ class App(Base):
 
     @property
     def deleted_tools(self) -> list:
+        from core.tools.tool_manager import ToolManager
+
         # get agent mode tools
         app_model_config = self.app_model_config
         if not app_model_config:
             return []
+
         if not app_model_config.agent_mode:
             return []
+
         agent_mode = app_model_config.agent_mode_dict
         tools = agent_mode.get("tools", [])
 
-        provider_ids = []
+        api_provider_ids: list[str] = []
+        builtin_provider_ids: list[GenericProviderID] = []
 
         for tool in tools:
             keys = list(tool.keys())
             if len(keys) >= 4:
                 provider_type = tool.get("provider_type", "")
                 provider_id = tool.get("provider_id", "")
-                if provider_type == "api":
-                    # check if provider id is a uuid string, if not, skip
+                if provider_type == ToolProviderType.API.value:
                     try:
                         uuid.UUID(provider_id)
                     except Exception:
                         continue
-                    provider_ids.append(provider_id)
+                    api_provider_ids.append(provider_id)
+                if provider_type == ToolProviderType.BUILT_IN.value:
+                    try:
+                        # check if it's hardcoded
+                        try:
+                            ToolManager.get_hardcoded_provider(provider_id)
+                            is_hardcoded = True
+                        except Exception:
+                            is_hardcoded = False
+
+                        provider_id = GenericProviderID(provider_id, is_hardcoded)
+                    except Exception:
+                        logger.exception(f"Invalid builtin provider id: {provider_id}")
+                        continue
 
-        if not provider_ids:
+                    builtin_provider_ids.append(provider_id)
+
+        if not api_provider_ids and not builtin_provider_ids:
             return []
 
-        api_providers = db.session.execute(
-            text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)}
-        ).fetchall()
+        with Session(db.engine) as session:
+            if api_provider_ids:
+                existing_api_providers = [
+                    api_provider.id
+                    for api_provider in session.execute(
+                        text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"),
+                        {"provider_ids": tuple(api_provider_ids)},
+                    ).fetchall()
+                ]
+            else:
+                existing_api_providers = []
+
+        if builtin_provider_ids:
+            # get the non-hardcoded builtin providers
+            non_hardcoded_builtin_providers = [
+                provider_id for provider_id in builtin_provider_ids if not provider_id.is_hardcoded
+            ]
+            if non_hardcoded_builtin_providers:
+                existence = list(PluginService.check_tools_existence(self.tenant_id, non_hardcoded_builtin_providers))
+            else:
+                existence = []
+            # add the hardcoded builtin providers
+            existence.extend([True] * (len(builtin_provider_ids) - len(non_hardcoded_builtin_providers)))
+            builtin_provider_ids = non_hardcoded_builtin_providers + [
+                provider_id for provider_id in builtin_provider_ids if provider_id.is_hardcoded
+            ]
+        else:
+            existence = []
+
+        existing_builtin_providers = {
+            provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
+        }
 
         deleted_tools = []
-        current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers]
 
         for tool in tools:
             keys = list(tool.keys())
             if len(keys) >= 4:
                 provider_type = tool.get("provider_type", "")
                 provider_id = tool.get("provider_id", "")
-                if provider_type == "api" and provider_id not in current_api_provider_ids:
-                    deleted_tools.append(tool["tool_name"])
+
+                if provider_type == ToolProviderType.API.value:
+                    if provider_id not in existing_api_providers:
+                        deleted_tools.append(
+                            {
+                                "type": ToolProviderType.API.value,
+                                "tool_name": tool["tool_name"],
+                                "provider_id": provider_id,
+                            }
+                        )
+
+                if provider_type == ToolProviderType.BUILT_IN.value:
+                    generic_provider_id = GenericProviderID(provider_id)
+
+                    if not existing_builtin_providers[generic_provider_id.provider_name]:
+                        deleted_tools.append(
+                            {
+                                "type": ToolProviderType.BUILT_IN.value,
+                                "tool_name": tool["tool_name"],
+                                "provider_id": provider_id,  # use the original one
+                            }
+                        )
 
         return deleted_tools
 

+ 15 - 1
api/services/plugin/plugin_service.py

@@ -7,7 +7,13 @@ from core.helper import marketplace
 from core.helper.download import download_with_size_limit
 from core.helper.marketplace import download_plugin_pkg
 from core.plugin.entities.bundle import PluginBundleDependency
-from core.plugin.entities.plugin import PluginDeclaration, PluginEntity, PluginInstallation, PluginInstallationSource
+from core.plugin.entities.plugin import (
+    GenericProviderID,
+    PluginDeclaration,
+    PluginEntity,
+    PluginInstallation,
+    PluginInstallationSource,
+)
 from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
 from core.plugin.manager.asset import PluginAssetManager
 from core.plugin.manager.debugging import PluginDebuggingManager
@@ -279,3 +285,11 @@ class PluginService:
     def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
         manager = PluginInstallationManager()
         return manager.uninstall(tenant_id, plugin_installation_id)
+
+    @staticmethod
+    def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
+        """
+        Check if the tools exist
+        """
+        manager = PluginInstallationManager()
+        return manager.check_tools_existence(tenant_id, provider_ids)