model_service.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. )
  13. func genericInvokePlugin[Req any, Rsp any](
  14. session *session_manager.Session,
  15. request *Req,
  16. response_buffer_size int,
  17. typ PluginAccessType,
  18. action PluginAccessAction,
  19. ) (
  20. *stream.StreamResponse[Rsp], error,
  21. ) {
  22. runtime := plugin_manager.Get(session.PluginIdentity())
  23. if runtime == nil {
  24. return nil, errors.New("plugin not found")
  25. }
  26. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  27. listener := runtime.Listen(session.ID())
  28. listener.AddListener(func(message []byte) {
  29. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  30. if err != nil {
  31. log.Error("unmarshal json failed: %s", err.Error())
  32. return
  33. }
  34. switch chunk.Type {
  35. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  36. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  37. if err != nil {
  38. log.Error("unmarshal json failed: %s", err.Error())
  39. return
  40. }
  41. response.Write(chunk)
  42. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  43. invokeDify(runtime, typ, session, chunk.Data)
  44. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  45. response.Close()
  46. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  47. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  48. if err != nil {
  49. break
  50. }
  51. response.WriteError(errors.New(e.Error))
  52. response.Close()
  53. default:
  54. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  55. response.Close()
  56. }
  57. })
  58. response.OnClose(func() {
  59. listener.Close()
  60. })
  61. session.Write(
  62. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  63. getInvokeModelMap(
  64. session,
  65. typ,
  66. action,
  67. request,
  68. ),
  69. )
  70. return response, nil
  71. }
  72. func getInvokeModelMap(
  73. session *session_manager.Session,
  74. typ PluginAccessType,
  75. action PluginAccessAction,
  76. request any,
  77. ) map[string]any {
  78. req := getBasicPluginAccessMap(session.UserID(), typ, action)
  79. for k, v := range parser.StructToMap(request) {
  80. req[k] = v
  81. }
  82. return req
  83. }
  84. func InvokeLLM(
  85. session *session_manager.Session,
  86. request *requests.RequestInvokeLLM,
  87. ) (
  88. *stream.StreamResponse[model_entities.LLMResultChunk], error,
  89. ) {
  90. return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
  91. session,
  92. request,
  93. 512,
  94. PLUGIN_ACCESS_TYPE_MODEL,
  95. PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  96. )
  97. }
  98. func InvokeTextEmbedding(
  99. session *session_manager.Session,
  100. request *requests.RequestInvokeTextEmbedding,
  101. ) (
  102. *stream.StreamResponse[model_entities.TextEmbeddingResult], error,
  103. ) {
  104. return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
  105. session,
  106. request,
  107. 1,
  108. PLUGIN_ACCESS_TYPE_MODEL,
  109. PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  110. )
  111. }
  112. func InvokeRerank(
  113. session *session_manager.Session,
  114. request *requests.RequestInvokeRerank,
  115. ) (
  116. *stream.StreamResponse[model_entities.RerankResult], error,
  117. ) {
  118. return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
  119. session,
  120. request,
  121. 1,
  122. PLUGIN_ACCESS_TYPE_MODEL,
  123. PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  124. )
  125. }
  126. func InvokeTTS(
  127. session *session_manager.Session,
  128. request *requests.RequestInvokeTTS,
  129. ) (
  130. *stream.StreamResponse[model_entities.TTSResult], error,
  131. ) {
  132. return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult](
  133. session,
  134. request,
  135. 1,
  136. PLUGIN_ACCESS_TYPE_MODEL,
  137. PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  138. )
  139. }
  140. func InvokeSpeech2Text(
  141. session *session_manager.Session,
  142. request *requests.RequestInvokeSpeech2Text,
  143. ) (
  144. *stream.StreamResponse[model_entities.Speech2TextResult], error,
  145. ) {
  146. return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult](
  147. session,
  148. request,
  149. 1,
  150. PLUGIN_ACCESS_TYPE_MODEL,
  151. PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  152. )
  153. }
  154. func InvokeModeration(
  155. session *session_manager.Session,
  156. request *requests.RequestInvokeModeration,
  157. ) (
  158. *stream.StreamResponse[model_entities.ModerationResult], error,
  159. ) {
  160. return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult](
  161. session,
  162. request,
  163. 1,
  164. PLUGIN_ACCESS_TYPE_MODEL,
  165. PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  166. )
  167. }
  168. func ValidateProviderCredentials(
  169. session *session_manager.Session,
  170. request *requests.RequestValidateProviderCredentials,
  171. ) (
  172. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  173. ) {
  174. return genericInvokePlugin[requests.RequestValidateProviderCredentials, model_entities.ValidateCredentialsResult](
  175. session,
  176. request,
  177. 1,
  178. PLUGIN_ACCESS_TYPE_MODEL,
  179. PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
  180. )
  181. }
  182. func ValidateModelCredentials(
  183. session *session_manager.Session,
  184. request *requests.RequestValidateModelCredentials,
  185. ) (
  186. *stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
  187. ) {
  188. return genericInvokePlugin[requests.RequestValidateModelCredentials, model_entities.ValidateCredentialsResult](
  189. session,
  190. request,
  191. 1,
  192. PLUGIN_ACCESS_TYPE_MODEL,
  193. PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
  194. )
  195. }