agent-model-trigger.tsx 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import type { FC } from 'react'
  2. import { useMemo, useState } from 'react'
  3. import { useTranslation } from 'react-i18next'
  4. import type {
  5. ModelItem,
  6. ModelProvider,
  7. } from '../declarations'
  8. import {
  9. CustomConfigurationStatusEnum,
  10. ModelTypeEnum,
  11. } from '../declarations'
  12. import { useInvalidateInstalledPluginList } from '@/service/use-plugins'
  13. import ConfigurationButton from './configuration-button'
  14. import Loading from '@/app/components/base/loading'
  15. import {
  16. useModelModalHandler,
  17. useUpdateModelList,
  18. useUpdateModelProviders,
  19. } from '../hooks'
  20. import ModelIcon from '../model-icon'
  21. import ModelDisplay from './model-display'
  22. import { InstallPluginButton } from '@/app/components/workflow/nodes/_base/components/install-plugin-button'
  23. import StatusIndicators from './status-indicators'
  24. import cn from '@/utils/classnames'
  25. import { useProviderContext } from '@/context/provider-context'
  26. import { RiEqualizer2Line } from '@remixicon/react'
  27. import { useModelInList, usePluginInfo } from '@/service/use-plugins'
  28. export type AgentModelTriggerProps = {
  29. open?: boolean
  30. disabled?: boolean
  31. currentProvider?: ModelProvider
  32. currentModel?: ModelItem
  33. providerName?: string
  34. modelId?: string
  35. hasDeprecated?: boolean
  36. scope?: string
  37. }
  38. const AgentModelTrigger: FC<AgentModelTriggerProps> = ({
  39. disabled,
  40. currentProvider,
  41. currentModel,
  42. providerName,
  43. modelId,
  44. hasDeprecated,
  45. scope,
  46. }) => {
  47. const { t } = useTranslation()
  48. const { modelProviders } = useProviderContext()
  49. const updateModelProviders = useUpdateModelProviders()
  50. const updateModelList = useUpdateModelList()
  51. const { modelProvider, needsConfiguration } = useMemo(() => {
  52. const modelProvider = modelProviders.find(item => item.provider === providerName)
  53. const needsConfiguration = modelProvider?.custom_configuration.status === CustomConfigurationStatusEnum.noConfigure && !(
  54. modelProvider.system_configuration.enabled === true
  55. && modelProvider.system_configuration.quota_configurations.find(
  56. item => item.quota_type === modelProvider.system_configuration.current_quota_type,
  57. )
  58. )
  59. return {
  60. modelProvider,
  61. needsConfiguration,
  62. }
  63. }, [modelProviders, providerName])
  64. const [installed, setInstalled] = useState(false)
  65. const invalidateInstalledPluginList = useInvalidateInstalledPluginList()
  66. const handleOpenModal = useModelModalHandler()
  67. const { data: inModelList = false } = useModelInList(currentProvider, modelId)
  68. const { data: pluginInfo, isLoading: isPluginLoading } = usePluginInfo(providerName)
  69. if (modelId && isPluginLoading)
  70. return <Loading />
  71. return (
  72. <div
  73. className={cn(
  74. 'relative group flex items-center p-1 gap-[2px] flex-grow rounded-lg bg-components-input-bg-normal cursor-pointer hover:bg-state-base-hover-alt',
  75. )}
  76. >
  77. {modelId ? (
  78. <>
  79. <ModelIcon
  80. className='p-0.5'
  81. provider={currentProvider || modelProvider}
  82. modelName={currentModel?.model || modelId}
  83. isDeprecated={hasDeprecated}
  84. />
  85. <ModelDisplay
  86. currentModel={currentModel}
  87. modelId={modelId}
  88. />
  89. {needsConfiguration && (
  90. <ConfigurationButton
  91. modelProvider={modelProvider}
  92. handleOpenModal={handleOpenModal}
  93. />
  94. )}
  95. <StatusIndicators
  96. needsConfiguration={needsConfiguration}
  97. modelProvider={!!modelProvider}
  98. inModelList={inModelList}
  99. disabled={!!disabled}
  100. pluginInfo={pluginInfo}
  101. t={t}
  102. />
  103. {!installed && !modelProvider && pluginInfo && (
  104. <InstallPluginButton
  105. onClick={e => e.stopPropagation()}
  106. size={'small'}
  107. uniqueIdentifier={pluginInfo.latest_package_identifier}
  108. onSuccess={() => {
  109. [
  110. ModelTypeEnum.textGeneration,
  111. ModelTypeEnum.textEmbedding,
  112. ModelTypeEnum.rerank,
  113. ModelTypeEnum.moderation,
  114. ModelTypeEnum.speech2text,
  115. ModelTypeEnum.tts,
  116. ].forEach((type: ModelTypeEnum) => {
  117. if (scope?.includes(type))
  118. updateModelList(type)
  119. },
  120. )
  121. updateModelProviders()
  122. invalidateInstalledPluginList()
  123. setInstalled(true)
  124. }}
  125. />
  126. )}
  127. {modelProvider && !disabled && !needsConfiguration && (
  128. <div className="flex pr-1 items-center">
  129. <RiEqualizer2Line className="w-4 h-4 text-text-tertiary group-hover:text-text-secondary" />
  130. </div>
  131. )}
  132. </>
  133. ) : (
  134. <>
  135. <div className="flex p-1 pl-2 items-center gap-1 grow">
  136. <span className="overflow-hidden text-ellipsis whitespace-nowrap system-sm-regular text-components-input-text-placeholder">
  137. {t('workflow.nodes.agent.configureModel')}
  138. </span>
  139. </div>
  140. <div className="flex pr-1 items-center">
  141. <RiEqualizer2Line className="w-4 h-4 text-text-tertiary group-hover:text-text-secondary" />
  142. </div>
  143. </>
  144. )}
  145. </div>
  146. )
  147. }
  148. export default AgentModelTrigger