123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- import type {
- FC,
- ReactNode,
- } from 'react'
- import { useMemo, useState } from 'react'
- import { useTranslation } from 'react-i18next'
- import type {
- DefaultModel,
- FormValue,
- } from '@/app/components/header/account-setting/model-provider-page/declarations'
- import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
- import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
- import {
- useModelList,
- } from '@/app/components/header/account-setting/model-provider-page/hooks'
- import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
- import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
- import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
- import {
- PortalToFollowElem,
- PortalToFollowElemContent,
- PortalToFollowElemTrigger,
- } from '@/app/components/base/portal-to-follow-elem'
- import LLMParamsPanel from './llm-params-panel'
- import TTSParamsPanel from './tts-params-panel'
- import { useProviderContext } from '@/context/provider-context'
- import cn from '@/utils/classnames'
/**
 * Props accepted by the `ModelParameterModal` component in this file.
 */
export type ModelParameterModalProps = {
  // --- data -------------------------------------------------------------
  /**
   * Current model config. The component reads `provider`, `model`,
   * `completion_params`, `language` and `voice` from it.
   */
  value: any
  /** Called with the complete replacement model config on every change. */
  setModel: (model: any) => void
  /** Forwarded as-is to the LLM parameters panel. */
  isAdvancedMode: boolean
  /**
   * '&'-separated list of model types and/or feature tags (e.g. "llm&vision").
   * An entry of 'all' selects every model type. The component falls back to
   * the text-generation model type when this is omitted.
   */
  scope?: string

  // --- trigger / interaction -------------------------------------------
  /** Custom trigger renderer; when provided it replaces the built-in triggers. */
  renderTrigger?: (v: TriggerProps) => ReactNode
  /** When true, clicking the trigger does not open the popover. */
  readonly?: boolean
  /**
   * Opens the popover to the left instead of bottom-end, and is forwarded to
   * the default trigger.
   */
  isInWorkflow?: boolean
  /** Uses the agent-strategy trigger instead of the default one (unless `renderTrigger` is set). */
  isAgentStrategy?: boolean

  // --- styling ------------------------------------------------------------
  /** Extra class names merged onto the popup panel. */
  popupClassName?: string
  /** Extra class names merged onto the portal content wrapper. */
  portalToFollowElemContentClassName?: string
}
/**
 * Popover that lets the user pick a model (restricted by `scope`) and, for
 * text-generation and TTS models, edit that model's parameters.
 *
 * Every change is reported through `setModel` with a full replacement config
 * object; this component keeps no model state of its own besides whether the
 * popover is open.
 */
const ModelParameterModal: FC<ModelParameterModalProps> = ({
  popupClassName,
  portalToFollowElemContentClassName,
  isAdvancedMode,
  value,
  setModel,
  renderTrigger,
  readonly,
  isInWorkflow,
  isAgentStrategy,
  scope = ModelTypeEnum.textGeneration,
}) => {
  const { t } = useTranslation()
  const { isAPIKeySet } = useProviderContext()
  const [open, setOpen] = useState(false)

  // Memoized on `scope`: a fresh array on every render would invalidate every
  // useMemo below that lists `scopeArray` as a dependency.
  const scopeArray = useMemo(() => scope.split('&'), [scope])

  // Scope entries that are not model types are treated as feature filters
  // for the model selector. 'all' means "no feature filtering".
  const scopeFeatures = useMemo(() => {
    if (scopeArray.includes('all'))
      return []
    return scopeArray.filter(item => ![
      ModelTypeEnum.textGeneration,
      ModelTypeEnum.textEmbedding,
      ModelTypeEnum.rerank,
      ModelTypeEnum.moderation,
      ModelTypeEnum.speech2text,
      ModelTypeEnum.tts,
    ].includes(item as ModelTypeEnum))
  }, [scopeArray])

  const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  const { data: ttsList } = useModelList(ModelTypeEnum.tts)

  // Provider/model list offered by the selector. With 'all' every list is
  // concatenated; otherwise only the first model type found in the checks
  // below (in this order) is used.
  const scopedModelList = useMemo(() => {
    const resultList: any[] = []
    if (scopeArray.includes('all')) {
      return [
        ...textGenerationList,
        ...textEmbeddingList,
        ...rerankList,
        ...sttList,
        ...ttsList,
        ...moderationList,
      ]
    }
    if (scopeArray.includes(ModelTypeEnum.textGeneration))
      return textGenerationList
    if (scopeArray.includes(ModelTypeEnum.textEmbedding))
      return textEmbeddingList
    if (scopeArray.includes(ModelTypeEnum.rerank))
      return rerankList
    if (scopeArray.includes(ModelTypeEnum.moderation))
      return moderationList
    if (scopeArray.includes(ModelTypeEnum.speech2text))
      return sttList
    if (scopeArray.includes(ModelTypeEnum.tts))
      return ttsList
    return resultList
  }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

  // Resolve the currently configured provider/model against the scoped list.
  const { currentProvider, currentModel } = useMemo(() => {
    const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
    const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
    return {
      currentProvider,
      currentModel,
    }
  }, [scopedModelList, value?.provider, value?.model])

  // True whenever the configured provider or model is not in the scoped list
  // (this includes the case where nothing is configured yet).
  const hasDeprecated = useMemo(() => {
    return !currentProvider || !currentModel
  }, [currentModel, currentProvider])

  const modelDisabled = useMemo(() => {
    return currentModel?.status !== ModelStatusEnum.active
  }, [currentModel?.status])

  const disabled = useMemo(() => {
    return !isAPIKeySet || hasDeprecated || modelDisabled
  }, [hasDeprecated, isAPIKeySet, modelDisabled])

  /**
   * Replace the config with the newly selected model. Text-generation models
   * also get their `mode` and start from empty completion params; any other
   * previous fields (e.g. TTS language/voice) are dropped.
   */
  const handleChangeModel = ({ provider, model }: DefaultModel) => {
    const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
    const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
    const model_type = targetModelItem?.model_type as string
    setModel({
      provider,
      model,
      model_type,
      ...(model_type === ModelTypeEnum.textGeneration
        ? {
          // Guarded: a model entry without `model_properties` must not crash selection.
          mode: targetModelItem?.model_properties?.mode as string,
          completion_params: {},
        }
        : {}),
    })
  }

  /**
   * Replace only `completion_params`, keeping the rest of the config.
   * (Previously this also spread a non-existent `completionParams` field into
   * the top level of the config.)
   */
  const handleLLMParamsChange = (newParams: FormValue) => {
    setModel({
      ...value,
      completion_params: newParams,
    })
  }

  /** Store the chosen TTS language and voice alongside the existing config. */
  const handleTTSParamsChange = (language: string, voice: string) => {
    setModel({
      ...value,
      language,
      voice,
    })
  }

  const isTextGeneration = currentModel?.model_type === ModelTypeEnum.textGeneration
  const isTTS = currentModel?.model_type === ModelTypeEnum.tts

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement={isInWorkflow ? 'left' : 'bottom-end'}
      offset={4}
    >
      <div className='relative'>
        <PortalToFollowElemTrigger
          onClick={() => {
            // Read-only mode never opens the popover.
            if (readonly)
              return
            setOpen(v => !v)
          }}
          className='block'
        >
          {
            renderTrigger
              ? renderTrigger({
                open,
                disabled,
                modelDisabled,
                hasDeprecated,
                currentProvider,
                currentModel,
                providerName: value?.provider,
                modelId: value?.model,
              })
              : (isAgentStrategy
                ? <AgentModelTrigger
                  disabled={disabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                  scope={scope}
                />
                : <Trigger
                  disabled={disabled}
                  isInWorkflow={isInWorkflow}
                  modelDisabled={modelDisabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                />
              )
          }
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
          <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
            <div className={cn('max-h-[420px] p-4 pt-3 overflow-y-auto')}>
              <div className='relative'>
                <div className={cn('mb-1 h-6 flex items-center text-text-secondary system-sm-semibold')}>
                  {t('common.modelProvider.model').toLocaleUpperCase()}
                </div>
                <ModelSelector
                  defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
                  modelList={scopedModelList}
                  scopeFeatures={scopeFeatures}
                  onSelect={handleChangeModel}
                />
              </div>
              {(isTextGeneration || isTTS) && (
                <div className='my-3 h-[1px] bg-divider-subtle' />
              )}
              {isTextGeneration && (
                <LLMParamsPanel
                  provider={value?.provider}
                  modelId={value?.model}
                  completionParams={value?.completion_params || {}}
                  onCompletionParamsChange={handleLLMParamsChange}
                  isAdvancedMode={isAdvancedMode}
                />
              )}
              {isTTS && (
                <TTSParamsPanel
                  currentModel={currentModel}
                  language={value?.language}
                  voice={value?.voice}
                  onChange={handleTTSParamsChange}
                />
              )}
            </div>
          </div>
        </PortalToFollowElemContent>
      </div>
    </PortalToFollowElem>
  )
}

export default ModelParameterModal
|