hooks.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. package remote_manager
  import (
  	"encoding/hex"
  	"errors"
  	"sync"
  	"sync/atomic"
  	"time"

  	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  	"github.com/panjf2000/gnet/v2"
  )
  16. var (
  17. _mode pluginRuntimeMode
  18. )
  19. type DifyServer struct {
  20. gnet.BuiltinEventEngine
  21. engine gnet.Engine
  22. mediaManager *media_manager.MediaManager
  23. // listening address
  24. addr string
  25. port uint16
  26. // enabled multicore
  27. multicore bool
  28. // event loop count
  29. num_loops int
  30. // read new connections
  31. response *stream.Stream[*RemotePluginRuntime]
  32. plugins map[int]*RemotePluginRuntime
  33. plugins_lock *sync.RWMutex
  34. shutdown_chan chan bool
  35. }
  36. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  37. s.engine = c
  38. return gnet.None
  39. }
  40. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  41. // new plugin connected
  42. c.SetContext(&codec{})
  43. runtime := &RemotePluginRuntime{
  44. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  45. s.mediaManager,
  46. ),
  47. conn: c,
  48. response: stream.NewStream[[]byte](512),
  49. callbacks: make(map[string][]func([]byte)),
  50. callbacks_lock: &sync.RWMutex{},
  51. shutdown_chan: make(chan bool),
  52. alive: true,
  53. }
  54. // store plugin runtime
  55. s.plugins_lock.Lock()
  56. s.plugins[c.Fd()] = runtime
  57. s.plugins_lock.Unlock()
  58. // start a timer to check if handshake is completed in 10 seconds
  59. time.AfterFunc(time.Second*10, func() {
  60. if !runtime.handshake {
  61. // close connection
  62. c.Close()
  63. }
  64. })
  65. // verified
  66. verified := true
  67. if verified {
  68. return nil, gnet.None
  69. }
  70. return nil, gnet.Close
  71. }
  72. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  73. // plugin disconnected
  74. s.plugins_lock.Lock()
  75. plugin := s.plugins[c.Fd()]
  76. delete(s.plugins, c.Fd())
  77. s.plugins_lock.Unlock()
  78. // close plugin
  79. plugin.onDisconnected()
  80. // clear assets
  81. plugin.ClearAssets()
  82. // uninstall plugin
  83. if plugin.assets_transferred {
  84. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  85. if err := plugin.Unregister(); err != nil {
  86. log.Error("unregister plugin failed, error: %v", err)
  87. }
  88. }
  89. }
  90. // send stopped event
  91. plugin.wait_chan_lock.Lock()
  92. for _, c := range plugin.wait_stopped_chan {
  93. select {
  94. case c <- true:
  95. default:
  96. }
  97. }
  98. plugin.wait_chan_lock.Unlock()
  99. return gnet.None
  100. }
  101. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  102. close(s.shutdown_chan)
  103. }
  104. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  105. codec := c.Context().(*codec)
  106. messages, err := codec.Decode(c)
  107. if err != nil {
  108. return gnet.Close
  109. }
  110. // get plugin runtime
  111. s.plugins_lock.RLock()
  112. runtime, ok := s.plugins[c.Fd()]
  113. s.plugins_lock.RUnlock()
  114. if !ok {
  115. return gnet.Close
  116. }
  117. // handle messages
  118. for _, message := range messages {
  119. if len(message) == 0 {
  120. continue
  121. }
  122. s.onMessage(runtime, message)
  123. }
  124. return gnet.None
  125. }
  126. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  127. // handle message
  128. if runtime.handshake_failed {
  129. // do nothing if handshake has failed
  130. return
  131. }
  132. close := func(message []byte) {
  133. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  134. runtime.conn.Write(message)
  135. runtime.conn.Close()
  136. }
  137. }
  138. if !runtime.handshake {
  139. key := string(message)
  140. info, err := GetConnectionInfo(key)
  141. if err == cache.ErrNotFound {
  142. // close connection if handshake failed
  143. close([]byte("handshake failed, invalid key\n"))
  144. runtime.handshake_failed = true
  145. return
  146. } else if err != nil {
  147. // close connection if handshake failed
  148. close([]byte("internal error\n"))
  149. return
  150. }
  151. runtime.tenant_id = info.TenantId
  152. // handshake completed
  153. runtime.handshake = true
  154. } else if !runtime.registration_transferred {
  155. // process handle shake if not completed
  156. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
  157. if err != nil {
  158. // close connection if handshake failed
  159. close([]byte("handshake failed, invalid plugin declaration\n"))
  160. return
  161. }
  162. runtime.Config = declaration
  163. // registration transferred
  164. runtime.registration_transferred = true
  165. } else if !runtime.tools_registration_transferred {
  166. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
  167. if err != nil {
  168. log.Error("tools register failed, error: %v", err)
  169. close([]byte("tools register failed, invalid tools declaration\n"))
  170. return
  171. }
  172. runtime.tools_registration_transferred = true
  173. if len(tools) > 0 {
  174. declaration := runtime.Config
  175. declaration.Tool = &tools[0]
  176. runtime.Config = declaration
  177. }
  178. } else if !runtime.models_registration_transferred {
  179. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
  180. if err != nil {
  181. log.Error("models register failed, error: %v", err)
  182. close([]byte("models register failed, invalid models declaration\n"))
  183. return
  184. }
  185. runtime.models_registration_transferred = true
  186. if len(models) > 0 {
  187. declaration := runtime.Config
  188. declaration.Model = &models[0]
  189. runtime.Config = declaration
  190. }
  191. } else if !runtime.endpoints_registration_transferred {
  192. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
  193. if err != nil {
  194. log.Error("endpoints register failed, error: %v", err)
  195. close([]byte("endpoints register failed, invalid endpoints declaration\n"))
  196. return
  197. }
  198. runtime.endpoints_registration_transferred = true
  199. if len(endpoints) > 0 {
  200. declaration := runtime.Config
  201. declaration.Endpoint = &endpoints[0]
  202. runtime.Config = declaration
  203. }
  204. } else if !runtime.assets_transferred {
  205. assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
  206. if err != nil {
  207. log.Error("assets register failed, error: %v", err)
  208. close([]byte("assets register failed, invalid assets declaration\n"))
  209. return
  210. }
  211. files := make(map[string][]byte)
  212. for _, asset := range assets {
  213. files[asset.Filename], err = hex.DecodeString(asset.Data)
  214. if err != nil {
  215. log.Error("assets decode failed, error: %v", err)
  216. close([]byte("assets decode failed, invalid assets data, cannot decode file\n"))
  217. return
  218. }
  219. }
  220. // remap assets
  221. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  222. log.Error("assets remap failed, error: %v", err)
  223. close([]byte("assets remap failed, invalid assets data, cannot remap\n"))
  224. return
  225. }
  226. runtime.assets_transferred = true
  227. runtime.checksum = runtime.calculateChecksum()
  228. runtime.InitState()
  229. runtime.SetActiveAt(time.Now())
  230. // trigger registration event
  231. if err := runtime.Register(); err != nil {
  232. log.Error("register failed, error: %v", err)
  233. close([]byte("register failed, cannot register\n"))
  234. return
  235. }
  236. // send started event
  237. runtime.wait_chan_lock.Lock()
  238. for _, c := range runtime.wait_started_chan {
  239. select {
  240. case c <- true:
  241. default:
  242. }
  243. }
  244. runtime.wait_chan_lock.Unlock()
  245. // publish runtime to watcher
  246. s.response.Write(runtime)
  247. } else {
  248. // continue handle messages if handshake completed
  249. runtime.response.Write(message)
  250. }
  251. }