hooks.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. package debugging_runtime
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_runtime"
  10. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_transport"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  15. "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
  16. "github.com/panjf2000/gnet/v2"
  17. )
  18. var (
  19. // mode is only used for testing
  20. // TODO: simplify this ugly code
  21. _mode pluginRuntimeMode
  22. )
  23. type DifyServer struct {
  24. gnet.BuiltinEventEngine
  25. engine gnet.Engine
  26. mediaManager *media_transport.MediaBucket
  27. // listening address
  28. addr string
  29. port uint16
  30. // enabled multicore
  31. multicore bool
  32. // event loop count
  33. numLoops int
  34. // read new connections
  35. response *stream.Stream[plugin_entities.PluginFullDuplexLifetime]
  36. plugins map[int]*RemotePluginRuntime
  37. pluginsLock *sync.RWMutex
  38. shutdownChan chan bool
  39. maxConn int32
  40. currentConn int32
  41. }
  42. func (s *DifyServer) OnBoot(c gnet.Engine) (action gnet.Action) {
  43. s.engine = c
  44. return gnet.None
  45. }
  46. func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
  47. // new plugin connected
  48. c.SetContext(&codec{})
  49. runtime := &RemotePluginRuntime{
  50. MediaTransport: basic_runtime.NewMediaTransport(
  51. s.mediaManager,
  52. ),
  53. conn: c,
  54. response: stream.NewStream[[]byte](512),
  55. messageCallbacks: make(map[string][]func([]byte)),
  56. messageCallbacksLock: &sync.RWMutex{},
  57. sessionMessageClosers: make(map[string][]func()),
  58. sessionMessageClosersLock: &sync.RWMutex{},
  59. assets: make(map[string]*bytes.Buffer),
  60. assetsBytes: 0,
  61. shutdownChan: make(chan bool),
  62. waitLaunchedChan: make(chan error),
  63. alive: true,
  64. }
  65. // store plugin runtime
  66. s.pluginsLock.Lock()
  67. s.plugins[c.Fd()] = runtime
  68. s.pluginsLock.Unlock()
  69. // start a timer to check if handshake is completed in 10 seconds
  70. time.AfterFunc(time.Second*10, func() {
  71. if !runtime.handshake {
  72. // close connection
  73. c.Close()
  74. }
  75. })
  76. // verified
  77. verified := true
  78. if verified {
  79. return nil, gnet.None
  80. }
  81. return nil, gnet.Close
  82. }
  83. func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
  84. // plugin disconnected
  85. s.pluginsLock.Lock()
  86. plugin := s.plugins[c.Fd()]
  87. delete(s.plugins, c.Fd())
  88. s.pluginsLock.Unlock()
  89. if plugin == nil {
  90. return gnet.None
  91. }
  92. // close plugin
  93. plugin.onDisconnected()
  94. // uninstall plugin
  95. if plugin.assetsTransferred {
  96. if _mode != _PLUGIN_RUNTIME_MODE_CI {
  97. if plugin.installationId != "" {
  98. if err := plugin.Unregister(); err != nil {
  99. log.Error("unregister plugin failed, error: %v", err)
  100. }
  101. }
  102. // decrease current connection
  103. atomic.AddInt32(&s.currentConn, -1)
  104. }
  105. }
  106. // send stopped event
  107. plugin.waitChanLock.Lock()
  108. for _, c := range plugin.waitStoppedChan {
  109. select {
  110. case c <- true:
  111. default:
  112. }
  113. }
  114. plugin.waitChanLock.Unlock()
  115. // recycle launched chan, avoid memory leak
  116. plugin.waitLaunchedChanOnce.Do(func() {
  117. close(plugin.waitLaunchedChan)
  118. })
  119. return gnet.None
  120. }
  121. func (s *DifyServer) OnShutdown(c gnet.Engine) {
  122. close(s.shutdownChan)
  123. }
  124. func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
  125. codec := c.Context().(*codec)
  126. messages, err := codec.Decode(c)
  127. if err != nil {
  128. return gnet.Close
  129. }
  130. // get plugin runtime
  131. s.pluginsLock.RLock()
  132. runtime, ok := s.plugins[c.Fd()]
  133. s.pluginsLock.RUnlock()
  134. if !ok {
  135. return gnet.Close
  136. }
  137. // handle messages
  138. for _, message := range messages {
  139. if len(message) == 0 {
  140. continue
  141. }
  142. s.onMessage(runtime, message)
  143. }
  144. return gnet.None
  145. }
  146. func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
  147. // handle message
  148. if runtime.handshakeFailed {
  149. // do nothing if handshake has failed
  150. return
  151. }
  152. closeConn := func(message []byte) {
  153. if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
  154. runtime.conn.Write(message)
  155. runtime.conn.Close()
  156. }
  157. }
  158. if !runtime.initialized {
  159. registerPayload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
  160. if err != nil {
  161. // close connection if handshake failed
  162. closeConn([]byte("handshake failed, invalid handshake message\n"))
  163. runtime.handshakeFailed = true
  164. return
  165. }
  166. if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
  167. if runtime.handshake {
  168. // handshake already completed
  169. return
  170. }
  171. key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](registerPayload.Data)
  172. if err != nil {
  173. // close connection if handshake failed
  174. closeConn([]byte("handshake failed, invalid handshake message\n"))
  175. runtime.handshakeFailed = true
  176. return
  177. }
  178. info, err := GetConnectionInfo(key.Key)
  179. if err == cache.ErrNotFound {
  180. // close connection if handshake failed
  181. closeConn([]byte("handshake failed, invalid key\n"))
  182. runtime.handshakeFailed = true
  183. return
  184. } else if err != nil {
  185. // close connection if handshake failed
  186. closeConn([]byte("internal error\n"))
  187. return
  188. }
  189. runtime.tenantId = info.TenantId
  190. // handshake completed
  191. runtime.handshake = true
  192. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
  193. if runtime.assetsTransferred {
  194. return
  195. }
  196. assetChunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](registerPayload.Data)
  197. if err != nil {
  198. log.Error("assets register failed, error: %v", err)
  199. closeConn([]byte("assets register failed, invalid assets chunk\n"))
  200. return
  201. }
  202. buffer, ok := runtime.assets[assetChunk.Filename]
  203. if !ok {
  204. runtime.assets[assetChunk.Filename] = &bytes.Buffer{}
  205. buffer = runtime.assets[assetChunk.Filename]
  206. }
  207. // allows at most 50MB assets
  208. if runtime.assetsBytes+int64(len(assetChunk.Data)) > 50*1024*1024 {
  209. closeConn([]byte("assets too large, at most 50MB\n"))
  210. return
  211. }
  212. // decode as base64
  213. data, err := base64.StdEncoding.DecodeString(assetChunk.Data)
  214. if err != nil {
  215. log.Error("assets decode failed, error: %v", err)
  216. closeConn([]byte("assets decode failed, invalid assets data\n"))
  217. return
  218. }
  219. buffer.Write(data)
  220. // update assets bytes
  221. runtime.assetsBytes += int64(len(data))
  222. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
  223. if !runtime.modelsRegistrationTransferred &&
  224. !runtime.endpointsRegistrationTransferred &&
  225. !runtime.toolsRegistrationTransferred &&
  226. !runtime.agentStrategyRegistrationTransferred {
  227. closeConn([]byte("no registration transferred, cannot initialize\n"))
  228. return
  229. }
  230. files := make(map[string][]byte)
  231. for filename, buffer := range runtime.assets {
  232. files[filename] = buffer.Bytes()
  233. }
  234. // remap assets
  235. if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
  236. log.Error("assets remap failed, error: %v", err)
  237. closeConn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
  238. return
  239. }
  240. atomic.AddInt32(&s.currentConn, 1)
  241. if atomic.LoadInt32(&s.currentConn) > int32(s.maxConn) {
  242. closeConn([]byte("server is busy now, please try again later\n"))
  243. return
  244. }
  245. // fill in default values
  246. runtime.Config.FillInDefaultValues()
  247. // mark assets transferred
  248. runtime.assetsTransferred = true
  249. runtime.checksum = runtime.calculateChecksum()
  250. runtime.InitState()
  251. runtime.SetActiveAt(time.Now())
  252. // trigger registration event
  253. if err := runtime.Register(); err != nil {
  254. closeConn([]byte(fmt.Sprintf("register failed, cannot register: %v\n", err)))
  255. return
  256. }
  257. // send started event
  258. runtime.waitChanLock.Lock()
  259. for _, c := range runtime.waitStartedChan {
  260. select {
  261. case c <- true:
  262. default:
  263. }
  264. }
  265. runtime.waitChanLock.Unlock()
  266. // notify launched
  267. runtime.waitLaunchedChanOnce.Do(func() {
  268. close(runtime.waitLaunchedChan)
  269. })
  270. // mark initialized
  271. runtime.initialized = true
  272. // publish runtime to watcher
  273. s.response.Write(runtime)
  274. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
  275. if runtime.registrationTransferred {
  276. return
  277. }
  278. // process handle shake if not completed
  279. declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](registerPayload.Data)
  280. if err != nil {
  281. // close connection if handshake failed
  282. closeConn([]byte(fmt.Sprintf("handshake failed, invalid plugin declaration: %v\n", err)))
  283. return
  284. }
  285. runtime.Config = declaration
  286. // registration transferred
  287. runtime.registrationTransferred = true
  288. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
  289. if runtime.toolsRegistrationTransferred {
  290. return
  291. }
  292. tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](registerPayload.Data)
  293. if err != nil {
  294. closeConn([]byte(fmt.Sprintf("tools register failed, invalid tools declaration: %v\n", err)))
  295. return
  296. }
  297. runtime.toolsRegistrationTransferred = true
  298. if len(tools) > 0 {
  299. declaration := runtime.Config
  300. declaration.Tool = &tools[0]
  301. runtime.Config = declaration
  302. }
  303. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
  304. if runtime.modelsRegistrationTransferred {
  305. return
  306. }
  307. models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](registerPayload.Data)
  308. if err != nil {
  309. closeConn([]byte(fmt.Sprintf("models register failed, invalid models declaration: %v\n", err)))
  310. return
  311. }
  312. runtime.modelsRegistrationTransferred = true
  313. if len(models) > 0 {
  314. declaration := runtime.Config
  315. declaration.Model = &models[0]
  316. runtime.Config = declaration
  317. }
  318. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
  319. if runtime.endpointsRegistrationTransferred {
  320. return
  321. }
  322. endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](registerPayload.Data)
  323. if err != nil {
  324. closeConn([]byte(fmt.Sprintf("endpoints register failed, invalid endpoints declaration: %v\n", err)))
  325. return
  326. }
  327. runtime.endpointsRegistrationTransferred = true
  328. if len(endpoints) > 0 {
  329. declaration := runtime.Config
  330. declaration.Endpoint = &endpoints[0]
  331. runtime.Config = declaration
  332. }
  333. } else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_AGENT_STRATEGY_DECLARATION {
  334. if runtime.agentStrategyRegistrationTransferred {
  335. return
  336. }
  337. agents, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.AgentStrategyProviderDeclaration](registerPayload.Data)
  338. if err != nil {
  339. closeConn([]byte(fmt.Sprintf("agent strategies register failed, invalid agent strategies declaration: %v\n", err)))
  340. return
  341. }
  342. runtime.agentStrategyRegistrationTransferred = true
  343. if len(agents) > 0 {
  344. declaration := runtime.Config
  345. declaration.AgentStrategy = &agents[0]
  346. runtime.Config = declaration
  347. }
  348. }
  349. } else {
  350. // continue handle messages if handshake completed
  351. runtime.response.Write(message)
  352. }
  353. }