import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File, FileTransferMethod, FileType, ImageConfig
from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.llm import LLMNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from models.enums import UserFrom
from models.workflow import (
    Workflow,
    WorkflowType,
)

logger = logging.getLogger(__name__)


class WorkflowEntry:
    """Entry point for running a whole workflow graph or a single node of it."""

    def __init__(
        self,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        workflow_type: WorkflowType,
        graph_config: Mapping[str, Any],
        graph: Graph,
        user_id: str,
        user_from: UserFrom,
        invoke_from: InvokeFrom,
        call_depth: int,
        variable_pool: VariablePool,
        thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Init workflow entry

        :param tenant_id: tenant id
        :param app_id: app id
        :param workflow_id: workflow id
        :param workflow_type: workflow type
        :param graph_config: workflow graph config
        :param graph: workflow graph
        :param user_id: user id
        :param user_from: user from
        :param invoke_from: invoke from
        :param call_depth: call depth
        :param variable_pool: variable pool
        :param thread_pool_id: thread pool id
        :raises ValueError: if ``call_depth`` exceeds ``dify_config.WORKFLOW_CALL_MAX_DEPTH``
        """
        # check call depth
        workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
        if call_depth > workflow_call_max_depth:
            raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.")

        # init workflow run state
        self.graph_engine = GraphEngine(
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_type=workflow_type,
            workflow_id=workflow_id,
            user_id=user_id,
            user_from=user_from,
            invoke_from=invoke_from,
            call_depth=call_depth,
            graph=graph,
            graph_config=graph_config,
            variable_pool=variable_pool,
            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
            thread_pool_id=thread_pool_id,
        )

    def run(
        self,
        *,
        callbacks: Sequence[WorkflowCallback],
    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Run the workflow and stream the graph engine events.

        Every event is handed to each callback before it is yielded.
        ``GenerateTaskStoppedError`` ends the stream silently. Any other
        exception is logged and reported to the callbacks as a
        ``GraphRunFailedEvent``; that failure event is not yielded.

        :param callbacks: workflow callbacks
        """
        graph_engine = self.graph_engine

        try:
            # run workflow
            for event in graph_engine.run():
                if callbacks:
                    for callback in callbacks:
                        callback.on_event(event=event)
                yield event
        except GenerateTaskStoppedError:
            # the task was stopped on purpose; just end the stream
            pass
        except Exception as e:
            logger.exception("Unknown Error when workflow entry running")
            if callbacks:
                for callback in callbacks:
                    callback.on_event(event=GraphRunFailedEvent(error=str(e)))

    @classmethod
    def single_step_run(
        cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict
    ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
        """
        Single step run workflow node

        :param workflow: Workflow instance
        :param node_id: node id
        :param user_id: user id
        :param user_inputs: user inputs
        :return: the node instance and the generator returned by its ``run()``
        :raises ValueError: if the graph, its nodes, the node id or the node class cannot be found
        :raises WorkflowNodeRunFailedError: if mapping the inputs or starting the node fails
        """
        # fetch node info from workflow graph
        graph_dict = workflow.graph_dict
        if not graph_dict:
            raise ValueError("workflow graph not found")

        nodes = graph_dict.get("nodes")
        if not nodes:
            raise ValueError("nodes not found in workflow graph")

        # fetch node config from node id
        node_config = next((node for node in nodes if node.get("id") == node_id), None)
        if not node_config:
            raise ValueError("node id not found in workflow graph")

        # Get node class
        node_type = NodeType(node_config.get("data", {}).get("type"))
        node_cls = node_type_classes_mapping.get(node_type)
        node_cls = cast(type[BaseNode], node_cls)

        if not node_cls:
            raise ValueError(f"Node class not found for node type {node_type}")

        # init variable pool
        variable_pool = VariablePool(
            system_variables={},
            user_inputs={},
            environment_variables=workflow.environment_variables,
        )

        # init graph
        graph = Graph.init(graph_config=workflow.graph_dict)

        # init workflow run state
        node_instance = node_cls(
            id=str(uuid.uuid4()),
            config=node_config,
            graph_init_params=GraphInitParams(
                tenant_id=workflow.tenant_id,
                app_id=workflow.app_id,
                workflow_type=WorkflowType.value_of(workflow.type),
                workflow_id=workflow.id,
                graph_config=workflow.graph_dict,
                user_id=user_id,
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph=graph,
            graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        )

        try:
            # variable selector to variable mapping
            try:
                variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                    graph_config=workflow.graph_dict, config=node_config
                )
            except NotImplementedError:
                # node types without an extractor simply have no mapped variables
                variable_mapping = {}

            cls.mapping_user_inputs_to_variable_pool(
                variable_mapping=variable_mapping,
                user_inputs=user_inputs,
                variable_pool=variable_pool,
                tenant_id=workflow.tenant_id,
                node_type=node_type,
                node_data=node_instance.node_data,
            )

            # run node
            generator = node_instance.run()

            return node_instance, generator
        except Exception as e:
            # chain explicitly so the original failure is kept as __cause__
            raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) from e

    @staticmethod
    def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
        """
        Return a copy of ``value`` in which every nested ``File`` is replaced by ``File.to_dict()``.

        :param value: mapping to convert, or None
        :return: converted mapping, or None when ``value`` is None
        """
        return WorkflowEntry._handle_special_values(value)

    @staticmethod
    def _handle_special_values(value: Any) -> Any:
        """
        Recursively walk dicts and lists and convert ``File`` objects to dicts.

        Dicts and lists are rebuilt; any other value is returned unchanged.
        """
        if value is None:
            return value
        if isinstance(value, dict):
            return {k: WorkflowEntry._handle_special_values(v) for k, v in value.items()}
        if isinstance(value, list):
            return [WorkflowEntry._handle_special_values(item) for item in value]
        if isinstance(value, File):
            return value.to_dict()
        return value

    @classmethod
    def mapping_user_inputs_to_variable_pool(
        cls,
        variable_mapping: Mapping[str, Sequence[str]],
        user_inputs: dict,
        variable_pool: VariablePool,
        tenant_id: str,
        node_type: NodeType,
        node_data: BaseNodeData,
    ) -> None:
        """
        Copy the user inputs referenced by ``variable_mapping`` into ``variable_pool``.

        Each mapping key is looked up in ``user_inputs`` first by its full name
        and then by the part after the first ``"."``. For LLM nodes, list items
        shaped like ``{"type": "image", ...}`` are converted into ``File`` objects.

        :param variable_mapping: node variable name -> variable selector
        :param user_inputs: user inputs keyed by full or short variable name
        :param variable_pool: variable pool that receives the values
        :param tenant_id: tenant id used for the created ``File`` objects
        :param node_type: type of the node being run
        :param node_data: node data; read as ``LLMNodeData`` for LLM nodes
        :raises ValueError: if a variable is neither in ``user_inputs`` nor already in the pool
        """
        for node_variable, variable_selector in variable_mapping.items():
            # The short key is everything after the first "."; a name without a
            # dot yields "" here and can then only match by its full name.
            node_variable_key = ".".join(node_variable.split(".")[1:])

            if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
                variable_selector
            ):
                raise ValueError(f"Variable key {node_variable} not found in user inputs.")

            # get input value; only fall back to the short key when the full
            # name gave nothing, so falsy inputs (0, False, "", []) are kept
            input_value = user_inputs.get(node_variable)
            if input_value is None:
                input_value = user_inputs.get(node_variable_key)

            # FIXME: temp fix for image type
            if node_type == NodeType.LLM:
                new_value = []
                if isinstance(input_value, list):
                    llm_node_data = cast(LLMNodeData, node_data)

                    detail = llm_node_data.vision.configs.detail if llm_node_data.vision.configs else None

                    for item in input_value:
                        if isinstance(item, dict) and item.get("type") == "image":
                            transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
                            file = File(
                                tenant_id=tenant_id,
                                type=FileType.IMAGE,
                                transfer_method=transfer_method,
                                remote_url=item.get("url")
                                if transfer_method == FileTransferMethod.REMOTE_URL
                                else None,
                                related_id=item.get("upload_file_id")
                                if transfer_method == FileTransferMethod.LOCAL_FILE
                                else None,
                                _extra_config=FileExtraConfig(
                                    image_config=ImageConfig(detail=detail) if detail else None
                                ),
                            )
                            new_value.append(file)

                # keep the raw input unless at least one image was converted
                if new_value:
                    input_value = new_value

            # append variable and value to variable pool
            # (selector = variable node id followed by the variable key path;
            # unpacking also accepts selectors that are not lists, e.g. tuples)
            variable_pool.add([variable_selector[0], *variable_selector[1:]], input_value)