provider.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from typing import Any
  2. from core.plugin.manager.tool import PluginToolManager
  3. from core.tools.__base.tool_runtime import ToolRuntime
  4. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  5. from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
  6. from core.tools.errors import ToolProviderCredentialValidationError
  7. from core.tools.plugin_tool.tool import PluginTool
  8. class PluginToolProviderController(BuiltinToolProviderController):
  9. entity: ToolProviderEntityWithPlugin
  10. tenant_id: str
  11. plugin_id: str
  12. def __init__(self, entity: ToolProviderEntityWithPlugin, plugin_id: str, tenant_id: str) -> None:
  13. self.entity = entity
  14. self.tenant_id = tenant_id
  15. self.plugin_id = plugin_id
  16. @property
  17. def provider_type(self) -> ToolProviderType:
  18. """
  19. returns the type of the provider
  20. :return: type of the provider
  21. """
  22. return ToolProviderType.PLUGIN
  23. def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
  24. """
  25. validate the credentials of the provider
  26. """
  27. manager = PluginToolManager()
  28. if not manager.validate_provider_credentials(
  29. tenant_id=self.tenant_id,
  30. user_id=user_id,
  31. provider=self.entity.identity.name,
  32. credentials=credentials,
  33. ):
  34. raise ToolProviderCredentialValidationError("Invalid credentials")
  35. def get_tool(self, tool_name: str) -> PluginTool:
  36. """
  37. return tool with given name
  38. """
  39. tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name)
  40. if not tool_entity:
  41. raise ValueError(f"Tool with name {tool_name} not found")
  42. return PluginTool(
  43. entity=tool_entity,
  44. runtime=ToolRuntime(tenant_id=self.tenant_id),
  45. tenant_id=self.tenant_id,
  46. )
  47. def get_tools(self) -> list[PluginTool]:
  48. """
  49. get all tools
  50. """
  51. return [
  52. PluginTool(
  53. entity=tool_entity,
  54. runtime=ToolRuntime(tenant_id=self.tenant_id),
  55. tenant_id=self.tenant_id,
  56. )
  57. for tool_entity in self.entity.tools
  58. ]