tool.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from collections.abc import Generator
  2. from typing import Any, Optional
  3. from core.plugin.manager.tool import PluginToolManager
  4. from core.tools.__base.tool import Tool
  5. from core.tools.__base.tool_runtime import ToolRuntime
  6. from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
  7. class PluginTool(Tool):
  8. tenant_id: str
  9. runtime_parameters: Optional[list[ToolParameter]]
  10. def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None:
  11. super().__init__(entity, runtime)
  12. self.tenant_id = tenant_id
  13. self.runtime_parameters = None
  14. @property
  15. def tool_provider_type(self) -> ToolProviderType:
  16. return ToolProviderType.PLUGIN
  17. def _invoke(
  18. self,
  19. user_id: str,
  20. tool_parameters: dict[str, Any],
  21. conversation_id: Optional[str] = None,
  22. app_id: Optional[str] = None,
  23. message_id: Optional[str] = None,
  24. ) -> Generator[ToolInvokeMessage, None, None]:
  25. manager = PluginToolManager()
  26. return manager.invoke(
  27. tenant_id=self.tenant_id,
  28. user_id=user_id,
  29. tool_provider=self.entity.identity.provider,
  30. tool_name=self.entity.identity.name,
  31. credentials=self.runtime.credentials,
  32. tool_parameters=tool_parameters,
  33. )
  34. def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
  35. return PluginTool(
  36. entity=self.entity,
  37. runtime=runtime,
  38. tenant_id=self.tenant_id,
  39. )
  40. def get_runtime_parameters(self) -> list[ToolParameter]:
  41. """
  42. get the runtime parameters
  43. """
  44. if not self.entity.has_runtime_parameters:
  45. return self.entity.parameters
  46. if self.runtime_parameters is not None:
  47. return self.runtime_parameters
  48. manager = PluginToolManager()
  49. self.runtime_parameters = manager.get_runtime_parameters(
  50. tenant_id=self.tenant_id,
  51. user_id="",
  52. provider=self.entity.identity.provider,
  53. tool=self.entity.identity.name,
  54. credentials=self.runtime.credentials,
  55. )
  56. return self.runtime_parameters