dataset_retriever_tool.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from collections.abc import Generator
  2. from typing import Any, Optional
  3. from core.app.app_config.entities import DatasetRetrieveConfigEntity
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  6. from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
  7. from core.tools.__base.tool import Tool
  8. from core.tools.__base.tool_runtime import ToolRuntime
  9. from core.tools.entities.common_entities import I18nObject
  10. from core.tools.entities.tool_entities import (
  11. ToolDescription,
  12. ToolEntity,
  13. ToolIdentity,
  14. ToolInvokeMessage,
  15. ToolParameter,
  16. ToolProviderType,
  17. )
  18. from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  19. class DatasetRetrieverTool(Tool):
  20. retrieval_tool: DatasetRetrieverBaseTool
  21. def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
  22. super().__init__(entity, runtime)
  23. self.retrieval_tool = retrieval_tool
  24. @staticmethod
  25. def get_dataset_tools(
  26. tenant_id: str,
  27. dataset_ids: list[str],
  28. retrieve_config: DatasetRetrieveConfigEntity | None,
  29. return_resource: bool,
  30. invoke_from: InvokeFrom,
  31. hit_callback: DatasetIndexToolCallbackHandler,
  32. ) -> list["DatasetRetrieverTool"]:
  33. """
  34. get dataset tool
  35. """
  36. # check if retrieve_config is valid
  37. if dataset_ids is None or len(dataset_ids) == 0:
  38. return []
  39. if retrieve_config is None:
  40. return []
  41. feature = DatasetRetrieval()
  42. # save original retrieve strategy, and set retrieve strategy to SINGLE
  43. # Agent only support SINGLE mode
  44. original_retriever_mode = retrieve_config.retrieve_strategy
  45. retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
  46. retrieval_tools = feature.to_dataset_retriever_tool(
  47. tenant_id=tenant_id,
  48. dataset_ids=dataset_ids,
  49. retrieve_config=retrieve_config,
  50. return_resource=return_resource,
  51. invoke_from=invoke_from,
  52. hit_callback=hit_callback,
  53. )
  54. if retrieval_tools is None or len(retrieval_tools) == 0:
  55. return []
  56. # restore retrieve strategy
  57. retrieve_config.retrieve_strategy = original_retriever_mode
  58. # convert retrieval tools to Tools
  59. tools = []
  60. for retrieval_tool in retrieval_tools:
  61. tool = DatasetRetrieverTool(
  62. retrieval_tool=retrieval_tool,
  63. entity=ToolEntity(
  64. identity=ToolIdentity(
  65. provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
  66. ),
  67. parameters=[],
  68. description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
  69. ),
  70. runtime=ToolRuntime(tenant_id=tenant_id),
  71. )
  72. tools.append(tool)
  73. return tools
  74. def get_runtime_parameters(
  75. self,
  76. conversation_id: Optional[str] = None,
  77. app_id: Optional[str] = None,
  78. message_id: Optional[str] = None,
  79. ) -> list[ToolParameter]:
  80. return [
  81. ToolParameter(
  82. name="query",
  83. label=I18nObject(en_US="", zh_Hans=""),
  84. human_description=I18nObject(en_US="", zh_Hans=""),
  85. type=ToolParameter.ToolParameterType.STRING,
  86. form=ToolParameter.ToolParameterForm.LLM,
  87. llm_description="Query for the dataset to be used to retrieve the dataset.",
  88. required=True,
  89. default="",
  90. placeholder=I18nObject(en_US="", zh_Hans=""),
  91. ),
  92. ]
  93. def tool_provider_type(self) -> ToolProviderType:
  94. return ToolProviderType.DATASET_RETRIEVAL
  95. def _invoke(
  96. self,
  97. user_id: str,
  98. tool_parameters: dict[str, Any],
  99. conversation_id: Optional[str] = None,
  100. app_id: Optional[str] = None,
  101. message_id: Optional[str] = None,
  102. ) -> Generator[ToolInvokeMessage, None, None]:
  103. """
  104. invoke dataset retriever tool
  105. """
  106. query = tool_parameters.get("query")
  107. if not query:
  108. yield self.create_text_message(text="please input query")
  109. else:
  110. # invoke dataset retriever tool
  111. result = self.retrieval_tool._run(query=query)
  112. yield self.create_text_message(text=result)
  113. def validate_credentials(
  114. self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
  115. ) -> str | None:
  116. """
  117. validate the credentials for dataset retriever tool
  118. """
  119. pass