api_tool_provider.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from typing import Optional
  2. from core.tools.entities.common_entities import I18nObject
  3. from core.tools.entities.tool_bundle import ApiToolBundle
  4. from core.tools.entities.tool_entities import (
  5. ApiProviderAuthType,
  6. ToolCredentialsOption,
  7. ToolDescription,
  8. ToolIdentity,
  9. ToolProviderCredentials,
  10. ToolProviderIdentity,
  11. ToolProviderType,
  12. )
  13. from core.tools.provider.tool_provider import ToolProviderController
  14. from core.tools.tool.api_tool import ApiTool
  15. from core.tools.tool.tool import Tool
  16. from extensions.ext_database import db
  17. from models.tools import ApiToolProvider
  18. class ApiToolProviderController(ToolProviderController):
  19. provider_id: str
  20. @staticmethod
  21. def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
  22. credentials_schema = {
  23. "auth_type": ToolProviderCredentials(
  24. name="auth_type",
  25. required=True,
  26. type=ToolProviderCredentials.CredentialsType.SELECT,
  27. options=[
  28. ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  29. ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
  30. ],
  31. default="none",
  32. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  33. )
  34. }
  35. if auth_type == ApiProviderAuthType.API_KEY:
  36. credentials_schema = {
  37. **credentials_schema,
  38. "api_key_header": ToolProviderCredentials(
  39. name="api_key_header",
  40. required=False,
  41. default="api_key",
  42. type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
  43. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  44. ),
  45. "api_key_value": ToolProviderCredentials(
  46. name="api_key_value",
  47. required=True,
  48. type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
  49. help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
  50. ),
  51. "api_key_header_prefix": ToolProviderCredentials(
  52. name="api_key_header_prefix",
  53. required=False,
  54. default="basic",
  55. type=ToolProviderCredentials.CredentialsType.SELECT,
  56. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  57. options=[
  58. ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  59. ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  60. ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  61. ],
  62. ),
  63. }
  64. elif auth_type == ApiProviderAuthType.NONE:
  65. pass
  66. else:
  67. raise ValueError(f"invalid auth type {auth_type}")
  68. user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
  69. return ApiToolProviderController(
  70. identity=ToolProviderIdentity(
  71. author=user_name,
  72. name=db_provider.name,
  73. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  74. description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
  75. icon=db_provider.icon,
  76. ),
  77. credentials_schema=credentials_schema,
  78. provider_id=db_provider.id or "",
  79. tools=None,
  80. )
  81. @property
  82. def provider_type(self) -> ToolProviderType:
  83. return ToolProviderType.API
  84. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  85. """
  86. parse tool bundle to tool
  87. :param tool_bundle: the tool bundle
  88. :return: the tool
  89. """
  90. return ApiTool(
  91. api_bundle=tool_bundle,
  92. identity=ToolIdentity(
  93. author=tool_bundle.author,
  94. name=tool_bundle.operation_id or "",
  95. label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
  96. icon=self.identity.icon if self.identity else None,
  97. provider=self.provider_id,
  98. ),
  99. description=ToolDescription(
  100. human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
  101. llm=tool_bundle.summary or "",
  102. ),
  103. parameters=tool_bundle.parameters or [],
  104. )
  105. def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
  106. """
  107. load bundled tools
  108. :param tools: the bundled tools
  109. :return: the tools
  110. """
  111. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  112. return self.tools
  113. def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
  114. """
  115. fetch tools from database
  116. :param user_id: the user id
  117. :param tenant_id: the tenant id
  118. :return: the tools
  119. """
  120. if self.tools is not None:
  121. return self.tools
  122. if self.identity is None:
  123. return None
  124. tools: list[Tool] = []
  125. # get tenant api providers
  126. db_providers: list[ApiToolProvider] = (
  127. db.session.query(ApiToolProvider)
  128. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
  129. .all()
  130. )
  131. if db_providers and len(db_providers) != 0:
  132. for db_provider in db_providers:
  133. for tool in db_provider.tools:
  134. assistant_tool = self._parse_tool_bundle(tool)
  135. assistant_tool.is_team_authorization = True
  136. tools.append(assistant_tool)
  137. self.tools = tools
  138. return tools
  139. def get_tool(self, tool_name: str) -> Tool:
  140. """
  141. get tool by name
  142. :param tool_name: the name of the tool
  143. :return: the tool
  144. """
  145. if self.tools is None:
  146. self.get_tools()
  147. for tool in self.tools or []:
  148. if tool.identity is None:
  149. continue
  150. if tool.identity.name == tool_name:
  151. return tool
  152. raise ValueError(f"tool {tool_name} not found")