llm.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package model_entities
import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/go-playground/validator/v10"
	"github.com/langgenius/dify-plugin-daemon/pkg/validators"
	"github.com/shopspring/decimal"
)
  9. type ModelType string
  10. const (
  11. MODEL_TYPE_LLM ModelType = "llm"
  12. MODEL_TYPE_TEXT_EMBEDDING ModelType = "text-embedding"
  13. MODEL_TYPE_RERANKING ModelType = "rerank"
  14. MODEL_TYPE_SPEECH2TEXT ModelType = "speech2text"
  15. MODEL_TYPE_TTS ModelType = "tts"
  16. MODEL_TYPE_MODERATION ModelType = "moderation"
  17. )
  18. type LLMModel string
  19. const (
  20. LLM_MODE_CHAT LLMModel = "chat"
  21. LLM_MODE_COMPLETION LLMModel = "completion"
  22. )
  23. type PromptMessageRole string
  24. const (
  25. PROMPT_MESSAGE_ROLE_SYSTEM = "system"
  26. PROMPT_MESSAGE_ROLE_USER = "user"
  27. PROMPT_MESSAGE_ROLE_ASSISTANT = "assistant"
  28. PROMPT_MESSAGE_ROLE_TOOL = "tool"
  29. )
  30. func isPromptMessageRole(fl validator.FieldLevel) bool {
  31. value := fl.Field().String()
  32. switch value {
  33. case string(PROMPT_MESSAGE_ROLE_SYSTEM),
  34. string(PROMPT_MESSAGE_ROLE_USER),
  35. string(PROMPT_MESSAGE_ROLE_ASSISTANT),
  36. string(PROMPT_MESSAGE_ROLE_TOOL):
  37. return true
  38. }
  39. return false
  40. }
  41. type PromptMessage struct {
  42. Role PromptMessageRole `json:"role" validate:"required,prompt_message_role"`
  43. Content any `json:"content" validate:"required,prompt_message_content"`
  44. Name string `json:"name"`
  45. ToolCalls []PromptMessageToolCall `json:"tool_calls" validate:"dive"`
  46. ToolCallId string `json:"tool_call_id"`
  47. }
  48. func isPromptMessageContent(fl validator.FieldLevel) bool {
  49. // only allow string or []PromptMessageContent
  50. value := fl.Field().Interface()
  51. switch valueType := value.(type) {
  52. case string:
  53. return true
  54. case []PromptMessageContent:
  55. // validate the content
  56. for _, content := range valueType {
  57. if err := validators.GlobalEntitiesValidator.Struct(content); err != nil {
  58. return false
  59. }
  60. }
  61. return true
  62. }
  63. return false
  64. }
  65. type PromptMessageContentType string
  66. const (
  67. PROMPT_MESSAGE_CONTENT_TYPE_TEXT PromptMessageContentType = "text"
  68. PROMPT_MESSAGE_CONTENT_TYPE_IMAGE PromptMessageContentType = "image"
  69. PROMPT_MESSAGE_CONTENT_TYPE_AUDIO PromptMessageContentType = "audio"
  70. PROMPT_MESSAGE_CONTENT_TYPE_VIDEO PromptMessageContentType = "video"
  71. PROMPT_MESSAGE_CONTENT_TYPE_DOCUMENT PromptMessageContentType = "document"
  72. )
  73. func isPromptMessageContentType(fl validator.FieldLevel) bool {
  74. value := fl.Field().String()
  75. switch value {
  76. case string(PROMPT_MESSAGE_CONTENT_TYPE_TEXT),
  77. string(PROMPT_MESSAGE_CONTENT_TYPE_IMAGE),
  78. string(PROMPT_MESSAGE_CONTENT_TYPE_AUDIO),
  79. string(PROMPT_MESSAGE_CONTENT_TYPE_VIDEO),
  80. string(PROMPT_MESSAGE_CONTENT_TYPE_DOCUMENT):
  81. return true
  82. }
  83. return false
  84. }
  85. type PromptMessageContent struct {
  86. Type PromptMessageContentType `json:"type" validate:"required,prompt_message_content_type"`
  87. Base64Data string `json:"base64_data"` // for multi-modal data
  88. URL string `json:"url"` // for multi-modal data
  89. Data string `json:"data"` // for text only
  90. EncodeFormat string `json:"encode_format"`
  91. Format string `json:"format"`
  92. MimeType string `json:"mime_type"`
  93. Detail string `json:"detail"` // for multi-modal data
  94. }
  95. type PromptMessageToolCall struct {
  96. ID string `json:"id"`
  97. Type string `json:"type"`
  98. Function struct {
  99. Name string `json:"name"`
  100. Arguments string `json:"arguments"`
  101. } `json:"function"`
  102. }
  103. func init() {
  104. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_role", isPromptMessageRole)
  105. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content", isPromptMessageContent)
  106. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content_type", isPromptMessageContentType)
  107. }
  108. func (p *PromptMessage) UnmarshalJSON(data []byte) error {
  109. // Unmarshal the JSON data into a map
  110. var raw map[string]json.RawMessage
  111. if err := json.Unmarshal(data, &raw); err != nil {
  112. return err
  113. }
  114. // Check if content is a string or an array which contains type and content
  115. if _, ok := raw["content"]; ok {
  116. var content string
  117. if err := json.Unmarshal(raw["content"], &content); err == nil {
  118. p.Content = content
  119. } else {
  120. var content []PromptMessageContent
  121. if err := json.Unmarshal(raw["content"], &content); err != nil {
  122. return err
  123. }
  124. p.Content = content
  125. }
  126. } else {
  127. return errors.New("content field is required")
  128. }
  129. // Unmarshal the rest of the fields
  130. if role, ok := raw["role"]; ok {
  131. if err := json.Unmarshal(role, &p.Role); err != nil {
  132. return err
  133. }
  134. } else {
  135. return errors.New("role field is required")
  136. }
  137. if name, ok := raw["name"]; ok {
  138. if err := json.Unmarshal(name, &p.Name); err != nil {
  139. return err
  140. }
  141. }
  142. if toolCalls, ok := raw["tool_calls"]; ok {
  143. if err := json.Unmarshal(toolCalls, &p.ToolCalls); err != nil {
  144. return err
  145. }
  146. }
  147. if toolCallId, ok := raw["tool_call_id"]; ok {
  148. if err := json.Unmarshal(toolCallId, &p.ToolCallId); err != nil {
  149. return err
  150. }
  151. }
  152. return nil
  153. }
  154. type PromptMessageTool struct {
  155. Name string `json:"name" validate:"required"`
  156. Description string `json:"description"`
  157. Parameters map[string]any `json:"parameters"`
  158. }
  159. type LLMResultChunk struct {
  160. Model LLMModel `json:"model" validate:"required"`
  161. PromptMessages []PromptMessage `json:"prompt_messages" validate:"required,dive"`
  162. SystemFingerprint string `json:"system_fingerprint" validate:"omitempty"`
  163. Delta LLMResultChunkDelta `json:"delta" validate:"required"`
  164. }
  165. type LLMUsage struct {
  166. PromptTokens *int `json:"prompt_tokens" validate:"required"`
  167. PromptUnitPrice decimal.Decimal `json:"prompt_unit_price" validate:"required"`
  168. PromptPriceUnit decimal.Decimal `json:"prompt_price_unit" validate:"required"`
  169. PromptPrice decimal.Decimal `json:"prompt_price" validate:"required"`
  170. CompletionTokens *int `json:"completion_tokens" validate:"required"`
  171. CompletionUnitPrice decimal.Decimal `json:"completion_unit_price" validate:"required"`
  172. CompletionPriceUnit decimal.Decimal `json:"completion_price_unit" validate:"required"`
  173. CompletionPrice decimal.Decimal `json:"completion_price" validate:"required"`
  174. TotalTokens *int `json:"total_tokens" validate:"required"`
  175. TotalPrice decimal.Decimal `json:"total_price" validate:"required"`
  176. Currency *string `json:"currency" validate:"required"`
  177. Latency *float64 `json:"latency" validate:"required"`
  178. }
  179. type LLMResultChunkDelta struct {
  180. Index *int `json:"index" validate:"required"`
  181. Message PromptMessage `json:"message" validate:"required"`
  182. Usage *LLMUsage `json:"usage" validate:"omitempty"`
  183. FinishReason *string `json:"finish_reason" validate:"omitempty"`
  184. }
  185. type LLMGetNumTokensResponse struct {
  186. NumTokens int `json:"num_tokens"`
  187. }