hooks.ts 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. import {
  2. useCallback,
  3. useEffect,
  4. useMemo,
  5. useRef,
  6. useState,
  7. } from 'react'
  8. import { useTranslation } from 'react-i18next'
  9. import { produce, setAutoFreeze } from 'immer'
  10. import { uniqBy } from 'lodash-es'
  11. import { useWorkflowRun } from '../../hooks'
  12. import { NodeRunningStatus, WorkflowRunningStatus } from '../../types'
  13. import { useWorkflowStore } from '../../store'
  14. import { DEFAULT_ITER_TIMES } from '../../constants'
  15. import type {
  16. ChatItem,
  17. ChatItemInTree,
  18. Inputs,
  19. } from '@/app/components/base/chat/types'
  20. import type { InputForm } from '@/app/components/base/chat/chat/type'
  21. import {
  22. getProcessedInputs,
  23. processOpeningStatement,
  24. } from '@/app/components/base/chat/chat/utils'
  25. import { useToastContext } from '@/app/components/base/toast'
  26. import { TransferMethod } from '@/types/app'
  27. import {
  28. getProcessedFiles,
  29. getProcessedFilesFromResponse,
  30. } from '@/app/components/base/file-uploader/utils'
  31. import type { FileEntity } from '@/app/components/base/file-uploader/types'
  32. import { getThreadMessages } from '@/app/components/base/chat/utils'
  33. import type { NodeTracing } from '@/types/workflow'
  /** Hands the AbortController of a just-started request back to the caller so it can be cancelled later. */
  type GetAbortController = (abortController: AbortController) => void

  /** Optional hooks accepted as the second argument of `handleSend`. */
  interface SendCallback {
    /**
     * Fetches suggested follow-up questions for the answer with id `responseItemId`.
     * The implementation should pass its AbortController to `getAbortController`.
     */
    onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<any>
  }
  /**
   * Chat state and streaming logic for the workflow debug/preview chat.
   *
   * Messages are kept as a tree (`chatTree`); `targetMessageId` selects which thread of
   * that tree is rendered as `chatList`. Sending a message runs the workflow through
   * `handleRun` and folds every streamed event (text chunks, workflow / iteration / node
   * tracing) into the answer node of the current question/answer pair.
   *
   * @param config - feature config; reads `opening_statement`, `suggested_questions`
   *   and `suggested_questions_after_answer.enabled`
   * @param formSettings - current inputs and their form schema, used to render the
   *   opening statement and to process inputs before sending
   * @param prevChatTree - initial chat tree, if any
   * @param stopChat - invoked with the current task id when a response is stopped
   * @returns conversation id, rendered chat list, thread selector, send/stop/restart
   *   handlers, responding flag and suggested questions
   */
  export const useChat = (
    config: any,
    formSettings?: {
      inputs: Inputs
      inputsForm: InputForm[]
    },
    prevChatTree?: ChatItemInTree[],
    stopChat?: (taskId: string) => void,
  ) => {
    const { t } = useTranslation()
    const { notify } = useToastContext()
    const { handleRun } = useWorkflowRun()
    // Set when the user stops a response; suppresses fetching suggested questions for it.
    const hasStopResponded = useRef(false)
    const workflowStore = useWorkflowStore()
    const conversationId = useRef('')
    const taskIdRef = useRef('')
    const [isResponding, setIsResponding] = useState(false)
    // Synchronous mirror of `isResponding` so callbacks never read a stale value.
    const isRespondingRef = useRef(false)
    const [suggestedQuestions, setSuggestQuestions] = useState<string[]>([])
    const suggestedQuestionsAbortControllerRef = useRef<AbortController | null>(null)
    const {
      setIterTimes,
    } = workflowStore.getState()

    const handleResponding = useCallback((isResponding: boolean) => {
      setIsResponding(isResponding)
      isRespondingRef.current = isResponding
    }, [])

    const [chatTree, setChatTree] = useState<ChatItemInTree[]>(prevChatTree || [])
    // Latest tree, readable from streaming callbacks created before the most recent render.
    const chatTreeRef = useRef<ChatItemInTree[]>(chatTree)
    const [targetMessageId, setTargetMessageId] = useState<string>()
    const threadMessages = useMemo(() => getThreadMessages(chatTree, targetMessageId), [chatTree, targetMessageId])

    const getIntroduction = useCallback((str: string) => {
      return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
    }, [formSettings?.inputs, formSettings?.inputsForm])

    /** Final chat list that will be rendered: the selected thread, led by the opening statement if configured. */
    const chatList = useMemo(() => {
      const ret = [...threadMessages]
      if (config?.opening_statement) {
        const index = threadMessages.findIndex(item => item.isOpeningStatement)

        if (index > -1) {
          ret[index] = {
            ...ret[index],
            content: getIntroduction(config.opening_statement),
            suggestedQuestions: config.suggested_questions,
          }
        }
        else {
          ret.unshift({
            id: `${Date.now()}`,
            content: getIntroduction(config.opening_statement),
            isAnswer: true,
            isOpeningStatement: true,
            suggestedQuestions: config.suggested_questions,
          })
        }
      }
      return ret
    }, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])

    // The stream callbacks keep mutating nested objects of `responseItem` (e.g.
    // `workflowProcess.tracing`) after a shallow copy of it has been put into produced
    // state; immer's auto-freeze would make those objects read-only. Note that
    // `setAutoFreeze` is a global immer setting, so this applies module-wide while mounted.
    useEffect(() => {
      setAutoFreeze(false)
      return () => {
        setAutoFreeze(true)
      }
    }, [])

    /** Find the target node by BFS over the latest tree and apply `operation` to it; returns the new tree. */
    const produceChatTreeNode = useCallback((targetId: string, operation: (node: ChatItemInTree) => void) => {
      return produce(chatTreeRef.current, (draft) => {
        const queue: ChatItemInTree[] = [...draft]
        while (queue.length > 0) {
          const current = queue.shift()!
          if (current.id === targetId) {
            operation(current)
            break
          }
          if (current.children)
            queue.push(...current.children)
        }
      })
    }, [])

    /**
     * Stop the current response: clears the responding flag, asks `stopChat` to stop
     * the current task (when both exist), resets iteration times and aborts any pending
     * suggested-questions request. It does not touch the `handleRun` request directly.
     */
    const handleStop = useCallback(() => {
      hasStopResponded.current = true
      handleResponding(false)
      if (stopChat && taskIdRef.current)
        stopChat(taskIdRef.current)
      setIterTimes(DEFAULT_ITER_TIMES)
      if (suggestedQuestionsAbortControllerRef.current)
        suggestedQuestionsAbortControllerRef.current.abort()
    }, [handleResponding, setIterTimes, stopChat])

    /** Start a fresh conversation: forget ids, stop, and clear the tree and suggestions. */
    const handleRestart = useCallback(() => {
      conversationId.current = ''
      // NOTE(review): the task id is cleared before `handleStop`, so `stopChat` is never
      // called from here — confirm whether in-flight tasks should be stopped on restart.
      taskIdRef.current = ''
      handleStop()
      setIterTimes(DEFAULT_ITER_TIMES)
      setChatTree([])
      // Keep the ref in sync with the state so tree lookups don't see the old conversation.
      chatTreeRef.current = []
      setSuggestQuestions([])
    }, [
      handleStop,
      setIterTimes,
    ])

    /**
     * Insert or replace the question/answer pair in the chat tree.
     *
     * A QA without `parentId` that is not yet at the root is appended as a new root;
     * otherwise it is looked up under its parent (by placeholder or real question id)
     * and replaced, or appended to the parent's children if absent.
     */
    const updateCurrentQAOnTree = useCallback(({
      parentId,
      responseItem,
      placeholderQuestionId,
      questionItem,
    }: {
      parentId?: string
      responseItem: ChatItem
      placeholderQuestionId: string
      questionItem: ChatItem
    }) => {
      let nextState: ChatItemInTree[]
      const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] }
      if (!parentId && !chatTree.some(item => [placeholderQuestionId, questionItem.id].includes(item.id))) {
        // QA whose parent is not provided is considered as a first message of the conversation,
        // and it should be a root node of the chat tree
        nextState = produce(chatTree, (draft) => {
          draft.push(currentQA)
        })
      }
      else {
        // find the target QA in the tree and update it; if not found, insert it to its parent node
        nextState = produceChatTreeNode(parentId!, (parentNode) => {
          const questionNodeIndex = parentNode.children!.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
          if (questionNodeIndex === -1)
            parentNode.children!.push(currentQA)
          else
            parentNode.children![questionNodeIndex] = currentQA
        })
      }
      setChatTree(nextState)
      chatTreeRef.current = nextState
    }, [chatTree, produceChatTreeNode])

    /**
     * Send a query and stream the workflow response into the chat tree.
     *
     * @returns `false` (after showing a toast) when a response is already in progress;
     *   otherwise nothing.
     */
    const handleSend = useCallback((
      params: {
        query: string
        files?: FileEntity[]
        parent_message_id?: string
        [key: string]: any
      },
      {
        onGetSuggestedQuestions,
      }: SendCallback,
    ) => {
      if (isRespondingRef.current) {
        notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
        return false
      }

      const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
      // Position of the new answer among its siblings; root-level QAs count against the whole tree.
      const siblingIndex = parentMessage?.children?.length ?? chatTree.length

      const placeholderQuestionId = `question-${Date.now()}`
      const questionItem = {
        id: placeholderQuestionId,
        content: params.query,
        isAnswer: false,
        message_files: params.files,
        parentMessageId: params.parent_message_id,
      }

      const placeholderAnswerId = `answer-placeholder-${Date.now()}`
      const placeholderAnswerItem = {
        id: placeholderAnswerId,
        content: '',
        isAnswer: true,
        parentMessageId: questionItem.id,
        siblingIndex,
      }

      setTargetMessageId(parentMessage?.id)
      updateCurrentQAOnTree({
        parentId: params.parent_message_id,
        responseItem: placeholderAnswerItem,
        placeholderQuestionId,
        questionItem,
      })

      // answer; mutated in place by the stream callbacks below
      const responseItem: ChatItem = {
        id: placeholderAnswerId,
        content: '',
        agent_thoughts: [],
        message_files: [],
        isAnswer: true,
        parentMessageId: questionItem.id,
        siblingIndex,
      }

      /** Write the current question/answer pair back into the chat tree. */
      const syncToTree = () => {
        updateCurrentQAOnTree({
          placeholderQuestionId,
          questionItem,
          responseItem,
          parentId: params.parent_message_id,
        })
      }

      handleResponding(true)
      // A stop of an earlier response must not suppress suggested questions for this one.
      hasStopResponded.current = false

      const { files, inputs, ...restParams } = params
      const bodyParams = {
        files: getProcessedFiles(files || []),
        inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
        ...restParams,
      }
      if (bodyParams?.files?.length) {
        // Local uploads are not sent by URL.
        bodyParams.files = bodyParams.files.map((item) => {
          if (item.transfer_method === TransferMethod.local_file) {
            return {
              ...item,
              url: '',
            }
          }
          return item
        })
      }

      let hasSetResponseId = false

      handleRun(
        bodyParams,
        {
          onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }: any) => {
            responseItem.content = responseItem.content + message

            if (messageId && !hasSetResponseId) {
              // First chunk carrying a real message id: replace both placeholder ids.
              questionItem.id = `question-${messageId}`
              responseItem.id = messageId
              responseItem.parentMessageId = questionItem.id
              hasSetResponseId = true
            }

            if (isFirstMessage && newConversationId)
              conversationId.current = newConversationId

            taskIdRef.current = taskId
            if (messageId)
              responseItem.id = messageId

            syncToTree()
          },
          async onCompleted(hasError?: boolean, errorMessage?: string) {
            handleResponding(false)

            if (hasError) {
              if (errorMessage) {
                responseItem.content = errorMessage
                responseItem.isError = true
                syncToTree()
              }
              return
            }

            if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) {
              try {
                const { data }: any = await onGetSuggestedQuestions(
                  responseItem.id,
                  newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController,
                )
                setSuggestQuestions(data)
              }
              catch {
                // Suggestions are best-effort: on any failure just show none.
                setSuggestQuestions([])
              }
            }
          },
          onMessageEnd: (messageEnd) => {
            responseItem.citation = messageEnd.metadata?.retriever_resources || []
            const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
            responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')

            syncToTree()
          },
          onMessageReplace: (messageReplace) => {
            responseItem.content = messageReplace.answer
          },
          onError() {
            handleResponding(false)
          },
          onWorkflowStarted: ({ workflow_run_id, task_id }) => {
            taskIdRef.current = task_id
            responseItem.workflow_run_id = workflow_run_id
            responseItem.workflowProcess = {
              status: WorkflowRunningStatus.Running,
              tracing: [],
            }
            syncToTree()
          },
          onWorkflowFinished: ({ data }) => {
            responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
            syncToTree()
          },
          onIterationStart: ({ data }) => {
            responseItem.workflowProcess!.tracing!.push({
              ...data,
              status: NodeRunningStatus.Running,
              details: [],
            } as any)
            syncToTree()
          },
          onIterationNext: ({ data }) => {
            const tracing = responseItem.workflowProcess!.tracing!
            const iterations = tracing.find(item => item.node_id === data.node_id
              && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))
            // An unmatched event would otherwise throw on `undefined.details`.
            iterations?.details?.push([])
            syncToTree()
          },
          onIterationFinish: ({ data }) => {
            const tracing = responseItem.workflowProcess!.tracing!
            const iterationsIndex = tracing.findIndex(item => item.node_id === data.node_id
              && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))
            // Skip when unmatched; writing `tracing[-1]` would only add a bogus property.
            if (iterationsIndex > -1) {
              tracing[iterationsIndex] = {
                ...tracing[iterationsIndex],
                ...data,
                status: NodeRunningStatus.Succeeded,
              } as any
            }
            syncToTree()
          },
          onNodeStarted: ({ data }) => {
            // Nodes inside an iteration are not tracked as top-level traces.
            if (data.iteration_id)
              return

            responseItem.workflowProcess!.tracing!.push({
              ...data,
              status: NodeRunningStatus.Running,
            } as any)
            syncToTree()
          },
          onNodeRetry: ({ data }) => {
            if (data.iteration_id)
              return

            const tracing = responseItem.workflowProcess!.tracing!
            const currentIndex = tracing.findIndex((item) => {
              if (!item.execution_metadata?.parallel_id)
                return item.node_id === data.node_id

              return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id)
            })
            // No matching trace to attach the retry to.
            if (currentIndex === -1)
              return

            const currentTrace = tracing[currentIndex]
            if (currentTrace.retryDetail)
              currentTrace.retryDetail.push(data as NodeTracing)
            else
              currentTrace.retryDetail = [data as NodeTracing]

            // This hook stores messages as a tree (there is no flat chat list here),
            // so persist through the same path as every other stream event.
            syncToTree()
          },
          onNodeFinished: ({ data }) => {
            if (data.iteration_id)
              return

            const tracing = responseItem.workflowProcess!.tracing!
            const currentIndex = tracing.findIndex((item) => {
              if (!item.execution_metadata?.parallel_id)
                return item.node_id === data.node_id

              return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id)
            })
            if (currentIndex > -1) {
              const previous = tracing[currentIndex]
              // Replace the running trace with the finished one, keeping any extras and retry history.
              tracing[currentIndex] = {
                ...(previous?.extras
                  ? { extras: previous.extras }
                  : {}),
                ...(previous?.retryDetail
                  ? { retryDetail: previous.retryDetail }
                  : {}),
                ...data,
              } as any
            }
            syncToTree()
          },
        },
      )
    }, [threadMessages, chatTree.length, updateCurrentQAOnTree, handleResponding, formSettings?.inputsForm, handleRun, notify, t, config?.suggested_questions_after_answer?.enabled])

    return {
      conversationId: conversationId.current,
      chatList,
      setTargetMessageId,
      handleSend,
      handleStop,
      handleRestart,
      isResponding,
      suggestedQuestions,
    }
  }