|
@@ -6,7 +6,10 @@ from os import listdir, path
|
|
|
from threading import Lock
|
|
|
from typing import TYPE_CHECKING, Any, Union, cast
|
|
|
|
|
|
+from core.plugin.manager.tool import PluginToolManager
|
|
|
from core.tools.__base.tool_runtime import ToolRuntime
|
|
|
+from core.tools.plugin_tool.provider import PluginToolProviderController
|
|
|
+from core.tools.plugin_tool.tool import PluginTool
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from core.workflow.nodes.tool.entities import ToolEntity
|
|
@@ -24,7 +27,7 @@ from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
|
|
from core.tools.builtin_tool.tool import BuiltinTool
|
|
|
from core.tools.custom_tool.provider import ApiToolProviderController
|
|
|
from core.tools.custom_tool.tool import ApiTool
|
|
|
-from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
|
|
+from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
|
|
from core.tools.entities.common_entities import I18nObject
|
|
|
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
|
|
|
from core.tools.errors import ToolProviderNotFoundError
|
|
@@ -41,38 +44,61 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ToolManager:
|
|
|
_builtin_provider_lock = Lock()
|
|
|
- _builtin_providers = {}
|
|
|
+ _hardcoded_providers = {}
|
|
|
_builtin_providers_loaded = False
|
|
|
_builtin_tools_labels = {}
|
|
|
|
|
|
@classmethod
|
|
|
- def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
|
|
|
+ def get_builtin_provider(
|
|
|
+ cls, provider: str, tenant_id: str
|
|
|
+ ) -> BuiltinToolProviderController | PluginToolProviderController:
|
|
|
"""
|
|
|
get the builtin provider
|
|
|
|
|
|
:param provider: the name of the provider
|
|
|
+ :param tenant_id: the id of the tenant
|
|
|
:return: the provider
|
|
|
"""
|
|
|
- if len(cls._builtin_providers) == 0:
|
|
|
+ if len(cls._hardcoded_providers) == 0:
|
|
|
# init the builtin providers
|
|
|
- cls.load_builtin_providers_cache()
|
|
|
+ cls.load_hardcoded_providers_cache()
|
|
|
|
|
|
- if provider not in cls._builtin_providers:
|
|
|
- raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
|
|
|
+ if provider not in cls._hardcoded_providers:
|
|
|
+ # get plugin provider
|
|
|
+ plugin_provider = cls.get_plugin_provider(provider, tenant_id)
|
|
|
+ if plugin_provider:
|
|
|
+ return plugin_provider
|
|
|
|
|
|
- return cls._builtin_providers[provider]
|
|
|
+ return cls._hardcoded_providers[provider]
|
|
|
|
|
|
@classmethod
|
|
|
- def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
|
|
|
+ def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
|
|
|
+ """
|
|
|
+ get the plugin provider
|
|
|
+ """
|
|
|
+ manager = PluginToolManager()
|
|
|
+ providers = manager.fetch_tool_providers(tenant_id)
|
|
|
+ provider_entity = next((x for x in providers if x.declaration.identity.name == provider), None)
|
|
|
+ if not provider_entity:
|
|
|
+ raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
|
|
|
+
|
|
|
+ return PluginToolProviderController(
|
|
|
+ entity=provider_entity.declaration,
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
|
|
"""
|
|
|
get the builtin tool
|
|
|
|
|
|
:param provider: the name of the provider
|
|
|
:param tool_name: the name of the tool
|
|
|
-
|
|
|
+ :param tenant_id: the id of the tenant
|
|
|
:return: the provider, the tool
|
|
|
"""
|
|
|
- provider_controller = cls.get_builtin_provider(provider)
|
|
|
+ provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
|
|
tool = provider_controller.get_tool(tool_name)
|
|
|
|
|
|
return tool
|
|
@@ -97,12 +123,12 @@ class ToolManager:
|
|
|
:return: the tool
|
|
|
"""
|
|
|
if provider_type == ToolProviderType.BUILT_IN:
|
|
|
- builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
|
|
+ builtin_tool = cls.get_builtin_tool(provider_id, tool_name, tenant_id)
|
|
|
if not builtin_tool:
|
|
|
raise ValueError(f"tool {tool_name} not found")
|
|
|
|
|
|
# check if the builtin tool need credentials
|
|
|
- provider_controller = cls.get_builtin_provider(provider_id)
|
|
|
+ provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
|
|
if not provider_controller.need_credentials:
|
|
|
return cast(
|
|
|
BuiltinTool,
|
|
@@ -131,7 +157,7 @@ class ToolManager:
|
|
|
|
|
|
# decrypt the credentials
|
|
|
credentials = builtin_provider.credentials
|
|
|
- controller = cls.get_builtin_provider(provider_id)
|
|
|
+ controller = cls.get_builtin_provider(provider_id, tenant_id)
|
|
|
tool_configuration = ProviderConfigEncrypter(
|
|
|
tenant_id=tenant_id,
|
|
|
config=controller.get_credentials_schema(),
|
|
@@ -246,7 +272,7 @@ class ToolManager:
|
|
|
tool_invoke_from=ToolInvokeFrom.AGENT,
|
|
|
)
|
|
|
runtime_parameters = {}
|
|
|
- parameters = tool_entity.get_all_runtime_parameters()
|
|
|
+ parameters = tool_entity.get_merged_runtime_parameters()
|
|
|
for parameter in parameters:
|
|
|
# check file types
|
|
|
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
|
@@ -294,7 +320,7 @@ class ToolManager:
|
|
|
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
|
|
)
|
|
|
runtime_parameters = {}
|
|
|
- parameters = tool_entity.get_all_runtime_parameters()
|
|
|
+ parameters = tool_entity.get_merged_runtime_parameters()
|
|
|
|
|
|
for parameter in parameters:
|
|
|
# save tool parameter to tool entity memory
|
|
@@ -321,16 +347,17 @@ class ToolManager:
|
|
|
return tool_entity
|
|
|
|
|
|
@classmethod
|
|
|
- def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
|
|
|
+ def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
|
|
|
"""
|
|
|
get the absolute path of the icon of the builtin provider
|
|
|
|
|
|
:param provider: the name of the provider
|
|
|
+ :param tenant_id: the id of the tenant
|
|
|
|
|
|
:return: the absolute path of the icon, the mime type of the icon
|
|
|
"""
|
|
|
# get provider
|
|
|
- provider_controller = cls.get_builtin_provider(provider)
|
|
|
+ provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
|
absolute_path = path.join(
|
|
|
path.dirname(path.realpath(__file__)),
|
|
@@ -351,21 +378,48 @@ class ToolManager:
|
|
|
return absolute_path, mime_type
|
|
|
|
|
|
@classmethod
|
|
|
- def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
|
|
+ def list_hardcoded_providers(cls):
|
|
|
# use cache first
|
|
|
if cls._builtin_providers_loaded:
|
|
|
- yield from list(cls._builtin_providers.values())
|
|
|
+ yield from list(cls._hardcoded_providers.values())
|
|
|
return
|
|
|
|
|
|
with cls._builtin_provider_lock:
|
|
|
if cls._builtin_providers_loaded:
|
|
|
- yield from list(cls._builtin_providers.values())
|
|
|
+ yield from list(cls._hardcoded_providers.values())
|
|
|
return
|
|
|
|
|
|
- yield from cls._list_builtin_providers()
|
|
|
+ yield from cls._list_hardcoded_providers()
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
|
|
|
+ """
|
|
|
+ list all the plugin providers
|
|
|
+ """
|
|
|
+ manager = PluginToolManager()
|
|
|
+ provider_entities = manager.fetch_tool_providers(tenant_id)
|
|
|
+ return [
|
|
|
+ PluginToolProviderController(
|
|
|
+ entity=provider.declaration,
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ plugin_unique_identifier=provider.plugin_unique_identifier,
|
|
|
+ )
|
|
|
+ for provider in provider_entities
|
|
|
+ ]
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def list_builtin_providers(
|
|
|
+ cls, tenant_id: str
|
|
|
+ ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
|
|
|
+ """
|
|
|
+ list all the builtin providers
|
|
|
+ """
|
|
|
+ yield from cls.list_hardcoded_providers()
|
|
|
+ # get plugin providers
|
|
|
+ yield from cls.list_plugin_providers(tenant_id)
|
|
|
|
|
|
@classmethod
|
|
|
- def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
|
|
+ def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
|
|
"""
|
|
|
list all the builtin providers
|
|
|
"""
|
|
@@ -391,7 +445,7 @@ class ToolManager:
|
|
|
parent_type=BuiltinToolProviderController,
|
|
|
)
|
|
|
provider: BuiltinToolProviderController = provider_class()
|
|
|
- cls._builtin_providers[provider.entity.identity.name] = provider
|
|
|
+ cls._hardcoded_providers[provider.entity.identity.name] = provider
|
|
|
for tool in provider.get_tools():
|
|
|
cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
|
|
|
yield provider
|
|
@@ -403,13 +457,13 @@ class ToolManager:
|
|
|
cls._builtin_providers_loaded = True
|
|
|
|
|
|
@classmethod
|
|
|
- def load_builtin_providers_cache(cls):
|
|
|
- for _ in cls.list_builtin_providers():
|
|
|
+ def load_hardcoded_providers_cache(cls):
|
|
|
+ for _ in cls.list_hardcoded_providers():
|
|
|
pass
|
|
|
|
|
|
@classmethod
|
|
|
- def clear_builtin_providers_cache(cls):
|
|
|
- cls._builtin_providers = {}
|
|
|
+ def clear_hardcoded_providers_cache(cls):
|
|
|
+ cls._hardcoded_providers = {}
|
|
|
cls._builtin_providers_loaded = False
|
|
|
|
|
|
@classmethod
|
|
@@ -423,7 +477,7 @@ class ToolManager:
|
|
|
"""
|
|
|
if len(cls._builtin_tools_labels) == 0:
|
|
|
# init the builtin providers
|
|
|
- cls.load_builtin_providers_cache()
|
|
|
+ cls.load_hardcoded_providers_cache()
|
|
|
|
|
|
if tool_name not in cls._builtin_tools_labels:
|
|
|
return None
|
|
@@ -432,9 +486,9 @@ class ToolManager:
|
|
|
|
|
|
@classmethod
|
|
|
def user_list_providers(
|
|
|
- cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral
|
|
|
- ) -> list[UserToolProvider]:
|
|
|
- result_providers: dict[str, UserToolProvider] = {}
|
|
|
+ cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
|
|
+ ) -> list[ToolProviderApiEntity]:
|
|
|
+ result_providers: dict[str, ToolProviderApiEntity] = {}
|
|
|
|
|
|
filters = []
|
|
|
if not typ:
|
|
@@ -444,7 +498,7 @@ class ToolManager:
|
|
|
|
|
|
if "builtin" in filters:
|
|
|
# get builtin providers
|
|
|
- builtin_providers = cls.list_builtin_providers()
|
|
|
+ builtin_providers = cls.list_builtin_providers(tenant_id)
|
|
|
|
|
|
# get db builtin providers
|
|
|
db_builtin_providers: list[BuiltinToolProvider] = (
|
|
@@ -666,4 +720,4 @@ class ToolManager:
|
|
|
raise ValueError(f"provider type {provider_type} not found")
|
|
|
|
|
|
|
|
|
-ToolManager.load_builtin_providers_cache()
|
|
|
+ToolManager.load_hardcoded_providers_cache()
|