|
@@ -1,7 +1,8 @@
|
|
|
package remote_manager
|
|
|
|
|
|
import (
|
|
|
- "encoding/hex"
|
|
|
+ "bytes"
|
|
|
+ "encoding/base64"
|
|
|
"fmt"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
@@ -69,6 +70,9 @@ func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
|
|
callbacks: make(map[string][]func([]byte)),
|
|
|
callbacks_lock: &sync.RWMutex{},
|
|
|
|
|
|
+ assets: make(map[string]*bytes.Buffer),
|
|
|
+ assets_bytes: 0,
|
|
|
+
|
|
|
shutdown_chan: make(chan bool),
|
|
|
wait_launched_chan: make(chan error),
|
|
|
|
|
@@ -186,148 +190,223 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if !runtime.handshake {
|
|
|
- key := string(message)
|
|
|
-
|
|
|
- info, err := GetConnectionInfo(key)
|
|
|
- if err == cache.ErrNotFound {
|
|
|
+ if !runtime.initialized {
|
|
|
+ register_payload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
|
|
|
+ if err != nil {
|
|
|
// close connection if handshake failed
|
|
|
- close_conn([]byte("handshake failed, invalid key\n"))
|
|
|
+ close_conn([]byte("handshake failed, invalid handshake message\n"))
|
|
|
runtime.handshake_failed = true
|
|
|
return
|
|
|
- } else if err != nil {
|
|
|
- // close connection if handshake failed
|
|
|
- close_conn([]byte("internal error\n"))
|
|
|
- return
|
|
|
}
|
|
|
|
|
|
- runtime.tenant_id = info.TenantId
|
|
|
+ if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
|
|
|
+ if runtime.handshake {
|
|
|
+ // handshake already completed
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // handshake completed
|
|
|
- runtime.handshake = true
|
|
|
- } else if !runtime.registration_transferred {
|
|
|
- // process handle shake if not completed
|
|
|
- declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
|
|
|
- if err != nil {
|
|
|
- // close connection if handshake failed
|
|
|
- close_conn([]byte("handshake failed, invalid plugin declaration\n"))
|
|
|
- return
|
|
|
- }
|
|
|
+ key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](register_payload.Data)
|
|
|
+ if err != nil {
|
|
|
+ // close connection if handshake failed
|
|
|
+ close_conn([]byte("handshake failed, invalid key\n"))
|
|
|
+ runtime.handshake_failed = true
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- runtime.Config = declaration
|
|
|
+ info, err := GetConnectionInfo(key.Key)
|
|
|
+ if err == cache.ErrNotFound {
|
|
|
+ // close connection if handshake failed
|
|
|
+ close_conn([]byte("handshake failed, invalid key\n"))
|
|
|
+ runtime.handshake_failed = true
|
|
|
+ return
|
|
|
+ } else if err != nil {
|
|
|
+ // close connection if handshake failed
|
|
|
+ close_conn([]byte("internal error\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // registration transferred
|
|
|
- runtime.registration_transferred = true
|
|
|
- } else if !runtime.tools_registration_transferred {
|
|
|
- tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
|
|
|
- if err != nil {
|
|
|
- log.Error("tools register failed, error: %v", err)
|
|
|
- close_conn([]byte("tools register failed, invalid tools declaration\n"))
|
|
|
- return
|
|
|
- }
|
|
|
+ runtime.tenant_id = info.TenantId
|
|
|
|
|
|
- runtime.tools_registration_transferred = true
|
|
|
+ // handshake completed
|
|
|
+ runtime.handshake = true
|
|
|
+ } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
|
|
|
+ if runtime.assets_transferred {
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- if len(tools) > 0 {
|
|
|
- declaration := runtime.Config
|
|
|
- declaration.Tool = &tools[0]
|
|
|
- runtime.Config = declaration
|
|
|
- }
|
|
|
- } else if !runtime.models_registration_transferred {
|
|
|
- models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
|
|
|
- if err != nil {
|
|
|
- log.Error("models register failed, error: %v", err)
|
|
|
- close_conn([]byte("models register failed, invalid models declaration\n"))
|
|
|
- return
|
|
|
- }
|
|
|
+ asset_chunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](register_payload.Data)
|
|
|
+ if err != nil {
|
|
|
+ log.Error("assets register failed, error: %v", err)
|
|
|
+ close_conn([]byte("assets register failed, invalid assets chunk\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- runtime.models_registration_transferred = true
|
|
|
+ buffer, ok := runtime.assets[asset_chunk.Filename]
|
|
|
+ if !ok {
|
|
|
+ runtime.assets[asset_chunk.Filename] = &bytes.Buffer{}
|
|
|
+ buffer = runtime.assets[asset_chunk.Filename]
|
|
|
+ }
|
|
|
|
|
|
- if len(models) > 0 {
|
|
|
- declaration := runtime.Config
|
|
|
- declaration.Model = &models[0]
|
|
|
- runtime.Config = declaration
|
|
|
- }
|
|
|
- } else if !runtime.endpoints_registration_transferred {
|
|
|
- endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
|
|
|
- if err != nil {
|
|
|
- log.Error("endpoints register failed, error: %v", err)
|
|
|
- close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
|
|
|
- return
|
|
|
- }
|
|
|
+ // allows at most 50MB assets
|
|
|
+ if runtime.assets_bytes+int64(len(asset_chunk.Data)) > 50*1024*1024 {
|
|
|
+ close_conn([]byte("assets too large, at most 50MB\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- runtime.endpoints_registration_transferred = true
|
|
|
+ // decode as base64
|
|
|
+ data, err := base64.StdEncoding.DecodeString(asset_chunk.Data)
|
|
|
+ if err != nil {
|
|
|
+ log.Error("assets decode failed, error: %v", err)
|
|
|
+ close_conn([]byte("assets decode failed, invalid assets data\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- if len(endpoints) > 0 {
|
|
|
- declaration := runtime.Config
|
|
|
- declaration.Endpoint = &endpoints[0]
|
|
|
- runtime.Config = declaration
|
|
|
- }
|
|
|
- } else if !runtime.assets_transferred {
|
|
|
- assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
|
|
|
- if err != nil {
|
|
|
- log.Error("assets register failed, error: %v", err)
|
|
|
- close_conn([]byte(fmt.Sprintf("assets register failed, invalid assets declaration: %v\n", err)))
|
|
|
- return
|
|
|
- }
|
|
|
+ buffer.Write(data)
|
|
|
+
|
|
|
+ // update assets bytes
|
|
|
+ runtime.assets_bytes += int64(len(data))
|
|
|
+ } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
|
|
|
+ if !runtime.models_registration_transferred &&
|
|
|
+ !runtime.endpoints_registration_transferred &&
|
|
|
+ !runtime.tools_registration_transferred {
|
|
|
+ close_conn([]byte("no registration transferred, cannot initialize\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ files := make(map[string][]byte)
|
|
|
+ for filename, buffer := range runtime.assets {
|
|
|
+ files[filename] = buffer.Bytes()
|
|
|
+ }
|
|
|
+
|
|
|
+ // remap assets
|
|
|
+ if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
|
|
|
+ log.Error("assets remap failed, error: %v", err)
|
|
|
+ close_conn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ atomic.AddInt32(&s.current_conn, 1)
|
|
|
+ if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
|
|
|
+ close_conn([]byte("server is busy now, please try again later\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // fill in default values
|
|
|
+ runtime.Config.FillInDefaultValues()
|
|
|
+
|
|
|
+ // mark assets transferred
|
|
|
+ runtime.assets_transferred = true
|
|
|
+
|
|
|
+ runtime.checksum = runtime.calculateChecksum()
|
|
|
+ runtime.InitState()
|
|
|
+ runtime.SetActiveAt(time.Now())
|
|
|
+
|
|
|
+ // trigger registration event
|
|
|
+ if err := runtime.Register(); err != nil {
|
|
|
+ log.Error("register failed, error: %v", err)
|
|
|
+ close_conn([]byte("register failed, cannot register\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- files := make(map[string][]byte)
|
|
|
- for _, asset := range assets {
|
|
|
- files[asset.Filename], err = hex.DecodeString(asset.Data)
|
|
|
+ // send started event
|
|
|
+ runtime.wait_chan_lock.Lock()
|
|
|
+ for _, c := range runtime.wait_started_chan {
|
|
|
+ select {
|
|
|
+ case c <- true:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ }
|
|
|
+ runtime.wait_chan_lock.Unlock()
|
|
|
+
|
|
|
+ // notify launched
|
|
|
+ runtime.wait_launched_chan_once.Do(func() {
|
|
|
+ close(runtime.wait_launched_chan)
|
|
|
+ })
|
|
|
+
|
|
|
+ // mark initialized
|
|
|
+ runtime.initialized = true
|
|
|
+
|
|
|
+ // publish runtime to watcher
|
|
|
+ s.response.Write(runtime)
|
|
|
+ } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
|
|
|
+ if runtime.registration_transferred {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // process handle shake if not completed
|
|
|
+ declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](register_payload.Data)
|
|
|
if err != nil {
|
|
|
- log.Error("assets decode failed, error: %v", err)
|
|
|
- close_conn([]byte(fmt.Sprintf("assets decode failed, invalid assets data, cannot decode file: %v\n", err)))
|
|
|
+ // close connection if handshake failed
|
|
|
+ close_conn([]byte("handshake failed, invalid plugin declaration\n"))
|
|
|
return
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- // remap assets
|
|
|
- if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
|
|
|
- log.Error("assets remap failed, error: %v", err)
|
|
|
- close_conn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
|
|
|
- return
|
|
|
- }
|
|
|
+ runtime.Config = declaration
|
|
|
|
|
|
- atomic.AddInt32(&s.current_conn, 1)
|
|
|
- if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
|
|
|
- close_conn([]byte("server is busy now, please try again later\n"))
|
|
|
- return
|
|
|
- }
|
|
|
+ // registration transferred
|
|
|
+ runtime.registration_transferred = true
|
|
|
+ } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
|
|
|
+ if runtime.tools_registration_transferred {
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // fill in default values
|
|
|
- runtime.Config.FillInDefaultValues()
|
|
|
+ tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](register_payload.Data)
|
|
|
+ if err != nil {
|
|
|
+ log.Error("tools register failed, error: %v", err)
|
|
|
+ close_conn([]byte("tools register failed, invalid tools declaration\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // mark assets transferred
|
|
|
- runtime.assets_transferred = true
|
|
|
+ runtime.tools_registration_transferred = true
|
|
|
|
|
|
- runtime.checksum = runtime.calculateChecksum()
|
|
|
- runtime.InitState()
|
|
|
- runtime.SetActiveAt(time.Now())
|
|
|
+ if len(tools) > 0 {
|
|
|
+ declaration := runtime.Config
|
|
|
+ declaration.Tool = &tools[0]
|
|
|
+ runtime.Config = declaration
|
|
|
+ }
|
|
|
+ } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
|
|
|
+ if runtime.models_registration_transferred {
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // trigger registration event
|
|
|
- if err := runtime.Register(); err != nil {
|
|
|
- log.Error("register failed, error: %v", err)
|
|
|
- close_conn([]byte("register failed, cannot register\n"))
|
|
|
- return
|
|
|
- }
|
|
|
+ models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](register_payload.Data)
|
|
|
+ if err != nil {
|
|
|
+ log.Error("models register failed, error: %v", err)
|
|
|
+ close_conn([]byte("models register failed, invalid models declaration\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ runtime.models_registration_transferred = true
|
|
|
|
|
|
- // send started event
|
|
|
- runtime.wait_chan_lock.Lock()
|
|
|
- for _, c := range runtime.wait_started_chan {
|
|
|
- select {
|
|
|
- case c <- true:
|
|
|
- default:
|
|
|
+ if len(models) > 0 {
|
|
|
+ declaration := runtime.Config
|
|
|
+ declaration.Model = &models[0]
|
|
|
+ runtime.Config = declaration
|
|
|
+ }
|
|
|
+ } else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
|
|
|
+ if runtime.endpoints_registration_transferred {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](register_payload.Data)
|
|
|
+ if err != nil {
|
|
|
+ log.Error("endpoints register failed, error: %v", err)
|
|
|
+ close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
|
|
|
+ return
|
|
|
}
|
|
|
- }
|
|
|
- runtime.wait_chan_lock.Unlock()
|
|
|
|
|
|
- // notify launched
|
|
|
- runtime.wait_launched_chan_once.Do(func() {
|
|
|
- close(runtime.wait_launched_chan)
|
|
|
- })
|
|
|
+ runtime.endpoints_registration_transferred = true
|
|
|
|
|
|
- // publish runtime to watcher
|
|
|
- s.response.Write(runtime)
|
|
|
+ if len(endpoints) > 0 {
|
|
|
+ declaration := runtime.Config
|
|
|
+ declaration.Endpoint = &endpoints[0]
|
|
|
+ runtime.Config = declaration
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // unknown event type
|
|
|
+ close_conn([]byte("unknown initialization event type\n"))
|
|
|
+ return
|
|
|
+ }
|
|
|
} else {
|
|
|
// continue handle messages if handshake completed
|
|
|
runtime.response.Write(message)
|