model_service.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. )
  13. func genericInvokePlugin[Req any, Rsp any](
  14. session *session_manager.Session,
  15. request *Req,
  16. response_buffer_size int,
  17. typ PluginAccessType,
  18. action PluginAccessAction,
  19. ) (
  20. *stream.StreamResponse[Rsp], error,
  21. ) {
  22. runtime := plugin_manager.Get(session.PluginIdentity())
  23. if runtime == nil {
  24. return nil, errors.New("plugin not found")
  25. }
  26. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  27. listener := runtime.Listen(session.ID())
  28. listener.AddListener(func(message []byte) {
  29. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  30. if err != nil {
  31. log.Error("unmarshal json failed: %s", err.Error())
  32. return
  33. }
  34. switch chunk.Type {
  35. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  36. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  37. if err != nil {
  38. log.Error("unmarshal json failed: %s", err.Error())
  39. return
  40. }
  41. response.Write(chunk)
  42. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  43. if err := invokeDify(runtime, typ, session, chunk.Data); err != nil {
  44. log.Error("invoke dify failed: %s", err.Error())
  45. return
  46. }
  47. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  48. response.Close()
  49. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  50. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  51. if err != nil {
  52. break
  53. }
  54. response.WriteError(errors.New(e.Error))
  55. response.Close()
  56. default:
  57. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  58. response.Close()
  59. }
  60. })
  61. response.OnClose(func() {
  62. listener.Close()
  63. })
  64. session.Write(
  65. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  66. getInvokeModelMap(
  67. session,
  68. typ,
  69. action,
  70. request,
  71. ),
  72. )
  73. return response, nil
  74. }
  75. func getInvokeModelMap(
  76. session *session_manager.Session,
  77. typ PluginAccessType,
  78. action PluginAccessAction,
  79. request any,
  80. ) map[string]any {
  81. req := getBasicPluginAccessMap(session.UserID(), typ, action)
  82. for k, v := range parser.StructToMap(request) {
  83. req[k] = v
  84. }
  85. return req
  86. }
  87. func InvokeLLM(
  88. session *session_manager.Session,
  89. request *requests.RequestInvokeLLM,
  90. ) (
  91. *stream.StreamResponse[model_entities.LLMResultChunk], error,
  92. ) {
  93. return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
  94. session,
  95. request,
  96. 512,
  97. PLUGIN_ACCESS_TYPE_MODEL,
  98. PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  99. )
  100. }
  101. func InvokeTextEmbedding(
  102. session *session_manager.Session,
  103. request *requests.RequestInvokeTextEmbedding,
  104. ) (
  105. *stream.StreamResponse[model_entities.TextEmbeddingResult], error,
  106. ) {
  107. return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
  108. session,
  109. request,
  110. 1,
  111. PLUGIN_ACCESS_TYPE_MODEL,
  112. PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  113. )
  114. }
  115. func InvokeRerank(
  116. session *session_manager.Session,
  117. request *requests.RequestInvokeRerank,
  118. ) (
  119. *stream.StreamResponse[model_entities.RerankResult], error,
  120. ) {
  121. return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
  122. session,
  123. request,
  124. 1,
  125. PLUGIN_ACCESS_TYPE_MODEL,
  126. PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  127. )
  128. }
  129. func InvokeTTS(
  130. session *session_manager.Session,
  131. request *requests.RequestInvokeTTS,
  132. ) (
  133. *stream.StreamResponse[model_entities.TTSResult], error,
  134. ) {
  135. return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult](
  136. session,
  137. request,
  138. 1,
  139. PLUGIN_ACCESS_TYPE_MODEL,
  140. PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  141. )
  142. }
  143. func InvokeSpeech2Text(
  144. session *session_manager.Session,
  145. request *requests.RequestInvokeSpeech2Text,
  146. ) (
  147. *stream.StreamResponse[model_entities.Speech2TextResult], error,
  148. ) {
  149. return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult](
  150. session,
  151. request,
  152. 1,
  153. PLUGIN_ACCESS_TYPE_MODEL,
  154. PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  155. )
  156. }
  157. func InvokeModeration(
  158. session *session_manager.Session,
  159. request *requests.RequestInvokeModeration,
  160. ) (
  161. *stream.StreamResponse[model_entities.ModerationResult], error,
  162. ) {
  163. return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult](
  164. session,
  165. request,
  166. 1,
  167. PLUGIN_ACCESS_TYPE_MODEL,
  168. PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  169. )
  170. }
  171. func ValidateProviderCredentials(
  172. session *session_manager.Session,
  173. request *requests.RequestValidateProviderCredentials,
  174. ) (
  175. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  176. ) {
  177. return genericInvokePlugin[requests.RequestValidateProviderCredentials, model_entities.ValidateCredentialsResult](
  178. session,
  179. request,
  180. 1,
  181. PLUGIN_ACCESS_TYPE_MODEL,
  182. PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  183. )
  184. }
  185. func ValidateModelCredentials(
  186. session *session_manager.Session,
  187. request *requests.RequestValidateModelCredentials,
  188. ) (
  189. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  190. ) {
  191. return genericInvokePlugin[requests.RequestValidateModelCredentials, model_entities.ValidateCredentialsResult](
  192. session,
  193. request,
  194. 1,
  195. PLUGIN_ACCESS_TYPE_MODEL,
  196. PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  197. )
  198. }