llm.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package model_entities
import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/go-playground/validator/v10"
	"github.com/langgenius/dify-plugin-daemon/pkg/validators"
	"github.com/shopspring/decimal"
)
  9. type ModelType string
  10. const (
  11. MODEL_TYPE_LLM ModelType = "llm"
  12. MODEL_TYPE_TEXT_EMBEDDING ModelType = "text-embedding"
  13. MODEL_TYPE_RERANKING ModelType = "rerank"
  14. MODEL_TYPE_SPEECH2TEXT ModelType = "speech2text"
  15. MODEL_TYPE_TTS ModelType = "tts"
  16. MODEL_TYPE_MODERATION ModelType = "moderation"
  17. )
  18. type LLMModel string
  19. const (
  20. LLM_MODE_CHAT LLMModel = "chat"
  21. LLM_MODE_COMPLETION LLMModel = "completion"
  22. )
  23. type PromptMessageRole string
  24. const (
  25. PROMPT_MESSAGE_ROLE_SYSTEM = "system"
  26. PROMPT_MESSAGE_ROLE_USER = "user"
  27. PROMPT_MESSAGE_ROLE_ASSISTANT = "assistant"
  28. PROMPT_MESSAGE_ROLE_TOOL = "tool"
  29. )
  30. func isPromptMessageRole(fl validator.FieldLevel) bool {
  31. value := fl.Field().String()
  32. switch value {
  33. case string(PROMPT_MESSAGE_ROLE_SYSTEM),
  34. string(PROMPT_MESSAGE_ROLE_USER),
  35. string(PROMPT_MESSAGE_ROLE_ASSISTANT),
  36. string(PROMPT_MESSAGE_ROLE_TOOL):
  37. return true
  38. }
  39. return false
  40. }
  41. type PromptMessage struct {
  42. Role PromptMessageRole `json:"role" validate:"required,prompt_message_role"`
  43. Content any `json:"content" validate:"required,prompt_message_content"`
  44. Name string `json:"name"`
  45. ToolCalls []PromptMessageToolCall `json:"tool_calls" validate:"dive"`
  46. ToolCallId string `json:"tool_call_id"`
  47. }
  48. func isPromptMessageContent(fl validator.FieldLevel) bool {
  49. // only allow string or []PromptMessageContent
  50. value := fl.Field().Interface()
  51. switch valueType := value.(type) {
  52. case string:
  53. return true
  54. case []PromptMessageContent:
  55. // validate the content
  56. for _, content := range valueType {
  57. if err := validators.GlobalEntitiesValidator.Struct(content); err != nil {
  58. return false
  59. }
  60. }
  61. return true
  62. }
  63. return false
  64. }
  65. type PromptMessageContentType string
  66. const (
  67. PROMPT_MESSAGE_CONTENT_TYPE_TEXT PromptMessageContentType = "text"
  68. PROMPT_MESSAGE_CONTENT_TYPE_IMAGE PromptMessageContentType = "image"
  69. PROMPT_MESSAGE_CONTENT_TYPE_AUDIO PromptMessageContentType = "audio"
  70. PROMPT_MESSAGE_CONTENT_TYPE_VIDEO PromptMessageContentType = "video"
  71. PROMPT_MESSAGE_CONTENT_TYPE_DOCUMENT PromptMessageContentType = "document"
  72. )
  73. func isPromptMessageContentType(fl validator.FieldLevel) bool {
  74. value := fl.Field().String()
  75. switch value {
  76. case string(PROMPT_MESSAGE_CONTENT_TYPE_TEXT),
  77. string(PROMPT_MESSAGE_CONTENT_TYPE_IMAGE),
  78. string(PROMPT_MESSAGE_CONTENT_TYPE_AUDIO),
  79. string(PROMPT_MESSAGE_CONTENT_TYPE_VIDEO),
  80. string(PROMPT_MESSAGE_CONTENT_TYPE_DOCUMENT):
  81. return true
  82. }
  83. return false
  84. }
  85. type PromptMessageContent struct {
  86. Type PromptMessageContentType `json:"type" validate:"required,prompt_message_content_type"`
  87. Base64Data string `json:"base64_data"` // for multi-modal data
  88. Data string `json:"data"` // for text only
  89. EncodeFormat string `json:"encode_format"`
  90. Format string `json:"format"`
  91. MimeType string `json:"mime_type"`
  92. }
  93. type PromptMessageToolCall struct {
  94. ID string `json:"id"`
  95. Type string `json:"type"`
  96. Function struct {
  97. Name string `json:"name"`
  98. Arguments string `json:"arguments"`
  99. } `json:"function"`
  100. }
  101. func init() {
  102. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_role", isPromptMessageRole)
  103. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content", isPromptMessageContent)
  104. validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content_type", isPromptMessageContentType)
  105. }
  106. func (p *PromptMessage) UnmarshalJSON(data []byte) error {
  107. // Unmarshal the JSON data into a map
  108. var raw map[string]json.RawMessage
  109. if err := json.Unmarshal(data, &raw); err != nil {
  110. return err
  111. }
  112. // Check if content is a string or an array which contains type and content
  113. if _, ok := raw["content"]; ok {
  114. var content string
  115. if err := json.Unmarshal(raw["content"], &content); err == nil {
  116. p.Content = content
  117. } else {
  118. var content []PromptMessageContent
  119. if err := json.Unmarshal(raw["content"], &content); err != nil {
  120. return err
  121. }
  122. p.Content = content
  123. }
  124. } else {
  125. return errors.New("content field is required")
  126. }
  127. // Unmarshal the rest of the fields
  128. if role, ok := raw["role"]; ok {
  129. if err := json.Unmarshal(role, &p.Role); err != nil {
  130. return err
  131. }
  132. } else {
  133. return errors.New("role field is required")
  134. }
  135. if name, ok := raw["name"]; ok {
  136. if err := json.Unmarshal(name, &p.Name); err != nil {
  137. return err
  138. }
  139. }
  140. if toolCalls, ok := raw["tool_calls"]; ok {
  141. if err := json.Unmarshal(toolCalls, &p.ToolCalls); err != nil {
  142. return err
  143. }
  144. }
  145. if toolCallId, ok := raw["tool_call_id"]; ok {
  146. if err := json.Unmarshal(toolCallId, &p.ToolCallId); err != nil {
  147. return err
  148. }
  149. }
  150. return nil
  151. }
  152. type PromptMessageTool struct {
  153. Name string `json:"name" validate:"required"`
  154. Description string `json:"description" validate:"required"`
  155. Parameters map[string]any `json:"parameters"`
  156. }
  157. type LLMResultChunk struct {
  158. Model LLMModel `json:"model" validate:"required"`
  159. PromptMessages []PromptMessage `json:"prompt_messages" validate:"required,dive"`
  160. SystemFingerprint string `json:"system_fingerprint" validate:"omitempty"`
  161. Delta LLMResultChunkDelta `json:"delta" validate:"required"`
  162. }
  163. type LLMUsage struct {
  164. PromptTokens *int `json:"prompt_tokens" validate:"required"`
  165. PromptUnitPrice decimal.Decimal `json:"prompt_unit_price" validate:"required"`
  166. PromptPriceUnit decimal.Decimal `json:"prompt_price_unit" validate:"required"`
  167. PromptPrice decimal.Decimal `json:"prompt_price" validate:"required"`
  168. CompletionTokens *int `json:"completion_tokens" validate:"required"`
  169. CompletionUnitPrice decimal.Decimal `json:"completion_unit_price" validate:"required"`
  170. CompletionPriceUnit decimal.Decimal `json:"completion_price_unit" validate:"required"`
  171. CompletionPrice decimal.Decimal `json:"completion_price" validate:"required"`
  172. TotalTokens *int `json:"total_tokens" validate:"required"`
  173. TotalPrice decimal.Decimal `json:"total_price" validate:"required"`
  174. Currency *string `json:"currency" validate:"required"`
  175. Latency *float64 `json:"latency" validate:"required"`
  176. }
  177. type LLMResultChunkDelta struct {
  178. Index *int `json:"index" validate:"required"`
  179. Message PromptMessage `json:"message" validate:"required"`
  180. Usage *LLMUsage `json:"usage" validate:"omitempty"`
  181. FinishReason *string `json:"finish_reason" validate:"omitempty"`
  182. }
  183. type LLMGetNumTokensResponse struct {
  184. NumTokens int `json:"num_tokens"`
  185. }