use-config.ts 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. import {
  2. useCallback,
  3. useEffect,
  4. useRef,
  5. useState,
  6. } from 'react'
  7. import produce from 'immer'
  8. import { isEqual } from 'lodash-es'
  9. import type { ValueSelector, Var } from '../../types'
  10. import { BlockEnum, VarType } from '../../types'
  11. import {
  12. useIsChatMode, useNodesReadOnly,
  13. useWorkflow,
  14. } from '../../hooks'
  15. import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
  16. import {
  17. getMultipleRetrievalConfig,
  18. getSelectedDatasetsMode,
  19. } from './utils'
  20. import { RETRIEVE_TYPE } from '@/types/app'
  21. import { DATASET_DEFAULT } from '@/config'
  22. import type { DataSet } from '@/models/datasets'
  23. import { fetchDatasets } from '@/service/datasets'
  24. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  25. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  26. import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  27. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  28. const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  29. const { nodesReadOnly: readOnly } = useNodesReadOnly()
  30. const isChatMode = useIsChatMode()
  31. const { getBeforeNodesInSameBranch } = useWorkflow()
  32. const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
  33. const startNodeId = startNode?.id
  34. const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
  35. const inputRef = useRef(inputs)
  36. const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
  37. const newInputs = produce(s, (draft) => {
  38. if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
  39. delete draft.single_retrieval_config
  40. else
  41. delete draft.multiple_retrieval_config
  42. })
  43. // not work in pass to draft...
  44. doSetInputs(newInputs)
  45. inputRef.current = newInputs
  46. }, [doSetInputs])
  47. const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
  48. const newInputs = produce(inputs, (draft) => {
  49. draft.query_variable_selector = newVar as ValueSelector
  50. })
  51. setInputs(newInputs)
  52. }, [inputs, setInputs])
  53. const {
  54. currentProvider,
  55. currentModel,
  56. } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
  57. const {
  58. modelList: rerankModelList,
  59. defaultModel: rerankDefaultModel,
  60. } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
  61. const {
  62. currentModel: currentRerankModel,
  63. } = useCurrentProviderAndModel(
  64. rerankModelList,
  65. rerankDefaultModel
  66. ? {
  67. ...rerankDefaultModel,
  68. provider: rerankDefaultModel.provider.provider,
  69. }
  70. : undefined,
  71. )
  72. const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
  73. const newInputs = produce(inputRef.current, (draft) => {
  74. if (!draft.single_retrieval_config) {
  75. draft.single_retrieval_config = {
  76. model: {
  77. provider: '',
  78. name: '',
  79. mode: '',
  80. completion_params: {},
  81. },
  82. }
  83. }
  84. const draftModel = draft.single_retrieval_config?.model
  85. draftModel.provider = model.provider
  86. draftModel.name = model.modelId
  87. draftModel.mode = model.mode!
  88. })
  89. setInputs(newInputs)
  90. }, [setInputs])
  91. const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
  92. // inputRef.current.single_retrieval_config?.model is old when change the provider...
  93. if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
  94. return
  95. const newInputs = produce(inputRef.current, (draft) => {
  96. if (!draft.single_retrieval_config) {
  97. draft.single_retrieval_config = {
  98. model: {
  99. provider: '',
  100. name: '',
  101. mode: '',
  102. completion_params: {},
  103. },
  104. }
  105. }
  106. draft.single_retrieval_config.model.completion_params = newParams
  107. })
  108. setInputs(newInputs)
  109. }, [setInputs])
  110. // set defaults models
  111. useEffect(() => {
  112. const inputs = inputRef.current
  113. if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
  114. return
  115. if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
  116. return
  117. const newInput = produce(inputs, (draft) => {
  118. if (currentProvider?.provider && currentModel?.model) {
  119. const hasSetModel = draft.single_retrieval_config?.model?.provider
  120. if (!hasSetModel) {
  121. draft.single_retrieval_config = {
  122. model: {
  123. provider: currentProvider?.provider,
  124. name: currentModel?.model,
  125. mode: currentModel?.model_properties?.mode as string,
  126. completion_params: {},
  127. },
  128. }
  129. }
  130. }
  131. const multipleRetrievalConfig = draft.multiple_retrieval_config
  132. draft.multiple_retrieval_config = {
  133. top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
  134. score_threshold: multipleRetrievalConfig?.score_threshold,
  135. reranking_model: multipleRetrievalConfig?.reranking_model,
  136. reranking_mode: multipleRetrievalConfig?.reranking_mode,
  137. weights: multipleRetrievalConfig?.weights,
  138. reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
  139. ? multipleRetrievalConfig.reranking_enable
  140. : Boolean(currentRerankModel && rerankDefaultModel),
  141. }
  142. })
  143. setInputs(newInput)
  144. // eslint-disable-next-line react-hooks/exhaustive-deps
  145. }, [currentProvider?.provider, currentModel, rerankDefaultModel])
  146. const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
  147. const [rerankModelOpen, setRerankModelOpen] = useState(false)
  148. const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
  149. const newInputs = produce(inputs, (draft) => {
  150. draft.retrieval_mode = newMode
  151. if (newMode === RETRIEVE_TYPE.multiWay) {
  152. const multipleRetrievalConfig = draft.multiple_retrieval_config
  153. draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
  154. }
  155. else {
  156. const hasSetModel = draft.single_retrieval_config?.model?.provider
  157. if (!hasSetModel) {
  158. draft.single_retrieval_config = {
  159. model: {
  160. provider: currentProvider?.provider || '',
  161. name: currentModel?.model || '',
  162. mode: currentModel?.model_properties?.mode as string,
  163. completion_params: {},
  164. },
  165. }
  166. }
  167. }
  168. })
  169. setInputs(newInputs)
  170. }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])
  171. const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
  172. const newInputs = produce(inputs, (draft) => {
  173. draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
  174. })
  175. setInputs(newInputs)
  176. }, [inputs, setInputs, selectedDatasets])
  177. // datasets
  178. useEffect(() => {
  179. (async () => {
  180. const inputs = inputRef.current
  181. const datasetIds = inputs.dataset_ids
  182. if (datasetIds?.length > 0) {
  183. const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
  184. setSelectedDatasets(dataSetsWithDetail)
  185. }
  186. const newInputs = produce(inputs, (draft) => {
  187. draft.dataset_ids = datasetIds
  188. })
  189. setInputs(newInputs)
  190. })()
  191. // eslint-disable-next-line react-hooks/exhaustive-deps
  192. }, [])
  193. useEffect(() => {
  194. const inputs = inputRef.current
  195. let query_variable_selector: ValueSelector = inputs.query_variable_selector
  196. if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
  197. query_variable_selector = [startNodeId, 'sys.query']
  198. setInputs(produce(inputs, (draft) => {
  199. draft.query_variable_selector = query_variable_selector
  200. }))
  201. // eslint-disable-next-line react-hooks/exhaustive-deps
  202. }, [])
  203. const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
  204. const {
  205. mixtureHighQualityAndEconomic,
  206. mixtureInternalAndExternal,
  207. inconsistentEmbeddingModel,
  208. allInternal,
  209. allExternal,
  210. } = getSelectedDatasetsMode(newDatasets)
  211. const newInputs = produce(inputs, (draft) => {
  212. draft.dataset_ids = newDatasets.map(d => d.id)
  213. if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
  214. const multipleRetrievalConfig = draft.multiple_retrieval_config
  215. draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
  216. }
  217. })
  218. setInputs(newInputs)
  219. setSelectedDatasets(newDatasets)
  220. if (
  221. (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
  222. || mixtureInternalAndExternal
  223. || (allExternal && newDatasets.length > 1)
  224. )
  225. setRerankModelOpen(true)
  226. }, [inputs, setInputs, payload.retrieval_mode])
  227. const filterVar = useCallback((varPayload: Var) => {
  228. return varPayload.type === VarType.string
  229. }, [])
  230. // single run
  231. const {
  232. isShowSingleRun,
  233. hideSingleRun,
  234. runningStatus,
  235. handleRun,
  236. handleStop,
  237. runInputData,
  238. setRunInputData,
  239. runResult,
  240. } = useOneStepRun<KnowledgeRetrievalNodeType>({
  241. id,
  242. data: inputs,
  243. defaultRunInputData: {
  244. query: '',
  245. },
  246. })
  247. const query = runInputData.query
  248. const setQuery = useCallback((newQuery: string) => {
  249. setRunInputData({
  250. ...runInputData,
  251. query: newQuery,
  252. })
  253. }, [runInputData, setRunInputData])
  254. return {
  255. readOnly,
  256. inputs,
  257. handleQueryVarChange,
  258. filterVar,
  259. handleRetrievalModeChange,
  260. handleMultipleRetrievalConfigChange,
  261. handleModelChanged,
  262. handleCompletionParamsChange,
  263. selectedDatasets: selectedDatasets.filter(d => d.name),
  264. handleOnDatasetsChange,
  265. isShowSingleRun,
  266. hideSingleRun,
  267. runningStatus,
  268. handleRun,
  269. handleStop,
  270. query,
  271. setQuery,
  272. runResult,
  273. rerankModelOpen,
  274. setRerankModelOpen,
  275. }
  276. }
  277. export default useConfig