model_service.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  13. )
  14. func genericInvokePlugin[Req any, Rsp any](
  15. session *session_manager.Session,
  16. request *Req,
  17. response_buffer_size int,
  18. typ backwards_invocation.PluginAccessType,
  19. action backwards_invocation.PluginAccessAction,
  20. ) (*stream.StreamResponse[Rsp], error) {
  21. runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginIdentity())
  22. if runtime == nil {
  23. return nil, errors.New("plugin not found")
  24. }
  25. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  26. listener := runtime.Listen(session.ID())
  27. listener.AddListener(func(message []byte) {
  28. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  29. if err != nil {
  30. log.Error("unmarshal json failed: %s", err.Error())
  31. return
  32. }
  33. switch chunk.Type {
  34. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  35. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  36. if err != nil {
  37. log.Error("unmarshal json failed: %s", err.Error())
  38. return
  39. }
  40. response.Write(chunk)
  41. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  42. if err := backwards_invocation.InvokeDify(runtime, typ, session, chunk.Data); err != nil {
  43. log.Error("invoke dify failed: %s", err.Error())
  44. return
  45. }
  46. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  47. response.Close()
  48. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  49. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  50. if err != nil {
  51. break
  52. }
  53. response.WriteError(errors.New(e.Error))
  54. response.Close()
  55. default:
  56. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  57. response.Close()
  58. }
  59. })
  60. response.OnClose(func() {
  61. listener.Close()
  62. })
  63. session.Write(
  64. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  65. getInvokeModelMap(
  66. session,
  67. typ,
  68. action,
  69. request,
  70. ),
  71. )
  72. return response, nil
  73. }
  74. func getInvokeModelMap(
  75. session *session_manager.Session,
  76. typ backwards_invocation.PluginAccessType,
  77. action backwards_invocation.PluginAccessAction,
  78. request any,
  79. ) map[string]any {
  80. req := getBasicPluginAccessMap(session.UserID(), typ, action)
  81. for k, v := range parser.StructToMap(request) {
  82. req[k] = v
  83. }
  84. return req
  85. }
  86. func InvokeLLM(
  87. session *session_manager.Session,
  88. request *requests.RequestInvokeLLM,
  89. ) (
  90. *stream.StreamResponse[model_entities.LLMResultChunk], error,
  91. ) {
  92. return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
  93. session,
  94. request,
  95. 512,
  96. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  97. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  98. )
  99. }
  100. func InvokeTextEmbedding(
  101. session *session_manager.Session,
  102. request *requests.RequestInvokeTextEmbedding,
  103. ) (
  104. *stream.StreamResponse[model_entities.TextEmbeddingResult], error,
  105. ) {
  106. return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
  107. session,
  108. request,
  109. 1,
  110. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  111. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  112. )
  113. }
  114. func InvokeRerank(
  115. session *session_manager.Session,
  116. request *requests.RequestInvokeRerank,
  117. ) (
  118. *stream.StreamResponse[model_entities.RerankResult], error,
  119. ) {
  120. return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
  121. session,
  122. request,
  123. 1,
  124. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  125. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  126. )
  127. }
  128. func InvokeTTS(
  129. session *session_manager.Session,
  130. request *requests.RequestInvokeTTS,
  131. ) (
  132. *stream.StreamResponse[model_entities.TTSResult], error,
  133. ) {
  134. return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult](
  135. session,
  136. request,
  137. 1,
  138. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  139. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  140. )
  141. }
  142. func InvokeSpeech2Text(
  143. session *session_manager.Session,
  144. request *requests.RequestInvokeSpeech2Text,
  145. ) (
  146. *stream.StreamResponse[model_entities.Speech2TextResult], error,
  147. ) {
  148. return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult](
  149. session,
  150. request,
  151. 1,
  152. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  153. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  154. )
  155. }
  156. func InvokeModeration(
  157. session *session_manager.Session,
  158. request *requests.RequestInvokeModeration,
  159. ) (
  160. *stream.StreamResponse[model_entities.ModerationResult], error,
  161. ) {
  162. return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult](
  163. session,
  164. request,
  165. 1,
  166. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  167. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  168. )
  169. }
  170. func ValidateProviderCredentials(
  171. session *session_manager.Session,
  172. request *requests.RequestValidateProviderCredentials,
  173. ) (
  174. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  175. ) {
  176. return genericInvokePlugin[requests.RequestValidateProviderCredentials, model_entities.ValidateCredentialsResult](
  177. session,
  178. request,
  179. 1,
  180. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  181. backwards_invocation.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  182. )
  183. }
  184. func ValidateModelCredentials(
  185. session *session_manager.Session,
  186. request *requests.RequestValidateModelCredentials,
  187. ) (
  188. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  189. ) {
  190. return genericInvokePlugin[requests.RequestValidateModelCredentials, model_entities.ValidateCredentialsResult](
  191. session,
  192. request,
  193. 1,
  194. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  195. backwards_invocation.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  196. )
  197. }