provider.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from typing import Any
  2. from core.plugin.manager.tool import PluginToolManager
  3. from core.tools.__base.tool_runtime import ToolRuntime
  4. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  5. from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
  6. from core.tools.errors import ToolProviderCredentialValidationError
  7. from core.tools.plugin_tool.tool import PluginTool
  8. class PluginToolProviderController(BuiltinToolProviderController):
  9. entity: ToolProviderEntityWithPlugin
  10. tenant_id: str
  11. plugin_unique_identifier: str
  12. def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_unique_identifier: str) -> None:
  13. self.entity = entity
  14. self.tenant_id = tenant_id
  15. self.plugin_unique_identifier = plugin_unique_identifier
  16. @property
  17. def provider_type(self) -> ToolProviderType:
  18. """
  19. returns the type of the provider
  20. :return: type of the provider
  21. """
  22. return ToolProviderType.PLUGIN
  23. def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
  24. """
  25. validate the credentials of the provider
  26. """
  27. manager = PluginToolManager()
  28. if not manager.validate_provider_credentials(
  29. tenant_id=self.tenant_id,
  30. user_id=user_id,
  31. plugin_unique_identifier=self.plugin_unique_identifier,
  32. provider=self.entity.identity.name,
  33. credentials=credentials,
  34. ):
  35. raise ToolProviderCredentialValidationError("Invalid credentials")
  36. def get_tool(self, tool_name: str) -> PluginTool:
  37. """
  38. return tool with given name
  39. """
  40. tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name)
  41. if not tool_entity:
  42. raise ValueError(f"Tool with name {tool_name} not found")
  43. return PluginTool(
  44. entity=tool_entity,
  45. runtime=ToolRuntime(tenant_id=self.tenant_id),
  46. tenant_id=self.tenant_id,
  47. plugin_unique_identifier=self.plugin_unique_identifier,
  48. )
  49. def get_tools(self) -> list[PluginTool]:
  50. """
  51. get all tools
  52. """
  53. return [
  54. PluginTool(
  55. entity=tool_entity,
  56. runtime=ToolRuntime(tenant_id=self.tenant_id),
  57. tenant_id=self.tenant_id,
  58. plugin_unique_identifier=self.plugin_unique_identifier,
  59. )
  60. for tool_entity in self.entity.tools
  61. ]