default.ts 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import { BlockEnum } from '../../types'
  2. import type { NodeDefault } from '../../types'
  3. import type { KnowledgeRetrievalNodeType } from './types'
  4. import { RerankingModeEnum } from '@/models/datasets'
  5. import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
  6. import { DATASET_DEFAULT } from '@/config'
  7. import { RETRIEVE_TYPE } from '@/types/app'
  // Namespace prefix for the workflow translation keys used in validation messages below.
  const i18nPrefix = 'workflow'

  /**
   * Node-level defaults and hooks for the Knowledge Retrieval workflow node:
   * the initial payload, which block types may be connected before/after it,
   * and the validation routine used to report the first missing required field.
   */
  const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
    // Initial payload: nothing selected yet, multi-way retrieval with reranking disabled
    // and `top_k` taken from the dataset defaults.
    defaultValue: {
      query_variable_selector: [],
      dataset_ids: [],
      retrieval_mode: RETRIEVE_TYPE.multiWay,
      multiple_retrieval_config: {
        top_k: DATASET_DEFAULT.top_k,
        score_threshold: undefined,
        reranking_enable: false,
      },
    },

    /**
     * Block types that may precede this node.
     * Chat mode: the full chat block list. Completion mode: the completion block list
     * with `BlockEnum.End` removed.
     */
    getAvailablePrevNodes(isChatMode: boolean) {
      if (isChatMode)
        return ALL_CHAT_AVAILABLE_BLOCKS

      return ALL_COMPLETION_AVAILABLE_BLOCKS.filter(blockType => blockType !== BlockEnum.End)
    },

    /**
     * Block types that may follow this node: the full block list for the current mode.
     */
    getAvailableNextNodes(isChatMode: boolean) {
      return isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
    },

    /**
     * Validate a node payload and report the first missing required field.
     *
     * Checks, in order:
     *  1. a query variable is selected;
     *  2. at least one dataset is selected;
     *  3. multi-way retrieval in reranking-model mode has a rerank model provider;
     *  4. one-way retrieval has a model provider.
     *
     * @param payload - the node data to validate
     * @param t - translation function; called as `t(key)` and `t(key, { field })`
     * @returns `isValid` plus the translated "field required" message (empty string when valid)
     */
    checkValid(payload: KnowledgeRetrievalNodeType, t: any) {
      const {
        query_variable_selector: querySelector,
        dataset_ids: datasetIds,
        retrieval_mode: retrievalMode,
        multiple_retrieval_config: multiConfig,
        single_retrieval_config: singleConfig,
      } = payload

      // Each rule pairs a lazy "is this field missing?" predicate with a lazy label lookup,
      // so nothing is evaluated or translated for rules after the first reported failure.
      const rules: Array<[() => boolean, () => string]> = [
        [
          () => !querySelector || querySelector.length === 0,
          () => t(`${i18nPrefix}.nodes.knowledgeRetrieval.queryVariable`),
        ],
        [
          () => !datasetIds || datasetIds.length === 0,
          () => t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`),
        ],
        [
          () => retrievalMode === RETRIEVE_TYPE.multiWay
            && multiConfig?.reranking_mode === RerankingModeEnum.RerankingModel
            && !multiConfig?.reranking_model?.provider,
          () => t(`${i18nPrefix}.errorMsg.fields.rerankModel`),
        ],
        [
          () => retrievalMode === RETRIEVE_TYPE.oneWay && !singleConfig?.model?.provider,
          () => t('common.modelProvider.systemReasoningModel.key'),
        ],
      ]

      let message = ''
      for (const [isMissing, fieldLabel] of rules) {
        // Stop as soon as a (truthy) message has been produced.
        if (message)
          break
        if (isMissing())
          message = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: fieldLabel() })
      }

      return {
        isValid: !message,
        errorMessage: message,
      }
    },
  }

  export default nodeDefault