123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- from core.tools.entities.common_entities import I18nObject
- from core.tools.entities.tool_bundle import ApiToolBundle
- from core.tools.entities.tool_entities import (
- ApiProviderAuthType,
- ToolCredentialsOption,
- ToolProviderCredentials,
- ToolProviderType,
- )
- from core.tools.provider.tool_provider import ToolProviderController
- from core.tools.tool.api_tool import ApiTool
- from core.tools.tool.tool import Tool
- from extensions.ext_database import db
- from models.tools import ApiToolProvider
- class ApiToolProviderController(ToolProviderController):
- provider_id: str
- @staticmethod
- def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
- credentials_schema = {
- "auth_type": ToolProviderCredentials(
- name="auth_type",
- required=True,
- type=ToolProviderCredentials.CredentialsType.SELECT,
- options=[
- ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
- ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
- ],
- default="none",
- help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
- )
- }
- if auth_type == ApiProviderAuthType.API_KEY:
- credentials_schema = {
- **credentials_schema,
- "api_key_header": ToolProviderCredentials(
- name="api_key_header",
- required=False,
- default="api_key",
- type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
- help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
- ),
- "api_key_value": ToolProviderCredentials(
- name="api_key_value",
- required=True,
- type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
- help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
- ),
- "api_key_header_prefix": ToolProviderCredentials(
- name="api_key_header_prefix",
- required=False,
- default="basic",
- type=ToolProviderCredentials.CredentialsType.SELECT,
- help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
- options=[
- ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
- ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
- ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
- ],
- ),
- }
- elif auth_type == ApiProviderAuthType.NONE:
- pass
- else:
- raise ValueError(f"invalid auth type {auth_type}")
- user_name = db_provider.user.name if db_provider.user_id else ""
- return ApiToolProviderController(
- **{
- "identity": {
- "author": user_name,
- "name": db_provider.name,
- "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
- "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
- "icon": db_provider.icon,
- },
- "credentials_schema": credentials_schema,
- "provider_id": db_provider.id or "",
- }
- )
- @property
- def provider_type(self) -> ToolProviderType:
- return ToolProviderType.API
- def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
- """
- parse tool bundle to tool
- :param tool_bundle: the tool bundle
- :return: the tool
- """
- return ApiTool(
- **{
- "api_bundle": tool_bundle,
- "identity": {
- "author": tool_bundle.author,
- "name": tool_bundle.operation_id,
- "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
- "icon": self.identity.icon,
- "provider": self.provider_id,
- },
- "description": {
- "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
- "llm": tool_bundle.summary or "",
- },
- "parameters": tool_bundle.parameters or [],
- }
- )
- def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
- """
- load bundled tools
- :param tools: the bundled tools
- :return: the tools
- """
- self.tools = [self._parse_tool_bundle(tool) for tool in tools]
- return self.tools
- def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
- """
- fetch tools from database
- :param user_id: the user id
- :param tenant_id: the tenant id
- :return: the tools
- """
- if self.tools is not None:
- return self.tools
- tools: list[Tool] = []
- # get tenant api providers
- db_providers: list[ApiToolProvider] = (
- db.session.query(ApiToolProvider)
- .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
- .all()
- )
- if db_providers and len(db_providers) != 0:
- for db_provider in db_providers:
- for tool in db_provider.tools:
- assistant_tool = self._parse_tool_bundle(tool)
- assistant_tool.is_team_authorization = True
- tools.append(assistant_tool)
- self.tools = tools
- return tools
- def get_tool(self, tool_name: str) -> ApiTool:
- """
- get tool by name
- :param tool_name: the name of the tool
- :return: the tool
- """
- if self.tools is None:
- self.get_tools()
- for tool in self.tools:
- if tool.identity.name == tool_name:
- return tool
- raise ValueError(f"tool {tool_name} not found")
|