session.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package session_manager
  2. import (
  3. "errors"
  4. "fmt"
  5. "sync"
  6. "time"
  7. "github.com/google/uuid"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  13. "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
  14. )
  15. var (
  16. sessions map[string]*Session = map[string]*Session{}
  17. session_lock sync.RWMutex
  18. )
  19. // session need to implement the backwards_invocation.BackwardsInvocationWriter interface
  20. type Session struct {
  21. ID string `json:"id"`
  22. runtime plugin_entities.PluginLifetime `json:"-"`
  23. backwardsInvocation dify_invocation.BackwardsInvocation `json:"-"`
  24. TenantID string `json:"tenant_id"`
  25. UserID string `json:"user_id"`
  26. PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier"`
  27. ClusterID string `json:"cluster_id"`
  28. InvokeFrom access_types.PluginAccessType `json:"invoke_from"`
  29. Action access_types.PluginAccessAction `json:"action"`
  30. Declaration *plugin_entities.PluginDeclaration `json:"declaration"`
  31. // information about incoming request
  32. ConversationID *string `json:"conversation_id"`
  33. MessageID *string `json:"message_id"`
  34. AppID *string `json:"app_id"`
  35. EndpointID *string `json:"endpoint_id"`
  36. }
  37. func sessionKey(id string) string {
  38. return fmt.Sprintf("session_info:%s", id)
  39. }
  40. type NewSessionPayload struct {
  41. TenantID string `json:"tenant_id"`
  42. UserID string `json:"user_id"`
  43. PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier"`
  44. ClusterID string `json:"cluster_id"`
  45. InvokeFrom access_types.PluginAccessType `json:"invoke_from"`
  46. Action access_types.PluginAccessAction `json:"action"`
  47. Declaration *plugin_entities.PluginDeclaration `json:"declaration"`
  48. BackwardsInvocation dify_invocation.BackwardsInvocation `json:"backwards_invocation"`
  49. IgnoreCache bool `json:"ignore_cache"`
  50. ConversationID *string `json:"conversation_id"`
  51. MessageID *string `json:"message_id"`
  52. AppID *string `json:"app_id"`
  53. EndpointID *string `json:"endpoint_id"`
  54. }
  55. func NewSession(payload NewSessionPayload) *Session {
  56. s := &Session{
  57. ID: uuid.New().String(),
  58. TenantID: payload.TenantID,
  59. UserID: payload.UserID,
  60. PluginUniqueIdentifier: payload.PluginUniqueIdentifier,
  61. ClusterID: payload.ClusterID,
  62. InvokeFrom: payload.InvokeFrom,
  63. Action: payload.Action,
  64. Declaration: payload.Declaration,
  65. backwardsInvocation: payload.BackwardsInvocation,
  66. ConversationID: payload.ConversationID,
  67. MessageID: payload.MessageID,
  68. AppID: payload.AppID,
  69. EndpointID: payload.EndpointID,
  70. }
  71. session_lock.Lock()
  72. sessions[s.ID] = s
  73. session_lock.Unlock()
  74. if !payload.IgnoreCache {
  75. if err := cache.Store(sessionKey(s.ID), s, time.Minute*30); err != nil {
  76. log.Error("set session info to cache failed, %s", err)
  77. }
  78. }
  79. return s
  80. }
  81. type GetSessionPayload struct {
  82. ID string `json:"id"`
  83. IgnoreCache bool `json:"ignore_cache"`
  84. }
  85. func GetSession(payload GetSessionPayload) *Session {
  86. session_lock.RLock()
  87. session := sessions[payload.ID]
  88. session_lock.RUnlock()
  89. if session == nil {
  90. // if session not found, it may be generated by another node, try to get it from cache
  91. session, err := cache.Get[Session](sessionKey(payload.ID))
  92. if err != nil {
  93. log.Error("get session info from cache failed, %s", err)
  94. return nil
  95. }
  96. return session
  97. }
  98. return session
  99. }
  100. type DeleteSessionPayload struct {
  101. ID string `json:"id"`
  102. IgnoreCache bool `json:"ignore_cache"`
  103. }
  104. func DeleteSession(payload DeleteSessionPayload) {
  105. session_lock.Lock()
  106. delete(sessions, payload.ID)
  107. session_lock.Unlock()
  108. if !payload.IgnoreCache {
  109. if err := cache.Del(sessionKey(payload.ID)); err != nil {
  110. log.Error("delete session info from cache failed, %s", err)
  111. }
  112. }
  113. }
  114. type CloseSessionPayload struct {
  115. IgnoreCache bool `json:"ignore_cache"`
  116. }
  117. func (s *Session) Close(payload CloseSessionPayload) {
  118. DeleteSession(DeleteSessionPayload{
  119. ID: s.ID,
  120. IgnoreCache: payload.IgnoreCache,
  121. })
  122. }
  123. func (s *Session) BindRuntime(runtime plugin_entities.PluginLifetime) {
  124. s.runtime = runtime
  125. }
  126. func (s *Session) Runtime() plugin_entities.PluginLifetime {
  127. return s.runtime
  128. }
  129. func (s *Session) BindBackwardsInvocation(backwardsInvocation dify_invocation.BackwardsInvocation) {
  130. s.backwardsInvocation = backwardsInvocation
  131. }
  132. func (s *Session) BackwardsInvocation() dify_invocation.BackwardsInvocation {
  133. return s.backwardsInvocation
  134. }
  135. type PLUGIN_IN_STREAM_EVENT string
  136. const (
  137. PLUGIN_IN_STREAM_EVENT_REQUEST PLUGIN_IN_STREAM_EVENT = "request"
  138. PLUGIN_IN_STREAM_EVENT_RESPONSE PLUGIN_IN_STREAM_EVENT = "backwards_response"
  139. )
  140. func (s *Session) Message(event PLUGIN_IN_STREAM_EVENT, data any) []byte {
  141. return parser.MarshalJsonBytes(map[string]any{
  142. "session_id": s.ID,
  143. "conversation_id": s.ConversationID,
  144. "message_id": s.MessageID,
  145. "app_id": s.AppID,
  146. "endpoint_id": s.EndpointID,
  147. "event": event,
  148. "data": data,
  149. })
  150. }
  151. func (s *Session) Write(event PLUGIN_IN_STREAM_EVENT, action access_types.PluginAccessAction, data any) error {
  152. if s.runtime == nil {
  153. return errors.New("runtime not bound")
  154. }
  155. s.runtime.Write(s.ID, action, s.Message(event, data))
  156. return nil
  157. }