hooks.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. package remote_manager
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  10. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  15. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  16. "github.com/panjf2000/gnet/v2"
  17. )
  18. var (
  19. // mode is only used for testing
  20. // TODO: simplify this ugly code
  21. _mode pluginRuntimeMode
  22. )
  23. type DifyServer struct {
  24. gnet.BuiltinEventEngine
  25. engine gnet.Engine
  26. mediaManager *media_manager.MediaBucket
  27. // listening address
  28. addr string
  29. port uint16
  30. // enabled multicore
  31. multicore bool
  32. // event loop count
  33. numLoops int
  34. // read new connections
  35. response *stream.Stream[plugin_entities.PluginFullDuplexLifetime]
  36. plugins map[int]*RemotePluginRuntime
  37. pluginsLock *sync.RWMutex
  38. shutdownChan chan bool
  39. maxConn int32
  40. currentConn int32
  41. }
  42. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  43. s.engine = c
  44. return gnet.None
  45. }
  46. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  47. // new plugin connected
  48. c.SetContext(&codec{})
  49. runtime := &RemotePluginRuntime{
  50. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  51. s.mediaManager,
  52. ),
  53. conn: c,
  54. response: stream.NewStream[[]byte](512),
  55. callbacks: make(map[string][]func([]byte)),
  56. callbacksLock: &sync.RWMutex{},
  57. assets: make(map[string]*bytes.Buffer),
  58. assetsBytes: 0,
  59. shutdownChan: make(chan bool),
  60. waitLaunchedChan: make(chan error),
  61. alive: true,
  62. }
  63. // store plugin runtime
  64. s.pluginsLock.Lock()
  65. s.plugins[c.Fd()] = runtime
  66. s.pluginsLock.Unlock()
  67. // start a timer to check if handshake is completed in 10 seconds
  68. time.AfterFunc(time.Second*10, func() {
  69. if !runtime.handshake {
  70. // close connection
  71. c.Close()
  72. }
  73. })
  74. // verified
  75. verified := true
  76. if verified {
  77. return nil, gnet.None
  78. }
  79. return nil, gnet.Close
  80. }
  81. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  82. // plugin disconnected
  83. s.pluginsLock.Lock()
  84. plugin := s.plugins[c.Fd()]
  85. delete(s.plugins, c.Fd())
  86. s.pluginsLock.Unlock()
  87. if plugin == nil {
  88. return gnet.None
  89. }
  90. // close plugin
  91. plugin.onDisconnected()
  92. // uninstall plugin
  93. if plugin.assetsTransferred {
  94. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  95. if plugin.installationId != "" {
  96. if err := plugin.Unregister(); err != nil {
  97. log.Error("unregister plugin failed, error: %v", err)
  98. }
  99. }
  100. // decrease current connection
  101. atomic.AddInt32(&s.currentConn, -1)
  102. }
  103. }
  104. // send stopped event
  105. plugin.waitChanLock.Lock()
  106. for _, c := range plugin.waitStoppedChan {
  107. select {
  108. case c <- true:
  109. default:
  110. }
  111. }
  112. plugin.waitChanLock.Unlock()
  113. // recycle launched chan, avoid memory leak
  114. plugin.waitLaunchedChanOnce.Do(func() {
  115. close(plugin.waitLaunchedChan)
  116. })
  117. return gnet.None
  118. }
  119. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  120. close(s.shutdownChan)
  121. }
  122. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  123. codec := c.Context().(*codec)
  124. messages, err := codec.Decode(c)
  125. if err != nil {
  126. return gnet.Close
  127. }
  128. // get plugin runtime
  129. s.pluginsLock.RLock()
  130. runtime, ok := s.plugins[c.Fd()]
  131. s.pluginsLock.RUnlock()
  132. if !ok {
  133. return gnet.Close
  134. }
  135. // handle messages
  136. for _, message := range messages {
  137. if len(message) == 0 {
  138. continue
  139. }
  140. s.onMessage(runtime, message)
  141. }
  142. return gnet.None
  143. }
  144. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  145. // handle message
  146. if runtime.handshakeFailed {
  147. // do nothing if handshake has failed
  148. return
  149. }
  150. closeConn := func(message []byte) {
  151. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  152. runtime.conn.Write(message)
  153. runtime.conn.Close()
  154. }
  155. }
  156. if !runtime.initialized {
  157. registerPayload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
  158. if err != nil {
  159. // close connection if handshake failed
  160. closeConn([]byte("handshake failed, invalid handshake message\n"))
  161. runtime.handshakeFailed = true
  162. return
  163. }
  164. if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
  165. if runtime.handshake {
  166. // handshake already completed
  167. return
  168. }
  169. key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](registerPayload.Data)
  170. if err != nil {
  171. // close connection if handshake failed
  172. closeConn([]byte("handshake failed, invalid handshake message\n"))
  173. runtime.handshakeFailed = true
  174. return
  175. }
  176. info, err := GetConnectionInfo(key.Key)
  177. if err == cache.ErrNotFound {
  178. // close connection if handshake failed
  179. closeConn([]byte("handshake failed, invalid key\n"))
  180. runtime.handshakeFailed = true
  181. return
  182. } else if err != nil {
  183. // close connection if handshake failed
  184. closeConn([]byte("internal error\n"))
  185. return
  186. }
  187. runtime.tenantId = info.TenantId
  188. // handshake completed
  189. runtime.handshake = true
  190. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
  191. if runtime.assetsTransferred {
  192. return
  193. }
  194. assetChunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](registerPayload.Data)
  195. if err != nil {
  196. log.Error("assets register failed, error: %v", err)
  197. closeConn([]byte("assets register failed, invalid assets chunk\n"))
  198. return
  199. }
  200. buffer, ok := runtime.assets[assetChunk.Filename]
  201. if !ok {
  202. runtime.assets[assetChunk.Filename] = &bytes.Buffer{}
  203. buffer = runtime.assets[assetChunk.Filename]
  204. }
  205. // allows at most 50MB assets
  206. if runtime.assetsBytes+int64(len(assetChunk.Data)) > 50*1024*1024 {
  207. closeConn([]byte("assets too large, at most 50MB\n"))
  208. return
  209. }
  210. // decode as base64
  211. data, err := base64.StdEncoding.DecodeString(assetChunk.Data)
  212. if err != nil {
  213. log.Error("assets decode failed, error: %v", err)
  214. closeConn([]byte("assets decode failed, invalid assets data\n"))
  215. return
  216. }
  217. buffer.Write(data)
  218. // update assets bytes
  219. runtime.assetsBytes += int64(len(data))
  220. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
  221. if !runtime.modelsRegistrationTransferred &&
  222. !runtime.endpointsRegistrationTransferred &&
  223. !runtime.toolsRegistrationTransferred &&
  224. !runtime.agentsRegistrationTransferred {
  225. closeConn([]byte("no registration transferred, cannot initialize\n"))
  226. return
  227. }
  228. files := make(map[string][]byte)
  229. for filename, buffer := range runtime.assets {
  230. files[filename] = buffer.Bytes()
  231. }
  232. // remap assets
  233. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  234. log.Error("assets remap failed, error: %v", err)
  235. closeConn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
  236. return
  237. }
  238. atomic.AddInt32(&s.currentConn, 1)
  239. if atomic.LoadInt32(&s.currentConn) > int32(s.maxConn) {
  240. closeConn([]byte("server is busy now, please try again later\n"))
  241. return
  242. }
  243. // fill in default values
  244. runtime.Config.FillInDefaultValues()
  245. // mark assets transferred
  246. runtime.assetsTransferred = true
  247. runtime.checksum = runtime.calculateChecksum()
  248. runtime.InitState()
  249. runtime.SetActiveAt(time.Now())
  250. // trigger registration event
  251. if err := runtime.Register(); err != nil {
  252. closeConn([]byte(fmt.Sprintf("register failed, cannot register: %v\n", err)))
  253. return
  254. }
  255. // send started event
  256. runtime.waitChanLock.Lock()
  257. for _, c := range runtime.waitStartedChan {
  258. select {
  259. case c <- true:
  260. default:
  261. }
  262. }
  263. runtime.waitChanLock.Unlock()
  264. // notify launched
  265. runtime.waitLaunchedChanOnce.Do(func() {
  266. close(runtime.waitLaunchedChan)
  267. })
  268. // mark initialized
  269. runtime.initialized = true
  270. // publish runtime to watcher
  271. s.response.Write(runtime)
  272. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
  273. if runtime.registrationTransferred {
  274. return
  275. }
  276. // process handle shake if not completed
  277. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](registerPayload.Data)
  278. if err != nil {
  279. // close connection if handshake failed
  280. closeConn([]byte(fmt.Sprintf("handshake failed, invalid plugin declaration: %v\n", err)))
  281. return
  282. }
  283. runtime.Config = declaration
  284. // registration transferred
  285. runtime.registrationTransferred = true
  286. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
  287. if runtime.toolsRegistrationTransferred {
  288. return
  289. }
  290. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](registerPayload.Data)
  291. if err != nil {
  292. closeConn([]byte(fmt.Sprintf("tools register failed, invalid tools declaration: %v\n", err)))
  293. return
  294. }
  295. runtime.toolsRegistrationTransferred = true
  296. if len(tools) > 0 {
  297. declaration := runtime.Config
  298. declaration.Tool = &tools[0]
  299. runtime.Config = declaration
  300. }
  301. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
  302. if runtime.modelsRegistrationTransferred {
  303. return
  304. }
  305. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](registerPayload.Data)
  306. if err != nil {
  307. closeConn([]byte(fmt.Sprintf("models register failed, invalid models declaration: %v\n", err)))
  308. return
  309. }
  310. runtime.modelsRegistrationTransferred = true
  311. if len(models) > 0 {
  312. declaration := runtime.Config
  313. declaration.Model = &models[0]
  314. runtime.Config = declaration
  315. }
  316. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
  317. if runtime.endpointsRegistrationTransferred {
  318. return
  319. }
  320. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](registerPayload.Data)
  321. if err != nil {
  322. closeConn([]byte(fmt.Sprintf("endpoints register failed, invalid endpoints declaration: %v\n", err)))
  323. return
  324. }
  325. runtime.endpointsRegistrationTransferred = true
  326. if len(endpoints) > 0 {
  327. declaration := runtime.Config
  328. declaration.Endpoint = &endpoints[0]
  329. runtime.Config = declaration
  330. }
  331. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_AGENT_DECLARATION {
  332. if runtime.agentsRegistrationTransferred {
  333. return
  334. }
  335. agents, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.AgentProviderDeclaration](registerPayload.Data)
  336. if err != nil {
  337. closeConn([]byte(fmt.Sprintf("agents register failed, invalid agents declaration: %v\n", err)))
  338. return
  339. }
  340. runtime.agentsRegistrationTransferred = true
  341. if len(agents) > 0 {
  342. declaration := runtime.Config
  343. declaration.Agent = &agents[0]
  344. runtime.Config = declaration
  345. }
  346. }
  347. } else {
  348. // continue handle messages if handshake completed
  349. runtime.response.Write(message)
  350. }
  351. }