Browse Source

fix: transform plugin icon incorrect

Yeuoly 7 months ago
parent
commit
459cb9dd72

+ 1 - 6
api/controllers/console/workspace/tool_providers.py

@@ -128,12 +128,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
 class ToolBuiltinProviderIconApi(Resource):
     @setup_required
     def get(self, provider):
-        user = current_user
-
-        user_id = user.id
-        tenant_id = user.current_tenant_id
-
-        icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider, tenant_id)
+        icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
         icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
 

+ 14 - 3
api/core/tools/tool_manager.py

@@ -57,6 +57,17 @@ class ToolManager:
     _builtin_tools_labels = {}
 
     @classmethod
+    def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
+        """
+        get the hardcoded provider
+        """
+        if len(cls._hardcoded_providers) == 0:
+            # init the builtin providers
+            cls.load_hardcoded_providers_cache()
+
+        return cls._hardcoded_providers[provider]
+
+    @classmethod
     def get_builtin_provider(
         cls, provider: str, tenant_id: str
     ) -> BuiltinToolProviderController | PluginToolProviderController:
@@ -407,9 +418,9 @@ class ToolManager:
         return tool_entity
 
     @classmethod
-    def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
+    def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
         """
-        get the absolute path of the icon of the builtin provider
+        get the absolute path of the icon of the hardcoded provider
 
         :param provider: the name of the provider
         :param tenant_id: the id of the tenant
@@ -417,7 +428,7 @@ class ToolManager:
         :return: the absolute path of the icon, the mime type of the icon
         """
         # get provider
-        provider_controller = cls.get_builtin_provider(provider, tenant_id)
+        provider_controller = cls.get_hardcoded_provider(provider)
 
         absolute_path = path.join(
             path.dirname(path.realpath(__file__)),

+ 1 - 0
api/core/tools/workflow_as_tool/provider.py

@@ -60,6 +60,7 @@ class WorkflowToolProviderController(ToolProviderController):
                     icon=db_provider.icon,
                 ),
                 credentials_schema=[],
+                plugin_id=None,
             ),
             provider_id=db_provider.id,
         )

+ 3 - 3
api/services/tools/builtin_tools_manage_service.py

@@ -178,11 +178,11 @@ class BuiltinToolManageService:
         return {"result": "success"}
 
     @staticmethod
-    def get_builtin_tool_provider_icon(provider: str, tenant_id: str):
+    def get_builtin_tool_provider_icon(provider: str):
         """
         get tool provider icon and it's mimetype
         """
-        icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider, tenant_id)
+        icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
         icon_bytes = Path(icon_path).read_bytes()
 
         return icon_bytes, mime_type
@@ -233,7 +233,7 @@ class BuiltinToolManageService:
                 )
 
                 # add icon
-                ToolTransformService.repack_provider(user_builtin_provider)
+                ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
 
                 tools = provider_controller.get_tools()
                 for tool in tools:

+ 1 - 1
api/services/tools/tools_manage_service.py

@@ -19,7 +19,7 @@ class ToolCommonService:
 
         # add icon
         for provider in providers:
-            ToolTransformService.repack_provider(provider)
+            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
 
         result = [provider.to_dict() for provider in providers]
 

+ 16 - 6
api/services/tools/tools_transform_service.py

@@ -2,6 +2,8 @@ import json
 import logging
 from typing import Optional, Union
 
+from yarl import URL
+
 from configs import dify_config
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
@@ -26,14 +28,19 @@ logger = logging.getLogger(__name__)
 
 class ToolTransformService:
     @classmethod
+    def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
+        url_prefix = URL(dify_config.CONSOLE_API_URL) / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
+        return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
+
+    @classmethod
     def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
         """
         get tool provider icon url
         """
-        url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
+        url_prefix = URL(dify_config.CONSOLE_API_URL) / "console" / "api" / "workspaces" / "current" / "tool-provider"
 
         if provider_type == ToolProviderType.BUILT_IN.value:
-            return url_prefix + "builtin/" + provider_name + "/icon"
+            return str(url_prefix / "builtin" / provider_name / "icon")
         elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
             try:
                 if isinstance(icon, str):
@@ -45,7 +52,7 @@ class ToolTransformService:
         return ""
 
     @staticmethod
-    def repack_provider(provider: Union[dict, ToolProviderApiEntity]):
+    def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
         """
         repack provider
 
@@ -56,9 +63,12 @@ class ToolTransformService:
                 provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
             )
         elif isinstance(provider, ToolProviderApiEntity):
-            provider.icon = ToolTransformService.get_tool_provider_icon_url(
-                provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
-            )
+            if provider.plugin_id:
+                provider.icon = ToolTransformService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
+            else:
+                provider.icon = ToolTransformService.get_tool_provider_icon_url(
+                    provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
+                )
 
     @classmethod
     def builtin_provider_to_user_provider(