hooks.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package remote_manager
  2. import (
  3. "sync"
  4. "time"
  5. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  6. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  7. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  10. "github.com/panjf2000/gnet/v2"
  11. )
  12. type DifyServer struct {
  13. gnet.BuiltinEventEngine
  14. engine gnet.Engine
  15. // listening address
  16. addr string
  17. port uint16
  18. // enabled multicore
  19. multicore bool
  20. // event loop count
  21. num_loops int
  22. // read new connections
  23. response *stream.StreamResponse[*RemotePluginRuntime]
  24. plugins map[int]*RemotePluginRuntime
  25. plugins_lock *sync.RWMutex
  26. shutdown_chan chan bool
  27. }
  28. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  29. s.engine = c
  30. return gnet.None
  31. }
  32. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  33. // new plugin connected
  34. c.SetContext(&codec{})
  35. runtime := &RemotePluginRuntime{
  36. conn: c,
  37. response: stream.NewStreamResponse[[]byte](512),
  38. callbacks: make(map[string][]func([]byte)),
  39. callbacks_lock: &sync.RWMutex{},
  40. shutdown_chan: make(chan bool),
  41. alive: true,
  42. }
  43. // store plugin runtime
  44. s.plugins_lock.Lock()
  45. s.plugins[c.Fd()] = runtime
  46. s.plugins_lock.Unlock()
  47. // start a timer to check if handshake is completed in 10 seconds
  48. time.AfterFunc(time.Second*10, func() {
  49. if !runtime.handshake {
  50. // close connection
  51. c.Close()
  52. }
  53. })
  54. // verified
  55. verified := true
  56. if verified {
  57. return nil, gnet.None
  58. }
  59. return nil, gnet.Close
  60. }
  61. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  62. // plugin disconnected
  63. s.plugins_lock.Lock()
  64. plugin := s.plugins[c.Fd()]
  65. delete(s.plugins, c.Fd())
  66. s.plugins_lock.Unlock()
  67. // close plugin
  68. plugin.onDisconnected()
  69. // uninstall plugin
  70. if plugin.handshake && plugin.registration_transferred &&
  71. plugin.endpoints_registration_transferred &&
  72. plugin.models_registration_transferred &&
  73. plugin.tools_registration_transferred {
  74. if err := plugin.Unregister(); err != nil {
  75. log.Error("unregister plugin failed, error: %v", err)
  76. }
  77. }
  78. return gnet.None
  79. }
  80. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  81. close(s.shutdown_chan)
  82. }
  83. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  84. codec := c.Context().(*codec)
  85. messages, err := codec.Decode(c)
  86. if err != nil {
  87. return gnet.Close
  88. }
  89. // get plugin runtime
  90. s.plugins_lock.RLock()
  91. runtime, ok := s.plugins[c.Fd()]
  92. s.plugins_lock.RUnlock()
  93. if !ok {
  94. return gnet.Close
  95. }
  96. // handle messages
  97. for _, message := range messages {
  98. if len(message) == 0 {
  99. continue
  100. }
  101. s.onMessage(runtime, message)
  102. }
  103. return gnet.None
  104. }
  // onMessage advances a plugin connection through its registration
  // protocol. Each stage consumes exactly one message, in this order:
  //
  //  1. the connection key, resolved via GetConnectionInfo (sets tenant_id);
  //  2. the plugin declaration, JSON-decoded into runtime.Config;
  //  3. tool provider configurations (JSON list; the first, if any, becomes
  //     Config.Tool);
  //  4. model provider configurations (JSON list; first -> Config.Model);
  //  5. endpoint provider declarations (JSON list; first -> Config.Endpoint).
  //     After this stage the runtime is checksummed, initialized, registered
  //     and published on s.response.
  //
  // After the last stage, every further message is forwarded to
  // runtime.response.
  //
  // If any stage fails, a short error line is written to the plugin, the
  // connection is closed and handshake_failed is set. OnTraffic keeps feeding
  // the rest of an already-decoded batch after Close, so setting the flag is
  // what stops those messages from being misread as the next stage.
  func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  	// reject tells the plugin why it is being dropped, closes the connection,
  	// and marks the runtime so any remaining messages are ignored.
  	reject := func(reply string) {
  		runtime.conn.Write([]byte(reply))
  		runtime.conn.Close()
  		runtime.handshake_failed = true
  	}

  	switch {
  	case runtime.handshake_failed:
  		// Connection already rejected; drop whatever is still queued.
  		return

  	case !runtime.handshake:
  		info, err := GetConnectionInfo(string(message))
  		if err == cache.ErrNotFound {
  			reject("handshake failed, invalid key\n")
  			return
  		}
  		if err != nil {
  			reject("internal error\n")
  			return
  		}
  		runtime.tenant_id = info.TenantId
  		runtime.handshake = true

  	case !runtime.registration_transferred:
  		declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
  		if err != nil {
  			reject("handshake failed\n")
  			return
  		}
  		runtime.Config = declaration
  		runtime.registration_transferred = true

  	case !runtime.tools_registration_transferred:
  		tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderConfiguration](message)
  		if err != nil {
  			log.Error("tools register failed, error: %v", err)
  			reject("tools register failed\n")
  			return
  		}
  		runtime.tools_registration_transferred = true
  		if len(tools) > 0 {
  			runtime.Config.Tool = &tools[0]
  		}

  	case !runtime.models_registration_transferred:
  		models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderConfiguration](message)
  		if err != nil {
  			log.Error("models register failed, error: %v", err)
  			reject("models register failed\n")
  			return
  		}
  		runtime.models_registration_transferred = true
  		if len(models) > 0 {
  			runtime.Config.Model = &models[0]
  		}

  	case !runtime.endpoints_registration_transferred:
  		endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
  		if err != nil {
  			log.Error("endpoints register failed, error: %v", err)
  			reject("endpoints register failed\n")
  			return
  		}
  		runtime.endpoints_registration_transferred = true
  		if len(endpoints) > 0 {
  			runtime.Config.Endpoint = &endpoints[0]
  		}

  		// All declarations received: finalize the runtime and announce it.
  		runtime.checksum = runtime.calculateChecksum()
  		runtime.InitState()
  		runtime.SetActiveAt(time.Now())
  		if err = runtime.Register(); err != nil {
  			log.Error("register failed, error: %v", err)
  			reject("register failed\n")
  			return
  		}
  		s.response.Write(runtime)

  	default:
  		// Registration complete: hand the payload to the runtime.
  		runtime.response.Write(message)
  	}
  }
  197. runtime.response.Write(message)
  198. }
  199. }