Browse Source

feat: compat tool provider credentials to updated data

Yeuoly 8 months ago
parent
commit
56b7853afe

+ 17 - 0
api/core/tools/entities/tool_entities.py

@@ -1,4 +1,5 @@
 import base64
+import re
 from enum import Enum
 from typing import Any, Optional, Union
 
@@ -377,3 +378,19 @@ class ToolInvokeFrom(Enum):
 
     WORKFLOW = "workflow"
     AGENT = "agent"
+
+
+class ToolProviderID:
+    organization: str
+    plugin_name: str
+    provider_name: str
+
+    def __str__(self) -> str:
+        return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
+
+    def __init__(self, value: str) -> None:
+        # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
+        if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
+            raise ValueError("Invalid plugin id")
+
+        self.organization, self.plugin_name, self.provider_name = value.split("/")

+ 36 - 11
api/core/tools/tool_manager.py

@@ -29,7 +29,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController
 from core.tools.custom_tool.tool import ApiTool
 from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
 from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
+from core.tools.entities.tool_entities import (
+    ApiProviderAuthType,
+    ToolInvokeFrom,
+    ToolParameter,
+    ToolProviderID,
+    ToolProviderType,
+)
 from core.tools.errors import ToolProviderNotFoundError
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
@@ -143,18 +149,30 @@ class ToolManager:
                     ),
                 )
 
-            # get credentials
-            builtin_provider: BuiltinToolProvider | None = (
-                db.session.query(BuiltinToolProvider)
-                .filter(
-                    BuiltinToolProvider.tenant_id == tenant_id,
-                    BuiltinToolProvider.provider == provider_id,
+            if isinstance(provider_controller, PluginToolProviderController):
+                provider_id_entity = ToolProviderID(provider_id)
+                # get credentials
+                builtin_provider: BuiltinToolProvider | None = (
+                    db.session.query(BuiltinToolProvider)
+                    .filter(
+                        BuiltinToolProvider.tenant_id == tenant_id,
+                        (BuiltinToolProvider.provider == provider_id)
+                        | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
+                    )
+                    .first()
                 )
-                .first()
-            )
 
-            if builtin_provider is None:
-                raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
+                if builtin_provider is None:
+                    raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
+            else:
+                builtin_provider: BuiltinToolProvider | None = (
+                    db.session.query(BuiltinToolProvider)
+                    .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
+                    .first()
+                )
+
+                if builtin_provider is None:
+                    raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
 
             # decrypt the credentials
             credentials = builtin_provider.credentials
@@ -505,6 +523,13 @@ class ToolManager:
                 db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
             )
 
+            # rewrite db_builtin_providers
+            for db_provider in db_builtin_providers:
+                try:
+                    ToolProviderID(db_provider.provider)
+                except Exception:
+                    db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
+
             find_db_builtin_provider = lambda provider: next(
                 (x for x in db_builtin_providers if x.provider == provider), None
             )

+ 1 - 0
api/models/tools.py

@@ -222,6 +222,7 @@ class ToolModelInvoke(Base):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
+@deprecated
 class ToolConversationVariables(Base):
     """
     store the conversation variables from tool invoke

+ 2 - 2
api/services/tools/api_tools_manage_service.py

@@ -221,7 +221,7 @@ class ApiToolManageService:
         labels = ToolLabelManager.get_tool_labels(controller)
 
         return [
-            ToolTransformService.tool_to_user_tool(
+            ToolTransformService.convert_tool_entity_to_api_entity(
                 tool_bundle,
                 tenant_id=tenant_id,
                 labels=labels,
@@ -465,7 +465,7 @@ class ApiToolManageService:
 
             for tool in tools:
                 user_provider.tools.append(
-                    ToolTransformService.tool_to_user_tool(
+                    ToolTransformService.convert_tool_entity_to_api_entity(
                         tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
                     )
                 )

+ 51 - 34
api/services/tools/builtin_tools_manage_service.py

@@ -7,6 +7,7 @@ from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
 from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
+from core.tools.entities.tool_entities import ToolProviderID
 from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_manager import ToolManager
@@ -40,14 +41,7 @@ class BuiltinToolManageService:
             provider_identity=provider_controller.entity.identity.name,
         )
         # check if user has added the provider
-        builtin_provider: BuiltinToolProvider | None = (
-            db.session.query(BuiltinToolProvider)
-            .filter(
-                BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider,
-            )
-            .first()
-        )
+        builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
 
         credentials = {}
         if builtin_provider is not None:
@@ -58,7 +52,7 @@ class BuiltinToolManageService:
         result = []
         for tool in tools:
             result.append(
-                ToolTransformService.tool_to_user_tool(
+                ToolTransformService.convert_tool_entity_to_api_entity(
                     tool=tool,
                     credentials=credentials,
                     tenant_id=tenant_id,
@@ -86,14 +80,7 @@ class BuiltinToolManageService:
         update builtin tool provider
         """
         # get if the provider exists
-        provider: BuiltinToolProvider | None = (
-            db.session.query(BuiltinToolProvider)
-            .filter(
-                BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider_name,
-            )
-            .first()
-        )
+        provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
 
         try:
             # get provider
@@ -149,14 +136,7 @@ class BuiltinToolManageService:
         """
         get builtin tool provider credentials
         """
-        provider_obj: BuiltinToolProvider | None = (
-            db.session.query(BuiltinToolProvider)
-            .filter(
-                BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider,
-            )
-            .first()
-        )
+        provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
 
         if provider_obj is None:
             return {}
@@ -177,14 +157,7 @@ class BuiltinToolManageService:
         """
         delete tool provider
         """
-        provider_obj: BuiltinToolProvider | None = (
-            db.session.query(BuiltinToolProvider)
-            .filter(
-                BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider_name,
-            )
-            .first()
-        )
+        provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
 
         if provider_obj is None:
             raise ValueError(f"you have not added provider {provider_name}")
@@ -227,6 +200,13 @@ class BuiltinToolManageService:
             db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
         )
 
+        # rewrite db_providers
+        for db_provider in db_providers:
+            try:
+                ToolProviderID(db_provider.provider)
+            except Exception:
+                db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
+
         # find provider
         find_provider = lambda provider: next(
             filter(lambda db_provider: db_provider.provider == provider, db_providers), None
@@ -258,7 +238,7 @@ class BuiltinToolManageService:
                 tools = provider_controller.get_tools()
                 for tool in tools:
                     user_builtin_provider.tools.append(
-                        ToolTransformService.tool_to_user_tool(
+                        ToolTransformService.convert_tool_entity_to_api_entity(
                             tenant_id=tenant_id,
                             tool=tool,
                             credentials=user_builtin_provider.original_credentials,
@@ -271,3 +251,40 @@ class BuiltinToolManageService:
                 raise e
 
         return BuiltinToolProviderSort.sort(result)
+
+    @staticmethod
+    def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
+        try:
+            provider_id_entity = ToolProviderID(provider_name)
+            provider_name = provider_id_entity.provider_name
+            if provider_id_entity.organization != "langgenius":
+                return None
+
+            provider_obj = (
+                db.session.query(BuiltinToolProvider)
+                .filter(
+                    BuiltinToolProvider.tenant_id == tenant_id,
+                    (BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == provider_name),
+                )
+                .first()
+            )
+
+            if provider_obj is None:
+                return None
+
+            try:
+                ToolProviderID(provider_obj.provider)
+            except Exception:
+                provider_obj.provider = f"langgenius/{provider_obj.provider}/{provider_obj.provider}"
+
+            return provider_obj
+        except Exception:
+            # it's an old provider without organization
+            return (
+                db.session.query(BuiltinToolProvider)
+                .filter(
+                    BuiltinToolProvider.tenant_id == tenant_id,
+                    (BuiltinToolProvider.provider == provider_name),
+                )
+                .first()
+            )

+ 1 - 1
api/services/tools/tools_transform_service.py

@@ -223,7 +223,7 @@ class ToolTransformService:
         return result
 
     @staticmethod
-    def tool_to_user_tool(
+    def convert_tool_entity_to_api_entity(
         tool: Union[ApiToolBundle, WorkflowTool, Tool],
         tenant_id: str,
         credentials: dict | None = None,

+ 3 - 3
api/services/tools/workflow_tools_manage_service.py

@@ -210,7 +210,7 @@ class WorkflowToolManageService:
             )
             ToolTransformService.repack_provider(user_tool_provider)
             user_tool_provider.tools = [
-                ToolTransformService.tool_to_user_tool(
+                ToolTransformService.convert_tool_entity_to_api_entity(
                     tool=tool.get_tools(user_id, tenant_id)[0],
                     labels=labels.get(tool.provider_id, []),
                     tenant_id=tenant_id,
@@ -299,7 +299,7 @@ class WorkflowToolManageService:
             "icon": json.loads(db_tool.icon),
             "description": db_tool.description,
             "parameters": jsonable_encoder(db_tool.parameter_configurations),
-            "tool": ToolTransformService.tool_to_user_tool(
+            "tool": ToolTransformService.convert_tool_entity_to_api_entity(
                 tool=tool.get_tools(db_tool.tenant_id)[0],
                 labels=ToolLabelManager.get_tool_labels(tool),
                 tenant_id=tenant_id,
@@ -329,7 +329,7 @@ class WorkflowToolManageService:
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
 
         return [
-            ToolTransformService.tool_to_user_tool(
+            ToolTransformService.convert_tool_entity_to_api_entity(
                 tool=tool.get_tools(db_tool.tenant_id)[0],
                 labels=ToolLabelManager.get_tool_labels(tool),
                 tenant_id=tenant_id,