provider.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from typing import Any
  2. from core.plugin.manager.tool import PluginToolManager
  3. from core.tools.__base.tool_runtime import ToolRuntime
  4. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  5. from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
  6. from core.tools.errors import ToolProviderCredentialValidationError
  7. from core.tools.plugin_tool.tool import PluginTool
  8. class PluginToolProviderController(BuiltinToolProviderController):
  9. entity: ToolProviderEntityWithPlugin
  10. tenant_id: str
  11. plugin_id: str
  12. plugin_unique_identifier: str
  13. def __init__(
  14. self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
  15. ) -> None:
  16. self.entity = entity
  17. self.tenant_id = tenant_id
  18. self.plugin_id = plugin_id
  19. self.plugin_unique_identifier = plugin_unique_identifier
  20. @property
  21. def provider_type(self) -> ToolProviderType:
  22. """
  23. returns the type of the provider
  24. :return: type of the provider
  25. """
  26. return ToolProviderType.PLUGIN
  27. def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
  28. """
  29. validate the credentials of the provider
  30. """
  31. manager = PluginToolManager()
  32. if not manager.validate_provider_credentials(
  33. tenant_id=self.tenant_id,
  34. user_id=user_id,
  35. provider=self.entity.identity.name,
  36. credentials=credentials,
  37. ):
  38. raise ToolProviderCredentialValidationError("Invalid credentials")
  39. def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
  40. """
  41. return tool with given name
  42. """
  43. tool_entity = next(
  44. (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
  45. )
  46. if not tool_entity:
  47. raise ValueError(f"Tool with name {tool_name} not found")
  48. return PluginTool(
  49. entity=tool_entity,
  50. runtime=ToolRuntime(tenant_id=self.tenant_id),
  51. tenant_id=self.tenant_id,
  52. icon=self.entity.identity.icon,
  53. plugin_unique_identifier=self.plugin_unique_identifier,
  54. )
  55. def get_tools(self) -> list[PluginTool]: # type: ignore
  56. """
  57. get all tools
  58. """
  59. return [
  60. PluginTool(
  61. entity=tool_entity,
  62. runtime=ToolRuntime(tenant_id=self.tenant_id),
  63. tenant_id=self.tenant_id,
  64. icon=self.entity.identity.icon,
  65. plugin_unique_identifier=self.plugin_unique_identifier,
  66. )
  67. for tool_entity in self.entity.tools
  68. ]