// utils.ts
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getOutgoers,
  5. } from 'reactflow'
  6. import dagre from '@dagrejs/dagre'
  7. import { v4 as uuid4 } from 'uuid'
  8. import {
  9. cloneDeep,
  10. uniqBy,
  11. } from 'lodash-es'
  12. import type {
  13. Edge,
  14. InputVar,
  15. Node,
  16. ToolWithProvider,
  17. } from './types'
  18. import { BlockEnum } from './types'
  19. import {
  20. CUSTOM_NODE,
  21. ITERATION_NODE_Z_INDEX,
  22. NODE_WIDTH_X_OFFSET,
  23. START_INITIAL_POSITION,
  24. } from './constants'
  25. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  26. import type { ToolNodeType } from './nodes/tool/types'
  27. import { CollectionType } from '@/app/components/tools/types'
  28. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  /**
   * Collects every edge that lies on a directed cycle of the graph.
   *
   * An edge is on a cycle exactly when its source and target belong to the same
   * strongly connected component (a self-loop is its own one-node cycle), so the
   * graph is partitioned into components with a single depth-first pass and the
   * edges are filtered against that partition. Edges that merely lead into or
   * out of a cycle are NOT reported.
   *
   * @param nodes - All nodes of the graph; only `id` is read.
   * @param edges - All edges of the graph; only `source` and `target` are read.
   * @returns The subset of `edges` (same object references, original order)
   * whose removal leaves the graph acyclic.
   */
  const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
    const adjacency = new Map<string, string[]>()
    for (const node of nodes)
      adjacency.set(node.id, [])
    for (const edge of edges) {
      // An edge with an unknown endpoint can never close a cycle, so skip it.
      if (adjacency.has(edge.target))
        adjacency.get(edge.source)?.push(edge.target)
    }

    const discoveryOrder = new Map<string, number>()
    const componentOf = new Map<string, number>()
    // Nodes discovered but not yet assigned to a component.
    const pending: string[] = []
    const isPending = new Set<string>()
    let componentCount = 0

    // Visits `nodeId` and returns the smallest discovery order reachable from it
    // through nodes that are still pending (its "low link").
    const strongConnect = (nodeId: string): number => {
      const order = discoveryOrder.size
      discoveryOrder.set(nodeId, order)
      pending.push(nodeId)
      isPending.add(nodeId)

      let low = order
      for (const childId of adjacency.get(nodeId) ?? []) {
        if (!discoveryOrder.has(childId))
          low = Math.min(low, strongConnect(childId))
        else if (isPending.has(childId))
          low = Math.min(low, discoveryOrder.get(childId)!)
      }

      // `nodeId` is the root of a component: everything above it on the pending
      // stack belongs to the same component.
      if (low === order) {
        let memberId: string
        do {
          memberId = pending.pop()!
          isPending.delete(memberId)
          componentOf.set(memberId, componentCount)
        } while (memberId !== nodeId)
        componentCount++
      }
      return low
    }

    for (const node of nodes) {
      if (!discoveryOrder.has(node.id))
        strongConnect(node.id)
    }

    return edges.filter((edge) => {
      const sourceComponent = componentOf.get(edge.source)
      return sourceComponent !== undefined && sourceComponent === componentOf.get(edge.target)
    })
  }
  /**
   * Prepares raw nodes for rendering without touching the inputs.
   *
   * Works on deep copies of both arguments. When the first node has no position,
   * every node is laid out on a horizontal line starting at
   * `START_INITIAL_POSITION`. Each node then receives a default `type`, the
   * handle ids of its connected edges, and, depending on `data.type`, its target
   * branches or its iteration children.
   *
   * @param originNodes - Nodes to prepare.
   * @param originEdges - Edges used to derive the connected handle ids.
   * @returns The prepared copies of the nodes.
   */
  export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
    const clonedNodes = cloneDeep(originNodes)
    const clonedEdges = cloneDeep(originEdges)

    // A missing position on the first node is taken to mean that no node has one.
    if (!clonedNodes[0]?.position) {
      for (let index = 0; index < clonedNodes.length; index++) {
        clonedNodes[index].position = {
          x: START_INITIAL_POSITION.x + index * NODE_WIDTH_X_OFFSET,
          y: START_INITIAL_POSITION.y,
        }
      }
    }

    // parent id -> ids of the nodes that declare it as `parentId`
    const childIdsByParentId = {} as Record<string, string[]>
    for (const candidate of clonedNodes) {
      if (!candidate.parentId)
        continue
      if (childIdsByParentId[candidate.parentId])
        childIdsByParentId[candidate.parentId].push(candidate.id)
      else
        childIdsByParentId[candidate.parentId] = [candidate.id]
    }

    return clonedNodes.map((node) => {
      if (!node.type)
        node.type = CUSTOM_NODE

      const sourceHandleIds: string[] = []
      const targetHandleIds: string[] = []
      for (const connectedEdge of getConnectedEdges([node], clonedEdges)) {
        if (connectedEdge.source === node.id)
          sourceHandleIds.push(connectedEdge.sourceHandle || 'source')
        if (connectedEdge.target === node.id)
          targetHandleIds.push(connectedEdge.targetHandle || 'target')
      }
      node.data._connectedSourceHandleIds = sourceHandleIds
      node.data._connectedTargetHandleIds = targetHandleIds

      switch (node.data.type) {
        case BlockEnum.IfElse:
          node.data._targetBranches = [
            { id: 'true', name: 'IS TRUE' },
            { id: 'false', name: 'IS FALSE' },
          ]
          break
        case BlockEnum.QuestionClassifier:
          // One branch per class, copied into a fresh array.
          node.data._targetBranches = (node.data as QuestionClassifierNodeType).classes.map(classItem => classItem)
          break
        case BlockEnum.Iteration:
          node.data._children = childIdsByParentId[node.id] || []
          break
      }

      return node
    })
  }
  /**
   * Prepares raw edges for rendering without touching the inputs.
   *
   * Works on deep copies. Edges reported by `getCycleEdges` are dropped (matched
   * by source/target pair); the remaining edges get the `'custom'` type, default
   * handle ids, the block types of their endpoints and, when some node is marked
   * `data.selected`, a flag telling whether the edge touches that node.
   *
   * @param originEdges - Edges to prepare.
   * @param originNodes - Nodes the edges refer to.
   * @returns The prepared copies of the surviving edges.
   */
  export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
    const clonedNodes = cloneDeep(originNodes)
    const clonedEdges = cloneDeep(originEdges)

    const nodeById = {} as Record<string, Node>
    let selectedNode: Node | null = null
    for (const node of clonedNodes) {
      nodeById[node.id] = node
      // When several nodes are flagged, the last one wins.
      if (node.data?.selected)
        selectedNode = node
    }
    const selected = selectedNode as Node | null

    const cycleEdges = getCycleEdges(clonedNodes, clonedEdges)
    const isCycleEdge = (edge: Edge) =>
      cycleEdges.some(cycEdge => cycEdge.source === edge.source && cycEdge.target === edge.target)

    const preparedEdges: Edge[] = []
    for (const edge of clonedEdges) {
      if (isCycleEdge(edge))
        continue

      edge.type = 'custom'
      if (!edge.sourceHandle)
        edge.sourceHandle = 'source'
      if (!edge.targetHandle)
        edge.targetHandle = 'target'

      const sourceNode = edge.source ? nodeById[edge.source] : undefined
      if (sourceNode && !edge.data?.sourceType)
        edge.data = { ...edge.data, sourceType: sourceNode.data.type! } as any

      const targetNode = edge.target ? nodeById[edge.target] : undefined
      if (targetNode && !edge.data?.targetType)
        edge.data = { ...edge.data, targetType: targetNode.data.type! } as any

      if (selected) {
        const touchesSelected = edge.source === selected.id || edge.target === selected.id
        edge.data = { ...edge.data, _connectedNodeIsSelected: touchesSelected } as any
      }

      preparedEdges.push(edge)
    }
    return preparedEdges
  }
  /**
   * Runs a left-to-right dagre layout over the top-level custom nodes.
   *
   * Nodes that have a `parentId` or whose `type` is not `CUSTOM_NODE`, and edges
   * flagged `data.isInIteration`, are left out of the layout.
   *
   * @param originNodes - Nodes to lay out; `width` and `height` are read from each.
   * @param originEdges - Edges connecting the nodes.
   * @returns The dagre graph after `dagre.layout` has been applied to it.
   */
  export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
    const layoutNodes = cloneDeep(originNodes).filter(node => !node.parentId && node.type === CUSTOM_NODE)
    const layoutEdges = cloneDeep(originEdges).filter(edge => !edge.data?.isInIteration)

    const dagreGraph = new dagre.graphlib.Graph()
    dagreGraph.setDefaultEdgeLabel(() => ({}))
    dagreGraph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      ranker: 'tight-tree',
      marginx: 30,
      marginy: 200,
    })

    for (const { id, width, height } of layoutNodes)
      dagreGraph.setNode(id, { width: width!, height: height! })

    for (const { source, target } of layoutEdges)
      dagreGraph.setEdge(source, target)

    dagre.layout(dagreGraph)
    return dagreGraph
  }
  /**
   * Tells whether a block type is one of the types listed as single-runnable.
   *
   * @param nodeType - Block type to check.
   * @returns `true` when `nodeType` is in the list below, otherwise `false`.
   */
  export const canRunBySingle = (nodeType: BlockEnum) => {
    const singleRunnableTypes: BlockEnum[] = [
      BlockEnum.LLM,
      BlockEnum.KnowledgeRetrieval,
      BlockEnum.Code,
      BlockEnum.TemplateTransform,
      BlockEnum.QuestionClassifier,
      BlockEnum.HttpRequest,
      BlockEnum.Tool,
      BlockEnum.ParameterExtractor,
      BlockEnum.Iteration,
    ]
    return singleRunnableTypes.includes(nodeType)
  }
  /** A batch of edge changes; `type` is `'add'` or `'remove'`, other values only seed entries. */
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]

  /**
   * Computes the connected handle ids of every node touched by a batch of edge
   * additions and removals.
   *
   * For each change the source and target nodes (when found in `nodes`) get an
   * entry seeded with copies of their current `_connectedSourceHandleIds` /
   * `_connectedTargetHandleIds`; the entry is then updated for the edge. The
   * input nodes are never mutated.
   *
   * Handle ids are compared with the same `'source'` / `'target'` defaults that
   * are used when adding, and removing a handle id that is not recorded is a
   * no-op (it must not delete an unrelated handle).
   *
   * @param changes - Edge changes to apply, in order.
   * @param nodes - Nodes the edges may refer to.
   * @returns A map from node id to its updated handle id lists; nodes not touched
   * by any change are absent.
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>

    // Returns the entry of `node`, creating it from the node's current data on first use.
    const ensureEntry = (node: Node) => {
      if (!nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
        nodesConnectedSourceOrTargetHandleIdsMap[node.id] = {
          _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
          _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
        }
      }
      return nodesConnectedSourceOrTargetHandleIdsMap[node.id]
    }

    // Applies one change to a handle id list in place.
    const applyChange = (handleIds: string[], type: string, handleId: string) => {
      if (type === 'remove') {
        const index = handleIds.indexOf(handleId)
        // splice(-1, 1) would drop the last element, so only remove real matches.
        if (index !== -1)
          handleIds.splice(index, 1)
      }
      if (type === 'add')
        handleIds.push(handleId)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)

      // Seed both entries before mutating either of them.
      const sourceEntry = sourceNode ? ensureEntry(sourceNode) : undefined
      const targetEntry = targetNode ? ensureEntry(targetNode) : undefined

      if (sourceEntry)
        applyChange(sourceEntry._connectedSourceHandleIds, type, edge.sourceHandle || 'source')
      if (targetEntry)
        applyChange(targetEntry._connectedTargetHandleIds, type, edge.targetHandle || 'target')
    })

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Builds a node object from partial node properties.
   *
   * `id` falls back to the current timestamp and `type` to `CUSTOM_NODE`. Handles
   * are placed left (target) and right (source). Nodes whose `data.type` is
   * `BlockEnum.Iteration` get `ITERATION_NODE_Z_INDEX` instead of the given
   * `zIndex`. Any remaining properties are spread last and therefore override
   * the computed ones.
   *
   * @returns The assembled node.
   */
  export const generateNewNode = ({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }) => {
    const isIterationNode = data.type === BlockEnum.Iteration
    const newNode = {
      id: id || `${Date.now()}`,
      type: type || CUSTOM_NODE,
      data,
      position,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      zIndex: isIterationNode ? ITERATION_NODE_Z_INDEX : zIndex,
      ...rest,
    }
    return newNode as Node
  }
  /**
   * Derives the title for a copy of a node.
   *
   * A title ending in a parenthesised number has that number incremented
   * (`"LLM (1)"` becomes `"LLM (2)"`); any other title gets `" (1)"` appended.
   *
   * @param oldTitle - Title of the node being copied.
   * @returns The title for the copy.
   */
  export const genNewNodeTitleFromOld = (oldTitle: string) => {
    const numberedTitle = /^(.+?)\s*\((\d+)\)\s*$/.exec(oldTitle)
    if (!numberedTitle)
      return `${oldTitle} (1)`

    const [, baseTitle, counter] = numberedTitle
    const nextCounter = Number.parseInt(counter, 10) + 1
    return `${baseTitle} (${nextCounter})`
  }
  /**
   * Collects the nodes reachable from the start node and the depth of the
   * longest path found.
   *
   * The graph is walked along outgoing edges from the node whose `data.type` is
   * `BlockEnum.Start`. Children of reached iteration nodes (nodes whose
   * `parentId` is the iteration node's id) are included as well. A node that is
   * already on the path currently being walked is not entered again, so a
   * cyclic graph terminates instead of recursing without bound; results for
   * acyclic graphs are unaffected by that guard.
   *
   * @param nodes - All nodes of the graph.
   * @param edges - All edges of the graph.
   * @returns `validNodes`, the reachable nodes de-duplicated by id, and
   * `maxDepth`, the number of nodes on the longest walked path (the start node
   * counts as depth 1). Both are empty / `0` when there is no start node.
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    let maxDepth = 1
    // Ids of the nodes on the path from the start node to the node being visited.
    const currentPathIds = new Set<string>()

    // Adds the nodes nested inside `node` when it is an iteration node.
    const collectIterationChildren = (node: Node) => {
      if (node.data.type === BlockEnum.Iteration)
        list.push(...nodes.filter(child => child.parentId === node.id))
    }

    const traverse = (root: Node, depth: number) => {
      if (depth > maxDepth)
        maxDepth = depth

      currentPathIds.add(root.id)
      const outgoers = getOutgoers(root, nodes, edges)

      if (outgoers.length) {
        outgoers.forEach((outgoer) => {
          // Following this edge would close a cycle; the node is already collected.
          if (currentPathIds.has(outgoer.id))
            return
          list.push(outgoer)
          collectIterationChildren(outgoer)
          traverse(outgoer, depth + 1)
        })
      }
      else {
        list.push(root)
        collectIterationChildren(root)
      }
      currentPathIds.delete(root.id)
    }

    traverse(startNode, maxDepth)

    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Gathers what is needed to validate a tool node: its input schema, its
   * setting schema and whether the provider still needs authorization.
   *
   * The provider is looked up by `provider_id` in the list matching
   * `provider_type` (built-in, custom, otherwise workflow) and the tool by
   * `tool_name` inside it. When the tool is not found both schemas are empty.
   *
   * @param toolData - Data of the tool node.
   * @param buildInTools - Built-in tool providers.
   * @param customTools - Custom tool providers.
   * @param workflowTools - Workflow tool providers.
   * @param language - Key used to pick the label, falling back to `en_US`.
   * @returns `toolInputsSchema` (form items whose `form` is `'llm'`),
   * `toolSettingSchema` (all other form items), `notAuthed` and `language`.
   */
  export const getToolCheckParams = (
    toolData: ToolNodeType,
    buildInTools: ToolWithProvider[],
    customTools: ToolWithProvider[],
    workflowTools: ToolWithProvider[],
    language: string,
  ) => {
    const { provider_id, provider_type, tool_name } = toolData
    const isBuiltIn = provider_type === CollectionType.builtIn

    let candidateCollections = workflowTools
    if (isBuiltIn)
      candidateCollections = buildInTools
    else if (provider_type === CollectionType.custom)
      candidateCollections = customTools

    const currCollection = candidateCollections.find(item => item.id === provider_id)
    const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
    const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []

    const toolInputsSchema: InputVar[] = formSchemas
      .filter((item: any) => item.form === 'llm')
      .map((item: any) => ({
        label: item.label[language] || item.label.en_US,
        variable: item.variable,
        type: item.type,
        required: item.required,
      }))
    const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')

    // Only built-in providers that allow deletion can be reported as not authorized.
    const notAuthed = isBuiltIn && !!currCollection?.allow_delete && !currCollection?.is_team_authorization

    return {
      toolInputsSchema,
      notAuthed,
      toolSettingSchema,
      language,
    }
  }
  /**
   * Returns copies of the given nodes and edges in which every node has a fresh
   * UUID and all references to those nodes follow the new ids.
   *
   * Edge `source` / `target` are rewritten through the old-id to new-id map, and
   * so is a node's `parentId` when its parent is part of `nodes`; without that,
   * a child would keep pointing at an id that no longer exists in the result. A
   * `parentId` that refers to a node outside `nodes` is left as is.
   *
   * NOTE(review): an edge endpoint that is not among `nodes` ends up `undefined`,
   * as before — callers are presumably expected to pass only edges between the
   * given nodes; confirm.
   *
   * @param nodes - Nodes to re-identify; not mutated.
   * @param edges - Edges between those nodes; not mutated.
   * @returns A `[nodes, edges]` tuple of shallow copies with the new ids.
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap = nodes.reduce((acc, node) => {
      acc[node.id] = uuid4()
      return acc
    }, {} as Record<string, string>)

    const newNodes = nodes.map((node) => {
      const newParentId = node.parentId ? idMap[node.parentId] : undefined
      return {
        ...node,
        id: idMap[node.id],
        // Only override parentId when the parent was re-identified too.
        ...(newParentId ? { parentId: newParentId } : {}),
      }
    })

    const newEdges = edges.map((edge) => {
      return {
        ...edge,
        source: idMap[edge.source],
        target: idMap[edge.target],
      }
    })

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /**
   * Tells whether the user agent string mentions "MAC" (case-insensitive).
   *
   * Returns `false` instead of throwing when no `navigator` global (or no
   * `userAgent`) is available, e.g. while rendering on the server.
   *
   * @returns `true` when the user agent contains "MAC", otherwise `false`.
   */
  export const isMac = () => {
    if (typeof navigator === 'undefined' || !navigator.userAgent)
      return false
    return navigator.userAgent.toUpperCase().includes('MAC')
  }
  /** Display symbols used on macOS for the keys that have one. */
  const specialKeysNameMap: Record<string, string | undefined> = {
    ctrl: '⌘',
    alt: '⌥',
  }

  /**
   * Returns the label to display for a keyboard key on the current system.
   *
   * @param key - Key name, e.g. `'ctrl'`.
   * @returns On macOS (per `isMac`) the mapped symbol when there is one;
   * otherwise `key` unchanged.
   */
  export const getKeyboardKeyNameBySystem = (key: string) => {
    return isMac() ? (specialKeysNameMap[key] || key) : key
  }
  /** Key codes that are substituted on macOS. */
  const specialKeysCodeMap: Record<string, string | undefined> = {
    ctrl: 'meta',
  }

  /**
   * Returns the key code to listen for on the current system.
   *
   * @param key - Key name, e.g. `'ctrl'`.
   * @returns On macOS (per `isMac`) the substituted code when there is one;
   * otherwise `key` unchanged.
   */
  export const getKeyboardKeyCodeBySystem = (key: string) => {
    return isMac() ? (specialKeysCodeMap[key] || key) : key
  }
  /**
   * Finds the top-left corner of the box enclosing the given node positions.
   *
   * @param nodes - Nodes whose `position` is inspected.
   * @returns The smallest `x` and the smallest `y` among the nodes (taken
   * independently); both are `Infinity` when `nodes` is empty.
   */
  export const getTopLeftNodePosition = (nodes: Node[]) => {
    return nodes.reduce(
      (topLeft, { position }) => ({
        x: position.x < topLeft.x ? position.x : topLeft.x,
        y: position.y < topLeft.y ? position.y : topLeft.y,
      }),
      { x: Infinity, y: Infinity },
    )
  }
  /**
   * Tells whether an event target is a text-entry element.
   *
   * @param target - Element the event was dispatched on.
   * @returns `true` for `INPUT` and `TEXTAREA` elements and for elements whose
   * `contentEditable` is `'true'`; `false` for everything else (always a
   * boolean, never `undefined`).
   */
  export const isEventTargetInputArea = (target: HTMLElement) => {
    if (target.tagName === 'INPUT' || target.tagName === 'TEXTAREA')
      return true

    return target.contentEditable === 'true'
  }