import type { FC } from 'react'
import { Fragment, useState } from 'react'
import { Popover, Transition } from '@headlessui/react'
import { useTranslation } from 'react-i18next'
import _ from 'lodash-es'
import cn from 'classnames'
import type { BackendModel, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
import { Check, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
import ModelName, { supportI18nModelName } from '@/app/components/app/configuration/config-model/model-name'
import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
import { useProviderContext } from '@/context/provider-context'

type Props = {
  /** Currently selected provider/model pair; `undefined` renders the "select model" placeholder. */
  value: {
    providerName: ProviderEnum
    modelName: string
  } | undefined
  /** Which provider-context list to offer when `supportAgentThought` is not truthy. */
  modelType: ModelType
  /** When truthy, offer `agentThoughtModelList` instead of the `modelType` list. */
  supportAgentThought?: boolean
  /** Called with the provider-context `BackendModel` of the clicked model row. */
  onChange: (value: BackendModel) => void
  /** Extra class name, presumably for the popover panel (usage not visible here). */
  popClassName?: string
  /** When truthy, parts of the UI guarded by `!readonly &&` (including the option panel) are not rendered. */
  readonly?: boolean
  /** Presumably selects a smaller trigger icon (usage not visible here). */
  triggerIconSmall?: boolean
}

/**
 * One rendered row of the option list: either a provider header ('provider')
 * or a selectable model belonging to `providerName` ('model').
 */
type ModelOption = {
  type: 'model'
  value: string
  providerName: ProviderEnum
  modelDisplayName: string
} | {
  type: 'provider'
  value: ProviderEnum
}

/**
 * Model picker, presumably built on Headless UI's `Popover` (imported above and
 * used through the `({ open }) => …` render prop in the markup).
 *
 * - The candidate list is `agentThoughtModelList` when `supportAgentThought` is
 *   truthy, otherwise the provider-context list for `modelType`.
 * - Typing in the search box filters by each model's label (translated for names
 *   in `supportI18nModelName`, the raw `model_name` otherwise).
 * - Options are rendered as a provider header row followed by that provider's
 *   models; clicking a model row calls `onChange` with the matching `BackendModel`.
 * - If `value` is set but no longer present in the list (`hasRemoved`), a warning
 *   with the text `common.modelProvider.selector.tip` is rendered.
 *
 * NOTE(review): as seen here, `FC` and `Record` have no type arguments, and the JSX
 * element tags after `return (` are missing. It looks like angle-bracket content
 * was lost (presumably `FC<Props>` and `Record<string, string>`). Confirm against
 * the real file; as written this does not type-check.
 */
const ModelSelector: FC = ({
  value,
  modelType,
  supportAgentThought,
  onChange,
  popClassName,
  readonly,
  triggerIconSmall,
}) => {
  const { t } = useTranslation()
  const { textGenerationModelList, embeddingsModelList, speech2textModelList, agentThoughtModelList } = useProviderContext()
  // Current search text; an empty string means "no filtering".
  const [search, setSearch] = useState('')
  // Candidate models: agent-thought-capable models override `modelType`.
  // NOTE(review): the lookup object only has textGeneration / embeddings /
  // speech2text keys. Any other `modelType` yields `undefined`, and the `.find`
  // below would throw. Confirm ModelType has no other members.
  const modelList = supportAgentThought
    ? agentThoughtModelList
    : ({
      [ModelType.textGeneration]: textGenerationModelList,
      [ModelType.embeddings]: embeddingsModelList,
      [ModelType.speech2text]: speech2textModelList,
    })[modelType]
  // Entry matching the current `value`, or undefined if nothing is selected or it was removed.
  const currModel = modelList.find(item => item.model_name === value?.modelName && item.model_provider.provider_name === value.providerName)
  // model_name -> label used for search matching. Labels are translated via
  // `common.modelName.<name>` for names in `supportI18nModelName`, otherwise the
  // raw name is used. Only built while a search string is present.
  const allModelNames = (() => {
    if (!search)
      return {}

    const res: Record = {}
    modelList.forEach(({ model_name }) => {
      res[model_name] = supportI18nModelName.includes(model_name) ? t(`common.modelName.${model_name}`) : model_name
    })
    return res
  })()

  // Models whose label contains `search`; the full list when not searching.
  // NOTE(review): `String.prototype.includes` is case-sensitive, so a query that
  // differs from the label only in letter case (e.g. "gpt" vs "GPT") will not
  // match. Lower-case both sides if case-insensitive search is intended.
  const filteredModelList = search
    ? modelList.filter(({ model_name }) => {
      if (allModelNames[model_name].includes(search))
        return true

      return false
    })
    : modelList

  // Truthy when a value is set but is no longer present in `modelList`. This uses
  // the same predicate as `currModel`, i.e. it is equivalent to `value && !currModel`.
  const hasRemoved = value && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName)

  // Flatten `filteredModelList` into render rows: one 'provider' header per
  // distinct provider (first-seen order, via `_.uniq`), each followed by that
  // provider's 'model' rows. The click handler maps a model row back to a
  // `BackendModel` with `modelList.find`. Every row comes from
  // `filteredModelList` (a subset of `modelList`), so that lookup should always
  // succeed, and the `as BackendModel` cast there relies on this.
  const modelOptions: ModelOption[] = (() => {
    const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
    const res: ModelOption[] = []
    providers.forEach((providerName) => {
      res.push({
        type: 'provider',
        value: providerName,
      })
      const models = filteredModelList.filter(m => m.model_provider.provider_name === providerName)
      models.forEach(({ model_name, model_display_name }) => {
        res.push({
          type: 'model',
          providerName,
          value: model_name,
          modelDisplayName: model_display_name,
        })
      })
    })
    return res
  })()

  return (
{ ({ open }) => ( <> { value ? ( <>
) : (
{t('common.modelProvider.selectModel')}
) } { hasRemoved && ( {t('common.modelProvider.selector.tip')}
} > ) } {!readonly && } ) } {!readonly && (
setSearch(e.target.value)} className={` block w-full h-8 bg-transparent text-[13px] text-gray-700 outline-none appearance-none border-none `} placeholder={t('common.modelProvider.searchModel') || ''} />
{ search && (
setSearch('')}>
) }
{ modelOptions.map((model) => { if (model.type === 'provider') { return (
) } if (model.type === 'model') { return ( { const selectedModel = modelList.find((item) => { return item.model_name === model.value && item.model_provider.provider_name === model.providerName }) onChange(selectedModel as BackendModel) }} >
{ (value?.providerName === model.providerName && value?.modelName === model.value) && }
) } return null }) } {(search && filteredModelList.length === 0) && (
{t('common.modelProvider.noModelFound', { model: search })}
)}
)} ) } export default ModelSelector