model_service.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. )
  13. func genericInvokePlugin[Req any, Rsp any](
  14. session *session_manager.Session,
  15. request *Req,
  16. response_buffer_size int,
  17. typ PluginAccessType,
  18. action PluginAccessAction,
  19. ) (
  20. *stream.StreamResponse[Rsp], error,
  21. ) {
  22. runtime := plugin_manager.Get(session.PluginIdentity())
  23. if runtime == nil {
  24. return nil, errors.New("plugin not found")
  25. }
  26. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  27. listener := runtime.Listen(session.ID())
  28. listener.AddListener(func(message []byte) {
  29. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  30. if err != nil {
  31. log.Error("unmarshal json failed: %s", err.Error())
  32. return
  33. }
  34. switch chunk.Type {
  35. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  36. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  37. if err != nil {
  38. log.Error("unmarshal json failed: %s", err.Error())
  39. return
  40. }
  41. response.Write(chunk)
  42. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  43. invokeDify(runtime, session, chunk.Data)
  44. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  45. response.Close()
  46. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  47. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  48. if err != nil {
  49. break
  50. }
  51. response.WriteError(errors.New(e.Error))
  52. response.Close()
  53. default:
  54. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  55. response.Close()
  56. }
  57. })
  58. response.OnClose(func() {
  59. listener.Close()
  60. })
  61. runtime.Write(session.ID(), []byte(parser.MarshalJson(
  62. getInvokeModelMap(
  63. session,
  64. typ,
  65. action,
  66. request,
  67. ),
  68. )))
  69. return response, nil
  70. }
  71. func getInvokeModelMap(
  72. session *session_manager.Session,
  73. typ PluginAccessType,
  74. action PluginAccessAction,
  75. request any,
  76. ) map[string]any {
  77. req := getBasicPluginAccessMap(session.ID(), session.UserID(), typ, action)
  78. data := req["data"].(map[string]any)
  79. for k, v := range parser.StructToMap(request) {
  80. data[k] = v
  81. }
  82. return req
  83. }
  84. func InvokeLLM(
  85. session *session_manager.Session,
  86. request *requests.RequestInvokeLLM,
  87. ) (
  88. *stream.StreamResponse[model_entities.LLMResultChunk], error,
  89. ) {
  90. return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
  91. session,
  92. request,
  93. 512,
  94. PLUGIN_ACCESS_TYPE_MODEL,
  95. PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  96. )
  97. }
  98. func InvokeTextEmbedding(
  99. session *session_manager.Session,
  100. request *requests.RequestInvokeTextEmbedding,
  101. ) (
  102. *stream.StreamResponse[model_entities.TextEmbeddingResult], error,
  103. ) {
  104. return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
  105. session,
  106. request,
  107. 1,
  108. PLUGIN_ACCESS_TYPE_MODEL,
  109. PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
  110. )
  111. }
  112. func InvokeRerank(
  113. session *session_manager.Session,
  114. request *requests.RequestInvokeRerank,
  115. ) (
  116. *stream.StreamResponse[model_entities.RerankResult], error,
  117. ) {
  118. return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
  119. session,
  120. request,
  121. 1,
  122. PLUGIN_ACCESS_TYPE_MODEL,
  123. PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
  124. )
  125. }
  126. func InvokeTTS(
  127. session *session_manager.Session,
  128. request *requests.RequestInvokeTTS,
  129. ) (
  130. *stream.StreamResponse[string], error,
  131. ) {
  132. return genericInvokePlugin[requests.RequestInvokeTTS, string](
  133. session,
  134. request,
  135. 1,
  136. PLUGIN_ACCESS_TYPE_MODEL,
  137. PLUGIN_ACCESS_ACTION_INVOKE_TTS,
  138. )
  139. }
  140. func InvokeSpeech2Text(
  141. session *session_manager.Session,
  142. request *requests.RequestInvokeSpeech2Text,
  143. ) (
  144. *stream.StreamResponse[string], error,
  145. ) {
  146. return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string](
  147. session,
  148. request,
  149. 1,
  150. PLUGIN_ACCESS_TYPE_MODEL,
  151. PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
  152. )
  153. }
  154. func InvokeModeration(
  155. session *session_manager.Session,
  156. request *requests.RequestInvokeModeration,
  157. ) (
  158. *stream.StreamResponse[bool], error,
  159. ) {
  160. return genericInvokePlugin[requests.RequestInvokeModeration, bool](
  161. session,
  162. request,
  163. 1,
  164. PLUGIN_ACCESS_TYPE_MODEL,
  165. PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
  166. )
  167. }