hooks.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package remote_manager
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  15. "github.com/panjf2000/gnet/v2"
  16. )
  17. var (
  18. // mode is only used for testing
  19. _mode pluginRuntimeMode
  20. )
  21. type DifyServer struct {
  22. gnet.BuiltinEventEngine
  23. engine gnet.Engine
  24. mediaManager *media_manager.MediaManager
  25. // listening address
  26. addr string
  27. port uint16
  28. // enabled multicore
  29. multicore bool
  30. // event loop count
  31. num_loops int
  32. // read new connections
  33. response *stream.Stream[*RemotePluginRuntime]
  34. plugins map[int]*RemotePluginRuntime
  35. plugins_lock *sync.RWMutex
  36. shutdown_chan chan bool
  37. max_conn int32
  38. current_conn int32
  39. }
  40. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  41. s.engine = c
  42. return gnet.None
  43. }
  44. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  45. // new plugin connected
  46. c.SetContext(&codec{})
  47. runtime := &RemotePluginRuntime{
  48. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  49. s.mediaManager,
  50. ),
  51. conn: c,
  52. response: stream.NewStream[[]byte](512),
  53. callbacks: make(map[string][]func([]byte)),
  54. callbacks_lock: &sync.RWMutex{},
  55. shutdown_chan: make(chan bool),
  56. wait_launched_chan: make(chan error),
  57. alive: true,
  58. }
  59. // store plugin runtime
  60. s.plugins_lock.Lock()
  61. s.plugins[c.Fd()] = runtime
  62. s.plugins_lock.Unlock()
  63. // start a timer to check if handshake is completed in 10 seconds
  64. time.AfterFunc(time.Second*10, func() {
  65. if !runtime.handshake {
  66. // close connection
  67. c.Close()
  68. }
  69. })
  70. // verified
  71. verified := true
  72. if verified {
  73. return nil, gnet.None
  74. }
  75. return nil, gnet.Close
  76. }
  77. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  78. // plugin disconnected
  79. s.plugins_lock.Lock()
  80. plugin := s.plugins[c.Fd()]
  81. delete(s.plugins, c.Fd())
  82. s.plugins_lock.Unlock()
  83. if plugin == nil {
  84. return gnet.None
  85. }
  86. // close plugin
  87. plugin.onDisconnected()
  88. // uninstall plugin
  89. if plugin.assets_transferred {
  90. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  91. if err := plugin.Unregister(); err != nil {
  92. log.Error("unregister plugin failed, error: %v", err)
  93. }
  94. // decrease current connection
  95. atomic.AddInt32(&s.current_conn, -1)
  96. }
  97. }
  98. // send stopped event
  99. plugin.wait_chan_lock.Lock()
  100. for _, c := range plugin.wait_stopped_chan {
  101. select {
  102. case c <- true:
  103. default:
  104. }
  105. }
  106. plugin.wait_chan_lock.Unlock()
  107. // recycle launched chan, avoid memory leak
  108. plugin.wait_launched_chan_once.Do(func() {
  109. close(plugin.wait_launched_chan)
  110. })
  111. return gnet.None
  112. }
  113. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  114. close(s.shutdown_chan)
  115. }
  116. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  117. codec := c.Context().(*codec)
  118. messages, err := codec.Decode(c)
  119. if err != nil {
  120. return gnet.Close
  121. }
  122. // get plugin runtime
  123. s.plugins_lock.RLock()
  124. runtime, ok := s.plugins[c.Fd()]
  125. s.plugins_lock.RUnlock()
  126. if !ok {
  127. return gnet.Close
  128. }
  129. // handle messages
  130. for _, message := range messages {
  131. if len(message) == 0 {
  132. continue
  133. }
  134. s.onMessage(runtime, message)
  135. }
  136. return gnet.None
  137. }
  138. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  139. // handle message
  140. if runtime.handshake_failed {
  141. // do nothing if handshake has failed
  142. return
  143. }
  144. close_conn := func(message []byte) {
  145. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  146. runtime.conn.Write(message)
  147. runtime.conn.Close()
  148. }
  149. }
  150. if !runtime.handshake {
  151. key := string(message)
  152. info, err := GetConnectionInfo(key)
  153. if err == cache.ErrNotFound {
  154. // close connection if handshake failed
  155. close_conn([]byte("handshake failed, invalid key\n"))
  156. runtime.handshake_failed = true
  157. return
  158. } else if err != nil {
  159. // close connection if handshake failed
  160. close_conn([]byte("internal error\n"))
  161. return
  162. }
  163. runtime.tenant_id = info.TenantId
  164. // handshake completed
  165. runtime.handshake = true
  166. } else if !runtime.registration_transferred {
  167. // process handle shake if not completed
  168. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
  169. if err != nil {
  170. // close connection if handshake failed
  171. close_conn([]byte("handshake failed, invalid plugin declaration\n"))
  172. return
  173. }
  174. runtime.Config = declaration
  175. // registration transferred
  176. runtime.registration_transferred = true
  177. } else if !runtime.tools_registration_transferred {
  178. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
  179. if err != nil {
  180. log.Error("tools register failed, error: %v", err)
  181. close_conn([]byte("tools register failed, invalid tools declaration\n"))
  182. return
  183. }
  184. runtime.tools_registration_transferred = true
  185. if len(tools) > 0 {
  186. declaration := runtime.Config
  187. declaration.Tool = &tools[0]
  188. runtime.Config = declaration
  189. }
  190. } else if !runtime.models_registration_transferred {
  191. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
  192. if err != nil {
  193. log.Error("models register failed, error: %v", err)
  194. close_conn([]byte("models register failed, invalid models declaration\n"))
  195. return
  196. }
  197. runtime.models_registration_transferred = true
  198. if len(models) > 0 {
  199. declaration := runtime.Config
  200. declaration.Model = &models[0]
  201. runtime.Config = declaration
  202. }
  203. } else if !runtime.endpoints_registration_transferred {
  204. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
  205. if err != nil {
  206. log.Error("endpoints register failed, error: %v", err)
  207. close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
  208. return
  209. }
  210. runtime.endpoints_registration_transferred = true
  211. if len(endpoints) > 0 {
  212. declaration := runtime.Config
  213. declaration.Endpoint = &endpoints[0]
  214. runtime.Config = declaration
  215. }
  216. } else if !runtime.assets_transferred {
  217. assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
  218. if err != nil {
  219. log.Error("assets register failed, error: %v", err)
  220. close_conn([]byte(fmt.Sprintf("assets register failed, invalid assets declaration: %v\n", err)))
  221. return
  222. }
  223. files := make(map[string][]byte)
  224. for _, asset := range assets {
  225. files[asset.Filename], err = hex.DecodeString(asset.Data)
  226. if err != nil {
  227. log.Error("assets decode failed, error: %v", err)
  228. close_conn([]byte(fmt.Sprintf("assets decode failed, invalid assets data, cannot decode file: %v\n", err)))
  229. return
  230. }
  231. }
  232. // remap assets
  233. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  234. log.Error("assets remap failed, error: %v", err)
  235. close_conn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
  236. return
  237. }
  238. atomic.AddInt32(&s.current_conn, 1)
  239. if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
  240. close_conn([]byte("server is busy now, please try again later\n"))
  241. return
  242. }
  243. // fill in default values
  244. runtime.Config.FillInDefaultValues()
  245. // mark assets transferred
  246. runtime.assets_transferred = true
  247. runtime.checksum = runtime.calculateChecksum()
  248. runtime.InitState()
  249. runtime.SetActiveAt(time.Now())
  250. // trigger registration event
  251. if err := runtime.Register(); err != nil {
  252. log.Error("register failed, error: %v", err)
  253. close_conn([]byte("register failed, cannot register\n"))
  254. return
  255. }
  256. // send started event
  257. runtime.wait_chan_lock.Lock()
  258. for _, c := range runtime.wait_started_chan {
  259. select {
  260. case c <- true:
  261. default:
  262. }
  263. }
  264. runtime.wait_chan_lock.Unlock()
  265. // notify launched
  266. runtime.wait_launched_chan_once.Do(func() {
  267. close(runtime.wait_launched_chan)
  268. })
  269. // publish runtime to watcher
  270. s.response.Write(runtime)
  271. } else {
  272. // continue handle messages if handshake completed
  273. runtime.response.Write(message)
  274. }
  275. }