tool.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from collections.abc import Generator
  2. from typing import Any, Optional
  3. from core.plugin.manager.tool import PluginToolManager
  4. from core.tools.__base.tool import Tool
  5. from core.tools.__base.tool_runtime import ToolRuntime
  6. from core.tools.entities.file_entities import PluginFileEntity
  7. from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
  8. from models.model import File
  9. class PluginTool(Tool):
  10. tenant_id: str
  11. icon: str
  12. plugin_unique_identifier: str
  13. runtime_parameters: Optional[list[ToolParameter]]
  14. def __init__(
  15. self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
  16. ) -> None:
  17. super().__init__(entity, runtime)
  18. self.tenant_id = tenant_id
  19. self.icon = icon
  20. self.plugin_unique_identifier = plugin_unique_identifier
  21. self.runtime_parameters = None
  22. def tool_provider_type(self) -> ToolProviderType:
  23. return ToolProviderType.PLUGIN
  24. @classmethod
  25. def _transform_image_parameters(cls, parameters: dict[str, Any]) -> dict[str, Any]:
  26. for parameter_name, parameter in parameters.items():
  27. if isinstance(parameter, File):
  28. url = parameter.generate_url()
  29. if url is None:
  30. raise ValueError(f"File {parameter.id} does not have a valid URL")
  31. parameters[parameter_name] = PluginFileEntity(
  32. url=url,
  33. mime_type=parameter.mime_type,
  34. type=parameter.type,
  35. filename=parameter.filename,
  36. extension=parameter.extension,
  37. size=parameter.size,
  38. ).model_dump()
  39. elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter):
  40. parameters[parameter_name] = []
  41. for p in parameter:
  42. assert isinstance(p, File)
  43. url = p.generate_url()
  44. if url is None:
  45. raise ValueError(f"File {p.id} does not have a valid URL")
  46. parameters[parameter_name].append(
  47. PluginFileEntity(
  48. url=url,
  49. mime_type=p.mime_type,
  50. type=p.type,
  51. filename=p.filename,
  52. extension=p.extension,
  53. size=p.size,
  54. ).model_dump()
  55. )
  56. return parameters
  57. def _invoke(
  58. self,
  59. user_id: str,
  60. tool_parameters: dict[str, Any],
  61. conversation_id: Optional[str] = None,
  62. app_id: Optional[str] = None,
  63. message_id: Optional[str] = None,
  64. ) -> Generator[ToolInvokeMessage, None, None]:
  65. manager = PluginToolManager()
  66. # convert tool parameters with File type to PluginFileEntity
  67. tool_parameters = self._transform_image_parameters(tool_parameters)
  68. yield from manager.invoke(
  69. tenant_id=self.tenant_id,
  70. user_id=user_id,
  71. tool_provider=self.entity.identity.provider,
  72. tool_name=self.entity.identity.name,
  73. credentials=self.runtime.credentials,
  74. tool_parameters=tool_parameters,
  75. conversation_id=conversation_id,
  76. app_id=app_id,
  77. message_id=message_id,
  78. )
  79. def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
  80. return PluginTool(
  81. entity=self.entity,
  82. runtime=runtime,
  83. tenant_id=self.tenant_id,
  84. icon=self.icon,
  85. plugin_unique_identifier=self.plugin_unique_identifier,
  86. )
  87. def get_runtime_parameters(self) -> list[ToolParameter]:
  88. """
  89. get the runtime parameters
  90. """
  91. if not self.entity.has_runtime_parameters:
  92. return self.entity.parameters
  93. if self.runtime_parameters is not None:
  94. return self.runtime_parameters
  95. manager = PluginToolManager()
  96. self.runtime_parameters = manager.get_runtime_parameters(
  97. tenant_id=self.tenant_id,
  98. user_id="",
  99. provider=self.entity.identity.provider,
  100. tool=self.entity.identity.name,
  101. credentials=self.runtime.credentials,
  102. )
  103. return self.runtime_parameters