model_service.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  13. )
  14. func genericInvokePlugin[Req any, Rsp any](
  15. session *session_manager.Session,
  16. request *Req,
  17. response_buffer_size int,
  18. typ backwards_invocation.PluginAccessType,
  19. action backwards_invocation.PluginAccessAction,
  20. ) (
  21. *stream.StreamResponse[Rsp], error,
  22. ) {
  23. runtime := plugin_manager.Get(session.PluginIdentity())
  24. if runtime == nil {
  25. return nil, errors.New("plugin not found")
  26. }
  27. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  28. listener := runtime.Listen(session.ID())
  29. listener.AddListener(func(message []byte) {
  30. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  31. if err != nil {
  32. log.Error("unmarshal json failed: %s", err.Error())
  33. return
  34. }
  35. switch chunk.Type {
  36. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  37. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  38. if err != nil {
  39. log.Error("unmarshal json failed: %s", err.Error())
  40. return
  41. }
  42. response.Write(chunk)
  43. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  44. if err := backwards_invocation.InvokeDify(runtime, typ, session, chunk.Data); err != nil {
  45. log.Error("invoke dify failed: %s", err.Error())
  46. return
  47. }
  48. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  49. response.Close()
  50. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  51. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  52. if err != nil {
  53. break
  54. }
  55. response.WriteError(errors.New(e.Error))
  56. response.Close()
  57. default:
  58. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  59. response.Close()
  60. }
  61. })
  62. response.OnClose(func() {
  63. listener.Close()
  64. })
  65. session.Write(
  66. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  67. getInvokeModelMap(
  68. session,
  69. typ,
  70. action,
  71. request,
  72. ),
  73. )
  74. return response, nil
  75. }
  76. func getInvokeModelMap(
  77. session *session_manager.Session,
  78. typ backwards_invocation.PluginAccessType,
  79. action backwards_invocation.PluginAccessAction,
  80. request any,
  81. ) map[string]any {
  82. req := getBasicPluginAccessMap(session.UserID(), typ, action)
  83. for k, v := range parser.StructToMap(request) {
  84. req[k] = v
  85. }
  86. return req
  87. }
  88. func InvokeLLM(
  89. session *session_manager.Session,
  90. request *requests.RequestInvokeLLM,
  91. ) (
  92. *stream.StreamResponse[model_entities.LLMResultChunk], error,
  93. ) {
  94. return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
  95. session,
  96. request,
  97. 512,
  98. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  99. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  100. )
  101. }
  102. func InvokeTextEmbedding(
  103. session *session_manager.Session,
  104. request *requests.RequestInvokeTextEmbedding,
  105. ) (
  106. *stream.StreamResponse[model_entities.TextEmbeddingResult], error,
  107. ) {
  108. return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
  109. session,
  110. request,
  111. 1,
  112. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  113. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  114. )
  115. }
  116. func InvokeRerank(
  117. session *session_manager.Session,
  118. request *requests.RequestInvokeRerank,
  119. ) (
  120. *stream.StreamResponse[model_entities.RerankResult], error,
  121. ) {
  122. return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
  123. session,
  124. request,
  125. 1,
  126. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  127. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  128. )
  129. }
  130. func InvokeTTS(
  131. session *session_manager.Session,
  132. request *requests.RequestInvokeTTS,
  133. ) (
  134. *stream.StreamResponse[model_entities.TTSResult], error,
  135. ) {
  136. return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult](
  137. session,
  138. request,
  139. 1,
  140. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  141. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  142. )
  143. }
  144. func InvokeSpeech2Text(
  145. session *session_manager.Session,
  146. request *requests.RequestInvokeSpeech2Text,
  147. ) (
  148. *stream.StreamResponse[model_entities.Speech2TextResult], error,
  149. ) {
  150. return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult](
  151. session,
  152. request,
  153. 1,
  154. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  155. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  156. )
  157. }
  158. func InvokeModeration(
  159. session *session_manager.Session,
  160. request *requests.RequestInvokeModeration,
  161. ) (
  162. *stream.StreamResponse[model_entities.ModerationResult], error,
  163. ) {
  164. return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult](
  165. session,
  166. request,
  167. 1,
  168. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  169. backwards_invocation.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  170. )
  171. }
  172. func ValidateProviderCredentials(
  173. session *session_manager.Session,
  174. request *requests.RequestValidateProviderCredentials,
  175. ) (
  176. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  177. ) {
  178. return genericInvokePlugin[requests.RequestValidateProviderCredentials, model_entities.ValidateCredentialsResult](
  179. session,
  180. request,
  181. 1,
  182. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  183. backwards_invocation.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  184. )
  185. }
  186. func ValidateModelCredentials(
  187. session *session_manager.Session,
  188. request *requests.RequestValidateModelCredentials,
  189. ) (
  190. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  191. ) {
  192. return genericInvokePlugin[requests.RequestValidateModelCredentials, model_entities.ValidateCredentialsResult](
  193. session,
  194. request,
  195. 1,
  196. backwards_invocation.PLUGIN_ACCESS_TYPE_MODEL,
  197. backwards_invocation.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  198. )
  199. }