use-config.ts 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import {
  2. useCallback,
  3. useEffect,
  4. useRef,
  5. useState,
  6. } from 'react'
  7. import produce from 'immer'
  8. import { isEqual } from 'lodash-es'
  9. import type { ValueSelector, Var } from '../../types'
  10. import { BlockEnum, VarType } from '../../types'
  11. import {
  12. useIsChatMode, useNodesReadOnly,
  13. useWorkflow,
  14. } from '../../hooks'
  15. import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
  16. import {
  17. getMultipleRetrievalConfig,
  18. getSelectedDatasetsMode,
  19. } from './utils'
  20. import { RETRIEVE_TYPE } from '@/types/app'
  21. import { DATASET_DEFAULT } from '@/config'
  22. import type { DataSet } from '@/models/datasets'
  23. import { fetchDatasets } from '@/service/datasets'
  24. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  25. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  26. import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  27. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  /**
   * State and handlers for the knowledge-retrieval node's configuration panel.
   *
   * Wraps the node data returned by `useNodeCrud` and exposes:
   * - handlers for the query variable, the retrieval mode (single / multiple),
   *   the single-retrieval model and its completion params, the
   *   multiple-retrieval settings, and the selected datasets;
   * - `selectedDatasets` (dataset detail objects, filtered to those with a `name`);
   * - `rerankModelOpen`, which `handleOnDatasetsChange` sets to true for certain
   *   mixes of selected datasets;
   * - single-run state from `useOneStepRun`, plus the `query` run input.
   *
   * Every write goes through the local `setInputs`, which drops whichever
   * retrieval config does not match `retrieval_mode`.
   *
   * @param id - Workflow node id; also used to find the Start node in the same branch.
   * @param payload - Node data handed to `useNodeCrud`.
   */
  const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
    const { nodesReadOnly: readOnly } = useNodesReadOnly()
    const isChatMode = useIsChatMode()
    const { getBeforeNodesInSameBranch } = useWorkflow()
    const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
    const startNodeId = startNode?.id
    const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)

    // Persist `s`, keeping only the retrieval config that matches its mode:
    // multiWay keeps `multiple_retrieval_config`, anything else keeps `single_retrieval_config`.
    const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
      const newInputs = produce(s, (draft) => {
        if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
          delete draft.single_retrieval_config
        else
          delete draft.multiple_retrieval_config
      })
      // NOTE(review): the original author noted this cleanup did "not work in
      // pass to draft" (i.e. when done inside callers' drafts); reason unclear — verify before moving it.
      doSetInputs(newInputs)
    }, [doSetInputs])

    // Latest `inputs`, for callbacks/effects that read current data without
    // listing `inputs` as a dependency. Refreshed after each render in which `inputs` changed.
    const inputRef = useRef(inputs)
    useEffect(() => {
      inputRef.current = inputs
    }, [inputs])

    const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
      const newInputs = produce(inputs, (draft) => {
        draft.query_variable_selector = newVar as ValueSelector
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    // Current text-generation provider/model: used below as the default
    // single-retrieval model.
    const {
      currentProvider,
      currentModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
    // Default rerank model: only referenced as a dependency of the defaults effect below.
    const {
      defaultModel: rerankDefaultModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

    // Set provider/name/mode of the single-retrieval model, creating an empty
    // config first if none exists.
    const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.single_retrieval_config) {
          draft.single_retrieval_config = {
            model: {
              provider: '',
              name: '',
              mode: '',
              completion_params: {},
            },
          }
        }
        const draftModel = draft.single_retrieval_config?.model
        draftModel.provider = model.provider
        draftModel.name = model.modelId
        draftModel.mode = model.mode!
      })
      setInputs(newInputs)
    }, [setInputs])

    // Replace the single-retrieval model's completion params; skipped when they
    // are deep-equal to the stored ones.
    const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
      // NOTE(review): the original author noted that inputRef.current's model can
      // still be the old one right after a provider change — verify.
      if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
        return
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.single_retrieval_config) {
          draft.single_retrieval_config = {
            model: {
              provider: '',
              name: '',
              mode: '',
              completion_params: {},
            },
          }
        }
        draft.single_retrieval_config.model.completion_params = newParams
      })
      setInputs(newInputs)
    }, [setInputs])

    // Set default models. Skipped when the config for the current mode already
    // has a provider. Otherwise fills the single-retrieval model from the current
    // text-generation model (if any and not yet set), and rebuilds the
    // multiple-retrieval config with `top_k` defaulted to DATASET_DEFAULT.top_k.
    useEffect(() => {
      const inputs = inputRef.current
      if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
        return
      if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
        return
      const newInput = produce(inputs, (draft) => {
        if (currentProvider?.provider && currentModel?.model) {
          const hasSetModel = draft.single_retrieval_config?.model?.provider
          if (!hasSetModel) {
            draft.single_retrieval_config = {
              model: {
                provider: currentProvider?.provider,
                name: currentModel?.model,
                mode: currentModel?.model_properties?.mode as string,
                completion_params: {},
              },
            }
          }
        }
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = {
          top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
          score_threshold: multipleRetrievalConfig?.score_threshold,
          reranking_model: multipleRetrievalConfig?.reranking_model,
          reranking_mode: multipleRetrievalConfig?.reranking_mode,
          weights: multipleRetrievalConfig?.weights,
        }
      })
      setInputs(newInput)
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [currentProvider?.provider, currentModel, rerankDefaultModel])

    const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
    const [rerankModelOpen, setRerankModelOpen] = useState(false)

    // Switch retrieval mode. multiWay recomputes the multiple-retrieval config
    // for the selected datasets; otherwise a single-retrieval model is filled
    // from the current text-generation model if none is set.
    const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
      const newInputs = produce(inputs, (draft) => {
        draft.retrieval_mode = newMode
        if (newMode === RETRIEVE_TYPE.multiWay) {
          const multipleRetrievalConfig = draft.multiple_retrieval_config
          draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
        }
        else {
          const hasSetModel = draft.single_retrieval_config?.model?.provider
          if (!hasSetModel) {
            draft.single_retrieval_config = {
              model: {
                provider: currentProvider?.provider || '',
                name: currentModel?.model || '',
                mode: currentModel?.model_properties?.mode as string,
                completion_params: {},
              },
            }
          }
        }
      })
      setInputs(newInputs)
    }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])

    const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
      const newInputs = produce(inputs, (draft) => {
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig, selectedDatasets)
      })
      setInputs(newInputs)
    }, [inputs, setInputs, selectedDatasets])

    // datasets: on mount, load details for the dataset ids already stored on the
    // node, then re-save the inputs through `setInputs` (which normalizes the
    // retrieval config for the current mode).
    useEffect(() => {
      // Set by the cleanup so a response that arrives after unmount is ignored.
      let cancelled = false
      const loadSelectedDatasets = async () => {
        const datasetIds = inputRef.current.dataset_ids
        if (datasetIds?.length > 0) {
          const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
          if (cancelled)
            return
          setSelectedDatasets(dataSetsWithDetail)
        }
        // Read the ref again *after* the await: the other mount effects may have
        // saved new inputs while the request was in flight, and writing back the
        // pre-request snapshot would revert those changes.
        setInputs(inputRef.current)
      }
      loadSelectedDatasets()
      return () => {
        cancelled = true
      }
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [])

    // On mount, in chat mode, default an empty query selector to the Start
    // node's `sys.query`.
    // NOTE(review): this writes back the mount-time `inputs` snapshot, as the
    // defaults effect does; check that it cannot overwrite fields that effect just set.
    useEffect(() => {
      let query_variable_selector: ValueSelector = inputs.query_variable_selector
      if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
        query_variable_selector = [startNodeId, 'sys.query']
      setInputs({
        ...inputs,
        query_variable_selector,
      })
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [])

    // Replace the selected datasets. In multiWay mode (with at least one dataset)
    // the multiple-retrieval config is recomputed for them, and `rerankModelOpen`
    // is set for the mixes flagged by `getSelectedDatasetsMode`.
    const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
      const {
        mixtureHighQualityAndEconomic,
        mixtureInternalAndExternal,
        inconsistentEmbeddingModel,
        allInternal,
        allExternal,
      } = getSelectedDatasetsMode(newDatasets)
      const newInputs = produce(inputs, (draft) => {
        draft.dataset_ids = newDatasets.map(d => d.id)
        // Use the live mode (like the other handlers), not the `payload` argument.
        if (draft.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
          const multipleRetrievalConfig = draft.multiple_retrieval_config
          draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
        }
      })
      setInputs(newInputs)
      setSelectedDatasets(newDatasets)
      if (
        (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
        || mixtureInternalAndExternal
        || (allExternal && newDatasets.length > 1)
      )
        setRerankModelOpen(true)
    }, [inputs, setInputs])

    // Only string variables may be picked as the query variable.
    const filterVar = useCallback((varPayload: Var) => {
      return varPayload.type === VarType.string
    }, [])

    // single run
    const {
      isShowSingleRun,
      hideSingleRun,
      runningStatus,
      handleRun,
      handleStop,
      runInputData,
      setRunInputData,
      runResult,
    } = useOneStepRun<KnowledgeRetrievalNodeType>({
      id,
      data: inputs,
      defaultRunInputData: {
        query: '',
      },
    })

    const query = runInputData.query
    const setQuery = useCallback((newQuery: string) => {
      setRunInputData({
        ...runInputData,
        query: newQuery,
      })
    }, [runInputData, setRunInputData])

    return {
      readOnly,
      inputs,
      handleQueryVarChange,
      filterVar,
      handleRetrievalModeChange,
      handleMultipleRetrievalConfigChange,
      handleModelChanged,
      handleCompletionParamsChange,
      selectedDatasets: selectedDatasets.filter(d => d.name),
      handleOnDatasetsChange,
      isShowSingleRun,
      hideSingleRun,
      runningStatus,
      handleRun,
      handleStop,
      query,
      setQuery,
      runResult,
      rerankModelOpen,
      setRerankModelOpen,
    }
  }

  export default useConfig