utils.ts 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getOutgoers,
  5. } from 'reactflow'
  6. import dagre from 'dagre'
  7. import { v4 as uuid4 } from 'uuid'
  8. import {
  9. cloneDeep,
  10. uniqBy,
  11. } from 'lodash-es'
  12. import type {
  13. Edge,
  14. InputVar,
  15. Node,
  16. ToolWithProvider,
  17. } from './types'
  18. import { BlockEnum } from './types'
  19. import {
  20. NODE_WIDTH_X_OFFSET,
  21. START_INITIAL_POSITION,
  22. } from './constants'
  23. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  24. import type { ToolNodeType } from './nodes/tool/types'
  25. import { CollectionType } from '@/app/components/tools/types'
  26. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  27. const WHITE = 'WHITE'
  28. const GRAY = 'GRAY'
  29. const BLACK = 'BLACK'
  30. const isCyclicUtil = (nodeId: string, color: Record<string, string>, adjaList: Record<string, string[]>, stack: string[]) => {
  31. color[nodeId] = GRAY
  32. stack.push(nodeId)
  33. for (let i = 0; i < adjaList[nodeId].length; ++i) {
  34. const childId = adjaList[nodeId][i]
  35. if (color[childId] === GRAY) {
  36. stack.push(childId)
  37. return true
  38. }
  39. if (color[childId] === WHITE && isCyclicUtil(childId, color, adjaList, stack))
  40. return true
  41. }
  42. color[nodeId] = BLACK
  43. if (stack.length > 0 && stack[stack.length - 1] === nodeId)
  44. stack.pop()
  45. return false
  46. }
  47. const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
  48. const adjaList: Record<string, string[]> = {}
  49. const color: Record<string, string> = {}
  50. const stack: string[] = []
  51. for (const node of nodes) {
  52. color[node.id] = WHITE
  53. adjaList[node.id] = []
  54. }
  55. for (const edge of edges)
  56. adjaList[edge.source]?.push(edge.target)
  57. for (let i = 0; i < nodes.length; i++) {
  58. if (color[nodes[i].id] === WHITE)
  59. isCyclicUtil(nodes[i].id, color, adjaList, stack)
  60. }
  61. const cycleEdges = []
  62. if (stack.length > 0) {
  63. const cycleNodes = new Set(stack)
  64. for (const edge of edges) {
  65. if (cycleNodes.has(edge.source) && cycleNodes.has(edge.target))
  66. cycleEdges.push(edge)
  67. }
  68. }
  69. return cycleEdges
  70. }
  71. export const initialNodes = (nodes: Node[], edges: Edge[]) => {
  72. const firstNode = nodes[0]
  73. if (!firstNode?.position) {
  74. nodes.forEach((node, index) => {
  75. node.position = {
  76. x: START_INITIAL_POSITION.x + index * NODE_WIDTH_X_OFFSET,
  77. y: START_INITIAL_POSITION.y,
  78. }
  79. })
  80. }
  81. return nodes.map((node) => {
  82. node.type = 'custom'
  83. const connectedEdges = getConnectedEdges([node], edges)
  84. node.data._connectedSourceHandleIds = connectedEdges.filter(edge => edge.source === node.id).map(edge => edge.sourceHandle || 'source')
  85. node.data._connectedTargetHandleIds = connectedEdges.filter(edge => edge.target === node.id).map(edge => edge.targetHandle || 'target')
  86. if (node.data.type === BlockEnum.IfElse) {
  87. node.data._targetBranches = [
  88. {
  89. id: 'true',
  90. name: 'IS TRUE',
  91. },
  92. {
  93. id: 'false',
  94. name: 'IS FALSE',
  95. },
  96. ]
  97. }
  98. if (node.data.type === BlockEnum.QuestionClassifier) {
  99. node.data._targetBranches = (node.data as QuestionClassifierNodeType).classes.map((topic) => {
  100. return topic
  101. })
  102. }
  103. return node
  104. })
  105. }
  106. export const initialEdges = (edges: Edge[], nodes: Node[]) => {
  107. let selectedNode: Node | null = null
  108. const nodesMap = nodes.reduce((acc, node) => {
  109. acc[node.id] = node
  110. if (node.data?.selected)
  111. selectedNode = node
  112. return acc
  113. }, {} as Record<string, Node>)
  114. const cycleEdges = getCycleEdges(nodes, edges)
  115. return edges.filter((edge) => {
  116. return !cycleEdges.find(cycEdge => cycEdge.source === edge.source && cycEdge.target === edge.target)
  117. }).map((edge) => {
  118. edge.type = 'custom'
  119. if (!edge.sourceHandle)
  120. edge.sourceHandle = 'source'
  121. if (!edge.targetHandle)
  122. edge.targetHandle = 'target'
  123. if (!edge.data?.sourceType && edge.source) {
  124. edge.data = {
  125. ...edge.data,
  126. sourceType: nodesMap[edge.source].data.type!,
  127. } as any
  128. }
  129. if (!edge.data?.targetType && edge.target) {
  130. edge.data = {
  131. ...edge.data,
  132. targetType: nodesMap[edge.target].data.type!,
  133. } as any
  134. }
  135. if (selectedNode) {
  136. edge.data = {
  137. ...edge.data,
  138. _connectedNodeIsSelected: edge.source === selectedNode.id || edge.target === selectedNode.id,
  139. } as any
  140. }
  141. return edge
  142. })
  143. }
  144. const dagreGraph = new dagre.graphlib.Graph()
  145. dagreGraph.setDefaultEdgeLabel(() => ({}))
  146. export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
  147. const nodes = cloneDeep(originNodes)
  148. const edges = cloneDeep(originEdges)
  149. dagreGraph.setGraph({
  150. rankdir: 'LR',
  151. align: 'UL',
  152. nodesep: 40,
  153. ranksep: 60,
  154. })
  155. nodes.forEach((node) => {
  156. dagreGraph.setNode(node.id, { width: node.width, height: node.height })
  157. })
  158. edges.forEach((edge) => {
  159. dagreGraph.setEdge(edge.source, edge.target)
  160. })
  161. dagre.layout(dagreGraph)
  162. return dagreGraph
  163. }
  164. export const canRunBySingle = (nodeType: BlockEnum) => {
  165. return nodeType === BlockEnum.LLM
  166. || nodeType === BlockEnum.KnowledgeRetrieval
  167. || nodeType === BlockEnum.Code
  168. || nodeType === BlockEnum.TemplateTransform
  169. || nodeType === BlockEnum.QuestionClassifier
  170. || nodeType === BlockEnum.HttpRequest
  171. || nodeType === BlockEnum.Tool
  172. }
  173. type ConnectedSourceOrTargetNodesChange = {
  174. type: string
  175. edge: Edge
  176. }[]
  177. export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
  178. const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>
  179. changes.forEach((change) => {
  180. const {
  181. edge,
  182. type,
  183. } = change
  184. const sourceNode = nodes.find(node => node.id === edge.source)!
  185. if (sourceNode) {
  186. nodesConnectedSourceOrTargetHandleIdsMap[sourceNode.id] = nodesConnectedSourceOrTargetHandleIdsMap[sourceNode.id] || {
  187. _connectedSourceHandleIds: [...(sourceNode?.data._connectedSourceHandleIds || [])],
  188. _connectedTargetHandleIds: [...(sourceNode?.data._connectedTargetHandleIds || [])],
  189. }
  190. }
  191. const targetNode = nodes.find(node => node.id === edge.target)!
  192. if (targetNode) {
  193. nodesConnectedSourceOrTargetHandleIdsMap[targetNode.id] = nodesConnectedSourceOrTargetHandleIdsMap[targetNode.id] || {
  194. _connectedSourceHandleIds: [...(targetNode?.data._connectedSourceHandleIds || [])],
  195. _connectedTargetHandleIds: [...(targetNode?.data._connectedTargetHandleIds || [])],
  196. }
  197. }
  198. if (sourceNode) {
  199. if (type === 'remove')
  200. nodesConnectedSourceOrTargetHandleIdsMap[sourceNode.id]._connectedSourceHandleIds = nodesConnectedSourceOrTargetHandleIdsMap[sourceNode.id]._connectedSourceHandleIds.filter((handleId: string) => handleId !== edge.sourceHandle)
  201. if (type === 'add')
  202. nodesConnectedSourceOrTargetHandleIdsMap[sourceNode.id]._connectedSourceHandleIds.push(edge.sourceHandle || 'source')
  203. }
  204. if (targetNode) {
  205. if (type === 'remove')
  206. nodesConnectedSourceOrTargetHandleIdsMap[targetNode.id]._connectedTargetHandleIds = nodesConnectedSourceOrTargetHandleIdsMap[targetNode.id]._connectedTargetHandleIds.filter((handleId: string) => handleId !== edge.targetHandle)
  207. if (type === 'add')
  208. nodesConnectedSourceOrTargetHandleIdsMap[targetNode.id]._connectedTargetHandleIds.push(edge.targetHandle || 'target')
  209. }
  210. })
  211. return nodesConnectedSourceOrTargetHandleIdsMap
  212. }
  213. export const generateNewNode = ({ data, position, id }: Pick<Node, 'data' | 'position'> & { id?: string }) => {
  214. return {
  215. id: id || `${Date.now()}`,
  216. type: 'custom',
  217. data,
  218. position,
  219. targetPosition: Position.Left,
  220. sourcePosition: Position.Right,
  221. } as Node
  222. }
  223. export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
  224. const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
  225. if (!startNode) {
  226. return {
  227. validNodes: [],
  228. maxDepth: 0,
  229. }
  230. }
  231. const list: Node[] = [startNode]
  232. let maxDepth = 1
  233. const traverse = (root: Node, depth: number) => {
  234. if (depth > maxDepth)
  235. maxDepth = depth
  236. const outgoers = getOutgoers(root, nodes, edges)
  237. if (outgoers.length) {
  238. outgoers.forEach((outgoer) => {
  239. list.push(outgoer)
  240. traverse(outgoer, depth + 1)
  241. })
  242. }
  243. else {
  244. list.push(root)
  245. }
  246. }
  247. traverse(startNode, maxDepth)
  248. return {
  249. validNodes: uniqBy(list, 'id'),
  250. maxDepth,
  251. }
  252. }
  253. export const getToolCheckParams = (
  254. toolData: ToolNodeType,
  255. buildInTools: ToolWithProvider[],
  256. customTools: ToolWithProvider[],
  257. language: string,
  258. ) => {
  259. const { provider_id, provider_type, tool_name } = toolData
  260. const isBuiltIn = provider_type === CollectionType.builtIn
  261. const currentTools = isBuiltIn ? buildInTools : customTools
  262. const currCollection = currentTools.find(item => item.id === provider_id)
  263. const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
  264. const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []
  265. const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm')
  266. const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')
  267. return {
  268. toolInputsSchema: (() => {
  269. const formInputs: InputVar[] = []
  270. toolInputVarSchema.forEach((item: any) => {
  271. formInputs.push({
  272. label: item.label[language] || item.label.en_US,
  273. variable: item.variable,
  274. type: item.type,
  275. required: item.required,
  276. })
  277. })
  278. return formInputs
  279. })(),
  280. notAuthed: isBuiltIn && !!currCollection?.allow_delete && !currCollection?.is_team_authorization,
  281. toolSettingSchema,
  282. language,
  283. }
  284. }
  285. export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
  286. const idMap = nodes.reduce((acc, node) => {
  287. acc[node.id] = uuid4()
  288. return acc
  289. }, {} as Record<string, string>)
  290. const newNodes = nodes.map((node) => {
  291. return {
  292. ...node,
  293. id: idMap[node.id],
  294. }
  295. })
  296. const newEdges = edges.map((edge) => {
  297. return {
  298. ...edge,
  299. source: idMap[edge.source],
  300. target: idMap[edge.target],
  301. }
  302. })
  303. return [newNodes, newEdges] as [Node[], Edge[]]
  304. }