hooks.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package remote_manager
import (
	"encoding/hex"
	"errors"
	"sync"
	"time"

	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
	"github.com/panjf2000/gnet/v2"
)
// DifyServer is the gnet event handler that accepts remote plugin
// connections. It embeds gnet.BuiltinEventEngine and overrides the
// OnBoot, OnOpen, OnClose, OnShutdown and OnTraffic hooks.
type DifyServer struct {
	gnet.BuiltinEventEngine

	// engine is the running gnet engine, captured in OnBoot.
	engine gnet.Engine
	// mediaManager is not referenced by the hooks in this file.
	mediaManager *media_manager.MediaManager

	// listening address
	addr string
	port uint16

	// enabled multicore
	multicore bool

	// event loop count
	num_loops int

	// response receives each RemotePluginRuntime after it has completed
	// registration (written by onMessage once runtime.Register succeeds).
	response *stream.StreamResponse[*RemotePluginRuntime]

	// plugins maps a connection's fd (c.Fd()) to its runtime; every access
	// in this file is guarded by plugins_lock.
	plugins      map[int]*RemotePluginRuntime
	plugins_lock *sync.RWMutex

	// shutdown_chan is closed by OnShutdown.
	shutdown_chan chan bool
}
  31. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  32. s.engine = c
  33. return gnet.None
  34. }
  35. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  36. // new plugin connected
  37. c.SetContext(&codec{})
  38. runtime := &RemotePluginRuntime{
  39. conn: c,
  40. response: stream.NewStreamResponse[[]byte](512),
  41. callbacks: make(map[string][]func([]byte)),
  42. callbacks_lock: &sync.RWMutex{},
  43. shutdown_chan: make(chan bool),
  44. alive: true,
  45. }
  46. // store plugin runtime
  47. s.plugins_lock.Lock()
  48. s.plugins[c.Fd()] = runtime
  49. s.plugins_lock.Unlock()
  50. // start a timer to check if handshake is completed in 10 seconds
  51. time.AfterFunc(time.Second*10, func() {
  52. if !runtime.handshake {
  53. // close connection
  54. c.Close()
  55. }
  56. })
  57. // verified
  58. verified := true
  59. if verified {
  60. return nil, gnet.None
  61. }
  62. return nil, gnet.Close
  63. }
  64. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  65. // plugin disconnected
  66. s.plugins_lock.Lock()
  67. plugin := s.plugins[c.Fd()]
  68. delete(s.plugins, c.Fd())
  69. s.plugins_lock.Unlock()
  70. // close plugin
  71. plugin.onDisconnected()
  72. // uninstall plugin
  73. if plugin.handshake && plugin.registration_transferred &&
  74. plugin.endpoints_registration_transferred &&
  75. plugin.models_registration_transferred &&
  76. plugin.tools_registration_transferred {
  77. if err := plugin.Unregister(); err != nil {
  78. log.Error("unregister plugin failed, error: %v", err)
  79. }
  80. }
  81. return gnet.None
  82. }
  83. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  84. close(s.shutdown_chan)
  85. }
  86. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  87. codec := c.Context().(*codec)
  88. messages, err := codec.Decode(c)
  89. if err != nil {
  90. return gnet.Close
  91. }
  92. // get plugin runtime
  93. s.plugins_lock.RLock()
  94. runtime, ok := s.plugins[c.Fd()]
  95. s.plugins_lock.RUnlock()
  96. if !ok {
  97. return gnet.Close
  98. }
  99. // handle messages
  100. for _, message := range messages {
  101. if len(message) == 0 {
  102. continue
  103. }
  104. s.onMessage(runtime, message)
  105. }
  106. return gnet.None
  107. }
  108. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  109. // handle message
  110. if runtime.handshake_failed {
  111. // do nothing if handshake has failed
  112. return
  113. }
  114. if !runtime.handshake {
  115. key := string(message)
  116. info, err := GetConnectionInfo(key)
  117. if err == cache.ErrNotFound {
  118. // close connection if handshake failed
  119. runtime.conn.Write([]byte("handshake failed, invalid key\n"))
  120. runtime.conn.Close()
  121. runtime.handshake_failed = true
  122. return
  123. } else if err != nil {
  124. // close connection if handshake failed
  125. runtime.conn.Write([]byte("internal error\n"))
  126. runtime.conn.Close()
  127. return
  128. }
  129. runtime.tenant_id = info.TenantId
  130. // handshake completed
  131. runtime.handshake = true
  132. } else if !runtime.registration_transferred {
  133. // process handle shake if not completed
  134. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
  135. if err != nil {
  136. // close connection if handshake failed
  137. runtime.conn.Write([]byte("handshake failed\n"))
  138. runtime.conn.Close()
  139. return
  140. }
  141. runtime.Config = declaration
  142. // registration transferred
  143. runtime.registration_transferred = true
  144. } else if !runtime.tools_registration_transferred {
  145. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
  146. if err != nil {
  147. runtime.conn.Write([]byte("tools register failed\n"))
  148. log.Error("tools register failed, error: %v", err)
  149. runtime.conn.Close()
  150. return
  151. }
  152. runtime.tools_registration_transferred = true
  153. if len(tools) > 0 {
  154. declaration := runtime.Config
  155. declaration.Tool = &tools[0]
  156. runtime.Config = declaration
  157. }
  158. } else if !runtime.models_registration_transferred {
  159. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
  160. if err != nil {
  161. runtime.conn.Write([]byte("models register failed\n"))
  162. log.Error("models register failed, error: %v", err)
  163. runtime.conn.Close()
  164. return
  165. }
  166. runtime.models_registration_transferred = true
  167. if len(models) > 0 {
  168. declaration := runtime.Config
  169. declaration.Model = &models[0]
  170. runtime.Config = declaration
  171. }
  172. } else if !runtime.endpoints_registration_transferred {
  173. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
  174. if err != nil {
  175. runtime.conn.Write([]byte("endpoints register failed\n"))
  176. log.Error("endpoints register failed, error: %v", err)
  177. runtime.conn.Close()
  178. return
  179. }
  180. runtime.endpoints_registration_transferred = true
  181. if len(endpoints) > 0 {
  182. declaration := runtime.Config
  183. declaration.Endpoint = &endpoints[0]
  184. runtime.Config = declaration
  185. }
  186. } else if !runtime.assets_transferred {
  187. assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
  188. if err != nil {
  189. runtime.conn.Write([]byte("assets register failed\n"))
  190. log.Error("assets register failed, error: %v", err)
  191. runtime.conn.Close()
  192. return
  193. }
  194. files := make(map[string][]byte)
  195. for _, asset := range assets {
  196. files[asset.Filename], err = hex.DecodeString(asset.Data)
  197. if err != nil {
  198. runtime.conn.Write([]byte("assets decode failed\n"))
  199. log.Error("assets decode failed, error: %v", err)
  200. runtime.conn.Close()
  201. return
  202. }
  203. }
  204. // remap assets
  205. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  206. runtime.conn.Write([]byte("assets remap failed\n"))
  207. log.Error("assets remap failed, error: %v", err)
  208. runtime.conn.Close()
  209. return
  210. }
  211. runtime.checksum = runtime.calculateChecksum()
  212. runtime.InitState()
  213. runtime.SetActiveAt(time.Now())
  214. // trigger registration event
  215. if err := runtime.Register(); err != nil {
  216. runtime.conn.Write([]byte("register failed\n"))
  217. log.Error("register failed, error: %v", err)
  218. runtime.conn.Close()
  219. return
  220. }
  221. // publish runtime to watcher
  222. s.response.Write(runtime)
  223. } else {
  224. // continue handle messages if handshake completed
  225. runtime.response.Write(message)
  226. }
  227. }
// onAssets is an empty placeholder: it ignores its arguments and does
// nothing. Asset payloads are currently decoded and remapped inline in
// onMessage rather than through this method.
func (s *DifyServer) onAssets(runtime *RemotePluginRuntime, assets []plugin_entities.RemoteAssetPayload) {
}