model_service.go 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  11. )
  12. func getInvokeModelMap(
  13. session *session_manager.Session,
  14. action PluginAccessAction,
  15. request *requests.RequestInvokeLLM,
  16. ) map[string]any {
  17. req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_MODEL, action)
  18. data := req["data"].(map[string]any)
  19. data["provider"] = request.Provider
  20. data["model"] = request.Model
  21. data["model_type"] = request.ModelType
  22. data["model_parameters"] = request.ModelParameters
  23. data["prompt_messages"] = request.PromptMessages
  24. data["tools"] = request.Tools
  25. data["stop"] = request.Stop
  26. data["stream"] = request.Stream
  27. data["credentials"] = request.Credentials
  28. return req
  29. }
  30. func InvokeLLM(
  31. session *session_manager.Session,
  32. request *requests.RequestInvokeLLM,
  33. ) (
  34. *stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error,
  35. ) {
  36. runtime := plugin_manager.Get(session.PluginIdentity())
  37. if runtime == nil {
  38. return nil, errors.New("plugin not found")
  39. }
  40. response := stream.NewStreamResponse[plugin_entities.InvokeModelResponseChunk](512)
  41. listener := runtime.Listen(session.ID())
  42. listener.AddListener(func(message []byte) {
  43. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  44. if err != nil {
  45. log.Error("unmarshal json failed: %s", err.Error())
  46. return
  47. }
  48. switch chunk.Type {
  49. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  50. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.InvokeModelResponseChunk](chunk.Data)
  51. if err != nil {
  52. log.Error("unmarshal json failed: %s", err.Error())
  53. return
  54. }
  55. response.Write(chunk)
  56. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  57. invokeDify(runtime, session, chunk.Data)
  58. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  59. response.Close()
  60. default:
  61. log.Error("unknown stream message type: %s", chunk.Type)
  62. response.Close()
  63. }
  64. })
  65. response.OnClose(func() {
  66. listener.Close()
  67. })
  68. runtime.Write(session.ID(), []byte(parser.MarshalJson(
  69. getInvokeModelMap(
  70. session,
  71. PLUGIN_ACCESS_ACTION_INVOKE_LLM,
  72. request,
  73. ),
  74. )))
  75. return response, nil
  76. }