provider.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from pydantic import Field
  2. from core.entities.provider_entities import ProviderConfig
  3. from core.tools.__base.tool_provider import ToolProviderController
  4. from core.tools.__base.tool_runtime import ToolRuntime
  5. from core.tools.custom_tool.tool import ApiTool
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_bundle import ApiToolBundle
  8. from core.tools.entities.tool_entities import (
  9. ApiProviderAuthType,
  10. ToolDescription,
  11. ToolEntity,
  12. ToolIdentity,
  13. ToolProviderEntity,
  14. ToolProviderIdentity,
  15. ToolProviderType,
  16. )
  17. from extensions.ext_database import db
  18. from models.tools import ApiToolProvider
  19. class ApiToolProviderController(ToolProviderController):
  20. provider_id: str
  21. tenant_id: str
  22. tools: list[ApiTool] = Field(default_factory=list)
  23. def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
  24. super().__init__(entity)
  25. self.provider_id = provider_id
  26. self.tenant_id = tenant_id
  27. self.tools = []
  28. @classmethod
  29. def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
  30. credentials_schema = [
  31. ProviderConfig(
  32. name="auth_type",
  33. required=True,
  34. type=ProviderConfig.Type.SELECT,
  35. options=[
  36. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  37. ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
  38. ],
  39. default="none",
  40. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  41. )
  42. ]
  43. if auth_type == ApiProviderAuthType.API_KEY:
  44. credentials_schema = [
  45. *credentials_schema,
  46. ProviderConfig(
  47. name="api_key_header",
  48. required=False,
  49. default="api_key",
  50. type=ProviderConfig.Type.TEXT_INPUT,
  51. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  52. ),
  53. ProviderConfig(
  54. name="api_key_value",
  55. required=True,
  56. type=ProviderConfig.Type.SECRET_INPUT,
  57. help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
  58. ),
  59. ProviderConfig(
  60. name="api_key_header_prefix",
  61. required=False,
  62. default="basic",
  63. type=ProviderConfig.Type.SELECT,
  64. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  65. options=[
  66. ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  67. ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  68. ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  69. ],
  70. ),
  71. ]
  72. elif auth_type == ApiProviderAuthType.NONE:
  73. pass
  74. user = db_provider.user
  75. user_name = user.name if user else ""
  76. return ApiToolProviderController(
  77. entity=ToolProviderEntity(
  78. identity=ToolProviderIdentity(
  79. author=user_name,
  80. name=db_provider.name,
  81. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  82. description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
  83. icon=db_provider.icon,
  84. ),
  85. credentials_schema=credentials_schema,
  86. plugin_id=None,
  87. ),
  88. provider_id=db_provider.id or "",
  89. tenant_id=db_provider.tenant_id or "",
  90. )
  91. @property
  92. def provider_type(self) -> ToolProviderType:
  93. return ToolProviderType.API
  94. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  95. """
  96. parse tool bundle to tool
  97. :param tool_bundle: the tool bundle
  98. :return: the tool
  99. """
  100. return ApiTool(
  101. api_bundle=tool_bundle,
  102. entity=ToolEntity(
  103. identity=ToolIdentity(
  104. author=tool_bundle.author,
  105. name=tool_bundle.operation_id or "default_tool",
  106. label=I18nObject(
  107. en_US=tool_bundle.operation_id or "default_tool",
  108. zh_Hans=tool_bundle.operation_id or "default_tool",
  109. ),
  110. icon=self.entity.identity.icon,
  111. provider=self.provider_id,
  112. ),
  113. description=ToolDescription(
  114. human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
  115. llm=tool_bundle.summary or "",
  116. ),
  117. parameters=tool_bundle.parameters or [],
  118. ),
  119. runtime=ToolRuntime(tenant_id=self.tenant_id),
  120. )
  121. def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
  122. """
  123. load bundled tools
  124. :param tools: the bundled tools
  125. :return: the tools
  126. """
  127. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  128. return self.tools
  129. def get_tools(self, tenant_id: str) -> list[ApiTool]:
  130. """
  131. fetch tools from database
  132. :param user_id: the user id
  133. :param tenant_id: the tenant id
  134. :return: the tools
  135. """
  136. if self.tools is not None:
  137. return self.tools
  138. tools: list[ApiTool] = []
  139. # get tenant api providers
  140. db_providers: list[ApiToolProvider] = (
  141. db.session.query(ApiToolProvider)
  142. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
  143. .all()
  144. )
  145. if db_providers and len(db_providers) != 0:
  146. for db_provider in db_providers:
  147. for tool in db_provider.tools:
  148. assistant_tool = self._parse_tool_bundle(tool)
  149. tools.append(assistant_tool)
  150. self.tools = tools
  151. return tools
  152. def get_tool(self, tool_name: str) -> ApiTool:
  153. """
  154. get tool by name
  155. :param tool_name: the name of the tool
  156. :return: the tool
  157. """
  158. if self.tools is None:
  159. self.get_tools(self.tenant_id)
  160. for tool in self.tools:
  161. if tool.entity.identity.name == tool_name:
  162. return tool
  163. raise ValueError(f"tool {tool_name} not found")