model_declaration_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package plugin_entities
  2. import (
  3. "encoding/json"
  4. "testing"
  5. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  6. "gopkg.in/yaml.v3"
  7. )
  8. func parse_yaml_to_json(data []byte) ([]byte, error) {
  9. var obj interface{}
  10. err := yaml.Unmarshal(data, &obj)
  11. if err != nil {
  12. return nil, err
  13. }
  14. jsonData, err := json.Marshal(obj)
  15. if err != nil {
  16. return nil, err
  17. }
  18. return jsonData, nil
  19. }
  20. func TestFullFunctionModelProvider_Validate(t *testing.T) {
  21. const (
  22. model_provider_template = `
  23. provider: openai
  24. label:
  25. en_US: OpenAI
  26. description:
  27. en_US: Models provided by OpenAI, such as GPT-3.5-Turbo and GPT-4.
  28. zh_Hans: OpenAI 提供的模型,例如 GPT-3.5-Turbo 和 GPT-4。
  29. icon_small:
  30. en_US: icon_s_en.svg
  31. icon_large:
  32. en_US: icon_l_en.svg
  33. background: "#E5E7EB"
  34. help:
  35. title:
  36. en_US: Get your API Key from OpenAI
  37. zh_Hans: 从 OpenAI 获取 API Key
  38. url:
  39. en_US: https://platform.openai.com/account/api-keys
  40. supported_model_types:
  41. - llm
  42. - text-embedding
  43. - speech2text
  44. - moderation
  45. - tts
  46. configurate_methods:
  47. - predefined-model
  48. - customizable-model
  49. model_credential_schema:
  50. model:
  51. label:
  52. en_US: Model Name
  53. zh_Hans: 模型名称
  54. placeholder:
  55. en_US: Enter your model name
  56. zh_Hans: 输入模型名称
  57. credential_form_schemas:
  58. - variable: openai_api_key
  59. label:
  60. en_US: API Key
  61. type: secret-input
  62. required: true
  63. placeholder:
  64. zh_Hans: 在此输入您的 API Key
  65. en_US: Enter your API Key
  66. - variable: openai_organization
  67. label:
  68. zh_Hans: 组织 ID
  69. en_US: Organization
  70. type: text-input
  71. required: false
  72. placeholder:
  73. zh_Hans: 在此输入您的组织 ID
  74. en_US: Enter your Organization ID
  75. - variable: openai_api_base
  76. label:
  77. zh_Hans: API Base
  78. en_US: API Base
  79. type: text-input
  80. required: false
  81. placeholder:
  82. zh_Hans: 在此输入您的 API Base
  83. en_US: Enter your API Base
  84. provider_credential_schema:
  85. credential_form_schemas:
  86. - variable: openai_api_key
  87. label:
  88. en_US: API Key
  89. type: secret-input
  90. required: true
  91. placeholder:
  92. zh_Hans: 在此输入您的 API Key
  93. en_US: Enter your API Key
  94. - variable: openai_organization
  95. label:
  96. zh_Hans: 组织 ID
  97. en_US: Organization
  98. type: text-input
  99. required: false
  100. placeholder:
  101. zh_Hans: 在此输入您的组织 ID
  102. en_US: Enter your Organization ID
  103. - variable: openai_api_base
  104. label:
  105. zh_Hans: API Base
  106. en_US: API Base
  107. type: text-input
  108. required: false
  109. placeholder:
  110. zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com
  111. en_US: Enter your API Base, e.g. https://api.openai.com
  112. `
  113. )
  114. jsonData, err := parse_yaml_to_json([]byte(model_provider_template))
  115. if err != nil {
  116. t.Error(err)
  117. }
  118. _, err = parser.UnmarshalYamlBytes[ModelProviderDeclaration](jsonData)
  119. if err != nil {
  120. t.Errorf("UnmarshalModelProviderConfiguration() error = %v", err)
  121. }
  122. }
  123. func TestModelParameterRule_UseTemplateYAML(t *testing.T) {
  124. const (
  125. model_parameter_rule_template = `
  126. name: temperature
  127. use_template: temperature
  128. `
  129. )
  130. yamlData := []byte(model_parameter_rule_template)
  131. model, err := parser.UnmarshalYamlBytes[ModelParameterRule](yamlData)
  132. if err != nil {
  133. t.Errorf("UnmarshalModelParameterRule() error = %v", err)
  134. return
  135. }
  136. if model.Type == nil {
  137. t.Errorf("UnmarshalModelParameterRule() error = %v", err)
  138. return
  139. }
  140. if *model.Type != PARAMETER_TYPE_FLOAT {
  141. t.Errorf("UnmarshalModelParameterRule() error = %v", err)
  142. }
  143. if model.Min == nil || model.Max == nil || model.Precision == nil {
  144. t.Errorf("Missing default value")
  145. }
  146. }
  147. func TestModelParameterRule_UseTemplateJSON(t *testing.T) {
  148. const (
  149. model_parameter_rule_template = `{"name": "temperature", "use_template": "temperature"}`
  150. )
  151. jsonData := []byte(model_parameter_rule_template)
  152. model, err := parser.UnmarshalJsonBytes[ModelParameterRule](jsonData)
  153. if err != nil {
  154. t.Errorf("UnmarshalModelParameterRule() error = %v", err)
  155. }
  156. if model.Type == nil {
  157. t.Errorf("UnmarshalModelParameterRule() error = %v", err)
  158. }
  159. if *model.Type != PARAMETER_TYPE_FLOAT {
  160. t.Errorf("UnmarshalModelParameterRule() error = %v", err)
  161. }
  162. if model.Min == nil || model.Max == nil || model.Precision == nil {
  163. t.Errorf("Missing default value")
  164. }
  165. }