dataset_retriever_tool.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from collections.abc import Generator
  2. from typing import Any
  3. from core.app.app_config.entities import DatasetRetrieveConfigEntity
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  6. from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
  7. from core.tools.entities.common_entities import I18nObject
  8. from core.tools.entities.tool_entities import (
  9. ToolDescription,
  10. ToolIdentity,
  11. ToolInvokeMessage,
  12. ToolParameter,
  13. ToolProviderType,
  14. )
  15. from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  16. from core.tools.tool.tool import Tool
  17. class DatasetRetrieverTool(Tool):
  18. retrieval_tool: DatasetRetrieverBaseTool
  19. @staticmethod
  20. def get_dataset_tools(
  21. tenant_id: str,
  22. dataset_ids: list[str],
  23. retrieve_config: DatasetRetrieveConfigEntity,
  24. return_resource: bool,
  25. invoke_from: InvokeFrom,
  26. hit_callback: DatasetIndexToolCallbackHandler,
  27. ) -> list["DatasetRetrieverTool"]:
  28. """
  29. get dataset tool
  30. """
  31. # check if retrieve_config is valid
  32. if dataset_ids is None or len(dataset_ids) == 0:
  33. return []
  34. if retrieve_config is None:
  35. return []
  36. feature = DatasetRetrieval()
  37. # save original retrieve strategy, and set retrieve strategy to SINGLE
  38. # Agent only support SINGLE mode
  39. original_retriever_mode = retrieve_config.retrieve_strategy
  40. retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
  41. retrieval_tools = feature.to_dataset_retriever_tool(
  42. tenant_id=tenant_id,
  43. dataset_ids=dataset_ids,
  44. retrieve_config=retrieve_config,
  45. return_resource=return_resource,
  46. invoke_from=invoke_from,
  47. hit_callback=hit_callback,
  48. )
  49. if retrieval_tools is None or len(retrieval_tools) == 0:
  50. return []
  51. # restore retrieve strategy
  52. retrieve_config.retrieve_strategy = original_retriever_mode
  53. # convert retrieval tools to Tools
  54. tools = []
  55. for retrieval_tool in retrieval_tools:
  56. tool = DatasetRetrieverTool(
  57. retrieval_tool=retrieval_tool,
  58. identity=ToolIdentity(
  59. provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
  60. ),
  61. parameters=[],
  62. is_team_authorization=True,
  63. description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
  64. runtime=DatasetRetrieverTool.Runtime(),
  65. )
  66. tools.append(tool)
  67. return tools
  68. def get_runtime_parameters(self) -> list[ToolParameter]:
  69. return [
  70. ToolParameter(
  71. name="query",
  72. label=I18nObject(en_US="", zh_Hans=""),
  73. human_description=I18nObject(en_US="", zh_Hans=""),
  74. type=ToolParameter.ToolParameterType.STRING,
  75. form=ToolParameter.ToolParameterForm.LLM,
  76. llm_description="Query for the dataset to be used to retrieve the dataset.",
  77. required=True,
  78. default="",
  79. ),
  80. ]
  81. def tool_provider_type(self) -> ToolProviderType:
  82. return ToolProviderType.DATASET_RETRIEVAL
  83. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
  84. """
  85. invoke dataset retriever tool
  86. """
  87. query = tool_parameters.get("query")
  88. if not query:
  89. yield self.create_text_message(text='please input query')
  90. else:
  91. # invoke dataset retriever tool
  92. result = self.retrieval_tool._run(query=query)
  93. yield self.create_text_message(text=result)
  94. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
  95. """
  96. validate the credentials for dataset retriever tool
  97. """
  98. pass