workflow_tool.py 6.9 KB


  1. import json
  2. import logging
  3. from collections.abc import Generator
  4. from copy import deepcopy
  5. from typing import Any, Optional, Union
  6. from core.file.file_obj import FileTransferMethod, FileVar
  7. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
  8. from core.tools.tool.tool import Tool
  9. from extensions.ext_database import db
  10. from models.account import Account
  11. from models.model import App, EndUser
  12. from models.workflow import Workflow
# Module-level logger named after this module; used by WorkflowTool to report
# file-parameter conversion failures.
logger = logging.getLogger(__name__)
  14. class WorkflowTool(Tool):
  15. workflow_app_id: str
  16. version: str
  17. workflow_entities: dict[str, Any]
  18. workflow_call_depth: int
  19. thread_pool_id: Optional[str] = None
  20. label: str
  21. """
  22. Workflow tool.
  23. """
  24. def tool_provider_type(self) -> ToolProviderType:
  25. """
  26. get the tool provider type
  27. :return: the tool provider type
  28. """
  29. return ToolProviderType.WORKFLOW
  30. def _invoke(
  31. self, user_id: str, tool_parameters: dict[str, Any]
  32. ) -> Generator[ToolInvokeMessage, None, None]:
  33. """
  34. invoke the tool
  35. """
  36. app = self._get_app(app_id=self.workflow_app_id)
  37. workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
  38. # transform the tool parameters
  39. tool_parameters, files = self._transform_args(tool_parameters)
  40. from core.app.apps.workflow.app_generator import WorkflowAppGenerator
  41. generator = WorkflowAppGenerator()
  42. assert self.runtime and self.runtime.invoke_from
  43. result = generator.generate(
  44. app_model=app,
  45. workflow=workflow,
  46. user=self._get_user(user_id),
  47. args={
  48. 'inputs': tool_parameters,
  49. 'files': files
  50. },
  51. invoke_from=self.runtime.invoke_from,
  52. stream=False,
  53. call_depth=self.workflow_call_depth + 1,
  54. workflow_thread_pool_id=self.thread_pool_id
  55. )
  56. data = result.get('data', {})
  57. if data.get('error'):
  58. raise Exception(data.get('error'))
  59. outputs = data.get('outputs', {})
  60. outputs, files = self._extract_files(outputs)
  61. for file in files:
  62. yield self.create_file_var_message(file)
  63. yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
  64. yield self.create_json_message(outputs)
  65. def _get_user(self, user_id: str) -> Union[EndUser, Account]:
  66. """
  67. get the user by user id
  68. """
  69. user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
  70. if not user:
  71. user = db.session.query(Account).filter(Account.id == user_id).first()
  72. if not user:
  73. raise ValueError('user not found')
  74. return user
  75. def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
  76. """
  77. fork a new tool with meta data
  78. :param meta: the meta data of a tool call processing, tenant_id is required
  79. :return: the new tool
  80. """
  81. return self.__class__(
  82. identity=deepcopy(self.identity),
  83. parameters=deepcopy(self.parameters),
  84. description=deepcopy(self.description),
  85. runtime=Tool.Runtime(**runtime),
  86. workflow_app_id=self.workflow_app_id,
  87. workflow_entities=self.workflow_entities,
  88. workflow_call_depth=self.workflow_call_depth,
  89. version=self.version,
  90. label=self.label
  91. )
  92. def _get_workflow(self, app_id: str, version: str) -> Workflow:
  93. """
  94. get the workflow by app id and version
  95. """
  96. if not version:
  97. workflow = db.session.query(Workflow).filter(
  98. Workflow.app_id == app_id,
  99. Workflow.version != 'draft'
  100. ).order_by(Workflow.created_at.desc()).first()
  101. else:
  102. workflow = db.session.query(Workflow).filter(
  103. Workflow.app_id == app_id,
  104. Workflow.version == version
  105. ).first()
  106. if not workflow:
  107. raise ValueError('workflow not found or not published')
  108. return workflow
  109. def _get_app(self, app_id: str) -> App:
  110. """
  111. get the app by app id
  112. """
  113. app = db.session.query(App).filter(App.id == app_id).first()
  114. if not app:
  115. raise ValueError('app not found')
  116. return app
  117. def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
  118. """
  119. transform the tool parameters
  120. :param tool_parameters: the tool parameters
  121. :return: tool_parameters, files
  122. """
  123. parameter_rules = self.get_all_runtime_parameters()
  124. parameters_result = {}
  125. files = []
  126. for parameter in parameter_rules:
  127. if parameter.type == ToolParameter.ToolParameterType.FILE:
  128. file = tool_parameters.get(parameter.name)
  129. if file:
  130. try:
  131. file_var_list = [FileVar(**f) for f in file]
  132. for file_var in file_var_list:
  133. file_dict: dict[str, Any] = {
  134. 'transfer_method': file_var.transfer_method.value,
  135. 'type': file_var.type.value,
  136. }
  137. if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
  138. file_dict['tool_file_id'] = file_var.related_id
  139. elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
  140. file_dict['upload_file_id'] = file_var.related_id
  141. elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
  142. file_dict['url'] = file_var.preview_url
  143. files.append(file_dict)
  144. except Exception as e:
  145. logger.exception(e)
  146. else:
  147. parameters_result[parameter.name] = tool_parameters.get(parameter.name)
  148. return parameters_result, files
  149. def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
  150. """
  151. extract files from the result
  152. :param result: the result
  153. :return: the result, files
  154. """
  155. files = []
  156. result = {}
  157. for key, value in outputs.items():
  158. if isinstance(value, list):
  159. has_file = False
  160. for item in value:
  161. if isinstance(item, dict) and item.get('__variant') == 'FileVar':
  162. try:
  163. files.append(FileVar(**item))
  164. has_file = True
  165. except Exception as e:
  166. pass
  167. if has_file:
  168. continue
  169. result[key] = value
  170. return result, files