index.tsx 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import type {
  2. FC,
  3. ReactNode,
  4. } from 'react'
  5. import { useMemo, useState } from 'react'
  6. import { useTranslation } from 'react-i18next'
  7. import type {
  8. DefaultModel,
  9. FormValue,
  10. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  11. import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  12. import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
  13. import {
  14. useModelList,
  15. } from '@/app/components/header/account-setting/model-provider-page/hooks'
  16. import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
  17. import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
  18. import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
  19. import {
  20. PortalToFollowElem,
  21. PortalToFollowElemContent,
  22. PortalToFollowElemTrigger,
  23. } from '@/app/components/base/portal-to-follow-elem'
  24. import LLMParamsPanel from './llm-params-panel'
  25. import TTSParamsPanel from './tts-params-panel'
  26. import { useProviderContext } from '@/context/provider-context'
  27. import cn from '@/utils/classnames'
  28. export type ModelParameterModalProps = {
  29. popupClassName?: string
  30. portalToFollowElemContentClassName?: string
  31. isAdvancedMode: boolean
  32. value: any
  33. setModel: (model: any) => void
  34. renderTrigger?: (v: TriggerProps) => ReactNode
  35. readonly?: boolean
  36. isInWorkflow?: boolean
  37. isAgentStrategy?: boolean
  38. scope?: string
  39. }
  40. const ModelParameterModal: FC<ModelParameterModalProps> = ({
  41. popupClassName,
  42. portalToFollowElemContentClassName,
  43. isAdvancedMode,
  44. value,
  45. setModel,
  46. renderTrigger,
  47. readonly,
  48. isInWorkflow,
  49. isAgentStrategy,
  50. scope = ModelTypeEnum.textGeneration,
  51. }) => {
  52. const { t } = useTranslation()
  53. const { isAPIKeySet } = useProviderContext()
  54. const [open, setOpen] = useState(false)
  55. const scopeArray = scope.split('&')
  56. const scopeFeatures = useMemo(() => {
  57. if (scopeArray.includes('all'))
  58. return []
  59. return scopeArray.filter(item => ![
  60. ModelTypeEnum.textGeneration,
  61. ModelTypeEnum.textEmbedding,
  62. ModelTypeEnum.rerank,
  63. ModelTypeEnum.moderation,
  64. ModelTypeEnum.speech2text,
  65. ModelTypeEnum.tts,
  66. ].includes(item as ModelTypeEnum))
  67. }, [scopeArray])
  68. const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  69. const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  70. const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  71. const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  72. const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  73. const { data: ttsList } = useModelList(ModelTypeEnum.tts)
  74. const scopedModelList = useMemo(() => {
  75. const resultList: any[] = []
  76. if (scopeArray.includes('all')) {
  77. return [
  78. ...textGenerationList,
  79. ...textEmbeddingList,
  80. ...rerankList,
  81. ...sttList,
  82. ...ttsList,
  83. ...moderationList,
  84. ]
  85. }
  86. if (scopeArray.includes(ModelTypeEnum.textGeneration))
  87. return textGenerationList
  88. if (scopeArray.includes(ModelTypeEnum.textEmbedding))
  89. return textEmbeddingList
  90. if (scopeArray.includes(ModelTypeEnum.rerank))
  91. return rerankList
  92. if (scopeArray.includes(ModelTypeEnum.moderation))
  93. return moderationList
  94. if (scopeArray.includes(ModelTypeEnum.speech2text))
  95. return sttList
  96. if (scopeArray.includes(ModelTypeEnum.tts))
  97. return ttsList
  98. return resultList
  99. }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
  100. const { currentProvider, currentModel } = useMemo(() => {
  101. const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
  102. const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
  103. return {
  104. currentProvider,
  105. currentModel,
  106. }
  107. }, [scopedModelList, value?.provider, value?.model])
  108. const hasDeprecated = useMemo(() => {
  109. return !currentProvider || !currentModel
  110. }, [currentModel, currentProvider])
  111. const modelDisabled = useMemo(() => {
  112. return currentModel?.status !== ModelStatusEnum.active
  113. }, [currentModel?.status])
  114. const disabled = useMemo(() => {
  115. return !isAPIKeySet || hasDeprecated || modelDisabled
  116. }, [hasDeprecated, isAPIKeySet, modelDisabled])
  117. const handleChangeModel = ({ provider, model }: DefaultModel) => {
  118. const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
  119. const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
  120. const model_type = targetModelItem?.model_type as string
  121. setModel({
  122. provider,
  123. model,
  124. model_type,
  125. ...(model_type === ModelTypeEnum.textGeneration ? {
  126. mode: targetModelItem?.model_properties.mode as string,
  127. completion_params: {},
  128. } : {}),
  129. })
  130. }
  131. const handleLLMParamsChange = (newParams: FormValue) => {
  132. const newValue = {
  133. ...(value?.completionParams || {}),
  134. completion_params: newParams,
  135. }
  136. setModel({
  137. ...value,
  138. ...newValue,
  139. })
  140. }
  141. const handleTTSParamsChange = (language: string, voice: string) => {
  142. setModel({
  143. ...value,
  144. language,
  145. voice,
  146. })
  147. }
  148. return (
  149. <PortalToFollowElem
  150. open={open}
  151. onOpenChange={setOpen}
  152. placement={isInWorkflow ? 'left' : 'bottom-end'}
  153. offset={4}
  154. >
  155. <div className='relative'>
  156. <PortalToFollowElemTrigger
  157. onClick={() => {
  158. if (readonly)
  159. return
  160. setOpen(v => !v)
  161. }}
  162. className='block'
  163. >
  164. {
  165. renderTrigger
  166. ? renderTrigger({
  167. open,
  168. disabled,
  169. modelDisabled,
  170. hasDeprecated,
  171. currentProvider,
  172. currentModel,
  173. providerName: value?.provider,
  174. modelId: value?.model,
  175. })
  176. : (isAgentStrategy
  177. ? <AgentModelTrigger
  178. disabled={disabled}
  179. hasDeprecated={hasDeprecated}
  180. currentProvider={currentProvider}
  181. currentModel={currentModel}
  182. providerName={value?.provider}
  183. modelId={value?.model}
  184. scope={scope}
  185. />
  186. : <Trigger
  187. disabled={disabled}
  188. isInWorkflow={isInWorkflow}
  189. modelDisabled={modelDisabled}
  190. hasDeprecated={hasDeprecated}
  191. currentProvider={currentProvider}
  192. currentModel={currentModel}
  193. providerName={value?.provider}
  194. modelId={value?.model}
  195. />
  196. )
  197. }
  198. </PortalToFollowElemTrigger>
  199. <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
  200. <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
  201. <div className={cn('max-h-[420px] p-4 pt-3 overflow-y-auto')}>
  202. <div className='relative'>
  203. <div className={cn('mb-1 h-6 flex items-center text-text-secondary system-sm-semibold')}>
  204. {t('common.modelProvider.model').toLocaleUpperCase()}
  205. </div>
  206. <ModelSelector
  207. defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
  208. modelList={scopedModelList}
  209. scopeFeatures={scopeFeatures}
  210. onSelect={handleChangeModel}
  211. />
  212. </div>
  213. {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
  214. <div className='my-3 h-[1px] bg-divider-subtle' />
  215. )}
  216. {currentModel?.model_type === ModelTypeEnum.textGeneration && (
  217. <LLMParamsPanel
  218. provider={value?.provider}
  219. modelId={value?.model}
  220. completionParams={value?.completion_params || {}}
  221. onCompletionParamsChange={handleLLMParamsChange}
  222. isAdvancedMode={isAdvancedMode}
  223. />
  224. )}
  225. {currentModel?.model_type === ModelTypeEnum.tts && (
  226. <TTSParamsPanel
  227. currentModel={currentModel}
  228. language={value?.language}
  229. voice={value?.voice}
  230. onChange={handleTTSParamsChange}
  231. />
  232. )}
  233. </div>
  234. </div>
  235. </PortalToFollowElemContent>
  236. </div>
  237. </PortalToFollowElem>
  238. )
  239. }
  240. export default ModelParameterModal