index.tsx 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import type { FC } from 'react'
  2. import { Fragment, useState } from 'react'
  3. import { Popover, Transition } from '@headlessui/react'
  4. import { useTranslation } from 'react-i18next'
  5. import _ from 'lodash-es'
  6. import cn from 'classnames'
  7. import type { BackendModel, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
  8. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  9. import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
  10. import { Check, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
  11. import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
  12. import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
  13. import Tooltip from '@/app/components/base/tooltip'
  14. import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
  15. import ModelName, { supportI18nModelName } from '@/app/components/app/configuration/config-model/model-name'
  16. import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
  17. import { useProviderContext } from '@/context/provider-context'
  18. type Props = {
  19. value: {
  20. providerName: ProviderEnum
  21. modelName: string
  22. } | undefined
  23. modelType: ModelType
  24. supportAgentThought?: boolean
  25. onChange: (value: BackendModel) => void
  26. popClassName?: string
  27. readonly?: boolean
  28. triggerIconSmall?: boolean
  29. }
  30. type ModelOption = {
  31. type: 'model'
  32. value: string
  33. providerName: ProviderEnum
  34. modelDisplayName: string
  35. } | {
  36. type: 'provider'
  37. value: ProviderEnum
  38. }
  39. const ModelSelector: FC<Props> = ({
  40. value,
  41. modelType,
  42. supportAgentThought,
  43. onChange,
  44. popClassName,
  45. readonly,
  46. triggerIconSmall,
  47. }) => {
  48. const { t } = useTranslation()
  49. const { textGenerationModelList, embeddingsModelList, speech2textModelList, agentThoughtModelList } = useProviderContext()
  50. const [search, setSearch] = useState('')
  51. const modelList = supportAgentThought
  52. ? agentThoughtModelList
  53. : ({
  54. [ModelType.textGeneration]: textGenerationModelList,
  55. [ModelType.embeddings]: embeddingsModelList,
  56. [ModelType.speech2text]: speech2textModelList,
  57. })[modelType]
  58. const currModel = modelList.find(item => item.model_name === value?.modelName && item.model_provider.provider_name === value.providerName)
  59. const allModelNames = (() => {
  60. if (!search)
  61. return {}
  62. const res: Record<string, string> = {}
  63. modelList.forEach(({ model_name }) => {
  64. res[model_name] = supportI18nModelName.includes(model_name) ? t(`common.modelName.${model_name}`) : model_name
  65. })
  66. return res
  67. })()
  68. const filteredModelList = search
  69. ? modelList.filter(({ model_name }) => {
  70. if (allModelNames[model_name].includes(search))
  71. return true
  72. return false
  73. })
  74. : modelList
  75. const hasRemoved = value && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName)
  76. const modelOptions: ModelOption[] = (() => {
  77. const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
  78. const res: ModelOption[] = []
  79. providers.forEach((providerName) => {
  80. res.push({
  81. type: 'provider',
  82. value: providerName,
  83. })
  84. const models = filteredModelList.filter(m => m.model_provider.provider_name === providerName)
  85. models.forEach(({ model_name, model_display_name }) => {
  86. res.push({
  87. type: 'model',
  88. providerName,
  89. value: model_name,
  90. modelDisplayName: model_display_name,
  91. })
  92. })
  93. })
  94. return res
  95. })()
  96. return (
  97. <div className=''>
  98. <Popover className='relative'>
  99. <Popover.Button className={cn('flex items-center px-2.5 w-full h-9 rounded-lg', readonly ? '!cursor-auto' : 'bg-gray-100', hasRemoved && '!bg-[#FEF3F2]')}>
  100. {
  101. ({ open }) => (
  102. <>
  103. {
  104. value
  105. ? (
  106. <>
  107. <ModelIcon
  108. className={cn('mr-1.5', !triggerIconSmall && 'w-5 h-5')}
  109. modelId={value.modelName}
  110. providerName={value.providerName}
  111. />
  112. <div className='mr-1.5 grow text-left text-sm text-gray-900 truncate'><ModelName modelId={value.modelName} modelDisplayName={currModel?.model_display_name} /></div>
  113. </>
  114. )
  115. : (
  116. <div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div>
  117. )
  118. }
  119. {
  120. hasRemoved && (
  121. <Tooltip
  122. selector='model-selector-remove-tip'
  123. htmlContent={
  124. <div className='w-[261px] text-gray-500'>{t('common.modelProvider.selector.tip')}</div>
  125. }
  126. >
  127. <AlertCircle className='mr-1 w-4 h-4 text-[#F04438]' />
  128. </Tooltip>
  129. )
  130. }
  131. {!readonly && <ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />}
  132. </>
  133. )
  134. }
  135. </Popover.Button>
  136. {!readonly && (
  137. <Transition
  138. as={Fragment}
  139. leave='transition ease-in duration-100'
  140. leaveFrom='opacity-100'
  141. leaveTo='opacity-0'
  142. >
  143. <Popover.Panel className={cn(popClassName, 'absolute top-10 p-1 min-w-[232px] max-w-[260px] max-h-[366px] bg-white border-[0.5px] border-gray-200 rounded-lg shadow-lg overflow-auto z-10')}>
  144. <div className='px-2 pt-2 pb-1'>
  145. <div className='flex items-center px-2 h-8 bg-gray-100 rounded-lg'>
  146. <div className='mr-1.5 p-[1px]'><SearchLg className='w-[14px] h-[14px] text-gray-400' /></div>
  147. <div className='grow px-0.5'>
  148. <input
  149. value={search}
  150. onChange={e => setSearch(e.target.value)}
  151. className={`
  152. block w-full h-8 bg-transparent text-[13px] text-gray-700
  153. outline-none appearance-none border-none
  154. `}
  155. placeholder={t('common.modelProvider.searchModel') || ''}
  156. />
  157. </div>
  158. {
  159. search && (
  160. <div className='ml-1 p-0.5 cursor-pointer' onClick={() => setSearch('')}>
  161. <XCircle className='w-3 h-3 text-gray-400' />
  162. </div>
  163. )
  164. }
  165. </div>
  166. </div>
  167. {
  168. modelOptions.map((model) => {
  169. if (model.type === 'provider') {
  170. return (
  171. <div
  172. className='px-3 pt-2 pb-1 text-xs font-medium text-gray-500'
  173. key={`${model.type}-${model.value}`}
  174. >
  175. <ProviderName provideName={model.value} />
  176. </div>
  177. )
  178. }
  179. if (model.type === 'model') {
  180. return (
  181. <Popover.Button
  182. key={`${model.providerName}-${model.value}`}
  183. className={`
  184. flex items-center px-3 w-full h-8 rounded-lg hover:bg-gray-50
  185. ${!readonly ? 'cursor-pointer' : 'cursor-auto'}
  186. ${(value?.providerName === model.providerName && value?.modelName === model.value) && 'bg-gray-50'}
  187. `}
  188. onClick={() => {
  189. const selectedModel = modelList.find((item) => {
  190. return item.model_name === model.value && item.model_provider.provider_name === model.providerName
  191. })
  192. onChange(selectedModel as BackendModel)
  193. }}
  194. >
  195. <ModelIcon
  196. className='mr-2 shrink-0'
  197. modelId={model.value}
  198. providerName={model.providerName}
  199. />
  200. <div className='grow text-left text-sm text-gray-900 truncate'><ModelName modelId={model.value} modelDisplayName={model.modelDisplayName} /></div>
  201. { (value?.providerName === model.providerName && value?.modelName === model.value) && <Check className='shrink-0 w-4 h-4 text-primary-600' /> }
  202. </Popover.Button>
  203. )
  204. }
  205. return null
  206. })
  207. }
  208. {(search && filteredModelList.length === 0) && (
  209. <div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div>
  210. )}
  211. </Popover.Panel>
  212. </Transition>
  213. )}
  214. </Popover>
  215. </div>
  216. )
  217. }
  218. export default ModelSelector