hooks.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package remote_manager
  2. import (
  3. "encoding/hex"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  14. "github.com/panjf2000/gnet/v2"
  15. )
  16. var (
  17. // mode is only used for testing
  18. _mode pluginRuntimeMode
  19. )
  20. type DifyServer struct {
  21. gnet.BuiltinEventEngine
  22. engine gnet.Engine
  23. mediaManager *media_manager.MediaManager
  24. // listening address
  25. addr string
  26. port uint16
  27. // enabled multicore
  28. multicore bool
  29. // event loop count
  30. num_loops int
  31. // read new connections
  32. response *stream.Stream[*RemotePluginRuntime]
  33. plugins map[int]*RemotePluginRuntime
  34. plugins_lock *sync.RWMutex
  35. shutdown_chan chan bool
  36. max_conn int32
  37. current_conn int32
  38. }
  39. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  40. s.engine = c
  41. return gnet.None
  42. }
  43. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  44. // new plugin connected
  45. c.SetContext(&codec{})
  46. runtime := &RemotePluginRuntime{
  47. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  48. s.mediaManager,
  49. ),
  50. conn: c,
  51. response: stream.NewStream[[]byte](512),
  52. callbacks: make(map[string][]func([]byte)),
  53. callbacks_lock: &sync.RWMutex{},
  54. shutdown_chan: make(chan bool),
  55. alive: true,
  56. }
  57. // store plugin runtime
  58. s.plugins_lock.Lock()
  59. s.plugins[c.Fd()] = runtime
  60. s.plugins_lock.Unlock()
  61. // start a timer to check if handshake is completed in 10 seconds
  62. time.AfterFunc(time.Second*10, func() {
  63. if !runtime.handshake {
  64. // close connection
  65. c.Close()
  66. }
  67. })
  68. // verified
  69. verified := true
  70. if verified {
  71. return nil, gnet.None
  72. }
  73. return nil, gnet.Close
  74. }
  75. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  76. // plugin disconnected
  77. s.plugins_lock.Lock()
  78. plugin := s.plugins[c.Fd()]
  79. delete(s.plugins, c.Fd())
  80. s.plugins_lock.Unlock()
  81. if plugin == nil {
  82. return gnet.None
  83. }
  84. // close plugin
  85. plugin.onDisconnected()
  86. // uninstall plugin
  87. if plugin.assets_transferred {
  88. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  89. if err := plugin.Unregister(); err != nil {
  90. log.Error("unregister plugin failed, error: %v", err)
  91. }
  92. // decrease current connection
  93. atomic.AddInt32(&s.current_conn, -1)
  94. }
  95. }
  96. // send stopped event
  97. plugin.wait_chan_lock.Lock()
  98. for _, c := range plugin.wait_stopped_chan {
  99. select {
  100. case c <- true:
  101. default:
  102. }
  103. }
  104. plugin.wait_chan_lock.Unlock()
  105. return gnet.None
  106. }
  107. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  108. close(s.shutdown_chan)
  109. }
  110. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  111. codec := c.Context().(*codec)
  112. messages, err := codec.Decode(c)
  113. if err != nil {
  114. return gnet.Close
  115. }
  116. // get plugin runtime
  117. s.plugins_lock.RLock()
  118. runtime, ok := s.plugins[c.Fd()]
  119. s.plugins_lock.RUnlock()
  120. if !ok {
  121. return gnet.Close
  122. }
  123. // handle messages
  124. for _, message := range messages {
  125. if len(message) == 0 {
  126. continue
  127. }
  128. s.onMessage(runtime, message)
  129. }
  130. return gnet.None
  131. }
  132. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  133. // handle message
  134. if runtime.handshake_failed {
  135. // do nothing if handshake has failed
  136. return
  137. }
  138. close := func(message []byte) {
  139. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  140. runtime.conn.Write(message)
  141. runtime.conn.Close()
  142. }
  143. }
  144. if !runtime.handshake {
  145. key := string(message)
  146. info, err := GetConnectionInfo(key)
  147. if err == cache.ErrNotFound {
  148. // close connection if handshake failed
  149. close([]byte("handshake failed, invalid key\n"))
  150. runtime.handshake_failed = true
  151. return
  152. } else if err != nil {
  153. // close connection if handshake failed
  154. close([]byte("internal error\n"))
  155. return
  156. }
  157. runtime.tenant_id = info.TenantId
  158. // handshake completed
  159. runtime.handshake = true
  160. } else if !runtime.registration_transferred {
  161. // process handle shake if not completed
  162. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
  163. if err != nil {
  164. // close connection if handshake failed
  165. close([]byte("handshake failed, invalid plugin declaration\n"))
  166. return
  167. }
  168. runtime.Config = declaration
  169. // registration transferred
  170. runtime.registration_transferred = true
  171. } else if !runtime.tools_registration_transferred {
  172. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
  173. if err != nil {
  174. log.Error("tools register failed, error: %v", err)
  175. close([]byte("tools register failed, invalid tools declaration\n"))
  176. return
  177. }
  178. runtime.tools_registration_transferred = true
  179. if len(tools) > 0 {
  180. declaration := runtime.Config
  181. declaration.Tool = &tools[0]
  182. runtime.Config = declaration
  183. }
  184. } else if !runtime.models_registration_transferred {
  185. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
  186. if err != nil {
  187. log.Error("models register failed, error: %v", err)
  188. close([]byte("models register failed, invalid models declaration\n"))
  189. return
  190. }
  191. runtime.models_registration_transferred = true
  192. if len(models) > 0 {
  193. declaration := runtime.Config
  194. declaration.Model = &models[0]
  195. runtime.Config = declaration
  196. }
  197. } else if !runtime.endpoints_registration_transferred {
  198. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
  199. if err != nil {
  200. log.Error("endpoints register failed, error: %v", err)
  201. close([]byte("endpoints register failed, invalid endpoints declaration\n"))
  202. return
  203. }
  204. runtime.endpoints_registration_transferred = true
  205. if len(endpoints) > 0 {
  206. declaration := runtime.Config
  207. declaration.Endpoint = &endpoints[0]
  208. runtime.Config = declaration
  209. }
  210. } else if !runtime.assets_transferred {
  211. assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
  212. if err != nil {
  213. log.Error("assets register failed, error: %v", err)
  214. close([]byte("assets register failed, invalid assets declaration\n"))
  215. return
  216. }
  217. files := make(map[string][]byte)
  218. for _, asset := range assets {
  219. files[asset.Filename], err = hex.DecodeString(asset.Data)
  220. if err != nil {
  221. log.Error("assets decode failed, error: %v", err)
  222. close([]byte("assets decode failed, invalid assets data, cannot decode file\n"))
  223. return
  224. }
  225. }
  226. // remap assets
  227. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  228. log.Error("assets remap failed, error: %v", err)
  229. close([]byte("assets remap failed, invalid assets data, cannot remap\n"))
  230. return
  231. }
  232. atomic.AddInt32(&s.current_conn, 1)
  233. if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
  234. close([]byte("server is busy now, please try again later\n"))
  235. return
  236. }
  237. // fill in default values
  238. runtime.Config.FillInDefaultValues()
  239. // mark assets transferred
  240. runtime.assets_transferred = true
  241. runtime.checksum = runtime.calculateChecksum()
  242. runtime.InitState()
  243. runtime.SetActiveAt(time.Now())
  244. // trigger registration event
  245. if err := runtime.Register(); err != nil {
  246. log.Error("register failed, error: %v", err)
  247. close([]byte("register failed, cannot register\n"))
  248. return
  249. }
  250. // send started event
  251. runtime.wait_chan_lock.Lock()
  252. for _, c := range runtime.wait_started_chan {
  253. select {
  254. case c <- true:
  255. default:
  256. }
  257. }
  258. runtime.wait_chan_lock.Unlock()
  259. // publish runtime to watcher
  260. s.response.Write(runtime)
  261. } else {
  262. // continue handle messages if handshake completed
  263. runtime.response.Write(message)
  264. }
  265. }