hooks.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package remote_manager
  2. import (
  3. "encoding/hex"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  14. "github.com/panjf2000/gnet/v2"
  15. )
  16. var (
  17. // mode is only used for testing
  18. _mode pluginRuntimeMode
  19. )
  20. type DifyServer struct {
  21. gnet.BuiltinEventEngine
  22. engine gnet.Engine
  23. mediaManager *media_manager.MediaManager
  24. // listening address
  25. addr string
  26. port uint16
  27. // enabled multicore
  28. multicore bool
  29. // event loop count
  30. num_loops int
  31. // read new connections
  32. response *stream.Stream[*RemotePluginRuntime]
  33. plugins map[int]*RemotePluginRuntime
  34. plugins_lock *sync.RWMutex
  35. shutdown_chan chan bool
  36. max_conn int32
  37. current_conn int32
  38. }
  39. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  40. s.engine = c
  41. return gnet.None
  42. }
  43. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  44. // new plugin connected
  45. c.SetContext(&codec{})
  46. runtime := &RemotePluginRuntime{
  47. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  48. s.mediaManager,
  49. ),
  50. conn: c,
  51. response: stream.NewStream[[]byte](512),
  52. callbacks: make(map[string][]func([]byte)),
  53. callbacks_lock: &sync.RWMutex{},
  54. shutdown_chan: make(chan bool),
  55. wait_launched_chan: make(chan error),
  56. alive: true,
  57. }
  58. // store plugin runtime
  59. s.plugins_lock.Lock()
  60. s.plugins[c.Fd()] = runtime
  61. s.plugins_lock.Unlock()
  62. // start a timer to check if handshake is completed in 10 seconds
  63. time.AfterFunc(time.Second*10, func() {
  64. if !runtime.handshake {
  65. // close connection
  66. c.Close()
  67. }
  68. })
  69. // verified
  70. verified := true
  71. if verified {
  72. return nil, gnet.None
  73. }
  74. return nil, gnet.Close
  75. }
  76. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  77. // plugin disconnected
  78. s.plugins_lock.Lock()
  79. plugin := s.plugins[c.Fd()]
  80. delete(s.plugins, c.Fd())
  81. s.plugins_lock.Unlock()
  82. if plugin == nil {
  83. return gnet.None
  84. }
  85. // close plugin
  86. plugin.onDisconnected()
  87. // uninstall plugin
  88. if plugin.assets_transferred {
  89. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  90. if err := plugin.Unregister(); err != nil {
  91. log.Error("unregister plugin failed, error: %v", err)
  92. }
  93. // decrease current connection
  94. atomic.AddInt32(&s.current_conn, -1)
  95. }
  96. }
  97. // send stopped event
  98. plugin.wait_chan_lock.Lock()
  99. for _, c := range plugin.wait_stopped_chan {
  100. select {
  101. case c <- true:
  102. default:
  103. }
  104. }
  105. plugin.wait_chan_lock.Unlock()
  106. // recycle launched chan, avoid memory leak
  107. plugin.wait_launched_chan_once.Do(func() {
  108. close(plugin.wait_launched_chan)
  109. })
  110. return gnet.None
  111. }
  112. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  113. close(s.shutdown_chan)
  114. }
  115. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  116. codec := c.Context().(*codec)
  117. messages, err := codec.Decode(c)
  118. if err != nil {
  119. return gnet.Close
  120. }
  121. // get plugin runtime
  122. s.plugins_lock.RLock()
  123. runtime, ok := s.plugins[c.Fd()]
  124. s.plugins_lock.RUnlock()
  125. if !ok {
  126. return gnet.Close
  127. }
  128. // handle messages
  129. for _, message := range messages {
  130. if len(message) == 0 {
  131. continue
  132. }
  133. s.onMessage(runtime, message)
  134. }
  135. return gnet.None
  136. }
  137. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  138. // handle message
  139. if runtime.handshake_failed {
  140. // do nothing if handshake has failed
  141. return
  142. }
  143. close_conn := func(message []byte) {
  144. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  145. runtime.conn.Write(message)
  146. runtime.conn.Close()
  147. }
  148. }
  149. if !runtime.handshake {
  150. key := string(message)
  151. info, err := GetConnectionInfo(key)
  152. if err == cache.ErrNotFound {
  153. // close connection if handshake failed
  154. close_conn([]byte("handshake failed, invalid key\n"))
  155. runtime.handshake_failed = true
  156. return
  157. } else if err != nil {
  158. // close connection if handshake failed
  159. close_conn([]byte("internal error\n"))
  160. return
  161. }
  162. runtime.tenant_id = info.TenantId
  163. // handshake completed
  164. runtime.handshake = true
  165. } else if !runtime.registration_transferred {
  166. // process handle shake if not completed
  167. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
  168. if err != nil {
  169. // close connection if handshake failed
  170. close_conn([]byte("handshake failed, invalid plugin declaration\n"))
  171. return
  172. }
  173. runtime.Config = declaration
  174. // registration transferred
  175. runtime.registration_transferred = true
  176. } else if !runtime.tools_registration_transferred {
  177. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
  178. if err != nil {
  179. log.Error("tools register failed, error: %v", err)
  180. close_conn([]byte("tools register failed, invalid tools declaration\n"))
  181. return
  182. }
  183. runtime.tools_registration_transferred = true
  184. if len(tools) > 0 {
  185. declaration := runtime.Config
  186. declaration.Tool = &tools[0]
  187. runtime.Config = declaration
  188. }
  189. } else if !runtime.models_registration_transferred {
  190. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
  191. if err != nil {
  192. log.Error("models register failed, error: %v", err)
  193. close_conn([]byte("models register failed, invalid models declaration\n"))
  194. return
  195. }
  196. runtime.models_registration_transferred = true
  197. if len(models) > 0 {
  198. declaration := runtime.Config
  199. declaration.Model = &models[0]
  200. runtime.Config = declaration
  201. }
  202. } else if !runtime.endpoints_registration_transferred {
  203. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
  204. if err != nil {
  205. log.Error("endpoints register failed, error: %v", err)
  206. close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
  207. return
  208. }
  209. runtime.endpoints_registration_transferred = true
  210. if len(endpoints) > 0 {
  211. declaration := runtime.Config
  212. declaration.Endpoint = &endpoints[0]
  213. runtime.Config = declaration
  214. }
  215. } else if !runtime.assets_transferred {
  216. assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
  217. if err != nil {
  218. log.Error("assets register failed, error: %v", err)
  219. close_conn([]byte("assets register failed, invalid assets declaration\n"))
  220. return
  221. }
  222. files := make(map[string][]byte)
  223. for _, asset := range assets {
  224. files[asset.Filename], err = hex.DecodeString(asset.Data)
  225. if err != nil {
  226. log.Error("assets decode failed, error: %v", err)
  227. close_conn([]byte("assets decode failed, invalid assets data, cannot decode file\n"))
  228. return
  229. }
  230. }
  231. // remap assets
  232. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  233. log.Error("assets remap failed, error: %v", err)
  234. close_conn([]byte("assets remap failed, invalid assets data, cannot remap\n"))
  235. return
  236. }
  237. atomic.AddInt32(&s.current_conn, 1)
  238. if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
  239. close_conn([]byte("server is busy now, please try again later\n"))
  240. return
  241. }
  242. // fill in default values
  243. runtime.Config.FillInDefaultValues()
  244. // mark assets transferred
  245. runtime.assets_transferred = true
  246. runtime.checksum = runtime.calculateChecksum()
  247. runtime.InitState()
  248. runtime.SetActiveAt(time.Now())
  249. // trigger registration event
  250. if err := runtime.Register(); err != nil {
  251. log.Error("register failed, error: %v", err)
  252. close_conn([]byte("register failed, cannot register\n"))
  253. return
  254. }
  255. // send started event
  256. runtime.wait_chan_lock.Lock()
  257. for _, c := range runtime.wait_started_chan {
  258. select {
  259. case c <- true:
  260. default:
  261. }
  262. }
  263. runtime.wait_chan_lock.Unlock()
  264. // notify launched
  265. runtime.wait_launched_chan_once.Do(func() {
  266. close(runtime.wait_launched_chan)
  267. })
  268. // publish runtime to watcher
  269. s.response.Write(runtime)
  270. } else {
  271. // continue handle messages if handshake completed
  272. runtime.response.Write(message)
  273. }
  274. }