Yeuoly 7 months ago
parent
commit
435e71eb60

+ 1 - 1
api/core/plugin/entities/plugin_daemon.py

@@ -3,7 +3,7 @@ from typing import Generic, Optional, TypeVar
 
 from pydantic import BaseModel
 
-T = TypeVar("T", bound=(BaseModel | dict | bool))
+T = TypeVar("T", bound=(BaseModel | dict | list | bool))
 
 
 class PluginDaemonBasicResponse(BaseModel, Generic[T]):

+ 38 - 11
api/core/plugin/manager/base.py

@@ -12,7 +12,7 @@ from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
 plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL
 plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
 
-T = TypeVar("T", bound=(BaseModel | dict | bool))
+T = TypeVar("T", bound=(BaseModel | dict | list | bool))
 
 
 class BasePluginManager:
@@ -22,6 +22,7 @@ class BasePluginManager:
         path: str,
         headers: dict | None = None,
         data: bytes | dict | None = None,
+        params: dict | None = None,
         stream: bool = False,
     ) -> requests.Response:
         """
@@ -30,16 +31,23 @@ class BasePluginManager:
         url = URL(str(plugin_daemon_inner_api_baseurl)) / path
         headers = headers or {}
         headers["X-Api-Key"] = plugin_daemon_inner_api_key
-        response = requests.request(method=method, url=str(url), headers=headers, data=data, stream=stream)
+        response = requests.request(
+            method=method, url=str(url), headers=headers, data=data, params=params, stream=stream
+        )
         return response
 
     def _stream_request(
-        self, method: str, path: str, headers: dict | None = None, data: bytes | dict | None = None
+        self,
+        method: str,
+        path: str,
+        params: dict | None = None,
+        headers: dict | None = None,
+        data: bytes | dict | None = None,
     ) -> Generator[bytes, None, None]:
         """
         Make a stream request to the plugin daemon inner API
         """
-        response = self._request(method, path, headers, data, stream=True)
+        response = self._request(method, path, headers, data, params, stream=True)
         yield from response.iter_lines()
 
     def _stream_request_with_model(
@@ -49,29 +57,42 @@ class BasePluginManager:
         type: type[T],
         headers: dict | None = None,
         data: bytes | dict | None = None,
+        params: dict | None = None,
     ) -> Generator[T, None, None]:
         """
         Make a stream request to the plugin daemon inner API and yield the response as a model.
         """
-        for line in self._stream_request(method, path, headers, data):
+        for line in self._stream_request(method, path, params, headers, data):
             yield type(**json.loads(line))
 
     def _request_with_model(
-        self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | None = None
+        self,
+        method: str,
+        path: str,
+        type: type[T],
+        headers: dict | None = None,
+        data: bytes | None = None,
+        params: dict | None = None,
     ) -> T:
         """
         Make a request to the plugin daemon inner API and return the response as a model.
         """
-        response = self._request(method, path, headers, data)
+        response = self._request(method, path, headers, data, params)
         return type(**response.json())
 
     def _request_with_plugin_daemon_response(
-        self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | dict | None = None
+        self,
+        method: str,
+        path: str,
+        type: type[T],
+        headers: dict | None = None,
+        data: bytes | dict | None = None,
+        params: dict | None = None,
     ) -> T:
         """
         Make a request to the plugin daemon inner API and return the response as a model.
         """
-        response = self._request(method, path, headers, data)
+        response = self._request(method, path, headers, data, params)
         rep = PluginDaemonBasicResponse[type](**response.json())
         if rep.code != 0:
             raise ValueError(f"got error from plugin daemon: {rep.message}, code: {rep.code}")
@@ -81,12 +102,18 @@ class BasePluginManager:
         return rep.data
 
     def _request_with_plugin_daemon_response_stream(
-        self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | dict | None = None
+        self,
+        method: str,
+        path: str,
+        type: type[T],
+        headers: dict | None = None,
+        data: bytes | dict | None = None,
+        params: dict | None = None,
     ) -> Generator[T, None, None]:
         """
         Make a stream request to the plugin daemon inner API and yield the response as a model.
         """
-        for line in self._stream_request(method, path, headers, data):
+        for line in self._stream_request(method, path, params, headers, data):
             line_data = json.loads(line)
             rep = PluginDaemonBasicResponse[type](**line_data)
             if rep.code != 0:

+ 9 - 1
api/core/plugin/manager/model.py

@@ -1,5 +1,13 @@
+from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.plugin.manager.base import BasePluginManager
 
 
 class PluginModelManager(BasePluginManager):
-    pass
+    def fetch_model_providers(self, tenant_id: str) -> list[ProviderEntity]:
+        """
+        Fetch model providers for the given tenant.
+        """
+        response = self._request_with_plugin_daemon_response(
+            "GET", f"plugin/{tenant_id}/models", list[ProviderEntity], params={"page": 1, "page_size": 256}
+        )
+        return response

+ 6 - 9
api/core/plugin/manager/plugin.py

@@ -1,5 +1,4 @@
 from collections.abc import Generator
-from urllib.parse import quote
 
 from core.plugin.entities.plugin_daemon import InstallPluginMessage
 from core.plugin.manager.base import BasePluginManager
@@ -9,9 +8,8 @@ class PluginInstallationManager(BasePluginManager):
     def fetch_plugin_by_identifier(self, tenant_id: str, identifier: str) -> bool:
         # urlencode the identifier
 
-        identifier = quote(identifier)
         return self._request_with_plugin_daemon_response(
-            "GET", f"/plugin/{tenant_id}/fetch/identifier?plugin_unique_identifier={identifier}", bool
+            "GET", f"plugin/{tenant_id}/fetch/identifier", bool, params={"plugin_unique_identifier": identifier}
         )
 
     def install_from_pkg(self, tenant_id: str, pkg: bytes) -> Generator[InstallPluginMessage, None, None]:
@@ -22,21 +20,20 @@ class PluginInstallationManager(BasePluginManager):
         body = {"dify_pkg": ("dify_pkg", pkg, "application/octet-stream")}
 
         return self._request_with_plugin_daemon_response_stream(
-            "POST", f"/plugin/{tenant_id}/install/pkg", InstallPluginMessage, data=body
+            "POST", f"plugin/{tenant_id}/install/pkg", InstallPluginMessage, data=body
         )
 
     def install_from_identifier(self, tenant_id: str, identifier: str) -> bool:
         """
         Install a plugin from an identifier.
         """
-        identifier = quote(identifier)
         # exception will be raised if the request failed
         return self._request_with_plugin_daemon_response(
             "POST",
-            f"/plugin/{tenant_id}/install/identifier",
+            f"plugin/{tenant_id}/install/identifier",
             bool,
-            headers={
-                "Content-Type": "application/json",
+            params={
+                "plugin_unique_identifier": identifier,
             },
             data={
                 "plugin_unique_identifier": identifier,
@@ -48,5 +45,5 @@ class PluginInstallationManager(BasePluginManager):
         Uninstall a plugin.
         """
         return self._request_with_plugin_daemon_response(
-            "DELETE", f"/plugin/{tenant_id}/uninstall?plugin_unique_identifier={identifier}", bool
+            "DELETE", f"plugin/{tenant_id}/uninstall", bool, params={"plugin_unique_identifier": identifier}
         )

+ 6 - 2
api/core/plugin/manager/tool.py

@@ -1,9 +1,13 @@
 from core.plugin.manager.base import BasePluginManager
+from core.tools.entities.tool_entities import ToolProviderEntity
 
 
 class PluginToolManager(BasePluginManager):
-    def fetch_tool_providers(self, asset_id: str) -> list[str]:
+    def fetch_tool_providers(self, tenant_id: str) -> list[ToolProviderEntity]:
         """
         Fetch tool providers for the given asset.
         """
-        response = self._request('GET', f'/plugin/asset/{asset_id}')
+        response = self._request_with_plugin_daemon_response(
+            "GET", f"plugin/{tenant_id}/tools", list[ToolProviderEntity], params={"page": 1, "page_size": 256}
+        )
+        return response

+ 24 - 20
api/core/tools/entities/tool_entities.py

@@ -274,9 +274,12 @@ class ToolProviderIdentity(BaseModel):
     )
 
 
-class ToolProviderEntity(BaseModel):
-    identity: ToolProviderIdentity
-    credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
+class ToolIdentity(BaseModel):
+    author: str = Field(..., description="The author of the tool")
+    name: str = Field(..., description="The name of the tool")
+    label: I18nObject = Field(..., description="The label of the tool")
+    provider: str = Field(..., description="The provider of the tool")
+    icon: Optional[str] = None
 
 
 class ToolDescription(BaseModel):
@@ -284,12 +287,24 @@ class ToolDescription(BaseModel):
     llm: str = Field(..., description="The description presented to the LLM")
 
 
-class ToolIdentity(BaseModel):
-    author: str = Field(..., description="The author of the tool")
-    name: str = Field(..., description="The name of the tool")
-    label: I18nObject = Field(..., description="The label of the tool")
-    provider: str = Field(..., description="The provider of the tool")
-    icon: Optional[str] = None
+class ToolEntity(BaseModel):
+    identity: ToolIdentity
+    parameters: list[ToolParameter] = Field(default_factory=list)
+    description: Optional[ToolDescription] = None
+
+    # pydantic configs
+    model_config = ConfigDict(protected_namespaces=())
+
+    @field_validator("parameters", mode="before")
+    @classmethod
+    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
+        return v or []
+
+
+class ToolProviderEntity(BaseModel):
+    identity: ToolProviderIdentity
+    credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
+    tools: list[ToolEntity] = Field(default_factory=list)
 
 
 class WorkflowToolParameterConfiguration(BaseModel):
@@ -352,15 +367,4 @@ class ToolInvokeFrom(Enum):
     AGENT = "agent"
 
 
-class ToolEntity(BaseModel):
-    identity: ToolIdentity
-    parameters: list[ToolParameter] = Field(default_factory=list)
-    description: Optional[ToolDescription] = None
-
-    # pydantic configs
-    model_config = ConfigDict(protected_namespaces=())
 
-    @field_validator("parameters", mode="before")
-    @classmethod
-    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
-        return v or []

+ 5 - 0
api/tests/integration_tests/.env.example

@@ -83,3 +83,8 @@ VOLC_EMBEDDING_ENDPOINT_ID=
 
 # 360 AI Credentials
 ZHINAO_API_KEY=
+
+# Plugin configuration
+PLUGIN_API_KEY=
+PLUGIN_API_URL=
+INNER_API_KEY=

+ 66 - 0
api/tests/integration_tests/plugin/__mock/http.py

@@ -0,0 +1,66 @@
+import os
+from typing import Literal
+
+import pytest
+import requests
+from _pytest.monkeypatch import MonkeyPatch
+
+from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolProviderEntity, ToolProviderIdentity
+
+
+class MockedHttp:
+    @classmethod
+    def list_tools(cls) -> list[ToolProviderEntity]:
+        return [
+            ToolProviderEntity(
+                identity=ToolProviderIdentity(
+                    author="Yeuoly",
+                    name="Yeuoly",
+                    description=I18nObject(en_US="Yeuoly"),
+                    icon="ssss.svg",
+                    label=I18nObject(en_US="Yeuoly"),
+                )
+            )
+        ]
+
+    @classmethod
+    def requests_request(
+        cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
+    ) -> requests.Response:
+        """
+        Mocked requests.request
+        """
+        request = requests.PreparedRequest()
+        request.method = method
+        request.url = url
+        if url.endswith("/tools"):
+            content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
+                code=0, message="success", data=cls.list_tools()
+            ).model_dump_json()
+        else:
+            raise ValueError("")
+
+        response = requests.Response()
+        response.status_code = 200
+        response.request = request
+        response._content = content.encode("utf-8")
+        return response
+
+
+MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
+
+@pytest.fixture
+def setup_http_mock(request, monkeypatch: MonkeyPatch):
+    if MOCK_SWITCH:
+        monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
+
+        def unpatch():
+            monkeypatch.undo()
+
+    yield
+
+    if MOCK_SWITCH:
+        unpatch()

+ 9 - 0
api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py

@@ -0,0 +1,9 @@
+from core.plugin.manager.tool import PluginToolManager
+from tests.integration_tests.plugin.__mock.http import setup_http_mock
+
+
+def test_fetch_all_plugin_tools(setup_http_mock):
+    manager = PluginToolManager()
+    tools = manager.fetch_tool_providers(tenant_id="test-tenant")
+    assert len(tools) >= 1
+

+ 0 - 23
api/tests/integration_tests/tools/test_all_provider.py

@@ -1,23 +0,0 @@
-import pytest
-
-from core.tools.tool_manager import ToolManager
-
-provider_generator = ToolManager.list_builtin_providers()
-provider_names = [provider.identity.name for provider in provider_generator]
-ToolManager.clear_builtin_providers_cache()
-provider_generator = ToolManager.list_builtin_providers()
-
-
-@pytest.mark.parametrize("name", provider_names)
-def test_tool_providers(benchmark, name):
-    """
-    Test that all tool providers can be loaded
-    """
-
-    def test(generator):
-        try:
-            return next(generator)
-        except StopIteration:
-            return None
-
-    benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)