model_declaration.go 27 KB


  1. package plugin_entities
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/go-playground/locales/en"
  6. ut "github.com/go-playground/universal-translator"
  7. "github.com/go-playground/validator/v10"
  8. en_translations "github.com/go-playground/validator/v10/translations/en"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  10. "github.com/langgenius/dify-plugin-daemon/pkg/validators"
  11. "github.com/shopspring/decimal"
  12. "gopkg.in/yaml.v3"
  13. )
// ModelType is the string identifier of a model category.
type ModelType string

// Known ModelType values. isModelType accepts exactly these strings.
const (
	MODEL_TYPE_LLM            ModelType = "llm"
	MODEL_TYPE_TEXT_EMBEDDING ModelType = "text-embedding"
	MODEL_TYPE_RERANKING      ModelType = "rerank"
	MODEL_TYPE_SPEECH2TEXT    ModelType = "speech2text"
	MODEL_TYPE_MODERATION     ModelType = "moderation"
	MODEL_TYPE_TTS            ModelType = "tts"
	MODEL_TYPE_TEXT2IMG       ModelType = "text2img"
)
  24. func isModelType(fl validator.FieldLevel) bool {
  25. value := fl.Field().String()
  26. switch value {
  27. case string(MODEL_TYPE_LLM),
  28. string(MODEL_TYPE_TEXT_EMBEDDING),
  29. string(MODEL_TYPE_RERANKING),
  30. string(MODEL_TYPE_SPEECH2TEXT),
  31. string(MODEL_TYPE_MODERATION),
  32. string(MODEL_TYPE_TTS),
  33. string(MODEL_TYPE_TEXT2IMG):
  34. return true
  35. }
  36. return false
  37. }
// ModelProviderConfigurateMethod is the string identifier of a way a provider's
// models can be configured.
type ModelProviderConfigurateMethod string

// Known ModelProviderConfigurateMethod values. isModelProviderConfigurateMethod
// accepts exactly these strings.
const (
	CONFIGURATE_METHOD_PREDEFINED_MODEL   ModelProviderConfigurateMethod = "predefined-model"
	CONFIGURATE_METHOD_CUSTOMIZABLE_MODEL ModelProviderConfigurateMethod = "customizable-model"
)
  43. func isModelProviderConfigurateMethod(fl validator.FieldLevel) bool {
  44. value := fl.Field().String()
  45. switch value {
  46. case string(CONFIGURATE_METHOD_PREDEFINED_MODEL),
  47. string(CONFIGURATE_METHOD_CUSTOMIZABLE_MODEL):
  48. return true
  49. }
  50. return false
  51. }
// ModelParameterType is the string identifier of a model parameter's value type.
type ModelParameterType string

// Known ModelParameterType values. isModelParameterType accepts exactly these
// strings.
const (
	PARAMETER_TYPE_FLOAT   ModelParameterType = "float"
	PARAMETER_TYPE_INT     ModelParameterType = "int"
	PARAMETER_TYPE_STRING  ModelParameterType = "string"
	PARAMETER_TYPE_BOOLEAN ModelParameterType = "boolean"
	PARAMETER_TYPE_TEXT    ModelParameterType = "text"
)
  60. func isModelParameterType(fl validator.FieldLevel) bool {
  61. value := fl.Field().String()
  62. switch value {
  63. case string(PARAMETER_TYPE_FLOAT),
  64. string(PARAMETER_TYPE_INT),
  65. string(PARAMETER_TYPE_STRING),
  66. string(PARAMETER_TYPE_BOOLEAN),
  67. string(PARAMETER_TYPE_TEXT):
  68. return true
  69. }
  70. return false
  71. }
// ModelParameterRule is the declaration of a single model parameter.
//
// Only Name is required by the validation tags. The pointer fields may be left
// unset and are filled in by TransformTemplate when UseTemplate names an entry
// of PARAMETER_RULE_TEMPLATE.
type ModelParameterRule struct {
	Name        string              `json:"name" yaml:"name" validate:"required,lt=256"`
	// UseTemplate, when set and non-empty, is looked up in PARAMETER_RULE_TEMPLATE.
	UseTemplate *string             `json:"use_template" yaml:"use_template" validate:"omitempty,lt=256"`
	Label       *I18nObject         `json:"label" yaml:"label" validate:"omitempty"`
	Type        *ModelParameterType `json:"type" yaml:"type" validate:"omitempty,model_parameter_type"`
	Help        *I18nObject         `json:"help" yaml:"help" validate:"omitempty"`
	Required    bool                `json:"required" yaml:"required"`
	Default     *any                `json:"default" yaml:"default" validate:"omitempty,is_basic_type"`
	Min         *float64            `json:"min" yaml:"min" validate:"omitempty"`
	Max         *float64            `json:"max" yaml:"max" validate:"omitempty"`
	Precision   *int                `json:"precision" yaml:"precision" validate:"omitempty"`
	// Options is normalized to an empty (non-nil) slice by TransformTemplate.
	Options     []string            `json:"options" yaml:"options" validate:"omitempty,dive,lt=256"`
}
// DefaultParameterName is the key type of PARAMETER_RULE_TEMPLATE.
type DefaultParameterName string

// Template names that have an entry in PARAMETER_RULE_TEMPLATE.
const (
	TEMPERATURE       DefaultParameterName = "temperature"
	TOP_P             DefaultParameterName = "top_p"
	TOP_K             DefaultParameterName = "top_k"
	PRESENCE_PENALTY  DefaultParameterName = "presence_penalty"
	FREQUENCY_PENALTY DefaultParameterName = "frequency_penalty"
	MAX_TOKENS        DefaultParameterName = "max_tokens"
	RESPONSE_FORMAT   DefaultParameterName = "response_format"
	JSON_SCHEMA       DefaultParameterName = "json_schema"
)
// PARAMETER_RULE_TEMPLATE maps each DefaultParameterName to the preset values
// that ModelParameterRule.TransformTemplate copies into a rule whose
// use_template names that entry. Name and UseTemplate are left unset in every
// template.
var PARAMETER_RULE_TEMPLATE = map[DefaultParameterName]ModelParameterRule{
	// temperature: float, default 0.0, range 0.0-1.0, precision 2.
	TEMPERATURE: {
		Label: &I18nObject{
			EnUS:   "Temperature",
			ZhHans: "温度",
			JaJp:   "温度",
			PtBr:   "Temperatura",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
		Help: &I18nObject{
			EnUS:   "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
			ZhHans: "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
			JaJp:   "温度はランダム性を制御します。温度が低いほどランダムな完成が少なくなります。温度がゼロに近づくと、モデルは決定論的で繰り返しになります。温度が高いほどランダムな完成が多くなります。",
			PtBr:   "A temperatura controla a aleatoriedade. Menores temperaturas resultam em menos conclusões aleatórias. À medida que a temperatura se aproxima de zero, o modelo se tornará determinístico e repetitivo. Temperaturas mais altas resultam em mais conclusões aleatórias.",
		},
		Required:  false,
		Default:   parser.ToPtr(any(0.0)),
		Min:       parser.ToPtr(0.0),
		Max:       parser.ToPtr(1.0),
		Precision: parser.ToPtr(2),
	},
	// top_p: float, default 1.0, range 0.0-1.0, precision 2.
	TOP_P: {
		Label: &I18nObject{
			EnUS:   "Top P",
			ZhHans: "Top P",
			JaJp:   "Top P",
			PtBr:   "Top P",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
		Help: &I18nObject{
			EnUS:   "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
			ZhHans: "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
			JaJp:   "核サンプリングを通じて多様性を制御します:0.5は、すべての可能性加权オプションの半分を考慮します。",
			PtBr:   "Controla a diversidade via amostragem de núcleo: 0.5 significa que metade das opções com maior probabilidade são consideradas.",
		},
		Required:  false,
		Default:   parser.ToPtr(any(1.0)),
		Min:       parser.ToPtr(0.0),
		Max:       parser.ToPtr(1.0),
		Precision: parser.ToPtr(2),
	},
	// top_k: int, default 50, range 1-100, precision 0. Only EnUS and ZhHans
	// texts are provided.
	TOP_K: {
		Label: &I18nObject{
			EnUS:   "Top K",
			ZhHans: "Top K",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_INT),
		Help: &I18nObject{
			EnUS:   "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
			ZhHans: "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
		},
		Required:  false,
		Default:   parser.ToPtr(any(50)),
		Min:       parser.ToPtr(1.0),
		Max:       parser.ToPtr(100.0),
		Precision: parser.ToPtr(0),
	},
	// presence_penalty: float, default 0.0, range 0.0-1.0, precision 2.
	PRESENCE_PENALTY: {
		Label: &I18nObject{
			EnUS:   "Presence Penalty",
			ZhHans: "存在惩罚",
			JaJp:   "存在ペナルティ",
			PtBr:   "Penalidade de presença",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
		Help: &I18nObject{
			EnUS:   "Applies a penalty to the log-probability of tokens already in the text.",
			ZhHans: "对文本中已有的标记的对数概率施加惩罚。",
			JaJp:   "テキストに既に存在するトークンの対数確率にペナルティを適用します。",
			PtBr:   "Aplica uma penalidade à probabilidade logarítmica de tokens já presentes no texto.",
		},
		Required:  false,
		Default:   parser.ToPtr(any(0.0)),
		Min:       parser.ToPtr(0.0),
		Max:       parser.ToPtr(1.0),
		Precision: parser.ToPtr(2),
	},
	// frequency_penalty: float, default 0.0, range 0.0-1.0, precision 2.
	FREQUENCY_PENALTY: {
		Label: &I18nObject{
			EnUS:   "Frequency Penalty",
			ZhHans: "频率惩罚",
			JaJp:   "頻度ペナルティ",
			PtBr:   "Penalidade de frequência",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
		Help: &I18nObject{
			EnUS:   "Applies a penalty to the log-probability of tokens that appear in the text.",
			ZhHans: "对文本中出现的标记的对数概率施加惩罚。",
			JaJp:   "テキストに出現するトークンの対数確率にペナルティを適用します。",
			PtBr:   "Aplica uma penalidade à probabilidade logarítmica de tokens que aparecem no texto.",
		},
		Required:  false,
		Default:   parser.ToPtr(any(0.0)),
		Min:       parser.ToPtr(0.0),
		Max:       parser.ToPtr(1.0),
		Precision: parser.ToPtr(2),
	},
	// max_tokens: int, default 64, range 1-2048, precision 0.
	MAX_TOKENS: {
		Label: &I18nObject{
			EnUS:   "Max Tokens",
			ZhHans: "最大标记",
			JaJp:   "最大トークン",
			PtBr:   "Máximo de tokens",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_INT),
		Help: &I18nObject{
			EnUS:   "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
			ZhHans: "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
			JaJp:   "生成結果の長さの上限を指定します。生成結果が切り捨てられた場合は、このパラメータを大きくすることができます。",
			PtBr:   "Especifica o limite superior para o comprimento dos resultados gerados. Se os resultados gerados forem truncados, você pode aumentar este parâmetro.",
		},
		Required:  false,
		Default:   parser.ToPtr(any(64)),
		Min:       parser.ToPtr(1.0),
		Max:       parser.ToPtr(2048.0),
		Precision: parser.ToPtr(0),
	},
	// response_format: string with options "JSON" and "XML"; no default or
	// numeric bounds.
	RESPONSE_FORMAT: {
		Label: &I18nObject{
			EnUS:   "Response Format",
			ZhHans: "回复格式",
			JaJp:   "応答形式",
			PtBr:   "Formato de resposta",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_STRING),
		Help: &I18nObject{
			EnUS:   "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
			ZhHans: "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
			JaJp:   "応答形式を設定します。llmの出力が可能な限り有効なコードブロックであることを確認します。",
			PtBr:   "Defina um formato de resposta para garantir que a saída do llm seja um bloco de código válido o mais possível, como JSON, XML, etc.",
		},
		Required: false,
		Options:  []string{"JSON", "XML"},
	},
	// json_schema: string; no default, bounds or options. The label is
	// provided in EnUS only.
	JSON_SCHEMA: {
		Label: &I18nObject{
			EnUS: "JSON Schema",
		},
		Type: parser.ToPtr(PARAMETER_TYPE_STRING),
		Help: &I18nObject{
			EnUS:   "Set a response json schema will ensure LLM to adhere it.",
			ZhHans: "设置返回的json schema,llm将按照它返回",
		},
		Required: false,
	},
}
  242. func (m *ModelParameterRule) TransformTemplate() error {
  243. if m.Label == nil || m.Label.EnUS == "" {
  244. m.Label = &I18nObject{
  245. EnUS: m.Name,
  246. }
  247. }
  248. // if use_template is not empty, transform to use default value
  249. if m.UseTemplate != nil && *m.UseTemplate != "" {
  250. // get the value of use_template
  251. useTemplateValue := m.UseTemplate
  252. // get the template
  253. template, ok := PARAMETER_RULE_TEMPLATE[DefaultParameterName(*useTemplateValue)]
  254. if !ok {
  255. return fmt.Errorf("use_template %s not found", *useTemplateValue)
  256. }
  257. // transform to default value
  258. if m.Label == nil {
  259. m.Label = template.Label
  260. }
  261. if m.Type == nil {
  262. m.Type = template.Type
  263. }
  264. if m.Help == nil {
  265. m.Help = template.Help
  266. }
  267. if m.Default == nil {
  268. m.Default = template.Default
  269. }
  270. if m.Min == nil {
  271. m.Min = template.Min
  272. }
  273. if m.Max == nil {
  274. m.Max = template.Max
  275. }
  276. if m.Precision == nil {
  277. m.Precision = template.Precision
  278. }
  279. if m.Options == nil {
  280. m.Options = template.Options
  281. }
  282. }
  283. if m.Options == nil {
  284. m.Options = []string{}
  285. }
  286. return nil
  287. }
  288. func (m *ModelParameterRule) UnmarshalJSON(data []byte) error {
  289. type alias ModelParameterRule
  290. temp := &struct {
  291. *alias
  292. }{
  293. alias: (*alias)(m),
  294. }
  295. if err := json.Unmarshal(data, &temp); err != nil {
  296. return err
  297. }
  298. if err := m.TransformTemplate(); err != nil {
  299. return err
  300. }
  301. return nil
  302. }
  303. func (m *ModelParameterRule) UnmarshalYAML(value *yaml.Node) error {
  304. type alias ModelParameterRule
  305. temp := &struct {
  306. *alias `yaml:",inline"`
  307. }{
  308. alias: (*alias)(m),
  309. }
  310. if err := value.Decode(&temp); err != nil {
  311. return err
  312. }
  313. if err := m.TransformTemplate(); err != nil {
  314. return err
  315. }
  316. return nil
  317. }
  318. func isParameterRule(fl validator.FieldLevel) bool {
  319. // if use_template is empty, then label, type should be required
  320. // try get the value of use_template
  321. useTemplateHandle := fl.Field().FieldByName("UseTemplate")
  322. // check if use_template is null pointer
  323. if useTemplateHandle.IsNil() {
  324. // label and type should be required
  325. // try get the value of label
  326. if fl.Field().FieldByName("Label").IsNil() {
  327. return false
  328. }
  329. // try get the value of type
  330. if fl.Field().FieldByName("Type").IsNil() {
  331. return false
  332. }
  333. }
  334. return true
  335. }
// ModelPriceConfig is the pricing section of a model declaration. Input, Unit
// and Currency are required by the validation tags; Output is optional.
type ModelPriceConfig struct {
	Input    decimal.Decimal  `json:"input" yaml:"input" validate:"required"`
	Output   *decimal.Decimal `json:"output" yaml:"output" validate:"omitempty"`
	Unit     decimal.Decimal  `json:"unit" yaml:"unit" validate:"required"`
	Currency string           `json:"currency" yaml:"currency" validate:"required"`
}
// ModelDeclaration is the declaration of a single model.
//
// The custom unmarshalers default FetchFrom to
// CONFIGURATE_METHOD_PREDEFINED_MODEL and ParameterRules to an empty slice;
// MarshalJSON emits Model as Label.EnUS when the latter is empty.
type ModelDeclaration struct {
	Model           string                         `json:"model" yaml:"model" validate:"required,lt=256"`
	Label           I18nObject                     `json:"label" yaml:"label" validate:"required"`
	ModelType       ModelType                      `json:"model_type" yaml:"model_type" validate:"required,model_type"`
	Features        []string                       `json:"features" yaml:"features" validate:"omitempty,lte=256,dive,lt=256"`
	FetchFrom       ModelProviderConfigurateMethod `json:"fetch_from" yaml:"fetch_from" validate:"omitempty,model_provider_configurate_method"`
	ModelProperties map[string]any                 `json:"model_properties" yaml:"model_properties" validate:"omitempty"`
	Deprecated      bool                           `json:"deprecated" yaml:"deprecated"`
	ParameterRules  []ModelParameterRule           `json:"parameter_rules" yaml:"parameter_rules" validate:"omitempty,lte=128,dive,parameter_rule"`
	// PriceConfig is serialized under the key "pricing".
	PriceConfig     *ModelPriceConfig              `json:"pricing" yaml:"pricing" validate:"omitempty"`
}
  353. func (m *ModelDeclaration) UnmarshalJSON(data []byte) error {
  354. type alias ModelDeclaration
  355. temp := &struct {
  356. *alias
  357. }{
  358. alias: (*alias)(m),
  359. }
  360. if err := json.Unmarshal(data, &temp); err != nil {
  361. return err
  362. }
  363. if m.FetchFrom == "" {
  364. m.FetchFrom = CONFIGURATE_METHOD_PREDEFINED_MODEL
  365. }
  366. if m.ParameterRules == nil {
  367. m.ParameterRules = []ModelParameterRule{}
  368. }
  369. return nil
  370. }
  371. func (m *ModelDeclaration) MarshalJSON() ([]byte, error) {
  372. type alias ModelDeclaration
  373. temp := &struct {
  374. *alias `json:",inline"`
  375. }{
  376. alias: (*alias)(m),
  377. }
  378. if temp.Label.EnUS == "" {
  379. temp.Label.EnUS = temp.Model
  380. }
  381. return json.Marshal(temp)
  382. }
  383. func (m *ModelDeclaration) UnmarshalYAML(value *yaml.Node) error {
  384. type alias ModelDeclaration
  385. temp := &struct {
  386. *alias `yaml:",inline"`
  387. }{
  388. alias: (*alias)(m),
  389. }
  390. if err := value.Decode(&temp); err != nil {
  391. return err
  392. }
  393. if m.FetchFrom == "" {
  394. m.FetchFrom = CONFIGURATE_METHOD_PREDEFINED_MODEL
  395. }
  396. if m.ParameterRules == nil {
  397. m.ParameterRules = []ModelParameterRule{}
  398. }
  399. return nil
  400. }
// ModelProviderFormType is the string identifier of a credential form field
// kind.
type ModelProviderFormType string

// Known ModelProviderFormType values. isModelProviderFormType accepts exactly
// these strings.
const (
	FORM_TYPE_TEXT_INPUT   ModelProviderFormType = "text-input"
	FORM_TYPE_SECRET_INPUT ModelProviderFormType = "secret-input"
	FORM_TYPE_SELECT       ModelProviderFormType = "select"
	FORM_TYPE_RADIO        ModelProviderFormType = "radio"
	FORM_TYPE_SWITCH       ModelProviderFormType = "switch"
)
  409. func isModelProviderFormType(fl validator.FieldLevel) bool {
  410. value := fl.Field().String()
  411. switch value {
  412. case string(FORM_TYPE_TEXT_INPUT),
  413. string(FORM_TYPE_SECRET_INPUT),
  414. string(FORM_TYPE_SELECT),
  415. string(FORM_TYPE_RADIO),
  416. string(FORM_TYPE_SWITCH):
  417. return true
  418. }
  419. return false
  420. }
// ModelProviderFormShowOnObject is a variable/value pair used in the show_on
// lists of form schemas and form options. Both fields are required.
type ModelProviderFormShowOnObject struct {
	Variable string `json:"variable" yaml:"variable" validate:"required,lt=256"`
	Value    string `json:"value" yaml:"value" validate:"required,lt=256"`
}
// ModelProviderFormOption is one entry of a credential form field's options
// list. Its unmarshalers normalize a missing ShowOn to an empty slice.
type ModelProviderFormOption struct {
	Label  I18nObject                      `json:"label" yaml:"label" validate:"required"`
	Value  string                          `json:"value" yaml:"value" validate:"required,lt=256"`
	ShowOn []ModelProviderFormShowOnObject `json:"show_on" yaml:"show_on" validate:"omitempty,lte=16,dive"`
}
  430. func (m *ModelProviderFormOption) UnmarshalJSON(data []byte) error {
  431. // avoid show_on to be nil
  432. type Alias ModelProviderFormOption
  433. aux := &struct {
  434. *Alias
  435. }{
  436. Alias: (*Alias)(m),
  437. }
  438. if err := json.Unmarshal(data, aux); err != nil {
  439. return err
  440. }
  441. if m.ShowOn == nil {
  442. m.ShowOn = []ModelProviderFormShowOnObject{}
  443. }
  444. return nil
  445. }
  446. func (m *ModelProviderFormOption) UnmarshalYAML(value *yaml.Node) error {
  447. // avoid show_on to be nil
  448. type Alias ModelProviderFormOption
  449. aux := &struct {
  450. *Alias `yaml:",inline"`
  451. }{
  452. Alias: (*Alias)(m),
  453. }
  454. if err := value.Decode(&aux); err != nil {
  455. return err
  456. }
  457. if m.ShowOn == nil {
  458. m.ShowOn = []ModelProviderFormShowOnObject{}
  459. }
  460. return nil
  461. }
// ModelProviderCredentialFormSchema is the declaration of one credential form
// field. Variable, Label and Type are required by the validation tags. The
// unmarshalers normalize missing Options and ShowOn to empty slices.
type ModelProviderCredentialFormSchema struct {
	Variable    string                          `json:"variable" yaml:"variable" validate:"required,lt=256"`
	Label       I18nObject                      `json:"label" yaml:"label" validate:"required"`
	Type        ModelProviderFormType           `json:"type" yaml:"type" validate:"required,model_provider_form_type"`
	Required    bool                            `json:"required" yaml:"required"`
	Default     *string                         `json:"default" yaml:"default" validate:"omitempty,lt=256"`
	Options     []ModelProviderFormOption       `json:"options" yaml:"options" validate:"omitempty,lte=128,dive"`
	Placeholder *I18nObject                     `json:"placeholder" yaml:"placeholder" validate:"omitempty"`
	MaxLength   int                             `json:"max_length" yaml:"max_length"`
	ShowOn      []ModelProviderFormShowOnObject `json:"show_on" yaml:"show_on" validate:"omitempty,lte=16,dive"`
}
  473. func (m *ModelProviderCredentialFormSchema) UnmarshalJSON(data []byte) error {
  474. type Alias ModelProviderCredentialFormSchema
  475. temp := &struct {
  476. *Alias
  477. }{
  478. Alias: (*Alias)(m),
  479. }
  480. if err := json.Unmarshal(data, &temp); err != nil {
  481. return err
  482. }
  483. if m.ShowOn == nil {
  484. m.ShowOn = []ModelProviderFormShowOnObject{}
  485. }
  486. if m.Options == nil {
  487. m.Options = []ModelProviderFormOption{}
  488. }
  489. return nil
  490. }
  491. func (m *ModelProviderCredentialFormSchema) UnmarshalYAML(value *yaml.Node) error {
  492. type Alias ModelProviderCredentialFormSchema
  493. temp := &struct {
  494. *Alias `yaml:",inline"`
  495. }{
  496. Alias: (*Alias)(m),
  497. }
  498. if err := value.Decode(&temp); err != nil {
  499. return err
  500. }
  501. if m.ShowOn == nil {
  502. m.ShowOn = []ModelProviderFormShowOnObject{}
  503. }
  504. if m.Options == nil {
  505. m.Options = []ModelProviderFormOption{}
  506. }
  507. return nil
  508. }
// ModelProviderCredentialSchema wraps the list of credential form fields of a
// provider (at most 32 per the validation tag).
type ModelProviderCredentialSchema struct {
	CredentialFormSchemas []ModelProviderCredentialFormSchema `json:"credential_form_schemas" yaml:"credential_form_schemas" validate:"omitempty,lte=32,dive"`
}
// FieldModelSchema holds the required label and optional placeholder of the
// "model" entry in ModelCredentialSchema.
type FieldModelSchema struct {
	Label       I18nObject  `json:"label" yaml:"label" validate:"required"`
	Placeholder *I18nObject `json:"placeholder" yaml:"placeholder" validate:"omitempty"`
}
// ModelCredentialSchema combines the required "model" field description with a
// list of credential form fields (at most 32 per the validation tag).
type ModelCredentialSchema struct {
	Model                 FieldModelSchema                    `json:"model" yaml:"model" validate:"required"`
	CredentialFormSchemas []ModelProviderCredentialFormSchema `json:"credential_form_schemas" yaml:"credential_form_schemas" validate:"omitempty,lte=32,dive"`
}
// ModelProviderHelpEntity is the optional help section of a provider: a
// localized title and a localized URL, both required when the section is
// present.
type ModelProviderHelpEntity struct {
	Title I18nObject `json:"title" yaml:"title" validate:"required"`
	URL   I18nObject `json:"url" yaml:"url" validate:"required"`
}
// ModelPosition holds one optional string list per model type. Every field is
// a pointer tagged omitempty, so unset lists are left out of JSON and YAML
// output.
type ModelPosition struct {
	LLM           *[]string `json:"llm,omitempty" yaml:"llm,omitempty"`
	TextEmbedding *[]string `json:"text_embedding,omitempty" yaml:"text_embedding,omitempty"`
	Rerank        *[]string `json:"rerank,omitempty" yaml:"rerank,omitempty"`
	TTS           *[]string `json:"tts,omitempty" yaml:"tts,omitempty"`
	Speech2text   *[]string `json:"speech2text,omitempty" yaml:"speech2text,omitempty"`
	Moderation    *[]string `json:"moderation,omitempty" yaml:"moderation,omitempty"`
}
// ModelProviderDeclaration is the declaration of a model provider.
//
// The custom unmarshalers accept "models" either as a list of model
// declarations or as a mapping keyed by model type; in the mapping form the
// referenced file paths are collected into ModelFiles and PositionFiles.
type ModelProviderDeclaration struct {
	Provider                 string                           `json:"provider" yaml:"provider" validate:"required,lt=256"`
	Label                    I18nObject                       `json:"label" yaml:"label" validate:"required"`
	Description              *I18nObject                      `json:"description" yaml:"description,omitempty" validate:"omitempty"`
	IconSmall                *I18nObject                      `json:"icon_small" yaml:"icon_small,omitempty" validate:"omitempty"`
	IconLarge                *I18nObject                      `json:"icon_large" yaml:"icon_large,omitempty" validate:"omitempty"`
	Background               *string                          `json:"background" yaml:"background,omitempty" validate:"omitempty"`
	Help                     *ModelProviderHelpEntity         `json:"help" yaml:"help,omitempty" validate:"omitempty"`
	SupportedModelTypes      []ModelType                      `json:"supported_model_types" yaml:"supported_model_types" validate:"required,lte=16,dive,model_type"`
	ConfigurateMethods       []ModelProviderConfigurateMethod `json:"configurate_methods" yaml:"configurate_methods" validate:"required,lte=16,dive,model_provider_configurate_method"`
	ProviderCredentialSchema *ModelProviderCredentialSchema   `json:"provider_credential_schema" yaml:"provider_credential_schema,omitempty" validate:"omitempty"`
	ModelCredentialSchema    *ModelCredentialSchema           `json:"model_credential_schema" yaml:"model_credential_schema,omitempty" validate:"omitempty"`
	Position                 *ModelPosition                   `json:"position,omitempty" yaml:"position,omitempty"`
	// Models uses the key "models" in JSON but "model_declarations" in YAML.
	Models                   []ModelDeclaration               `json:"models" yaml:"model_declarations,omitempty"`
	// ModelFiles and PositionFiles are never serialized; the unmarshalers fill
	// them from the "predefined" and "position" entries of a "models" mapping.
	ModelFiles               []string                         `json:"-" yaml:"-"`
	PositionFiles            map[string]string                `json:"-" yaml:"-"`
}
  549. func (m *ModelProviderDeclaration) UnmarshalJSON(data []byte) error {
  550. type alias ModelProviderDeclaration
  551. var temp struct {
  552. alias
  553. Models json.RawMessage `json:"models"`
  554. }
  555. if err := json.Unmarshal(data, &temp); err != nil {
  556. return err
  557. }
  558. *m = ModelProviderDeclaration(temp.alias)
  559. if m.ModelCredentialSchema != nil && m.ModelCredentialSchema.CredentialFormSchemas == nil {
  560. m.ModelCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
  561. }
  562. if m.ProviderCredentialSchema != nil && m.ProviderCredentialSchema.CredentialFormSchemas == nil {
  563. m.ProviderCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
  564. }
  565. // unmarshal models into map[string]any
  566. var models map[string]any
  567. if err := json.Unmarshal(temp.Models, &models); err != nil {
  568. // can not unmarshal it into map, so it's a list
  569. if err := json.Unmarshal(temp.Models, &m.Models); err != nil {
  570. return err
  571. }
  572. return nil
  573. }
  574. m.PositionFiles = make(map[string]string)
  575. types := []string{
  576. "llm",
  577. "text_embedding",
  578. "tts",
  579. "speech2text",
  580. "moderation",
  581. "rerank",
  582. }
  583. for _, model_type := range types {
  584. modelTypeMap, ok := models[model_type].(map[string]any)
  585. if ok {
  586. modelTypePositionFile, ok := modelTypeMap["position"]
  587. if ok {
  588. modelTypePositionFilePath, ok := modelTypePositionFile.(string)
  589. if ok {
  590. m.PositionFiles[model_type] = modelTypePositionFilePath
  591. }
  592. }
  593. modelTypePredefinedFiles, ok := modelTypeMap["predefined"].([]string)
  594. if ok {
  595. m.ModelFiles = append(m.ModelFiles, modelTypePredefinedFiles...)
  596. }
  597. }
  598. }
  599. if m.Models == nil {
  600. m.Models = []ModelDeclaration{}
  601. }
  602. return nil
  603. }
  604. func (m *ModelProviderDeclaration) MarshalJSON() ([]byte, error) {
  605. type alias ModelProviderDeclaration
  606. temp := &struct {
  607. *alias `json:",inline"`
  608. }{
  609. alias: (*alias)(m),
  610. }
  611. if temp.Models == nil {
  612. temp.Models = []ModelDeclaration{}
  613. }
  614. return json.Marshal(temp)
  615. }
  616. func (m *ModelProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
  617. type alias ModelProviderDeclaration
  618. var temp struct {
  619. alias `yaml:",inline"`
  620. Models yaml.Node `yaml:"models"`
  621. }
  622. if err := value.Decode(&temp); err != nil {
  623. return err
  624. }
  625. *m = ModelProviderDeclaration(temp.alias)
  626. if m.ModelCredentialSchema != nil && m.ModelCredentialSchema.CredentialFormSchemas == nil {
  627. m.ModelCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
  628. }
  629. if m.ProviderCredentialSchema != nil && m.ProviderCredentialSchema.CredentialFormSchemas == nil {
  630. m.ProviderCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
  631. }
  632. // Check if Models is a mapping node
  633. if temp.Models.Kind == yaml.MappingNode {
  634. m.PositionFiles = make(map[string]string)
  635. types := []string{
  636. "llm",
  637. "text_embedding",
  638. "tts",
  639. "speech2text",
  640. "moderation",
  641. "rerank",
  642. }
  643. for i := 0; i < len(temp.Models.Content); i += 2 {
  644. key := temp.Models.Content[i].Value
  645. value := temp.Models.Content[i+1]
  646. for _, model_type := range types {
  647. if key == model_type {
  648. if value.Kind == yaml.MappingNode {
  649. for j := 0; j < len(value.Content); j += 2 {
  650. if value.Content[j].Value == "position" {
  651. m.PositionFiles[model_type] = value.Content[j+1].Value
  652. } else if value.Content[j].Value == "predefined" {
  653. // get content of predefined
  654. if value.Content[j+1].Kind == yaml.SequenceNode {
  655. for _, file := range value.Content[j+1].Content {
  656. m.ModelFiles = append(m.ModelFiles, file.Value)
  657. }
  658. }
  659. }
  660. }
  661. }
  662. }
  663. }
  664. }
  665. } else if temp.Models.Kind == yaml.SequenceNode {
  666. if err := temp.Models.Decode(&m.Models); err != nil {
  667. return err
  668. }
  669. }
  670. if m.Models == nil {
  671. m.Models = []ModelDeclaration{}
  672. }
  673. return nil
  674. }
  675. func init() {
  676. // init validator
  677. en := en.New()
  678. uni := ut.New(en, en)
  679. translator, _ := uni.GetTranslator("en")
  680. // register translations for default validators
  681. en_translations.RegisterDefaultTranslations(validators.GlobalEntitiesValidator, translator)
  682. validators.GlobalEntitiesValidator.RegisterValidation("model_type", isModelType)
  683. validators.GlobalEntitiesValidator.RegisterTranslation(
  684. "model_type",
  685. translator,
  686. func(ut ut.Translator) error {
  687. return ut.Add("model_type", "{0} is not a valid model type", true)
  688. },
  689. func(ut ut.Translator, fe validator.FieldError) string {
  690. t, _ := ut.T("model_type", fe.Field())
  691. return t
  692. },
  693. )
  694. validators.GlobalEntitiesValidator.RegisterValidation("model_provider_configurate_method", isModelProviderConfigurateMethod)
  695. validators.GlobalEntitiesValidator.RegisterTranslation(
  696. "model_provider_configurate_method",
  697. translator,
  698. func(ut ut.Translator) error {
  699. return ut.Add("model_provider_configurate_method", "{0} is not a valid model provider configurate method", true)
  700. },
  701. func(ut ut.Translator, fe validator.FieldError) string {
  702. t, _ := ut.T("model_provider_configurate_method", fe.Field())
  703. return t
  704. },
  705. )
  706. validators.GlobalEntitiesValidator.RegisterValidation("model_provider_form_type", isModelProviderFormType)
  707. validators.GlobalEntitiesValidator.RegisterTranslation(
  708. "model_provider_form_type",
  709. translator,
  710. func(ut ut.Translator) error {
  711. return ut.Add("model_provider_form_type", "{0} is not a valid model provider form type", true)
  712. },
  713. func(ut ut.Translator, fe validator.FieldError) string {
  714. t, _ := ut.T("model_provider_form_type", fe.Field())
  715. return t
  716. },
  717. )
  718. validators.GlobalEntitiesValidator.RegisterValidation("model_parameter_type", isModelParameterType)
  719. validators.GlobalEntitiesValidator.RegisterTranslation(
  720. "model_parameter_type",
  721. translator,
  722. func(ut ut.Translator) error {
  723. return ut.Add("model_parameter_type", "{0} is not a valid model parameter type", true)
  724. },
  725. func(ut ut.Translator, fe validator.FieldError) string {
  726. t, _ := ut.T("model_parameter_type", fe.Field())
  727. return t
  728. },
  729. )
  730. validators.GlobalEntitiesValidator.RegisterValidation("parameter_rule", isParameterRule)
  731. validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
  732. }