provider.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. from pydantic import Field
  2. from core.entities.provider_entities import ProviderConfig
  3. from core.tools.__base.tool_provider import ToolProviderController
  4. from core.tools.custom_tool.tool import ApiTool
  5. from core.tools.entities.common_entities import I18nObject
  6. from core.tools.entities.tool_bundle import ApiToolBundle
  7. from core.tools.entities.tool_entities import (
  8. ApiProviderAuthType,
  9. ToolProviderEntity,
  10. ToolProviderIdentity,
  11. ToolProviderType,
  12. )
  13. from extensions.ext_database import db
  14. from models.tools import ApiToolProvider
  15. class ApiToolProviderController(ToolProviderController):
  16. provider_id: str
  17. tenant_id: str
  18. tools: list[ApiTool] = Field(default_factory=list)
  19. def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
  20. super().__init__(entity)
  21. self.provider_id = provider_id
  22. self.tenant_id = tenant_id
  23. self.tools = []
  24. @classmethod
  25. def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
  26. credentials_schema = {
  27. "auth_type": ProviderConfig(
  28. name="auth_type",
  29. required=True,
  30. type=ProviderConfig.Type.SELECT,
  31. options=[
  32. ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
  33. ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
  34. ],
  35. default="none",
  36. help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
  37. )
  38. }
  39. if auth_type == ApiProviderAuthType.API_KEY:
  40. credentials_schema = {
  41. **credentials_schema,
  42. "api_key_header": ProviderConfig(
  43. name="api_key_header",
  44. required=False,
  45. default="api_key",
  46. type=ProviderConfig.Type.TEXT_INPUT,
  47. help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
  48. ),
  49. "api_key_value": ProviderConfig(
  50. name="api_key_value",
  51. required=True,
  52. type=ProviderConfig.Type.SECRET_INPUT,
  53. help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
  54. ),
  55. "api_key_header_prefix": ProviderConfig(
  56. name="api_key_header_prefix",
  57. required=False,
  58. default="basic",
  59. type=ProviderConfig.Type.SELECT,
  60. help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
  61. options=[
  62. ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
  63. ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
  64. ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
  65. ],
  66. ),
  67. }
  68. elif auth_type == ApiProviderAuthType.NONE:
  69. pass
  70. user = db_provider.user
  71. user_name = user.name if user else ""
  72. return ApiToolProviderController(
  73. entity=ToolProviderEntity(
  74. identity=ToolProviderIdentity(
  75. author=user_name,
  76. name=db_provider.name,
  77. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  78. description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
  79. icon=db_provider.icon,
  80. ),
  81. credentials_schema=credentials_schema,
  82. ),
  83. provider_id=db_provider.id or "",
  84. tenant_id=db_provider.tenant_id or "",
  85. )
  86. @property
  87. def provider_type(self) -> ToolProviderType:
  88. return ToolProviderType.API
  89. def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
  90. """
  91. parse tool bundle to tool
  92. :param tool_bundle: the tool bundle
  93. :return: the tool
  94. """
  95. return ApiTool(
  96. **{
  97. "api_bundle": tool_bundle,
  98. "identity": {
  99. "author": tool_bundle.author,
  100. "name": tool_bundle.operation_id,
  101. "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
  102. "icon": self.entity.identity.icon,
  103. "provider": self.provider_id,
  104. },
  105. "description": {
  106. "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
  107. "llm": tool_bundle.summary or "",
  108. },
  109. "parameters": tool_bundle.parameters or [],
  110. }
  111. )
  112. def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
  113. """
  114. load bundled tools
  115. :param tools: the bundled tools
  116. :return: the tools
  117. """
  118. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  119. return self.tools
  120. def get_tools(self, tenant_id: str) -> list[ApiTool]:
  121. """
  122. fetch tools from database
  123. :param user_id: the user id
  124. :param tenant_id: the tenant id
  125. :return: the tools
  126. """
  127. if self.tools is not None:
  128. return self.tools
  129. tools: list[ApiTool] = []
  130. # get tenant api providers
  131. db_providers: list[ApiToolProvider] = (
  132. db.session.query(ApiToolProvider)
  133. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
  134. .all()
  135. )
  136. if db_providers and len(db_providers) != 0:
  137. for db_provider in db_providers:
  138. for tool in db_provider.tools:
  139. assistant_tool = self._parse_tool_bundle(tool)
  140. tools.append(assistant_tool)
  141. self.tools = tools
  142. return tools
  143. def get_tool(self, tool_name: str) -> ApiTool:
  144. """
  145. get tool by name
  146. :param tool_name: the name of the tool
  147. :return: the tool
  148. """
  149. if self.tools is None:
  150. self.get_tools(self.tenant_id)
  151. for tool in self.tools:
  152. if tool.entity.identity.name == tool_name:
  153. return tool
  154. raise ValueError(f"tool {tool_name} not found")