provider-context.tsx 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. 'use client'
  2. import { createContext, useContext } from 'use-context-selector'
  3. import useSWR from 'swr'
  4. import { useEffect, useState } from 'react'
  5. import { fetchDefaultModal, fetchModelList, fetchSupportRetrievalMethods } from '@/service/common'
  6. import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  7. import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
  8. import type { RETRIEVE_METHOD } from '@/types/app'
  9. import { Plan, type UsagePlanInfo } from '@/app/components/billing/type'
  10. import { fetchCurrentPlanInfo } from '@/service/billing'
  11. import { parseCurrentPlan } from '@/app/components/billing/utils'
  12. import { defaultPlan } from '@/app/components/billing/config'
  /**
   * Value shape published by `ProviderContext`: workspace model lists, default
   * models (with their SWR-style mutators), retrieval methods and billing plan info.
   */
  type ProviderContextState = {
    textGenerationModelList: BackendModel[]
    embeddingsModelList: BackendModel[]
    speech2textModelList: BackendModel[]
    rerankModelList: BackendModel[]
    /** Text-generation models whose `features` include `ModelFeature.agentThought`. */
    agentThoughtModelList: BackendModel[]
    /** Triggers a re-fetch of the model list for the given model type. */
    updateModelList: (type: ModelType) => void
    textGenerationDefaultModel?: BackendModel
    mutateTextGenerationDefaultModel: () => void
    embeddingsDefaultModel?: BackendModel
    /** Whether the embeddings default model is present in `embeddingsModelList`. */
    isEmbeddingsDefaultModelValid: boolean
    mutateEmbeddingsDefaultModel: () => void
    speech2textDefaultModel?: BackendModel
    mutateSpeech2textDefaultModel: () => void
    rerankDefaultModel?: BackendModel
    /**
     * Whether the rerank default model is present in `rerankModelList`.
     * The misspelling ("Vaild") is part of the published context shape and is
     * kept so existing consumers continue to compile.
     */
    isRerankDefaultModelVaild: boolean
    mutateRerankDefaultModel: () => void
    supportRetrievalMethods: RETRIEVE_METHOD[]
    /** Current billing plan: plan type plus per-resource usage and totals. */
    plan: {
      type: Plan
      usage: UsagePlanInfo
      total: UsagePlanInfo
    }
    /** Becomes true once plan info has been fetched while billing is enabled. */
    isFetchedPlan: boolean
    enableBilling: boolean
    enableReplaceWebAppLogo: boolean
  }

  /**
   * Workspace model / billing context.
   *
   * The default value is only observed by consumers rendered outside
   * `ProviderContextProvider`: empty model lists, no-op mutators, billing off,
   * and the billing `defaultPlan` (the same object the provider uses as its
   * initial plan state), so no fabricated usage figures leak into the UI.
   */
  const ProviderContext = createContext<ProviderContextState>({
    textGenerationModelList: [],
    embeddingsModelList: [],
    speech2textModelList: [],
    rerankModelList: [],
    agentThoughtModelList: [],
    updateModelList: () => { },
    textGenerationDefaultModel: undefined,
    mutateTextGenerationDefaultModel: () => { },
    speech2textDefaultModel: undefined,
    mutateSpeech2textDefaultModel: () => { },
    embeddingsDefaultModel: undefined,
    isEmbeddingsDefaultModelValid: false,
    mutateEmbeddingsDefaultModel: () => { },
    rerankDefaultModel: undefined,
    isRerankDefaultModelVaild: false,
    mutateRerankDefaultModel: () => { },
    supportRetrievalMethods: [],
    plan: defaultPlan,
    isFetchedPlan: false,
    enableBilling: false,
    enableReplaceWebAppLogo: false,
  })
  /** Reads the full `ProviderContext` value (re-renders on any change to it). */
  export function useProviderContext() {
    return useContext(ProviderContext)
  }
  /** Props accepted by `ProviderContextProvider`. */
  type ProviderContextProviderProps = {
    /** Subtree rendered inside the provider. */
    children: React.ReactNode
  }
  /**
   * Returns true when `model` is set and `list` contains an entry with the same
   * model name and the same provider name.
   */
  const isModelInList = (list: BackendModel[] | undefined, model: BackendModel | undefined): boolean => {
    if (!list || !model)
      return false
    return list.some(item =>
      item.model_name === model.model_name
      && item.model_provider.provider_name === model.model_provider.provider_name,
    )
  }

  /**
   * Supplies `ProviderContext` to `children`.
   *
   * Model data comes from `useSWR` requests (default model per model type, model
   * list per model type, supported retrieval methods); each `data` is undefined
   * until its request resolves, so list values fall back to `[]`.
   * Billing/plan info is fetched once on mount via `fetchCurrentPlanInfo`.
   */
  export const ProviderContextProvider = ({
    children,
  }: ProviderContextProviderProps) => {
    const { data: textGenerationDefaultModel, mutate: mutateTextGenerationDefaultModel } = useSWR('/workspaces/current/default-model?model_type=text-generation', fetchDefaultModal)
    const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
    const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
    const { data: rerankDefaultModel, mutate: mutateRerankDefaultModel } = useSWR('/workspaces/current/default-model?model_type=reranking', fetchDefaultModal)

    const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'
    const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
    const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
    const { data: speech2textModelList, mutate: mutateSpeech2textModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.speech2text}`, fetchModelList)
    const { data: rerankModelList, mutate: mutateRerankModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.reranking}`, fetchModelList)
    const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)

    const agentThoughtModelList = (textGenerationModelList ?? []).filter(
      item => item.features?.includes(ModelFeature.agentThought),
    )
    // A default model is only "valid" if it is still offered in the corresponding list.
    const isRerankDefaultModelVaild = isModelInList(rerankModelList, rerankDefaultModel)
    const isEmbeddingsDefaultModelValid = isModelInList(embeddingsModelList, embeddingsDefaultModel)

    /** Revalidates the SWR-cached model list for `type`; other types are ignored. */
    const updateModelList = (type: ModelType) => {
      switch (type) {
        case ModelType.textGeneration:
          mutateTextGenerationModelList()
          break
        case ModelType.embeddings:
          mutateEmbeddingsModelList()
          break
        case ModelType.speech2text:
          mutateSpeech2textModelList()
          break
        case ModelType.reranking:
          mutateRerankModelList()
          break
      }
    }

    const [plan, setPlan] = useState(defaultPlan)
    const [isFetchedPlan, setIsFetchedPlan] = useState(false)
    const [enableBilling, setEnableBilling] = useState(true)
    const [enableReplaceWebAppLogo, setEnableReplaceWebAppLogo] = useState(false)

    useEffect(() => {
      // Prevents state updates after unmount or from a superseded effect run.
      let cancelled = false
      const loadPlan = async () => {
        const data = await fetchCurrentPlanInfo()
        if (cancelled)
          return
        const enabled = data.billing.enabled
        setEnableBilling(enabled)
        setEnableReplaceWebAppLogo(data.can_replace_logo)
        // Plan details are only meaningful (and only marked fetched) when billing is on.
        if (enabled) {
          setPlan(parseCurrentPlan(data))
          setIsFetchedPlan(true)
        }
      }
      loadPlan().catch((e: unknown) => {
        // Keep the initial plan state; report instead of leaving an unhandled rejection.
        console.error('Failed to fetch current plan info', e)
      })
      return () => {
        cancelled = true
      }
    }, [])

    return (
      <ProviderContext.Provider value={{
        textGenerationModelList: textGenerationModelList ?? [],
        embeddingsModelList: embeddingsModelList ?? [],
        speech2textModelList: speech2textModelList ?? [],
        rerankModelList: rerankModelList ?? [],
        agentThoughtModelList,
        updateModelList,
        textGenerationDefaultModel,
        mutateTextGenerationDefaultModel,
        embeddingsDefaultModel,
        mutateEmbeddingsDefaultModel,
        speech2textDefaultModel,
        mutateSpeech2textDefaultModel,
        rerankDefaultModel,
        isRerankDefaultModelVaild,
        isEmbeddingsDefaultModelValid,
        mutateRerankDefaultModel,
        supportRetrievalMethods: supportRetrievalMethods?.retrieval_method ?? [],
        plan,
        isFetchedPlan,
        enableBilling,
        enableReplaceWebAppLogo,
      }}>
        {children}
      </ProviderContext.Provider>
    )
  }
  // Default export is the context object itself (same binding as the named `ProviderContext`).
  export { ProviderContext as default }