hooks.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package remote_manager
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  10. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  15. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  16. "github.com/panjf2000/gnet/v2"
  17. )
  18. var (
  19. // mode is only used for testing
  20. _mode pluginRuntimeMode
  21. )
  22. type DifyServer struct {
  23. gnet.BuiltinEventEngine
  24. engine gnet.Engine
  25. mediaManager *media_manager.MediaManager
  26. // listening address
  27. addr string
  28. port uint16
  29. // enabled multicore
  30. multicore bool
  31. // event loop count
  32. num_loops int
  33. // read new connections
  34. response *stream.Stream[plugin_entities.PluginFullDuplexLifetime]
  35. plugins map[int]*RemotePluginRuntime
  36. plugins_lock *sync.RWMutex
  37. shutdown_chan chan bool
  38. max_conn int32
  39. current_conn int32
  40. }
  41. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  42. s.engine = c
  43. return gnet.None
  44. }
  45. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  46. // new plugin connected
  47. c.SetContext(&codec{})
  48. runtime := &RemotePluginRuntime{
  49. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  50. s.mediaManager,
  51. ),
  52. conn: c,
  53. response: stream.NewStream[[]byte](512),
  54. callbacks: make(map[string][]func([]byte)),
  55. callbacks_lock: &sync.RWMutex{},
  56. assets: make(map[string]*bytes.Buffer),
  57. assets_bytes: 0,
  58. shutdown_chan: make(chan bool),
  59. wait_launched_chan: make(chan error),
  60. alive: true,
  61. }
  62. // store plugin runtime
  63. s.plugins_lock.Lock()
  64. s.plugins[c.Fd()] = runtime
  65. s.plugins_lock.Unlock()
  66. // start a timer to check if handshake is completed in 10 seconds
  67. time.AfterFunc(time.Second*10, func() {
  68. if !runtime.handshake {
  69. // close connection
  70. c.Close()
  71. }
  72. })
  73. // verified
  74. verified := true
  75. if verified {
  76. return nil, gnet.None
  77. }
  78. return nil, gnet.Close
  79. }
  80. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  81. // plugin disconnected
  82. s.plugins_lock.Lock()
  83. plugin := s.plugins[c.Fd()]
  84. delete(s.plugins, c.Fd())
  85. s.plugins_lock.Unlock()
  86. if plugin == nil {
  87. return gnet.None
  88. }
  89. // close plugin
  90. plugin.onDisconnected()
  91. // uninstall plugin
  92. if plugin.assets_transferred {
  93. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  94. if err := plugin.Unregister(); err != nil {
  95. log.Error("unregister plugin failed, error: %v", err)
  96. }
  97. // decrease current connection
  98. atomic.AddInt32(&s.current_conn, -1)
  99. }
  100. }
  101. // send stopped event
  102. plugin.wait_chan_lock.Lock()
  103. for _, c := range plugin.wait_stopped_chan {
  104. select {
  105. case c <- true:
  106. default:
  107. }
  108. }
  109. plugin.wait_chan_lock.Unlock()
  110. // recycle launched chan, avoid memory leak
  111. plugin.wait_launched_chan_once.Do(func() {
  112. close(plugin.wait_launched_chan)
  113. })
  114. return gnet.None
  115. }
  116. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  117. close(s.shutdown_chan)
  118. }
  119. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  120. codec := c.Context().(*codec)
  121. messages, err := codec.Decode(c)
  122. if err != nil {
  123. return gnet.Close
  124. }
  125. // get plugin runtime
  126. s.plugins_lock.RLock()
  127. runtime, ok := s.plugins[c.Fd()]
  128. s.plugins_lock.RUnlock()
  129. if !ok {
  130. return gnet.Close
  131. }
  132. // handle messages
  133. for _, message := range messages {
  134. if len(message) == 0 {
  135. continue
  136. }
  137. s.onMessage(runtime, message)
  138. }
  139. return gnet.None
  140. }
  141. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  142. // handle message
  143. if runtime.handshake_failed {
  144. // do nothing if handshake has failed
  145. return
  146. }
  147. close_conn := func(message []byte) {
  148. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  149. runtime.conn.Write(message)
  150. runtime.conn.Close()
  151. }
  152. }
  153. if !runtime.initialized {
  154. register_payload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
  155. if err != nil {
  156. // close connection if handshake failed
  157. close_conn([]byte("handshake failed, invalid handshake message\n"))
  158. runtime.handshake_failed = true
  159. return
  160. }
  161. if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
  162. if runtime.handshake {
  163. // handshake already completed
  164. return
  165. }
  166. key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](register_payload.Data)
  167. if err != nil {
  168. // close connection if handshake failed
  169. close_conn([]byte("handshake failed, invalid handshake message\n"))
  170. runtime.handshake_failed = true
  171. return
  172. }
  173. info, err := GetConnectionInfo(key.Key)
  174. if err == cache.ErrNotFound {
  175. // close connection if handshake failed
  176. close_conn([]byte("handshake failed, invalid key\n"))
  177. runtime.handshake_failed = true
  178. return
  179. } else if err != nil {
  180. // close connection if handshake failed
  181. close_conn([]byte("internal error\n"))
  182. return
  183. }
  184. runtime.tenant_id = info.TenantId
  185. // handshake completed
  186. runtime.handshake = true
  187. } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
  188. if runtime.assets_transferred {
  189. return
  190. }
  191. asset_chunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](register_payload.Data)
  192. if err != nil {
  193. log.Error("assets register failed, error: %v", err)
  194. close_conn([]byte("assets register failed, invalid assets chunk\n"))
  195. return
  196. }
  197. buffer, ok := runtime.assets[asset_chunk.Filename]
  198. if !ok {
  199. runtime.assets[asset_chunk.Filename] = &bytes.Buffer{}
  200. buffer = runtime.assets[asset_chunk.Filename]
  201. }
  202. // allows at most 50MB assets
  203. if runtime.assets_bytes+int64(len(asset_chunk.Data)) > 50*1024*1024 {
  204. close_conn([]byte("assets too large, at most 50MB\n"))
  205. return
  206. }
  207. // decode as base64
  208. data, err := base64.StdEncoding.DecodeString(asset_chunk.Data)
  209. if err != nil {
  210. log.Error("assets decode failed, error: %v", err)
  211. close_conn([]byte("assets decode failed, invalid assets data\n"))
  212. return
  213. }
  214. buffer.Write(data)
  215. // update assets bytes
  216. runtime.assets_bytes += int64(len(data))
  217. } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
  218. if !runtime.models_registration_transferred &&
  219. !runtime.endpoints_registration_transferred &&
  220. !runtime.tools_registration_transferred {
  221. close_conn([]byte("no registration transferred, cannot initialize\n"))
  222. return
  223. }
  224. files := make(map[string][]byte)
  225. for filename, buffer := range runtime.assets {
  226. files[filename] = buffer.Bytes()
  227. }
  228. // remap assets
  229. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  230. log.Error("assets remap failed, error: %v", err)
  231. close_conn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
  232. return
  233. }
  234. atomic.AddInt32(&s.current_conn, 1)
  235. if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
  236. close_conn([]byte("server is busy now, please try again later\n"))
  237. return
  238. }
  239. // fill in default values
  240. runtime.Config.FillInDefaultValues()
  241. // mark assets transferred
  242. runtime.assets_transferred = true
  243. runtime.checksum = runtime.calculateChecksum()
  244. runtime.InitState()
  245. runtime.SetActiveAt(time.Now())
  246. // trigger registration event
  247. if err := runtime.Register(); err != nil {
  248. log.Error("register failed, error: %v", err)
  249. close_conn([]byte("register failed, cannot register\n"))
  250. return
  251. }
  252. // send started event
  253. runtime.wait_chan_lock.Lock()
  254. for _, c := range runtime.wait_started_chan {
  255. select {
  256. case c <- true:
  257. default:
  258. }
  259. }
  260. runtime.wait_chan_lock.Unlock()
  261. // notify launched
  262. runtime.wait_launched_chan_once.Do(func() {
  263. close(runtime.wait_launched_chan)
  264. })
  265. // mark initialized
  266. runtime.initialized = true
  267. // publish runtime to watcher
  268. s.response.Write(runtime)
  269. } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
  270. if runtime.registration_transferred {
  271. return
  272. }
  273. // process handle shake if not completed
  274. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](register_payload.Data)
  275. if err != nil {
  276. // close connection if handshake failed
  277. close_conn([]byte("handshake failed, invalid plugin declaration\n"))
  278. return
  279. }
  280. runtime.Config = declaration
  281. // registration transferred
  282. runtime.registration_transferred = true
  283. } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
  284. if runtime.tools_registration_transferred {
  285. return
  286. }
  287. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](register_payload.Data)
  288. if err != nil {
  289. log.Error("tools register failed, error: %v", err)
  290. close_conn([]byte("tools register failed, invalid tools declaration\n"))
  291. return
  292. }
  293. runtime.tools_registration_transferred = true
  294. if len(tools) > 0 {
  295. declaration := runtime.Config
  296. declaration.Tool = &tools[0]
  297. runtime.Config = declaration
  298. }
  299. } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
  300. if runtime.models_registration_transferred {
  301. return
  302. }
  303. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](register_payload.Data)
  304. if err != nil {
  305. log.Error("models register failed, error: %v", err)
  306. close_conn([]byte("models register failed, invalid models declaration\n"))
  307. return
  308. }
  309. runtime.models_registration_transferred = true
  310. if len(models) > 0 {
  311. declaration := runtime.Config
  312. declaration.Model = &models[0]
  313. runtime.Config = declaration
  314. }
  315. } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
  316. if runtime.endpoints_registration_transferred {
  317. return
  318. }
  319. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](register_payload.Data)
  320. if err != nil {
  321. log.Error("endpoints register failed, error: %v", err)
  322. close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
  323. return
  324. }
  325. runtime.endpoints_registration_transferred = true
  326. if len(endpoints) > 0 {
  327. declaration := runtime.Config
  328. declaration.Endpoint = &endpoints[0]
  329. runtime.Config = declaration
  330. }
  331. } else {
  332. // unknown event type
  333. close_conn([]byte("unknown initialization event type\n"))
  334. return
  335. }
  336. } else {
  337. // continue handle messages if handshake completed
  338. runtime.response.Write(message)
  339. }
  340. }