provider.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from typing import Any
  2. from core.plugin.manager.tool import PluginToolManager
  3. from core.tools.__base.tool_runtime import ToolRuntime
  4. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  5. from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
  6. from core.tools.errors import ToolProviderCredentialValidationError
  7. from core.tools.plugin_tool.tool import PluginTool
  8. class PluginToolProviderController(BuiltinToolProviderController):
  9. entity: ToolProviderEntityWithPlugin
  10. tenant_id: str
  11. plugin_id: str
  12. plugin_unique_identifier: str
  13. def __init__(
  14. self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
  15. ) -> None:
  16. self.entity = entity
  17. self.tenant_id = tenant_id
  18. self.plugin_id = plugin_id
  19. self.plugin_unique_identifier = plugin_unique_identifier
  20. @property
  21. def provider_type(self) -> ToolProviderType:
  22. """
  23. returns the type of the provider
  24. :return: type of the provider
  25. """
  26. return ToolProviderType.PLUGIN
  27. def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
  28. """
  29. validate the credentials of the provider
  30. """
  31. manager = PluginToolManager()
  32. if not manager.validate_provider_credentials(
  33. tenant_id=self.tenant_id,
  34. user_id=user_id,
  35. provider=self.entity.identity.name,
  36. credentials=credentials,
  37. ):
  38. raise ToolProviderCredentialValidationError("Invalid credentials")
  39. def get_tool(self, tool_name: str) -> PluginTool:
  40. """
  41. return tool with given name
  42. """
  43. tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name)
  44. if not tool_entity:
  45. raise ValueError(f"Tool with name {tool_name} not found")
  46. return PluginTool(
  47. entity=tool_entity,
  48. runtime=ToolRuntime(tenant_id=self.tenant_id),
  49. tenant_id=self.tenant_id,
  50. icon=self.entity.identity.icon,
  51. plugin_unique_identifier=self.plugin_unique_identifier,
  52. )
  53. def get_tools(self) -> list[PluginTool]:
  54. """
  55. get all tools
  56. """
  57. return [
  58. PluginTool(
  59. entity=tool_entity,
  60. runtime=ToolRuntime(tenant_id=self.tenant_id),
  61. tenant_id=self.tenant_id,
  62. icon=self.entity.identity.icon,
  63. plugin_unique_identifier=self.plugin_unique_identifier,
  64. )
  65. for tool_entity in self.entity.tools
  66. ]