hooks.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. package debugging_runtime
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_runtime"
  10. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_transport"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  15. "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
  16. "github.com/panjf2000/gnet/v2"
  17. )
  18. var (
  19. // mode is only used for testing
  20. // TODO: simplify this ugly code
  21. _mode pluginRuntimeMode
  22. )
  23. type DifyServer struct {
  24. gnet.BuiltinEventEngine
  25. engine gnet.Engine
  26. mediaManager *media_transport.MediaBucket
  27. // listening address
  28. addr string
  29. port uint16
  30. // enabled multicore
  31. multicore bool
  32. // event loop count
  33. numLoops int
  34. // read new connections
  35. response *stream.Stream[plugin_entities.PluginFullDuplexLifetime]
  36. plugins map[int]*RemotePluginRuntime
  37. pluginsLock *sync.RWMutex
  38. shutdownChan chan bool
  39. maxConn int32
  40. currentConn int32
  41. }
  42. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  43. s.engine = c
  44. return gnet.None
  45. }
  46. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  47. // new plugin connected
  48. c.SetContext(&codec{})
  49. runtime := &RemotePluginRuntime{
  50. MediaTransport: basic_runtime.NewMediaTransport(
  51. s.mediaManager,
  52. ),
  53. conn: c,
  54. response: stream.NewStream[[]byte](512),
  55. messageCallbacks: make(map[string][]func([]byte)),
  56. messageCallbacksLock: &sync.RWMutex{},
  57. sessionMessageClosers: make(map[string][]func()),
  58. sessionMessageClosersLock: &sync.RWMutex{},
  59. assets: make(map[string]*bytes.Buffer),
  60. assetsBytes: 0,
  61. shutdownChan: make(chan bool),
  62. waitLaunchedChan: make(chan error),
  63. alive: true,
  64. }
  65. // store plugin runtime
  66. s.pluginsLock.Lock()
  67. s.plugins[c.Fd()] = runtime
  68. s.pluginsLock.Unlock()
  69. // start a timer to check if handshake is completed in 10 seconds
  70. time.AfterFunc(time.Second*10, func() {
  71. if !runtime.handshake {
  72. // close connection
  73. c.Close()
  74. }
  75. })
  76. // verified
  77. verified := true
  78. if verified {
  79. return nil, gnet.None
  80. }
  81. return nil, gnet.Close
  82. }
  83. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  84. // plugin disconnected
  85. s.pluginsLock.Lock()
  86. plugin := s.plugins[c.Fd()]
  87. delete(s.plugins, c.Fd())
  88. s.pluginsLock.Unlock()
  89. if plugin == nil {
  90. return gnet.None
  91. }
  92. // close plugin
  93. plugin.onDisconnected()
  94. // uninstall plugin
  95. if plugin.assetsTransferred {
  96. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  97. if plugin.installationId != "" {
  98. if err := plugin.Unregister(); err != nil {
  99. log.Error("unregister plugin failed, error: %v", err)
  100. }
  101. }
  102. // decrease current connection
  103. atomic.AddInt32(&s.currentConn, -1)
  104. }
  105. }
  106. // send stopped event
  107. plugin.waitChanLock.Lock()
  108. for _, c := range plugin.waitStoppedChan {
  109. select {
  110. case c <- true:
  111. default:
  112. }
  113. }
  114. plugin.waitChanLock.Unlock()
  115. // recycle launched chan, avoid memory leak
  116. plugin.waitLaunchedChanOnce.Do(func() {
  117. close(plugin.waitLaunchedChan)
  118. })
  119. return gnet.None
  120. }
  121. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  122. close(s.shutdownChan)
  123. }
  124. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  125. codec := c.Context().(*codec)
  126. messages, err := codec.Decode(c)
  127. if err != nil {
  128. return gnet.Close
  129. }
  130. // get plugin runtime
  131. s.pluginsLock.RLock()
  132. runtime, ok := s.plugins[c.Fd()]
  133. s.pluginsLock.RUnlock()
  134. if !ok {
  135. return gnet.Close
  136. }
  137. // handle messages
  138. for _, message := range messages {
  139. if len(message) == 0 {
  140. continue
  141. }
  142. s.onMessage(runtime, message)
  143. }
  144. return gnet.None
  145. }
  146. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  147. // handle message
  148. if runtime.handshakeFailed {
  149. // do nothing if handshake has failed
  150. return
  151. }
  152. closeConn := func(message []byte) {
  153. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  154. runtime.conn.Write(message)
  155. runtime.conn.Close()
  156. }
  157. }
  158. if !runtime.initialized {
  159. registerPayload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
  160. if err != nil {
  161. // close connection if handshake failed
  162. closeConn([]byte("handshake failed, invalid handshake message\n"))
  163. runtime.handshakeFailed = true
  164. return
  165. }
  166. if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
  167. if runtime.handshake {
  168. // handshake already completed
  169. return
  170. }
  171. key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](registerPayload.Data)
  172. if err != nil {
  173. // close connection if handshake failed
  174. closeConn([]byte("handshake failed, invalid handshake message\n"))
  175. runtime.handshakeFailed = true
  176. return
  177. }
  178. info, err := GetConnectionInfo(key.Key)
  179. if err == cache.ErrNotFound {
  180. // close connection if handshake failed
  181. closeConn([]byte("handshake failed, invalid key\n"))
  182. runtime.handshakeFailed = true
  183. return
  184. } else if err != nil {
  185. // close connection if handshake failed
  186. log.Error("failed to get connection info: %v", err)
  187. closeConn([]byte("internal error\n"))
  188. return
  189. }
  190. runtime.tenantId = info.TenantId
  191. // handshake completed
  192. runtime.handshake = true
  193. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
  194. if runtime.assetsTransferred {
  195. return
  196. }
  197. assetChunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](registerPayload.Data)
  198. if err != nil {
  199. log.Error("assets register failed, error: %v", err)
  200. closeConn([]byte("assets register failed, invalid assets chunk\n"))
  201. return
  202. }
  203. buffer, ok := runtime.assets[assetChunk.Filename]
  204. if !ok {
  205. runtime.assets[assetChunk.Filename] = &bytes.Buffer{}
  206. buffer = runtime.assets[assetChunk.Filename]
  207. }
  208. // allows at most 50MB assets
  209. if runtime.assetsBytes+int64(len(assetChunk.Data)) > 50*1024*1024 {
  210. closeConn([]byte("assets too large, at most 50MB\n"))
  211. return
  212. }
  213. // decode as base64
  214. data, err := base64.StdEncoding.DecodeString(assetChunk.Data)
  215. if err != nil {
  216. log.Error("assets decode failed, error: %v", err)
  217. closeConn([]byte("assets decode failed, invalid assets data\n"))
  218. return
  219. }
  220. buffer.Write(data)
  221. // update assets bytes
  222. runtime.assetsBytes += int64(len(data))
  223. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
  224. if !runtime.modelsRegistrationTransferred &&
  225. !runtime.endpointsRegistrationTransferred &&
  226. !runtime.toolsRegistrationTransferred &&
  227. !runtime.agentStrategyRegistrationTransferred {
  228. closeConn([]byte("no registration transferred, cannot initialize\n"))
  229. return
  230. }
  231. files := make(map[string][]byte)
  232. for filename, buffer := range runtime.assets {
  233. files[filename] = buffer.Bytes()
  234. }
  235. // remap assets
  236. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  237. log.Error("assets remap failed, error: %v", err)
  238. closeConn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
  239. return
  240. }
  241. atomic.AddInt32(&s.currentConn, 1)
  242. if atomic.LoadInt32(&s.currentConn) > int32(s.maxConn) {
  243. closeConn([]byte("server is busy now, please try again later\n"))
  244. return
  245. }
  246. // fill in default values
  247. runtime.Config.FillInDefaultValues()
  248. // mark assets transferred
  249. runtime.assetsTransferred = true
  250. runtime.checksum = runtime.calculateChecksum()
  251. runtime.InitState()
  252. runtime.SetActiveAt(time.Now())
  253. // trigger registration event
  254. if err := runtime.Register(); err != nil {
  255. closeConn([]byte(fmt.Sprintf("register failed, cannot register: %v\n", err)))
  256. return
  257. }
  258. // send started event
  259. runtime.waitChanLock.Lock()
  260. for _, c := range runtime.waitStartedChan {
  261. select {
  262. case c <- true:
  263. default:
  264. }
  265. }
  266. runtime.waitChanLock.Unlock()
  267. // notify launched
  268. runtime.waitLaunchedChanOnce.Do(func() {
  269. close(runtime.waitLaunchedChan)
  270. })
  271. // mark initialized
  272. runtime.initialized = true
  273. // publish runtime to watcher
  274. s.response.Write(runtime)
  275. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
  276. if runtime.registrationTransferred {
  277. return
  278. }
  279. // process handle shake if not completed
  280. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](registerPayload.Data)
  281. if err != nil {
  282. // close connection if handshake failed
  283. closeConn([]byte(fmt.Sprintf("handshake failed, invalid plugin declaration: %v\n", err)))
  284. return
  285. }
  286. runtime.Config = declaration
  287. // registration transferred
  288. runtime.registrationTransferred = true
  289. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
  290. if runtime.toolsRegistrationTransferred {
  291. return
  292. }
  293. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](registerPayload.Data)
  294. if err != nil {
  295. closeConn([]byte(fmt.Sprintf("tools register failed, invalid tools declaration: %v\n", err)))
  296. return
  297. }
  298. runtime.toolsRegistrationTransferred = true
  299. if len(tools) > 0 {
  300. declaration := runtime.Config
  301. declaration.Tool = &tools[0]
  302. runtime.Config = declaration
  303. }
  304. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
  305. if runtime.modelsRegistrationTransferred {
  306. return
  307. }
  308. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](registerPayload.Data)
  309. if err != nil {
  310. closeConn([]byte(fmt.Sprintf("models register failed, invalid models declaration: %v\n", err)))
  311. return
  312. }
  313. runtime.modelsRegistrationTransferred = true
  314. if len(models) > 0 {
  315. declaration := runtime.Config
  316. declaration.Model = &models[0]
  317. runtime.Config = declaration
  318. }
  319. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
  320. if runtime.endpointsRegistrationTransferred {
  321. return
  322. }
  323. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](registerPayload.Data)
  324. if err != nil {
  325. closeConn([]byte(fmt.Sprintf("endpoints register failed, invalid endpoints declaration: %v\n", err)))
  326. return
  327. }
  328. runtime.endpointsRegistrationTransferred = true
  329. if len(endpoints) > 0 {
  330. declaration := runtime.Config
  331. declaration.Endpoint = &endpoints[0]
  332. runtime.Config = declaration
  333. }
  334. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_AGENT_STRATEGY_DECLARATION {
  335. if runtime.agentStrategyRegistrationTransferred {
  336. return
  337. }
  338. agents, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.AgentStrategyProviderDeclaration](registerPayload.Data)
  339. if err != nil {
  340. closeConn([]byte(fmt.Sprintf("agent strategies register failed, invalid agent strategies declaration: %v\n", err)))
  341. return
  342. }
  343. runtime.agentStrategyRegistrationTransferred = true
  344. if len(agents) > 0 {
  345. declaration := runtime.Config
  346. declaration.AgentStrategy = &agents[0]
  347. runtime.Config = declaration
  348. }
  349. }
  350. } else {
  351. // continue handle messages if handshake completed
  352. runtime.response.Write(message)
  353. }
  354. }