generic.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  13. )
  14. func genericInvokePlugin[Req any, Rsp any](
  15. session *session_manager.Session,
  16. request *Req,
  17. response_buffer_size int,
  18. typ access_types.PluginAccessType,
  19. action access_types.PluginAccessAction,
  20. ) (*stream.StreamResponse[Rsp], error) {
  21. runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginIdentity())
  22. if runtime == nil {
  23. return nil, errors.New("plugin not found")
  24. }
  25. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  26. listener := runtime.Listen(session.ID())
  27. listener.Listen(func(message []byte) {
  28. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  29. if err != nil {
  30. log.Error("unmarshal json failed: %s", err.Error())
  31. return
  32. }
  33. switch chunk.Type {
  34. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  35. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  36. if err != nil {
  37. log.Error("unmarshal json failed: %s", err.Error())
  38. response.WriteError(err)
  39. } else {
  40. response.Write(chunk)
  41. }
  42. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  43. // check if the request contains a aws_event_id
  44. var writer backwards_invocation.BackwardsInvocationWriter
  45. if chunk.RuntimeType == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
  46. writer = transaction.NewAWSTransactionWriter(chunk.ServerlessEventId)
  47. } else {
  48. writer = transaction.NewFullDuplexEventWriter(session)
  49. }
  50. if err := backwards_invocation.InvokeDify(runtime, typ, session, writer, chunk.Data); err != nil {
  51. log.Error("invoke dify failed: %s", err.Error())
  52. return
  53. }
  54. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  55. response.Close()
  56. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  57. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  58. if err != nil {
  59. break
  60. }
  61. response.WriteError(errors.New(e.Error))
  62. response.Close()
  63. default:
  64. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  65. response.Close()
  66. }
  67. })
  68. response.OnClose(func() {
  69. listener.Close()
  70. })
  71. session.Write(
  72. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  73. getInvokePluginMap(
  74. session,
  75. typ,
  76. action,
  77. request,
  78. ),
  79. )
  80. return response, nil
  81. }
  82. func getInvokePluginMap(
  83. session *session_manager.Session,
  84. typ access_types.PluginAccessType,
  85. action access_types.PluginAccessAction,
  86. request any,
  87. ) map[string]any {
  88. req := getBasicPluginAccessMap(session.UserID(), typ, action)
  89. for k, v := range parser.StructToMap(request) {
  90. req[k] = v
  91. }
  92. return req
  93. }