llm.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package model_entities
import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/go-playground/validator/v10"
	"github.com/langgenius/dify-plugin-daemon/pkg/validators"
	"github.com/shopspring/decimal"
)
  9. type ModelType string
  10. const (
  11. MODEL_TYPE_LLM ModelType = "llm"
  12. MODEL_TYPE_TEXT_EMBEDDING ModelType = "text-embedding"
  13. MODEL_TYPE_RERANKING ModelType = "rerank"
  14. MODEL_TYPE_SPEECH2TEXT ModelType = "speech2text"
  15. MODEL_TYPE_TTS ModelType = "tts"
  16. MODEL_TYPE_MODERATION ModelType = "moderation"
  17. )
  18. type LLMModel string
  19. const (
  20. LLM_MODE_CHAT LLMModel = "chat"
  21. LLM_MODE_COMPLETION LLMModel = "completion"
  22. )
  23. type PromptMessageRole string
  24. const (
  25. PROMPT_MESSAGE_ROLE_SYSTEM = "system"
  26. PROMPT_MESSAGE_ROLE_USER = "user"
  27. PROMPT_MESSAGE_ROLE_ASSISTANT = "assistant"
  28. PROMPT_MESSAGE_ROLE_TOOL = "tool"
  29. )
  30. func isPromptMessageRole(fl validator.FieldLevel) bool {
  31. value := fl.Field().String()
  32. switch value {
  33. case string(PROMPT_MESSAGE_ROLE_SYSTEM),
  34. string(PROMPT_MESSAGE_ROLE_USER),
  35. string(PROMPT_MESSAGE_ROLE_ASSISTANT),
  36. string(PROMPT_MESSAGE_ROLE_TOOL):
  37. return true
  38. }
  39. return false
  40. }
  41. type PromptMessage struct {
  42. Role PromptMessageRole `json:"role" validate:"required,prompt_message_role"`
  43. Content any `json:"content" validate:"required,prompt_message_content"`
  44. Name string `json:"name"`
  45. ToolCalls []PromptMessageToolCall `json:"tool_calls" validate:"dive"`
  46. ToolCallId string `json:"tool_call_id"`
  47. }
  48. func isPromptMessageContent(fl validator.FieldLevel) bool {
  49. // only allow string or []PromptMessageContent
  50. value := fl.Field().Interface()
  51. switch valueType := value.(type) {
  52. case string:
  53. return true
  54. case []PromptMessageContent:
  55. // validate the content
  56. for _, content := range valueType {
  57. if err := validators.GlobalEntitiesValidator.Struct(content); err != nil {
  58. return false
  59. }
  60. }
  61. return true
  62. }
  63. return false
  64. }
  65. type PromptMessageContentType string
  66. const (
  67. PROMPT_MESSAGE_CONTENT_TYPE_TEXT PromptMessageContentType = "text"
  68. PROMPT_MESSAGE_CONTENT_TYPE_IMAGE PromptMessageContentType = "image"
  69. PROMPT_MESSAGE_CONTENT_TYPE_AUDIO PromptMessageContentType = "audio"
  70. PROMPT_MESSAGE_CONTENT_TYPE_VIDEO PromptMessageContentType = "video"
  71. PROMPT_MESSAGE_CONTENT_TYPE_DOCUMENT PromptMessageContentType = "document"
  72. )
  73. func isPromptMessageContentType(fl validator.FieldLevel) bool {
  74. value := fl.Field().String()
  75. switch value {
  76. case string(PROMPT_MESSAGE_CONTENT_TYPE_TEXT),
  77. string(PROMPT_MESSAGE_CONTENT_TYPE_IMAGE),
  78. string(PROMPT_MESSAGE_CONTENT_TYPE_AUDIO),
  79. string(PROMPT_MESSAGE_CONTENT_TYPE_VIDEO),
  80. string(PROMPT_MESSAGE_CONTENT_TYPE_DOCUMENT):
  81. return true
  82. }
  83. return false
  84. }
  85. type PromptMessageContent struct {
  86. Type PromptMessageContentType `json:"type" validate:"required,prompt_message_content_type"`
  87. Base64Data string `json:"base64_data"` // for multi-modal data
  88. URL string `json:"url"` // for multi-modal data
  89. Data string `json:"data"` // for text only
  90. EncodeFormat string `json:"encode_format"`
  91. Format string `json:"format"`
  92. MimeType string `json:"mime_type"`
  93. }
  94. type PromptMessageToolCall struct {
  95. ID string `json:"id"`
  96. Type string `json:"type"`
  97. Function struct {
  98. Name string `json:"name"`
  99. Arguments string `json:"arguments"`
  100. } `json:"function"`
  101. }
  102. func init() {
  103. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_role", isPromptMessageRole)
  104. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content", isPromptMessageContent)
  105. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content_type", isPromptMessageContentType)
  106. }
  107. func (p *PromptMessage) UnmarshalJSON(data []byte) error {
  108. // Unmarshal the JSON data into a map
  109. var raw map[string]json.RawMessage
  110. if err := json.Unmarshal(data, &raw); err != nil {
  111. return err
  112. }
  113. // Check if content is a string or an array which contains type and content
  114. if _, ok := raw["content"]; ok {
  115. var content string
  116. if err := json.Unmarshal(raw["content"], &content); err == nil {
  117. p.Content = content
  118. } else {
  119. var content []PromptMessageContent
  120. if err := json.Unmarshal(raw["content"], &content); err != nil {
  121. return err
  122. }
  123. p.Content = content
  124. }
  125. } else {
  126. return errors.New("content field is required")
  127. }
  128. // Unmarshal the rest of the fields
  129. if role, ok := raw["role"]; ok {
  130. if err := json.Unmarshal(role, &p.Role); err != nil {
  131. return err
  132. }
  133. } else {
  134. return errors.New("role field is required")
  135. }
  136. if name, ok := raw["name"]; ok {
  137. if err := json.Unmarshal(name, &p.Name); err != nil {
  138. return err
  139. }
  140. }
  141. if toolCalls, ok := raw["tool_calls"]; ok {
  142. if err := json.Unmarshal(toolCalls, &p.ToolCalls); err != nil {
  143. return err
  144. }
  145. }
  146. if toolCallId, ok := raw["tool_call_id"]; ok {
  147. if err := json.Unmarshal(toolCallId, &p.ToolCallId); err != nil {
  148. return err
  149. }
  150. }
  151. return nil
  152. }
  153. type PromptMessageTool struct {
  154. Name string `json:"name" validate:"required"`
  155. Description string `json:"description"`
  156. Parameters map[string]any `json:"parameters"`
  157. }
  158. type LLMResultChunk struct {
  159. Model LLMModel `json:"model" validate:"required"`
  160. PromptMessages []PromptMessage `json:"prompt_messages" validate:"required,dive"`
  161. SystemFingerprint string `json:"system_fingerprint" validate:"omitempty"`
  162. Delta LLMResultChunkDelta `json:"delta" validate:"required"`
  163. }
  164. type LLMUsage struct {
  165. PromptTokens *int `json:"prompt_tokens" validate:"required"`
  166. PromptUnitPrice decimal.Decimal `json:"prompt_unit_price" validate:"required"`
  167. PromptPriceUnit decimal.Decimal `json:"prompt_price_unit" validate:"required"`
  168. PromptPrice decimal.Decimal `json:"prompt_price" validate:"required"`
  169. CompletionTokens *int `json:"completion_tokens" validate:"required"`
  170. CompletionUnitPrice decimal.Decimal `json:"completion_unit_price" validate:"required"`
  171. CompletionPriceUnit decimal.Decimal `json:"completion_price_unit" validate:"required"`
  172. CompletionPrice decimal.Decimal `json:"completion_price" validate:"required"`
  173. TotalTokens *int `json:"total_tokens" validate:"required"`
  174. TotalPrice decimal.Decimal `json:"total_price" validate:"required"`
  175. Currency *string `json:"currency" validate:"required"`
  176. Latency *float64 `json:"latency" validate:"required"`
  177. }
  178. type LLMResultChunkDelta struct {
  179. Index *int `json:"index" validate:"required"`
  180. Message PromptMessage `json:"message" validate:"required"`
  181. Usage *LLMUsage `json:"usage" validate:"omitempty"`
  182. FinishReason *string `json:"finish_reason" validate:"omitempty"`
  183. }
  184. type LLMGetNumTokensResponse struct {
  185. NumTokens int `json:"num_tokens"`
  186. }