hooks.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. package remote_manager
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
  10. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  15. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  16. "github.com/panjf2000/gnet/v2"
  17. )
  18. var (
  19. // mode is only used for testing
  20. _mode pluginRuntimeMode
  21. )
  22. type DifyServer struct {
  23. gnet.BuiltinEventEngine
  24. engine gnet.Engine
  25. mediaManager *media_manager.MediaBucket
  26. // listening address
  27. addr string
  28. port uint16
  29. // enabled multicore
  30. multicore bool
  31. // event loop count
  32. numLoops int
  33. // read new connections
  34. response *stream.Stream[plugin_entities.PluginFullDuplexLifetime]
  35. plugins map[int]*RemotePluginRuntime
  36. pluginsLock *sync.RWMutex
  37. shutdownChan chan bool
  38. maxConn int32
  39. currentConn int32
  40. }
  41. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  42. s.engine = c
  43. return gnet.None
  44. }
  45. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  46. // new plugin connected
  47. c.SetContext(&codec{})
  48. runtime := &RemotePluginRuntime{
  49. BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(
  50. s.mediaManager,
  51. ),
  52. conn: c,
  53. response: stream.NewStream[[]byte](512),
  54. callbacks: make(map[string][]func([]byte)),
  55. callbacksLock: &sync.RWMutex{},
  56. assets: make(map[string]*bytes.Buffer),
  57. assetsBytes: 0,
  58. shutdownChan: make(chan bool),
  59. waitLaunchedChan: make(chan error),
  60. alive: true,
  61. }
  62. // store plugin runtime
  63. s.pluginsLock.Lock()
  64. s.plugins[c.Fd()] = runtime
  65. s.pluginsLock.Unlock()
  66. // start a timer to check if handshake is completed in 10 seconds
  67. time.AfterFunc(time.Second*10, func() {
  68. if !runtime.handshake {
  69. // close connection
  70. c.Close()
  71. }
  72. })
  73. // verified
  74. verified := true
  75. if verified {
  76. return nil, gnet.None
  77. }
  78. return nil, gnet.Close
  79. }
  80. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  81. // plugin disconnected
  82. s.pluginsLock.Lock()
  83. plugin := s.plugins[c.Fd()]
  84. delete(s.plugins, c.Fd())
  85. s.pluginsLock.Unlock()
  86. if plugin == nil {
  87. return gnet.None
  88. }
  89. // close plugin
  90. plugin.onDisconnected()
  91. // uninstall plugin
  92. if plugin.assetsTransferred {
  93. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  94. if plugin.installationId != "" {
  95. if err := plugin.Unregister(); err != nil {
  96. log.Error("unregister plugin failed, error: %v", err)
  97. }
  98. }
  99. // decrease current connection
  100. atomic.AddInt32(&s.currentConn, -1)
  101. }
  102. }
  103. // send stopped event
  104. plugin.waitChanLock.Lock()
  105. for _, c := range plugin.waitStoppedChan {
  106. select {
  107. case c <- true:
  108. default:
  109. }
  110. }
  111. plugin.waitChanLock.Unlock()
  112. // recycle launched chan, avoid memory leak
  113. plugin.waitLaunchedChanOnce.Do(func() {
  114. close(plugin.waitLaunchedChan)
  115. })
  116. return gnet.None
  117. }
  118. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  119. close(s.shutdownChan)
  120. }
  121. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  122. codec := c.Context().(*codec)
  123. messages, err := codec.Decode(c)
  124. if err != nil {
  125. return gnet.Close
  126. }
  127. // get plugin runtime
  128. s.pluginsLock.RLock()
  129. runtime, ok := s.plugins[c.Fd()]
  130. s.pluginsLock.RUnlock()
  131. if !ok {
  132. return gnet.Close
  133. }
  134. // handle messages
  135. for _, message := range messages {
  136. if len(message) == 0 {
  137. continue
  138. }
  139. s.onMessage(runtime, message)
  140. }
  141. return gnet.None
  142. }
  143. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  144. // handle message
  145. if runtime.handshakeFailed {
  146. // do nothing if handshake has failed
  147. return
  148. }
  149. closeConn := func(message []byte) {
  150. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  151. runtime.conn.Write(message)
  152. runtime.conn.Close()
  153. }
  154. }
  155. if !runtime.initialized {
  156. registerPayload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
  157. if err != nil {
  158. // close connection if handshake failed
  159. closeConn([]byte("handshake failed, invalid handshake message\n"))
  160. runtime.handshakeFailed = true
  161. return
  162. }
  163. if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
  164. if runtime.handshake {
  165. // handshake already completed
  166. return
  167. }
  168. key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](registerPayload.Data)
  169. if err != nil {
  170. // close connection if handshake failed
  171. closeConn([]byte("handshake failed, invalid handshake message\n"))
  172. runtime.handshakeFailed = true
  173. return
  174. }
  175. info, err := GetConnectionInfo(key.Key)
  176. if err == cache.ErrNotFound {
  177. // close connection if handshake failed
  178. closeConn([]byte("handshake failed, invalid key\n"))
  179. runtime.handshakeFailed = true
  180. return
  181. } else if err != nil {
  182. // close connection if handshake failed
  183. closeConn([]byte("internal error\n"))
  184. return
  185. }
  186. runtime.tenantId = info.TenantId
  187. // handshake completed
  188. runtime.handshake = true
  189. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
  190. if runtime.assetsTransferred {
  191. return
  192. }
  193. assetChunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](registerPayload.Data)
  194. if err != nil {
  195. log.Error("assets register failed, error: %v", err)
  196. closeConn([]byte("assets register failed, invalid assets chunk\n"))
  197. return
  198. }
  199. buffer, ok := runtime.assets[assetChunk.Filename]
  200. if !ok {
  201. runtime.assets[assetChunk.Filename] = &bytes.Buffer{}
  202. buffer = runtime.assets[assetChunk.Filename]
  203. }
  204. // allows at most 50MB assets
  205. if runtime.assetsBytes+int64(len(assetChunk.Data)) > 50*1024*1024 {
  206. closeConn([]byte("assets too large, at most 50MB\n"))
  207. return
  208. }
  209. // decode as base64
  210. data, err := base64.StdEncoding.DecodeString(assetChunk.Data)
  211. if err != nil {
  212. log.Error("assets decode failed, error: %v", err)
  213. closeConn([]byte("assets decode failed, invalid assets data\n"))
  214. return
  215. }
  216. buffer.Write(data)
  217. // update assets bytes
  218. runtime.assetsBytes += int64(len(data))
  219. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
  220. if !runtime.modelsRegistrationTransferred &&
  221. !runtime.endpointsRegistrationTransferred &&
  222. !runtime.toolsRegistrationTransferred {
  223. closeConn([]byte("no registration transferred, cannot initialize\n"))
  224. return
  225. }
  226. files := make(map[string][]byte)
  227. for filename, buffer := range runtime.assets {
  228. files[filename] = buffer.Bytes()
  229. }
  230. // remap assets
  231. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  232. log.Error("assets remap failed, error: %v", err)
  233. closeConn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
  234. return
  235. }
  236. atomic.AddInt32(&s.currentConn, 1)
  237. if atomic.LoadInt32(&s.currentConn) > int32(s.maxConn) {
  238. closeConn([]byte("server is busy now, please try again later\n"))
  239. return
  240. }
  241. // fill in default values
  242. runtime.Config.FillInDefaultValues()
  243. // mark assets transferred
  244. runtime.assetsTransferred = true
  245. runtime.checksum = runtime.calculateChecksum()
  246. runtime.InitState()
  247. runtime.SetActiveAt(time.Now())
  248. // trigger registration event
  249. if err := runtime.Register(); err != nil {
  250. closeConn([]byte(fmt.Sprintf("register failed, cannot register: %v\n", err)))
  251. return
  252. }
  253. // send started event
  254. runtime.waitChanLock.Lock()
  255. for _, c := range runtime.waitStartedChan {
  256. select {
  257. case c <- true:
  258. default:
  259. }
  260. }
  261. runtime.waitChanLock.Unlock()
  262. // notify launched
  263. runtime.waitLaunchedChanOnce.Do(func() {
  264. close(runtime.waitLaunchedChan)
  265. })
  266. // mark initialized
  267. runtime.initialized = true
  268. // publish runtime to watcher
  269. s.response.Write(runtime)
  270. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
  271. if runtime.registrationTransferred {
  272. return
  273. }
  274. // process handle shake if not completed
  275. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](registerPayload.Data)
  276. if err != nil {
  277. // close connection if handshake failed
  278. closeConn([]byte(fmt.Sprintf("handshake failed, invalid plugin declaration: %v\n", err)))
  279. return
  280. }
  281. runtime.Config = declaration
  282. // registration transferred
  283. runtime.registrationTransferred = true
  284. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
  285. if runtime.toolsRegistrationTransferred {
  286. return
  287. }
  288. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](registerPayload.Data)
  289. if err != nil {
  290. closeConn([]byte(fmt.Sprintf("tools register failed, invalid tools declaration: %v\n", err)))
  291. return
  292. }
  293. runtime.toolsRegistrationTransferred = true
  294. if len(tools) > 0 {
  295. declaration := runtime.Config
  296. declaration.Tool = &tools[0]
  297. runtime.Config = declaration
  298. }
  299. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
  300. if runtime.modelsRegistrationTransferred {
  301. return
  302. }
  303. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](registerPayload.Data)
  304. if err != nil {
  305. closeConn([]byte(fmt.Sprintf("models register failed, invalid models declaration: %v\n", err)))
  306. return
  307. }
  308. runtime.modelsRegistrationTransferred = true
  309. if len(models) > 0 {
  310. declaration := runtime.Config
  311. declaration.Model = &models[0]
  312. runtime.Config = declaration
  313. }
  314. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
  315. if runtime.endpointsRegistrationTransferred {
  316. return
  317. }
  318. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](registerPayload.Data)
  319. if err != nil {
  320. closeConn([]byte(fmt.Sprintf("endpoints register failed, invalid endpoints declaration: %v\n", err)))
  321. return
  322. }
  323. runtime.endpointsRegistrationTransferred = true
  324. if len(endpoints) > 0 {
  325. declaration := runtime.Config
  326. declaration.Endpoint = &endpoints[0]
  327. runtime.Config = declaration
  328. }
  329. } else {
  330. // unknown event type
  331. closeConn([]byte("unknown initialization event type\n"))
  332. return
  333. }
  334. } else {
  335. // continue handle messages if handshake completed
  336. runtime.response.Write(message)
  337. }
  338. }