api_tool_provider.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from pydantic import Field
  2. from core.entities.provider_entities import ProviderConfig
  3. from core.tools.entities.common_entities import I18nObject
  4. from core.tools.entities.tool_bundle import ApiToolBundle
  5. from core.tools.entities.tool_entities import (
  6. ApiProviderAuthType,
  7. ToolProviderType,
  8. )
  9. from core.tools.provider.tool_provider import ToolProviderController
  10. from core.tools.tool.api_tool import ApiTool
  11. from extensions.ext_database import db
  12. from models.tools import ApiToolProvider
  13. class ApiToolProviderController(ToolProviderController):
  14. provider_id: str
  15. tenant_id: str
  16. tools: list[ApiTool] = Field(default_factory=list)
  17. @staticmethod
  18. def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
  19. credentials_schema = {
  20. "auth_type": ProviderConfig(
  21. name="auth_type",
  22. required=True,
  23. type=ProviderConfig.Type.SELECT,
  24. options=[
  25. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  26. ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
  27. ],
  28. default="none",
  29. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  30. )
  31. }
  32. if auth_type == ApiProviderAuthType.API_KEY:
  33. credentials_schema = {
  34. **credentials_schema,
  35. "api_key_header": ProviderConfig(
  36. name="api_key_header",
  37. required=False,
  38. default="api_key",
  39. type=ProviderConfig.Type.TEXT_INPUT,
  40. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  41. ),
  42. "api_key_value": ProviderConfig(
  43. name="api_key_value",
  44. required=True,
  45. type=ProviderConfig.Type.SECRET_INPUT,
  46. help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
  47. ),
  48. "api_key_header_prefix": ProviderConfig(
  49. name="api_key_header_prefix",
  50. required=False,
  51. default="basic",
  52. type=ProviderConfig.Type.SELECT,
  53. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  54. options=[
  55. ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  56. ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  57. ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  58. ],
  59. ),
  60. }
  61. elif auth_type == ApiProviderAuthType.NONE:
  62. pass
  63. else:
  64. raise ValueError(f"invalid auth type {auth_type}")
  65. user_name = db_provider.user.name if db_provider.user_id else ""
  66. return ApiToolProviderController(
  67. **{
  68. "identity": {
  69. "author": user_name,
  70. "name": db_provider.name,
  71. "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
  72. "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
  73. "icon": db_provider.icon,
  74. },
  75. "credentials_schema": credentials_schema,
  76. "provider_id": db_provider.id or "",
  77. "tenant_id": db_provider.tenant_id or "",
  78. },
  79. )
  80. @property
  81. def provider_type(self) -> ToolProviderType:
  82. return ToolProviderType.API
  83. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  84. """
  85. parse tool bundle to tool
  86. :param tool_bundle: the tool bundle
  87. :return: the tool
  88. """
  89. return ApiTool(
  90. **{
  91. "api_bundle": tool_bundle,
  92. "identity": {
  93. "author": tool_bundle.author,
  94. "name": tool_bundle.operation_id,
  95. "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
  96. "icon": self.identity.icon,
  97. "provider": self.provider_id,
  98. },
  99. "description": {
  100. "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
  101. "llm": tool_bundle.summary or "",
  102. },
  103. "parameters": tool_bundle.parameters or [],
  104. }
  105. )
  106. def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
  107. """
  108. load bundled tools
  109. :param tools: the bundled tools
  110. :return: the tools
  111. """
  112. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  113. return self.tools
  114. def get_tools(self, tenant_id: str) -> list[ApiTool]:
  115. """
  116. fetch tools from database
  117. :param user_id: the user id
  118. :param tenant_id: the tenant id
  119. :return: the tools
  120. """
  121. if self.tools is not None:
  122. return self.tools
  123. tools: list[ApiTool] = []
  124. # get tenant api providers
  125. db_providers: list[ApiToolProvider] = (
  126. db.session.query(ApiToolProvider)
  127. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
  128. .all()
  129. )
  130. if db_providers and len(db_providers) != 0:
  131. for db_provider in db_providers:
  132. for tool in db_provider.tools:
  133. assistant_tool = self._parse_tool_bundle(tool)
  134. assistant_tool.is_team_authorization = True
  135. tools.append(assistant_tool)
  136. self.tools = tools
  137. return tools
  138. def get_tool(self, tool_name: str) -> ApiTool:
  139. """
  140. get tool by name
  141. :param tool_name: the name of the tool
  142. :return: the tool
  143. """
  144. if self.tools is None:
  145. self.get_tools(self.tenant_id)
  146. for tool in self.tools:
  147. if tool.identity.name == tool_name:
  148. return tool
  149. raise ValueError(f"tool {tool_name} not found")