utils.ts 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getIncomers,
  5. getOutgoers,
  6. } from 'reactflow'
  7. import dagre from '@dagrejs/dagre'
  8. import { v4 as uuid4 } from 'uuid'
  9. import {
  10. cloneDeep,
  11. groupBy,
  12. isEqual,
  13. uniqBy,
  14. } from 'lodash-es'
  15. import type {
  16. Edge,
  17. InputVar,
  18. Node,
  19. ToolWithProvider,
  20. ValueSelector,
  21. } from './types'
  22. import {
  23. BlockEnum,
  24. ErrorHandleMode,
  25. NodeRunningStatus,
  26. } from './types'
  27. import {
  28. CUSTOM_NODE,
  29. DEFAULT_RETRY_INTERVAL,
  30. DEFAULT_RETRY_MAX,
  31. ITERATION_CHILDREN_Z_INDEX,
  32. ITERATION_NODE_Z_INDEX,
  33. LOOP_CHILDREN_Z_INDEX,
  34. LOOP_NODE_Z_INDEX,
  35. NODE_WIDTH_X_OFFSET,
  36. START_INITIAL_POSITION,
  37. } from './constants'
  38. import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants'
  39. import { CUSTOM_LOOP_START_NODE } from './nodes/loop-start/constants'
  40. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  41. import type { IfElseNodeType } from './nodes/if-else/types'
  42. import { branchNameCorrect } from './nodes/if-else/utils'
  43. import type { ToolNodeType } from './nodes/tool/types'
  44. import type { IterationNodeType } from './nodes/iteration/types'
  45. import type { LoopNodeType } from './nodes/loop/types'
  46. import { CollectionType } from '@/app/components/tools/types'
  47. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  48. import { canFindTool, correctModelProvider } from '@/utils'
  // Colours used by the depth-first cycle search below.
  const WHITE = 'WHITE' // node has not been entered yet
  const GRAY = 'GRAY' // node was entered and its exploration has not finished
  const BLACK = 'BLACK' // node was fully explored without reaching a GRAY node

  /**
   * One depth-first step of the cycle search.
   *
   * Marks `nodeId` GRAY, pushes it onto `stack` and walks its successors.
   * When a successor is already GRAY, that successor is pushed too and `true`
   * is returned through every recursion level WITHOUT unwinding `stack` or
   * resetting colours, so the caller can read the offending path from `stack`.
   * When no GRAY successor is reachable the node is marked BLACK and removed
   * from `stack` again.
   *
   * @param nodeId - id of the node to explore; must be a key of `adjList`
   * @param color - colour per node id, mutated in place
   * @param adjList - successor ids per node id
   * @param stack - current DFS path, mutated in place
   * @returns `true` when a GRAY node was reached, otherwise `false`
   */
  const isCyclicUtil = (nodeId: string, color: Record<string, string>, adjList: Record<string, string[]>, stack: string[]) => {
    color[nodeId] = GRAY
    stack.push(nodeId)

    for (const childId of adjList[nodeId]) {
      if (color[childId] === GRAY) {
        stack.push(childId)
        return true
      }
      if (color[childId] === WHITE && isCyclicUtil(childId, color, adjList, stack))
        return true
    }

    color[nodeId] = BLACK
    if (stack.length > 0 && stack[stack.length - 1] === nodeId)
      stack.pop()
    return false
  }

  /**
   * Returns the edges that lie on the cycles found by a depth-first search
   * started from every not-yet-visited node, in the order they appear in `edges`.
   *
   * Every DFS root gets its own path stack. When a search stops on a repeated
   * node, only the part of the path from the first occurrence of that node
   * onwards forms the cycle; the nodes before it merely lead into the cycle
   * and their edges are not reported. An edge is reported when both of its
   * endpoints belong to the same detected cycle.
   *
   * @param nodes - graph nodes; only `id` is read
   * @param edges - graph edges; `source` and `target` are read, edges whose
   *   source is not in `nodes` are ignored when building the adjacency list
   * @returns the edges located on a detected cycle (empty when none was found)
   */
  const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
    const adjList: Record<string, string[]> = {}
    const color: Record<string, string> = {}

    for (const node of nodes) {
      color[node.id] = WHITE
      adjList[node.id] = []
    }
    for (const edge of edges)
      adjList[edge.source]?.push(edge.target)

    const cycles: Set<string>[] = []
    for (const node of nodes) {
      if (color[node.id] !== WHITE)
        continue

      const stack: string[] = []
      if (!isCyclicUtil(node.id, color, adjList, stack))
        continue

      // The last entry is the node that was reached a second time; everything
      // before its first occurrence is only the approach path.
      const repeatedId = stack[stack.length - 1]
      cycles.push(new Set(stack.slice(stack.indexOf(repeatedId))))
    }

    if (!cycles.length)
      return []

    return edges.filter(edge => cycles.some(cycle => cycle.has(edge.source) && cycle.has(edge.target)))
  }
  /**
   * Builds the start node placed inside the iteration container `iterationId`.
   *
   * The node id is `${iterationId}start`; the node is parented to the
   * container, positioned at (24, 68) and is neither selectable nor draggable.
   *
   * @param iterationId - id of the owning iteration node
   * @returns the generated start node
   */
  export function getIterationStartNode(iterationId: string): Node {
    const { newNode: iterationStartNode } = generateNewNode({
      id: `${iterationId}start`,
      type: CUSTOM_ITERATION_START_NODE,
      data: {
        title: '',
        desc: '',
        type: BlockEnum.IterationStart,
        isInIteration: true,
      },
      position: { x: 24, y: 68 },
      zIndex: ITERATION_CHILDREN_Z_INDEX,
      parentId: iterationId,
      selectable: false,
      draggable: false,
    })

    return iterationStartNode
  }
  /**
   * Builds the start node placed inside the loop container `loopId`.
   *
   * The node id is `${loopId}start`; the node is parented to the container,
   * positioned at (24, 68) and is neither selectable nor draggable.
   *
   * @param loopId - id of the owning loop node
   * @returns the generated start node
   */
  export function getLoopStartNode(loopId: string): Node {
    const { newNode: loopStartNode } = generateNewNode({
      id: `${loopId}start`,
      type: CUSTOM_LOOP_START_NODE,
      data: {
        title: '',
        desc: '',
        type: BlockEnum.LoopStart,
        isInLoop: true,
      },
      position: { x: 24, y: 68 },
      zIndex: LOOP_CHILDREN_Z_INDEX,
      parentId: loopId,
      selectable: false,
      draggable: false,
    })

    return loopStartNode
  }
  /**
   * Creates a workflow node from partial node properties.
   *
   * Defaults: `id` falls back to the current timestamp, `type` to `CUSTOM_NODE`,
   * the target handle sits on the left and the source handle on the right.
   * Iteration and loop nodes always get their dedicated z-index; any other node
   * keeps the `zIndex` passed in. Remaining properties are spread last and can
   * therefore override the defaults.
   *
   * For iteration and loop nodes a matching start node is generated as well and
   * the passed `data` object is updated in place (`start_node_id`, `_children`).
   *
   * @returns the new node, plus the generated start node for iteration/loop nodes
   */
  export function generateNewNode({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }): {
    newNode: Node
    newIterationStartNode?: Node
    newLoopStartNode?: Node
  } {
    let resolvedZIndex = zIndex
    if (data.type === BlockEnum.Iteration)
      resolvedZIndex = ITERATION_NODE_Z_INDEX
    else if (data.type === BlockEnum.Loop)
      resolvedZIndex = LOOP_NODE_Z_INDEX

    const newNode = {
      id: id || `${Date.now()}`,
      type: type || CUSTOM_NODE,
      data,
      position,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      zIndex: resolvedZIndex,
      ...rest,
    } as Node

    switch (data.type) {
      case BlockEnum.Iteration: {
        const newIterationStartNode = getIterationStartNode(newNode.id)
        const iterationData = newNode.data as IterationNodeType
        iterationData.start_node_id = newIterationStartNode.id
        iterationData._children = [newIterationStartNode.id]
        return { newNode, newIterationStartNode }
      }
      case BlockEnum.Loop: {
        const newLoopStartNode = getLoopStartNode(newNode.id)
        const loopData = newNode.data as LoopNodeType
        loopData.start_node_id = newLoopStartNode.id
        loopData._children = [newLoopStartNode.id]
        return { newNode, newLoopStartNode }
      }
      default:
        return { newNode }
    }
  }
  /**
   * Makes sure every iteration and loop node owns a dedicated start node.
   *
   * A container needs a new start node when its `start_node_id` is empty or
   * refers to a node that is not of the custom iteration/loop start type.
   * For each such container a start node is generated and `start_node_id` is
   * pointed at it. When the previous `start_node_id` referred to an existing
   * node, an edge from the new start node to that node is added as well.
   *
   * The graph is returned untouched when it contains neither an iteration nor
   * a loop node. Otherwise `start_node_id` of the affected entries of `nodes`
   * is updated in place and new arrays are returned.
   *
   * @param nodes - workflow nodes
   * @param edges - workflow edges
   * @returns the nodes and edges including the generated start nodes and edges
   */
  export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
    const hasIterationNode = nodes.some(node => node.data.type === BlockEnum.Iteration)
    const hasLoopNode = nodes.some(node => node.data.type === BlockEnum.Loop)

    // Only skip the work when there is no container node of either kind;
    // a graph with iterations but no loops (or vice versa) still needs it.
    if (!hasIterationNode && !hasLoopNode) {
      return {
        nodes,
        edges,
      }
    }

    const nodesMap = nodes.reduce((prev, next) => {
      prev[next.id] = next
      return prev
    }, {} as Record<string, Node>)

    const iterationNodesWithStartNode: Node<IterationNodeType | LoopNodeType>[] = []
    const iterationNodesWithoutStartNode: Node<IterationNodeType | LoopNodeType>[] = []
    const loopNodesWithStartNode: Node<IterationNodeType | LoopNodeType>[] = []
    const loopNodesWithoutStartNode: Node<IterationNodeType | LoopNodeType>[] = []

    for (let i = 0; i < nodes.length; i++) {
      const currentNode = nodes[i] as Node<IterationNodeType | LoopNodeType>
      const startNodeId = currentNode.data.start_node_id

      if (currentNode.data.type === BlockEnum.Iteration) {
        if (!startNodeId)
          iterationNodesWithoutStartNode.push(currentNode)
        else if (nodesMap[startNodeId]?.type !== CUSTOM_ITERATION_START_NODE)
          iterationNodesWithStartNode.push(currentNode)
      }

      if (currentNode.data.type === BlockEnum.Loop) {
        if (!startNodeId)
          loopNodesWithoutStartNode.push(currentNode)
        else if (nodesMap[startNodeId]?.type !== CUSTOM_LOOP_START_NODE)
          loopNodesWithStartNode.push(currentNode)
      }
    }

    const newIterationStartNodesMap = {} as Record<string, Node>
    const newIterationStartNodes = [...iterationNodesWithStartNode, ...iterationNodesWithoutStartNode].map((iterationNode, index) => {
      const newNode = getIterationStartNode(iterationNode.id)
      newNode.id = newNode.id + index
      newIterationStartNodesMap[iterationNode.id] = newNode
      return newNode
    })

    const newLoopStartNodesMap = {} as Record<string, Node>
    const newLoopStartNodes = [...loopNodesWithStartNode, ...loopNodesWithoutStartNode].map((loopNode, index) => {
      const newNode = getLoopStartNode(loopNode.id)
      newNode.id = newNode.id + index
      newLoopStartNodesMap[loopNode.id] = newNode
      return newNode
    })

    // Connect each generated start node to the node that used to act as the start.
    const newEdges = [...iterationNodesWithStartNode, ...loopNodesWithStartNode].flatMap((nodeItem) => {
      const isIteration = nodeItem.data.type === BlockEnum.Iteration
      const newNode = (isIteration ? newIterationStartNodesMap : newLoopStartNodesMap)[nodeItem.id]
      const startNode = nodesMap[nodeItem.data.start_node_id]

      // `start_node_id` may point at a node that no longer exists; there is
      // nothing to connect to in that case, the container just gets its new start node.
      if (!startNode)
        return []

      const source = newNode.id
      const sourceHandle = 'source'
      const target = startNode.id
      const targetHandle = 'target'

      const parentNode = nodes.find(node => node.id === startNode.parentId) || null
      const isInIteration = !!parentNode && parentNode.data.type === BlockEnum.Iteration
      const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop

      return [{
        id: `${source}-${sourceHandle}-${target}-${targetHandle}`,
        type: 'custom',
        source,
        sourceHandle,
        target,
        targetHandle,
        data: {
          sourceType: newNode.data.type,
          targetType: startNode.data.type,
          isInIteration,
          iteration_id: isInIteration ? startNode.parentId : undefined,
          isInLoop,
          loop_id: isInLoop ? startNode.parentId : undefined,
          _connectedNodeIsSelected: true,
        },
        zIndex: isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX,
      }]
    })

    // Done last on purpose: the edges above still need the previous start_node_id.
    nodes.forEach((node) => {
      if (node.data.type === BlockEnum.Iteration && newIterationStartNodesMap[node.id])
        (node.data as IterationNodeType).start_node_id = newIterationStartNodesMap[node.id].id

      if (node.data.type === BlockEnum.Loop && newLoopStartNodesMap[node.id])
        (node.data as LoopNodeType).start_node_id = newLoopStartNodesMap[node.id].id
    })

    return {
      nodes: [...nodes, ...newIterationStartNodes, ...newLoopStartNodes],
      edges: [...edges, ...newEdges],
    }
  }
  /**
   * Normalises stored nodes before they are handed to the canvas.
   *
   * Works on deep clones of the inputs (after `preprocessNodesAndEdges`), so the
   * arguments are not modified. For every node it
   * - falls back to `CUSTOM_NODE` when no type is set,
   * - records the ids of the connected source/target handles,
   * - derives `_targetBranches` for if-else and question-classifier nodes
   *   (legacy if-else data without `cases` is converted to a single case),
   * - fills `_children` and defaults for iteration and loop nodes,
   * - runs model providers through `correctModelProvider`,
   * - adds a default retry configuration to HTTP-request and
   *   intent-recon-train nodes that have none.
   *
   * When the first node has no position, all nodes are laid out in one row.
   * Nodes with incomplete data (no `cases`, `classes` or `model`) are tolerated
   * instead of raising a `TypeError`.
   *
   * @param originNodes - stored nodes
   * @param originEdges - stored edges
   * @returns the normalised node list
   */
  export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
    const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))

    if (!nodes[0]?.position) {
      nodes.forEach((node, index) => {
        node.position = {
          x: START_INITIAL_POSITION.x + index * NODE_WIDTH_X_OFFSET,
          y: START_INITIAL_POSITION.y,
        }
      })
    }

    // parent id -> ids of the nodes nested inside that iteration/loop
    const iterationOrLoopNodeMap = nodes.reduce((acc, node) => {
      if (node.parentId) {
        if (acc[node.parentId])
          acc[node.parentId].push(node.id)
        else
          acc[node.parentId] = [node.id]
      }
      return acc
    }, {} as Record<string, string[]>)

    return nodes.map((node) => {
      if (!node.type)
        node.type = CUSTOM_NODE

      const nodeType = node.data.type
      const connectedEdges = getConnectedEdges([node], edges)
      node.data._connectedSourceHandleIds = connectedEdges.filter(edge => edge.source === node.id).map(edge => edge.sourceHandle || 'source')
      node.data._connectedTargetHandleIds = connectedEdges.filter(edge => edge.target === node.id).map(edge => edge.targetHandle || 'target')

      if (nodeType === BlockEnum.IfElse) {
        const nodeData = node.data as IfElseNodeType

        // Legacy shape: a single condition group stored directly on the node.
        if (!nodeData.cases && nodeData.logical_operator && nodeData.conditions) {
          nodeData.cases = [
            {
              case_id: 'true',
              logical_operator: nodeData.logical_operator,
              conditions: nodeData.conditions,
            },
          ]
        }
        node.data._targetBranches = branchNameCorrect([
          // `cases` can still be missing when the legacy fields are absent too.
          ...(nodeData.cases ?? []).map(item => ({ id: item.case_id, name: '' })),
          { id: 'false', name: '' },
        ])
      }

      if (nodeType === BlockEnum.QuestionClassifier)
        node.data._targetBranches = ((node.data as QuestionClassifierNodeType).classes ?? []).slice()

      if (nodeType === BlockEnum.Iteration) {
        const iterationNodeData = node.data as IterationNodeType
        iterationNodeData._children = iterationOrLoopNodeMap[node.id] || []
        iterationNodeData.is_parallel = iterationNodeData.is_parallel || false
        iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10
        iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated
      }

      // TODO: loop error handle mode
      if (nodeType === BlockEnum.Loop) {
        const loopNodeData = node.data as LoopNodeType
        loopNodeData._children = iterationOrLoopNodeMap[node.id] || []
        loopNodeData.error_handle_mode = loopNodeData.error_handle_mode || ErrorHandleMode.Terminated
      }

      // legacy provider handle
      if (nodeType === BlockEnum.LLM || nodeType === BlockEnum.QuestionClassifier || nodeType === BlockEnum.ParameterExtractor) {
        const model = (node as any).data.model
        if (model)
          model.provider = correctModelProvider(model.provider)
      }

      if (nodeType === BlockEnum.KnowledgeRetrieval) {
        const rerankingModel = (node as any).data.multiple_retrieval_config?.reranking_model
        if (rerankingModel)
          rerankingModel.provider = correctModelProvider(rerankingModel.provider)
      }

      if ((nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.IntentReconTrain) && !node.data.retry_config) {
        node.data.retry_config = {
          retry_enabled: true,
          max_retries: DEFAULT_RETRY_MAX,
          retry_interval: DEFAULT_RETRY_INTERVAL,
        }
      }

      return node
    })
  }
  /**
   * Normalises stored edges before they are handed to the canvas.
   *
   * Works on deep clones of the inputs (after `preprocessNodesAndEdges`).
   * Edges whose source/target pair matches an edge reported by `getCycleEdges`
   * are dropped. Every remaining edge gets the `custom` type, default handle
   * ids, the source/target node types when they are missing and can be looked
   * up, and - when some node has `data.selected` set - a
   * `_connectedNodeIsSelected` flag. If several nodes are selected, the last
   * one in node order is used.
   *
   * @param originEdges - stored edges
   * @param originNodes - stored nodes
   * @returns the normalised edge list
   */
  export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
    const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))

    const nodesMap: Record<string, Node> = {}
    let selectedNode: Node | null = null
    nodes.forEach((node) => {
      nodesMap[node.id] = node
      if (node.data?.selected)
        selectedNode = node
    })

    const cycleEdges = getCycleEdges(nodes, edges)
    const isCycleEdge = (edge: Edge) => cycleEdges.some(cycEdge => cycEdge.source === edge.source && cycEdge.target === edge.target)

    return edges
      .filter(edge => !isCycleEdge(edge))
      .map((edge) => {
        edge.type = 'custom'
        edge.sourceHandle = edge.sourceHandle || 'source'
        edge.targetHandle = edge.targetHandle || 'target'

        // Collect everything that has to be added to edge.data, then merge once.
        const extraData: Record<string, unknown> = {}

        if (!edge.data?.sourceType && edge.source && nodesMap[edge.source])
          extraData.sourceType = nodesMap[edge.source].data.type!

        if (!edge.data?.targetType && edge.target && nodesMap[edge.target])
          extraData.targetType = nodesMap[edge.target].data.type!

        if (selectedNode)
          extraData._connectedNodeIsSelected = edge.source === selectedNode.id || edge.target === selectedNode.id

        if (Object.keys(extraData).length > 0) {
          edge.data = {
            ...edge.data,
            ...extraData,
          } as any
        }

        return edge
      })
  }
  /**
   * Runs a left-to-right dagre layout over the top-level part of the workflow.
   *
   * Only nodes without a parent and of type `CUSTOM_NODE` take part, and edges
   * flagged as being inside an iteration or loop are left out. Node sizes are
   * taken from `width` / `height` of each node.
   *
   * @param originNodes - workflow nodes (not modified)
   * @param originEdges - workflow edges (not modified)
   * @returns the dagre graph after `dagre.layout` has been applied
   */
  export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
    const layoutNodes = cloneDeep(originNodes).filter(node => !node.parentId && node.type === CUSTOM_NODE)
    const layoutEdges = cloneDeep(originEdges).filter(edge => !(edge.data?.isInIteration || edge.data?.isInLoop))

    const graph = new dagre.graphlib.Graph()
    graph.setDefaultEdgeLabel(() => ({}))
    graph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      ranker: 'tight-tree',
      marginx: 30,
      marginy: 200,
    })

    for (const { id, width, height } of layoutNodes)
      graph.setNode(id, { width: width!, height: height! })

    for (const { source, target } of layoutEdges)
      graph.setEdge(source, target)

    dagre.layout(graph)

    return graph
  }
  /**
   * Tells whether `nodeType` is one of the block types listed below as
   * runnable on its own.
   *
   * @param nodeType - block type to check
   * @returns `true` when the type is in the list, otherwise `false`
   */
  export const canRunBySingle = (nodeType: BlockEnum) => {
    const singleRunTypes: BlockEnum[] = [
      BlockEnum.LLM,
      BlockEnum.KnowledgeRetrieval,
      BlockEnum.Code,
      BlockEnum.TemplateTransform,
      BlockEnum.QuestionClassifier,
      BlockEnum.HttpRequest,
      BlockEnum.IntentReconTrain,
      BlockEnum.Tool,
      BlockEnum.ParameterExtractor,
      BlockEnum.Iteration,
      BlockEnum.Agent,
      BlockEnum.DocExtractor,
      BlockEnum.Loop,
    ]

    return singleRunTypes.includes(nodeType)
  }
  /** Edge changes to apply: `type` is `'add'` or `'remove'`; other values are ignored. */
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]

  /**
   * Computes the connected source/target handle ids of the nodes touched by a
   * batch of edge additions and removals.
   *
   * For each change the edge's source and target nodes are looked up in `nodes`.
   * Each touched node gets one entry, seeded with copies of its current
   * `_connectedSourceHandleIds` / `_connectedTargetHandleIds`, so the nodes
   * themselves are not mutated. `'add'` appends the edge's handle id, `'remove'`
   * deletes one occurrence of it. A missing handle id is treated as `'source'`
   * or `'target'` for both operations, and removing a handle that is not
   * recorded leaves the list untouched.
   *
   * @param changes - edge changes to apply, in order
   * @param nodes - nodes the edges refer to; edges to unknown nodes are skipped
   * @returns a map from node id to `{ _connectedSourceHandleIds, _connectedTargetHandleIds }`
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>

    // Creates the entry of a node on first use, copying its current handle lists.
    const ensureEntry = (node: Node) => {
      if (!nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
        nodesConnectedSourceOrTargetHandleIdsMap[node.id] = {
          _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
          _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
        }
      }
      return nodesConnectedSourceOrTargetHandleIdsMap[node.id]
    }

    // Applies one add/remove to a handle id list.
    const applyChange = (handleIds: string[], type: string, handleId: string) => {
      if (type === 'remove') {
        const index = handleIds.indexOf(handleId)
        // splice(-1, 1) would silently drop the last, unrelated handle id
        if (index !== -1)
          handleIds.splice(index, 1)
      }

      if (type === 'add')
        handleIds.push(handleId)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)

      if (sourceNode)
        applyChange(ensureEntry(sourceNode)._connectedSourceHandleIds, type, edge.sourceHandle || 'source')

      if (targetNode)
        applyChange(ensureEntry(targetNode)._connectedTargetHandleIds, type, edge.targetHandle || 'target')
    })

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Derives the title for a copy of a node.
   *
   * A title ending in a parenthesised number ("LLM (2)") gets that number
   * increased by one; any other title gets " (1)" appended.
   *
   * @param oldTitle - title of the node being copied
   * @returns the title for the new node
   */
  export const genNewNodeTitleFromOld = (oldTitle: string) => {
    const numbered = /^(.+?)\s*\((\d+)\)\s*$/.exec(oldTitle)
    if (!numbered)
      return `${oldTitle} (1)`

    const [, baseTitle, counter] = numbered
    return `${baseTitle} (${Number.parseInt(counter, 10) + 1})`
  }
  /**
   * Collects the nodes reachable from the start node and the length of the
   * longest path found while doing so.
   *
   * The walk follows `getOutgoers` from the node of type `BlockEnum.Start`.
   * Whenever an iteration or loop node is reached, the nodes parented to it are
   * collected as well. `maxDepth` counts the start node as depth 1.
   *
   * A node that is already on the path currently being walked is not entered
   * again, so a graph containing a cycle terminates instead of recursing
   * without bound.
   *
   * @param nodes - workflow nodes
   * @param edges - workflow edges
   * @returns `validNodes` (unique by id, in discovery order) and `maxDepth`;
   *   an empty list and depth 0 when there is no start node
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)

    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    let maxDepth = 1
    // ids of the nodes on the path from the start node to the node being visited
    const onPath = new Set<string>()

    const isContainer = (node: Node) => node.data.type === BlockEnum.Iteration || node.data.type === BlockEnum.Loop

    const traverse = (root: Node, depth: number) => {
      if (depth > maxDepth)
        maxDepth = depth

      onPath.add(root.id)
      getOutgoers(root, nodes, edges).forEach((outgoer) => {
        list.push(outgoer)

        if (isContainer(outgoer))
          list.push(...nodes.filter(node => node.parentId === outgoer.id))

        // Following an edge back onto the current path would never terminate.
        if (!onPath.has(outgoer.id))
          traverse(outgoer, depth + 1)
      })
      onPath.delete(root.id)
    }

    traverse(startNode, maxDepth)

    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Gathers what is needed to validate a tool node.
   *
   * The tool collection is looked up in the list matching the node's
   * `provider_type` (built-in, custom, otherwise workflow) and the tool by name.
   * Its parameters are converted to form schemas and split into schemas with
   * `form === 'llm'` (returned as input variables) and all others (returned as
   * settings). When the tool cannot be found both lists are empty.
   *
   * @param toolData - data of the tool node
   * @param buildInTools - built-in tool collections
   * @param customTools - custom tool collections
   * @param workflowTools - workflow tool collections
   * @param language - key used to pick the label, falling back to `en_US`
   * @returns `toolInputsSchema`, `notAuthed`, `toolSettingSchema` and `language`
   */
  export const getToolCheckParams = (
    toolData: ToolNodeType,
    buildInTools: ToolWithProvider[],
    customTools: ToolWithProvider[],
    workflowTools: ToolWithProvider[],
    language: string,
  ) => {
    const { provider_id, provider_type, tool_name } = toolData
    const isBuiltIn = provider_type === CollectionType.builtIn

    let candidateCollections = workflowTools
    if (isBuiltIn)
      candidateCollections = buildInTools
    else if (provider_type === CollectionType.custom)
      candidateCollections = customTools

    const currCollection = candidateCollections.find(item => canFindTool(item.id, provider_id))
    const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
    const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []

    const toolInputsSchema: InputVar[] = formSchemas
      .filter((item: any) => item.form === 'llm')
      .map((item: any) => ({
        label: item.label[language] || item.label.en_US,
        variable: item.variable,
        type: item.type,
        required: item.required,
      }))
    const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')

    return {
      toolInputsSchema,
      notAuthed: isBuiltIn && !!currCollection?.allow_delete && !currCollection?.is_team_authorization,
      toolSettingSchema,
      language,
    }
  }
  /**
   * Returns copies of the given nodes and edges with fresh node ids.
   *
   * Every node id is mapped to a newly generated uuid; nodes receive the new id
   * and edges have `source` and `target` rewritten through the same mapping.
   * All other properties are copied shallowly and left as they are.
   *
   * @param nodes - nodes to re-identify
   * @param edges - edges connecting those nodes
   * @returns a `[newNodes, newEdges]` tuple
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap: Record<string, string> = {}
    nodes.forEach((node) => {
      idMap[node.id] = uuid4()
    })

    const newNodes = nodes.map(node => ({
      ...node,
      id: idMap[node.id],
    }))

    const newEdges = edges.map(edge => ({
      ...edge,
      source: idMap[edge.source],
      target: idMap[edge.target],
    }))

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /**
   * Tells whether the user agent string contains "MAC" (case-insensitive).
   *
   * @returns `true` when it does; `false` otherwise, including when no
   *   `navigator` global exists (for example outside a browser)
   */
  export const isMac = () => {
    if (typeof navigator === 'undefined')
      return false

    return navigator.userAgent.toUpperCase().includes('MAC')
  }
  // Symbols shown instead of the key name when `isMac()` is true.
  const specialKeysNameMap: Record<string, string | undefined> = {
    ctrl: '⌘',
    alt: '⌥',
    shift: '⇧',
  }

  /**
   * Returns the label to display for a keyboard key.
   *
   * @param key - key name such as `ctrl`, `alt` or `shift`
   * @returns the mapped symbol when `isMac()` is true and a mapping exists,
   *   otherwise `key` unchanged
   */
  export const getKeyboardKeyNameBySystem = (key: string) => {
    return isMac() ? (specialKeysNameMap[key] || key) : key
  }
  // Key codes substituted when `isMac()` is true.
  const specialKeysCodeMap: Record<string, string | undefined> = {
    ctrl: 'meta',
  }

  /**
   * Returns the key code to listen for.
   *
   * @param key - key code such as `ctrl`
   * @returns the mapped code when `isMac()` is true and a mapping exists,
   *   otherwise `key` unchanged
   */
  export const getKeyboardKeyCodeBySystem = (key: string) => {
    return isMac() ? (specialKeysCodeMap[key] || key) : key
  }
  /**
   * Finds the smallest x and the smallest y among the positions of `nodes`.
   * The two values may come from different nodes.
   *
   * @param nodes - nodes to inspect
   * @returns `{ x, y }` with the minima; both are `Infinity` for an empty list
   */
  export const getTopLeftNodePosition = (nodes: Node[]) => {
    return nodes.reduce(
      (topLeft, { position }) => ({
        x: position.x < topLeft.x ? position.x : topLeft.x,
        y: position.y < topLeft.y ? position.y : topLeft.y,
      }),
      { x: Infinity, y: Infinity },
    )
  }
  /**
   * Tells whether an event target is a text entry area: an `INPUT` or
   * `TEXTAREA` element, or an element whose `contentEditable` is `'true'`.
   *
   * @param target - element the event was dispatched to
   * @returns `true` for a text entry area; `undefined` (not `false`) otherwise
   */
  export const isEventTargetInputArea = (target: HTMLElement) => {
    const isFormField = ['INPUT', 'TEXTAREA'].includes(target.tagName)

    if (isFormField || target.contentEditable === 'true')
      return true
  }
  /**
   * Converts between the two representations of a variable reference.
   *
   * - string -> selector: a leading `{{#` and a trailing `#}}` are stripped and
   *   the rest is split on `.`
   * - selector -> string: the parts are joined with `.` and wrapped in `{{#` / `#}}`
   *
   * @param v - either form of the reference
   * @returns the other form
   */
  export const variableTransformer = (v: ValueSelector | string) => {
    if (typeof v !== 'string')
      return `{{#${v.join('.')}#}}`

    const unwrapped = v.replace(/^{{#|#}}$/g, '')
    return unwrapped.split('.')
  }
  /** Result of one traversal pass of `getParallelInfo`. */
  type ParallelInfoItem = {
    // id of the first node in the pass whose handle has more than one outgoing edge ('' if none)
    parallelNodeId: string
    // largest depth recorded during the pass
    depth: number
    // not written by the code in this block
    isBranch?: boolean
  }
  /** Per-node bookkeeping kept during one traversal pass. */
  type NodeParallelInfo = {
    parallelNodeId: string
    edgeHandleId: string
    depth: number
  }
  /** A node together with the source handle to continue from. */
  type NodeHandle = {
    node: Node
    handle: string
  }
  /** Per-node record of upstream fan-out nodes and downstream fan-out edges. */
  type NodeStreamInfo = {
    upstreamNodes: Set<string>
    downstreamEdges: Set<string>
  }

  /**
   * Analyses how deeply parallel branches are nested in a workflow graph.
   *
   * The walk starts at the `BlockEnum.Start` node, or - when `parentNodeId` is
   * given - at the node referenced by `start_node_id` of that parent. It runs in
   * passes: each pass walks breadth-first over (node, source handle) pairs and
   * adds one item to `parallelList`. A handle with more than one outgoing edge
   * raises the depth of its targets by one. A pass stops at a node with several
   * incomers once the fan-out edges inherited by that node equal all fan-out
   * edges of the pass that are not downstream of it; the next pass continues
   * from that node.
   *
   * `hasAbnormalEdges` is set when a node with more than one incomer is one of
   * several targets of the same handle.
   *
   * @param nodes - workflow nodes
   * @param edges - workflow edges
   * @param parentNodeId - optional id of an iteration/loop node to analyse inside
   * @returns `parallelList` (one item per pass) and `hasAbnormalEdges`
   * @throws Error 'Parent node not found' when `parentNodeId` matches no node
   * @throws Error 'Start node not found' when no start node can be resolved
   */
  export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => {
    const resolveStartNode = () => {
      if (!parentNodeId)
        return nodes.find(node => node.data.type === BlockEnum.Start)

      const parentNode = nodes.find(node => node.id === parentNodeId)
      if (!parentNode)
        throw new Error('Parent node not found')

      return nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id)
    }

    const startNode = resolveStartNode()
    if (!startNode)
      throw new Error('Start node not found')

    const parallelList: ParallelInfoItem[] = []
    // entry points of the passes still to run
    const pendingPasses: NodeHandle[] = [{ node: startNode, handle: 'source' }]
    let hasAbnormalEdges = false

    const createStreamInfo = (): NodeStreamInfo => ({
      upstreamNodes: new Set<string>(),
      downstreamEdges: new Set<string>(),
    })

    const traverse = (firstNodeHandle: NodeHandle) => {
      // node id -> ids of the fan-out edges passed on the way to that node
      const inheritedEdgeIds: Record<string, Set<string>> = {}
      // ids of every fan-out edge seen in this pass
      const allFanOutEdgeIds = new Set<string>()
      const queue: NodeHandle[] = [firstNodeHandle]
      const streams: Record<string, NodeStreamInfo> = {}
      const passResult: ParallelInfoItem = {
        parallelNodeId: '',
        depth: 0,
      }
      const parallelInfoByNode: Record<string, NodeParallelInfo> = {
        [firstNodeHandle.node.id]: {
          parallelNodeId: '',
          edgeHandleId: '',
          depth: 0,
        },
      }

      while (queue.length) {
        const { node: currentNode, handle: currentHandle = 'source' } = queue.shift()!
        const currentId = currentNode.id
        const handleEdges = edges.filter(edge => edge.source === currentId && edge.sourceHandle === currentHandle)
        const isFanOut = handleEdges.length > 1
        const outgoers = nodes.filter(node => handleEdges.some(edge => edge.target === node.id))
        const currentIncomers = getIncomers(currentNode, nodes, edges)

        if (!streams[currentId])
          streams[currentId] = createStreamInfo()
        const currentStream = streams[currentId]
        const currentInfo = parallelInfoByNode[currentId]

        const inherited = inheritedEdgeIds[currentId]
        if (inherited?.size > 0 && currentIncomers.length > 1) {
          const edgesNotDownstream = new Set<string>()
          for (const edgeId of allFanOutEdgeIds) {
            if (!currentStream.downstreamEdges.has(edgeId))
              edgesNotDownstream.add(edgeId)
          }

          // Every open branch has arrived here: end the pass and let the
          // next one continue from this node.
          if (isEqual(inherited, edgesNotDownstream)) {
            passResult.depth = currentInfo.depth
            pendingPasses.push({ node: currentNode, handle: currentHandle })
            break
          }
        }

        if (currentInfo.depth > passResult.depth)
          passResult.depth = currentInfo.depth

        for (const outgoer of outgoers) {
          const outgoerId = outgoer.id
          const outgoerSourceEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoerId)
          const usedSourceHandles = Object.keys(groupBy(outgoerSourceEdges, 'sourceHandle'))
          const outgoerIncomers = getIncomers(outgoer, nodes, edges)

          if (outgoers.length > 1 && outgoerIncomers.length > 1)
            hasAbnormalEdges = true

          for (const sourceHandle of usedSourceHandles)
            queue.push({ node: outgoer, handle: sourceHandle })
          // a node without outgoing edges is still visited once
          if (!outgoerSourceEdges.length)
            queue.push({ node: outgoer, handle: 'source' })

          if (!inheritedEdgeIds[outgoerId])
            inheritedEdgeIds[outgoerId] = new Set<string>()
          // looked up again on purpose: the line above may just have created it (self loop)
          if (inheritedEdgeIds[currentId]) {
            for (const edgeId of inheritedEdgeIds[currentId])
              inheritedEdgeIds[outgoerId].add(edgeId)
          }

          if (!streams[outgoerId])
            streams[outgoerId] = createStreamInfo()
          if (!parallelInfoByNode[outgoerId])
            parallelInfoByNode[outgoerId] = { ...currentInfo }

          const outgoerStream = streams[outgoerId]
          const outgoerInfo = parallelInfoByNode[outgoerId]

          if (isFanOut) {
            const edge = handleEdges.find(edge => edge.target === outgoerId)!
            inheritedEdgeIds[outgoerId].add(edge.id)
            allFanOutEdgeIds.add(edge.id)

            currentStream.downstreamEdges.add(edge.id)
            outgoerStream.upstreamNodes.add(currentId)
            for (const upstreamId of currentStream.upstreamNodes)
              streams[upstreamId].downstreamEdges.add(edge.id)

            if (!passResult.parallelNodeId)
              passResult.parallelNodeId = currentId

            outgoerInfo.depth = Math.max(currentInfo.depth + 1, outgoerInfo.depth)
          }
          else {
            for (const upstreamId of currentStream.upstreamNodes)
              outgoerStream.upstreamNodes.add(upstreamId)

            outgoerInfo.depth = currentInfo.depth
          }
        }
      }

      parallelList.push(passResult)
    }

    while (pendingPasses.length)
      traverse(pendingPasses.shift()!)

    return {
      parallelList,
      hasAbnormalEdges,
    }
  }
  /**
   * Tells whether `nodeType` is one of the block types listed here as having
   * error handling: LLM, Tool, HttpRequest, IntentReconTrain or Code.
   *
   * @param nodeType - block type to check; `undefined` yields `false`
   */
  export const hasErrorHandleNode = (nodeType?: BlockEnum) => {
    switch (nodeType) {
      case BlockEnum.LLM:
      case BlockEnum.Tool:
      case BlockEnum.HttpRequest:
      case BlockEnum.IntentReconTrain:
      case BlockEnum.Code:
        return true
      default:
        return false
    }
  }
  /**
   * Picks the CSS colour variable for an edge from the running status of its node.
   *
   * @param nodeRunningStatus - status of the node; any status not handled below
   *   (or none) yields the normal line colour
   * @param isFailBranch - only considered while the status is `Running`
   * @returns a `var(--color-workflow-link-line-...)` expression
   */
  export const getEdgeColor = (nodeRunningStatus?: NodeRunningStatus, isFailBranch?: boolean) => {
    switch (nodeRunningStatus) {
      case NodeRunningStatus.Succeeded:
        return 'var(--color-workflow-link-line-success-handle)'
      case NodeRunningStatus.Failed:
        return 'var(--color-workflow-link-line-error-handle)'
      case NodeRunningStatus.Exception:
        return 'var(--color-workflow-link-line-failure-handle)'
      case NodeRunningStatus.Running:
        return isFailBranch
          ? 'var(--color-workflow-link-line-failure-handle)'
          : 'var(--color-workflow-link-line-handle)'
      default:
        return 'var(--color-workflow-link-line-normal)'
    }
  }
  /**
   * Tells whether `variable` is `error_message` or `error_type` on a node type
   * for which `hasErrorHandleNode` returns true.
   *
   * @param variable - variable name
   * @param nodeType - block type of the node owning the variable
   */
  export const isExceptionVariable = (variable: string, nodeType?: BlockEnum) => {
    if (variable !== 'error_message' && variable !== 'error_type')
      return false

    return !!hasErrorHandleNode(nodeType)
  }
  /**
   * Tells whether `nodeType` is one of the block types listed here as
   * supporting retry: LLM, Tool, HttpRequest, IntentReconTrain or Code.
   *
   * @param nodeType - block type to check; `undefined` yields `false`
   */
  export const hasRetryNode = (nodeType?: BlockEnum) => {
    const retryableTypes = [
      BlockEnum.LLM,
      BlockEnum.Tool,
      BlockEnum.HttpRequest,
      BlockEnum.IntentReconTrain,
      BlockEnum.Code,
    ]

    return retryableTypes.some(type => type === nodeType)
  }