workflow_tool.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. import json
  2. import logging
  3. from copy import deepcopy
  4. from typing import Any, Optional, Union, cast
  5. from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
  6. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
  7. from core.tools.tool.tool import Tool
  8. from extensions.ext_database import db
  9. from factories.file_factory import build_from_mapping
  10. from models.account import Account
  11. from models.model import App, EndUser
  12. from models.workflow import Workflow
  13. logger = logging.getLogger(__name__)
  14. class WorkflowTool(Tool):
  15. workflow_app_id: str
  16. version: str
  17. workflow_entities: dict[str, Any]
  18. workflow_call_depth: int
  19. thread_pool_id: Optional[str] = None
  20. label: str
  21. """
  22. Workflow tool.
  23. """
  24. def tool_provider_type(self) -> ToolProviderType:
  25. """
  26. get the tool provider type
  27. :return: the tool provider type
  28. """
  29. return ToolProviderType.WORKFLOW
  30. def _invoke(
  31. self, user_id: str, tool_parameters: dict[str, Any]
  32. ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
  33. """
  34. invoke the tool
  35. """
  36. app = self._get_app(app_id=self.workflow_app_id)
  37. workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
  38. # transform the tool parameters
  39. tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
  40. from core.app.apps.workflow.app_generator import WorkflowAppGenerator
  41. generator = WorkflowAppGenerator()
  42. assert self.runtime is not None
  43. assert self.runtime.invoke_from is not None
  44. result = generator.generate(
  45. app_model=app,
  46. workflow=workflow,
  47. user=self._get_user(user_id),
  48. args={"inputs": tool_parameters, "files": files},
  49. invoke_from=self.runtime.invoke_from,
  50. streaming=False,
  51. call_depth=self.workflow_call_depth + 1,
  52. workflow_thread_pool_id=self.thread_pool_id,
  53. )
  54. assert isinstance(result, dict)
  55. data = result.get("data", {})
  56. if data.get("error"):
  57. raise Exception(data.get("error"))
  58. r = []
  59. outputs = data.get("outputs")
  60. if outputs == None:
  61. outputs = {}
  62. else:
  63. outputs, extracted_files = self._extract_files(outputs)
  64. for f in extracted_files:
  65. r.append(self.create_file_message(f))
  66. r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
  67. r.append(self.create_json_message(outputs))
  68. return r
  69. def _get_user(self, user_id: str) -> Union[EndUser, Account]:
  70. """
  71. get the user by user id
  72. """
  73. user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
  74. if not user:
  75. user = db.session.query(Account).filter(Account.id == user_id).first()
  76. if not user:
  77. raise ValueError("user not found")
  78. return user
  79. def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool":
  80. """
  81. fork a new tool with meta data
  82. :param meta: the meta data of a tool call processing, tenant_id is required
  83. :return: the new tool
  84. """
  85. return self.__class__(
  86. identity=deepcopy(self.identity),
  87. parameters=deepcopy(self.parameters),
  88. description=deepcopy(self.description),
  89. runtime=Tool.Runtime(**runtime),
  90. workflow_app_id=self.workflow_app_id,
  91. workflow_entities=self.workflow_entities,
  92. workflow_call_depth=self.workflow_call_depth,
  93. version=self.version,
  94. label=self.label,
  95. )
  96. def _get_workflow(self, app_id: str, version: str) -> Workflow:
  97. """
  98. get the workflow by app id and version
  99. """
  100. if not version:
  101. workflow = (
  102. db.session.query(Workflow)
  103. .filter(Workflow.app_id == app_id, Workflow.version != "draft")
  104. .order_by(Workflow.created_at.desc())
  105. .first()
  106. )
  107. else:
  108. workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
  109. if not workflow:
  110. raise ValueError("workflow not found or not published")
  111. return workflow
  112. def _get_app(self, app_id: str) -> App:
  113. """
  114. get the app by app id
  115. """
  116. app = db.session.query(App).filter(App.id == app_id).first()
  117. if not app:
  118. raise ValueError("app not found")
  119. return app
  120. def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
  121. """
  122. transform the tool parameters
  123. :param tool_parameters: the tool parameters
  124. :return: tool_parameters, files
  125. """
  126. parameter_rules = self.get_all_runtime_parameters()
  127. parameters_result = {}
  128. files = []
  129. for parameter in parameter_rules:
  130. if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
  131. file = tool_parameters.get(parameter.name)
  132. if file:
  133. try:
  134. file_var_list = [File.model_validate(f) for f in file]
  135. for file in file_var_list:
  136. file_dict: dict[str, str | None] = {
  137. "transfer_method": file.transfer_method.value,
  138. "type": file.type.value,
  139. }
  140. if file.transfer_method == FileTransferMethod.TOOL_FILE:
  141. file_dict["tool_file_id"] = file.related_id
  142. elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
  143. file_dict["upload_file_id"] = file.related_id
  144. elif file.transfer_method == FileTransferMethod.REMOTE_URL:
  145. file_dict["url"] = file.generate_url()
  146. files.append(file_dict)
  147. except Exception as e:
  148. logger.exception(f"Failed to transform file {file}")
  149. else:
  150. parameters_result[parameter.name] = tool_parameters.get(parameter.name)
  151. return parameters_result, files
  152. def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
  153. """
  154. extract files from the result
  155. :param result: the result
  156. :return: the result, files
  157. """
  158. files = []
  159. result = {}
  160. for key, value in outputs.items():
  161. if isinstance(value, list):
  162. for item in value:
  163. if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
  164. item = self._update_file_mapping(item)
  165. file = build_from_mapping(
  166. mapping=item,
  167. tenant_id=str(cast(Tool.Runtime, self.runtime).tenant_id),
  168. )
  169. files.append(file)
  170. elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
  171. value = self._update_file_mapping(value)
  172. file = build_from_mapping(
  173. mapping=value,
  174. tenant_id=str(cast(Tool.Runtime, self.runtime).tenant_id),
  175. )
  176. files.append(file)
  177. result[key] = value
  178. return result, files
  179. def _update_file_mapping(self, file_dict: dict) -> dict:
  180. transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
  181. if transfer_method == FileTransferMethod.TOOL_FILE:
  182. file_dict["tool_file_id"] = file_dict.get("related_id")
  183. elif transfer_method == FileTransferMethod.LOCAL_FILE:
  184. file_dict["upload_file_id"] = file_dict.get("related_id")
  185. return file_dict