api_tool_provider.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from core.tools.entities.common_entities import I18nObject
  2. from core.tools.entities.tool_bundle import ApiToolBundle
  3. from core.tools.entities.tool_entities import (
  4. ApiProviderAuthType,
  5. ToolCredentialsOption,
  6. ToolProviderCredentials,
  7. ToolProviderType,
  8. )
  9. from core.tools.provider.tool_provider import ToolProviderController
  10. from core.tools.tool.api_tool import ApiTool
  11. from core.tools.tool.tool import Tool
  12. from extensions.ext_database import db
  13. from models.tools import ApiToolProvider
  14. class ApiToolProviderController(ToolProviderController):
  15. provider_id: str
  16. @staticmethod
  17. def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
  18. credentials_schema = {
  19. "auth_type": ToolProviderCredentials(
  20. name="auth_type",
  21. required=True,
  22. type=ToolProviderCredentials.CredentialsType.SELECT,
  23. options=[
  24. ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  25. ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
  26. ],
  27. default="none",
  28. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  29. )
  30. }
  31. if auth_type == ApiProviderAuthType.API_KEY:
  32. credentials_schema = {
  33. **credentials_schema,
  34. "api_key_header": ToolProviderCredentials(
  35. name="api_key_header",
  36. required=False,
  37. default="api_key",
  38. type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
  39. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  40. ),
  41. "api_key_value": ToolProviderCredentials(
  42. name="api_key_value",
  43. required=True,
  44. type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
  45. help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
  46. ),
  47. "api_key_header_prefix": ToolProviderCredentials(
  48. name="api_key_header_prefix",
  49. required=False,
  50. default="basic",
  51. type=ToolProviderCredentials.CredentialsType.SELECT,
  52. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  53. options=[
  54. ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  55. ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  56. ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  57. ],
  58. ),
  59. }
  60. elif auth_type == ApiProviderAuthType.NONE:
  61. pass
  62. else:
  63. raise ValueError(f"invalid auth type {auth_type}")
  64. user_name = db_provider.user.name if db_provider.user_id else ""
  65. return ApiToolProviderController(
  66. **{
  67. "identity": {
  68. "author": user_name,
  69. "name": db_provider.name,
  70. "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
  71. "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
  72. "icon": db_provider.icon,
  73. },
  74. "credentials_schema": credentials_schema,
  75. "provider_id": db_provider.id or "",
  76. }
  77. )
  78. @property
  79. def provider_type(self) -> ToolProviderType:
  80. return ToolProviderType.API
  81. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  82. """
  83. parse tool bundle to tool
  84. :param tool_bundle: the tool bundle
  85. :return: the tool
  86. """
  87. return ApiTool(
  88. **{
  89. "api_bundle": tool_bundle,
  90. "identity": {
  91. "author": tool_bundle.author,
  92. "name": tool_bundle.operation_id,
  93. "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
  94. "icon": self.identity.icon,
  95. "provider": self.provider_id,
  96. },
  97. "description": {
  98. "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
  99. "llm": tool_bundle.summary or "",
  100. },
  101. "parameters": tool_bundle.parameters or [],
  102. }
  103. )
  104. def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
  105. """
  106. load bundled tools
  107. :param tools: the bundled tools
  108. :return: the tools
  109. """
  110. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  111. return self.tools
  112. def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
  113. """
  114. fetch tools from database
  115. :param user_id: the user id
  116. :param tenant_id: the tenant id
  117. :return: the tools
  118. """
  119. if self.tools is not None:
  120. return self.tools
  121. tools: list[Tool] = []
  122. # get tenant api providers
  123. db_providers: list[ApiToolProvider] = (
  124. db.session.query(ApiToolProvider)
  125. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
  126. .all()
  127. )
  128. if db_providers and len(db_providers) != 0:
  129. for db_provider in db_providers:
  130. for tool in db_provider.tools:
  131. assistant_tool = self._parse_tool_bundle(tool)
  132. assistant_tool.is_team_authorization = True
  133. tools.append(assistant_tool)
  134. self.tools = tools
  135. return tools
  136. def get_tool(self, tool_name: str) -> ApiTool:
  137. """
  138. get tool by name
  139. :param tool_name: the name of the tool
  140. :return: the tool
  141. """
  142. if self.tools is None:
  143. self.get_tools()
  144. for tool in self.tools:
  145. if tool.identity.name == tool_name:
  146. return tool
  147. raise ValueError(f"tool {tool_name} not found")