provider.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from typing import Any
  2. from core.plugin.manager.tool import PluginToolManager
  3. from core.tools.__base.tool_runtime import ToolRuntime
  4. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  5. from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
  6. from core.tools.errors import ToolProviderCredentialValidationError
  7. from core.tools.plugin_tool.tool import PluginTool
  8. class PluginToolProviderController(BuiltinToolProviderController):
  9. entity: ToolProviderEntityWithPlugin
  10. tenant_id: str
  11. def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str) -> None:
  12. self.entity = entity
  13. self.tenant_id = tenant_id
  14. @property
  15. def provider_type(self) -> ToolProviderType:
  16. """
  17. returns the type of the provider
  18. :return: type of the provider
  19. """
  20. return ToolProviderType.PLUGIN
  21. def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
  22. """
  23. validate the credentials of the provider
  24. """
  25. manager = PluginToolManager()
  26. if not manager.validate_provider_credentials(
  27. tenant_id=self.tenant_id,
  28. user_id=user_id,
  29. provider=self.entity.identity.name,
  30. credentials=credentials,
  31. ):
  32. raise ToolProviderCredentialValidationError("Invalid credentials")
  33. def get_tool(self, tool_name: str) -> PluginTool:
  34. """
  35. return tool with given name
  36. """
  37. tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name)
  38. if not tool_entity:
  39. raise ValueError(f"Tool with name {tool_name} not found")
  40. return PluginTool(
  41. entity=tool_entity,
  42. runtime=ToolRuntime(tenant_id=self.tenant_id),
  43. tenant_id=self.tenant_id,
  44. )
  45. def get_tools(self) -> list[PluginTool]:
  46. """
  47. get all tools
  48. """
  49. return [
  50. PluginTool(
  51. entity=tool_entity,
  52. runtime=ToolRuntime(tenant_id=self.tenant_id),
  53. tenant_id=self.tenant_id,
  54. )
  55. for tool_entity in self.entity.tools
  56. ]