Ver código fonte

refactor: using initialization event to take over of init processes

Yeuoly 8 meses atrás
pai
commit
f544291c4b

+ 194 - 115
internal/core/plugin_manager/remote_manager/hooks.go

@@ -1,7 +1,8 @@
 package remote_manager
 
 import (
-	"encoding/hex"
+	"bytes"
+	"encoding/base64"
 	"fmt"
 	"sync"
 	"sync/atomic"
@@ -69,6 +70,9 @@ func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
 		callbacks:      make(map[string][]func([]byte)),
 		callbacks_lock: &sync.RWMutex{},
 
+		assets:       make(map[string]*bytes.Buffer),
+		assets_bytes: 0,
+
 		shutdown_chan:      make(chan bool),
 		wait_launched_chan: make(chan error),
 
@@ -186,148 +190,223 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		}
 	}
 
-	if !runtime.handshake {
-		key := string(message)
-
-		info, err := GetConnectionInfo(key)
-		if err == cache.ErrNotFound {
+	if !runtime.initialized {
+		register_payload, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterPayload](message)
+		if err != nil {
 			// close connection if handshake failed
-			close_conn([]byte("handshake failed, invalid key\n"))
+			close_conn([]byte("handshake failed, invalid handshake message\n"))
 			runtime.handshake_failed = true
 			return
-		} else if err != nil {
-			// close connection if handshake failed
-			close_conn([]byte("internal error\n"))
-			return
 		}
 
-		runtime.tenant_id = info.TenantId
+		if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_HAND_SHAKE {
+			if runtime.handshake {
+				// handshake already completed
+				return
+			}
 
-		// handshake completed
-		runtime.handshake = true
-	} else if !runtime.registration_transferred {
-		// process handle shake if not completed
-		declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
-		if err != nil {
-			// close connection if handshake failed
-			close_conn([]byte("handshake failed, invalid plugin declaration\n"))
-			return
-		}
+			key, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterHandshake](register_payload.Data)
+			if err != nil {
+				// close connection if handshake failed
+				close_conn([]byte("handshake failed, invalid key\n"))
+				runtime.handshake_failed = true
+				return
+			}
 
-		runtime.Config = declaration
+			info, err := GetConnectionInfo(key.Key)
+			if err == cache.ErrNotFound {
+				// close connection if handshake failed
+				close_conn([]byte("handshake failed, invalid key\n"))
+				runtime.handshake_failed = true
+				return
+			} else if err != nil {
+				// close connection if handshake failed
+				close_conn([]byte("internal error\n"))
+				return
+			}
 
-		// registration transferred
-		runtime.registration_transferred = true
-	} else if !runtime.tools_registration_transferred {
-		tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
-		if err != nil {
-			log.Error("tools register failed, error: %v", err)
-			close_conn([]byte("tools register failed, invalid tools declaration\n"))
-			return
-		}
+			runtime.tenant_id = info.TenantId
 
-		runtime.tools_registration_transferred = true
+			// handshake completed
+			runtime.handshake = true
+		} else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_ASSET_CHUNK {
+			if runtime.assets_transferred {
+				return
+			}
 
-		if len(tools) > 0 {
-			declaration := runtime.Config
-			declaration.Tool = &tools[0]
-			runtime.Config = declaration
-		}
-	} else if !runtime.models_registration_transferred {
-		models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
-		if err != nil {
-			log.Error("models register failed, error: %v", err)
-			close_conn([]byte("models register failed, invalid models declaration\n"))
-			return
-		}
+			asset_chunk, err := parser.UnmarshalJsonBytes[plugin_entities.RemotePluginRegisterAssetChunk](register_payload.Data)
+			if err != nil {
+				log.Error("assets register failed, error: %v", err)
+				close_conn([]byte("assets register failed, invalid assets chunk\n"))
+				return
+			}
 
-		runtime.models_registration_transferred = true
+			buffer, ok := runtime.assets[asset_chunk.Filename]
+			if !ok {
+				runtime.assets[asset_chunk.Filename] = &bytes.Buffer{}
+				buffer = runtime.assets[asset_chunk.Filename]
+			}
 
-		if len(models) > 0 {
-			declaration := runtime.Config
-			declaration.Model = &models[0]
-			runtime.Config = declaration
-		}
-	} else if !runtime.endpoints_registration_transferred {
-		endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
-		if err != nil {
-			log.Error("endpoints register failed, error: %v", err)
-			close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
-			return
-		}
+			// allows at most 50MB assets
+			if runtime.assets_bytes+int64(len(asset_chunk.Data)) > 50*1024*1024 {
+				close_conn([]byte("assets too large, at most 50MB\n"))
+				return
+			}
 
-		runtime.endpoints_registration_transferred = true
+			// decode as base64
+			data, err := base64.StdEncoding.DecodeString(asset_chunk.Data)
+			if err != nil {
+				log.Error("assets decode failed, error: %v", err)
+				close_conn([]byte("assets decode failed, invalid assets data\n"))
+				return
+			}
 
-		if len(endpoints) > 0 {
-			declaration := runtime.Config
-			declaration.Endpoint = &endpoints[0]
-			runtime.Config = declaration
-		}
-	} else if !runtime.assets_transferred {
-		assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
-		if err != nil {
-			log.Error("assets register failed, error: %v", err)
-			close_conn([]byte(fmt.Sprintf("assets register failed, invalid assets declaration: %v\n", err)))
-			return
-		}
+			buffer.Write(data)
+
+			// update assets bytes
+			runtime.assets_bytes += int64(len(data))
+		} else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
+			if !runtime.models_registration_transferred &&
+				!runtime.endpoints_registration_transferred &&
+				!runtime.tools_registration_transferred {
+				close_conn([]byte("no registration transferred, cannot initialize\n"))
+				return
+			}
+
+			files := make(map[string][]byte)
+			for filename, buffer := range runtime.assets {
+				files[filename] = buffer.Bytes()
+			}
+
+			// remap assets
+			if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
+				log.Error("assets remap failed, error: %v", err)
+				close_conn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
+				return
+			}
+
+			atomic.AddInt32(&s.current_conn, 1)
+			if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
+				close_conn([]byte("server is busy now, please try again later\n"))
+				return
+			}
+
+			// fill in default values
+			runtime.Config.FillInDefaultValues()
+
+			// mark assets transferred
+			runtime.assets_transferred = true
+
+			runtime.checksum = runtime.calculateChecksum()
+			runtime.InitState()
+			runtime.SetActiveAt(time.Now())
+
+			// trigger registration event
+			if err := runtime.Register(); err != nil {
+				log.Error("register failed, error: %v", err)
+				close_conn([]byte("register failed, cannot register\n"))
+				return
+			}
 
-		files := make(map[string][]byte)
-		for _, asset := range assets {
-			files[asset.Filename], err = hex.DecodeString(asset.Data)
+			// send started event
+			runtime.wait_chan_lock.Lock()
+			for _, c := range runtime.wait_started_chan {
+				select {
+				case c <- true:
+				default:
+				}
+			}
+			runtime.wait_chan_lock.Unlock()
+
+			// notify launched
+			runtime.wait_launched_chan_once.Do(func() {
+				close(runtime.wait_launched_chan)
+			})
+
+			// mark initialized
+			runtime.initialized = true
+
+			// publish runtime to watcher
+			s.response.Write(runtime)
+		} else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_MANIFEST_DECLARATION {
+			if runtime.registration_transferred {
+				return
+			}
+
+			// process handle shake if not completed
+			declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](register_payload.Data)
 			if err != nil {
-				log.Error("assets decode failed, error: %v", err)
-				close_conn([]byte(fmt.Sprintf("assets decode failed, invalid assets data, cannot decode file: %v\n", err)))
+				// close connection if handshake failed
+				close_conn([]byte("handshake failed, invalid plugin declaration\n"))
 				return
 			}
-		}
 
-		// remap assets
-		if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
-			log.Error("assets remap failed, error: %v", err)
-			close_conn([]byte(fmt.Sprintf("assets remap failed, invalid assets data, cannot remap: %v\n", err)))
-			return
-		}
+			runtime.Config = declaration
 
-		atomic.AddInt32(&s.current_conn, 1)
-		if atomic.LoadInt32(&s.current_conn) > int32(s.max_conn) {
-			close_conn([]byte("server is busy now, please try again later\n"))
-			return
-		}
+			// registration transferred
+			runtime.registration_transferred = true
+		} else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_TOOL_DECLARATION {
+			if runtime.tools_registration_transferred {
+				return
+			}
 
-		// fill in default values
-		runtime.Config.FillInDefaultValues()
+			tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](register_payload.Data)
+			if err != nil {
+				log.Error("tools register failed, error: %v", err)
+				close_conn([]byte("tools register failed, invalid tools declaration\n"))
+				return
+			}
 
-		// mark assets transferred
-		runtime.assets_transferred = true
+			runtime.tools_registration_transferred = true
 
-		runtime.checksum = runtime.calculateChecksum()
-		runtime.InitState()
-		runtime.SetActiveAt(time.Now())
+			if len(tools) > 0 {
+				declaration := runtime.Config
+				declaration.Tool = &tools[0]
+				runtime.Config = declaration
+			}
+		} else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_MODEL_DECLARATION {
+			if runtime.models_registration_transferred {
+				return
+			}
 
-		// trigger registration event
-		if err := runtime.Register(); err != nil {
-			log.Error("register failed, error: %v", err)
-			close_conn([]byte("register failed, cannot register\n"))
-			return
-		}
+			models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](register_payload.Data)
+			if err != nil {
+				log.Error("models register failed, error: %v", err)
+				close_conn([]byte("models register failed, invalid models declaration\n"))
+				return
+			}
+
+			runtime.models_registration_transferred = true
 
-		// send started event
-		runtime.wait_chan_lock.Lock()
-		for _, c := range runtime.wait_started_chan {
-			select {
-			case c <- true:
-			default:
+			if len(models) > 0 {
+				declaration := runtime.Config
+				declaration.Model = &models[0]
+				runtime.Config = declaration
+			}
+		} else if register_payload.Type == plugin_entities.REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION {
+			if runtime.endpoints_registration_transferred {
+				return
+			}
+
+			endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](register_payload.Data)
+			if err != nil {
+				log.Error("endpoints register failed, error: %v", err)
+				close_conn([]byte("endpoints register failed, invalid endpoints declaration\n"))
+				return
 			}
-		}
-		runtime.wait_chan_lock.Unlock()
 
-		// notify launched
-		runtime.wait_launched_chan_once.Do(func() {
-			close(runtime.wait_launched_chan)
-		})
+			runtime.endpoints_registration_transferred = true
 
-		// publish runtime to watcher
-		s.response.Write(runtime)
+			if len(endpoints) > 0 {
+				declaration := runtime.Config
+				declaration.Endpoint = &endpoints[0]
+				runtime.Config = declaration
+			}
+		} else {
+			// unknown event type
+			close_conn([]byte("unknown initialization event type\n"))
+			return
+		}
 	} else {
 		// continue handle messages if handshake completed
 		runtime.response.Write(message)

+ 7 - 0
internal/core/plugin_manager/remote_manager/type.go

@@ -1,6 +1,7 @@
 package remote_manager
 
 import (
+	"bytes"
 	"sync"
 	"time"
 
@@ -35,10 +36,16 @@ type RemotePluginRuntime struct {
 	// heartbeat
 	last_active_at time.Time
 
+	assets       map[string]*bytes.Buffer
+	assets_bytes int64
+
 	// hand shake process completed
 	handshake        bool
 	handshake_failed bool
 
+	// initialized, wether registration transferred
+	initialized bool
+
 	// registration transferred
 	registration_transferred bool
 

+ 33 - 7
internal/types/entities/plugin_entities/plugin_declaration.go

@@ -198,18 +198,17 @@ type PluginDeclarationWithoutAdvancedFields struct {
 }
 
 func (p *PluginDeclarationWithoutAdvancedFields) UnmarshalJSON(data []byte) error {
-	type alias PluginDeclarationWithoutAdvancedFields
-
-	var temp struct {
-		alias
+	type Alias PluginDeclarationWithoutAdvancedFields
+	aux := &struct {
+		*Alias
+	}{
+		Alias: (*Alias)(p),
 	}
 
-	if err := json.Unmarshal(data, &temp); err != nil {
+	if err := json.Unmarshal(data, aux); err != nil {
 		return err
 	}
 
-	*p = PluginDeclarationWithoutAdvancedFields(temp.alias)
-
 	if p.Tags == nil {
 		p.Tags = []PluginTag{}
 	}
@@ -225,6 +224,33 @@ type PluginDeclaration struct {
 	Tool                                   *ToolProviderDeclaration     `json:"tool,omitempty" yaml:"tool,omitempty" validate:"omitempty"`
 }
 
+func (p *PluginDeclaration) UnmarshalJSON(data []byte) error {
+	// First unmarshal the embedded struct
+	if err := json.Unmarshal(data, &p.PluginDeclarationWithoutAdvancedFields); err != nil {
+		return err
+	}
+
+	// Then unmarshal the remaining fields
+	type PluginExtra struct {
+		Verified bool                         `json:"verified"`
+		Endpoint *EndpointProviderDeclaration `json:"endpoint,omitempty"`
+		Model    *ModelProviderDeclaration    `json:"model,omitempty"`
+		Tool     *ToolProviderDeclaration     `json:"tool,omitempty"`
+	}
+
+	var extra PluginExtra
+	if err := json.Unmarshal(data, &extra); err != nil {
+		return err
+	}
+
+	p.Verified = extra.Verified
+	p.Endpoint = extra.Endpoint
+	p.Model = extra.Model
+	p.Tool = extra.Tool
+
+	return nil
+}
+
 func (p *PluginDeclaration) MarshalJSON() ([]byte, error) {
 	// TODO: performance issue, need a better way to do this
 	c := *p

+ 29 - 0
internal/types/entities/plugin_entities/remote_entities.go

@@ -1,6 +1,35 @@
 package plugin_entities
 
+import "encoding/json"
+
 type RemoteAssetPayload struct {
 	Filename string `json:"filename" validate:"required"`
 	Data     string `json:"data" validate:"required"`
 }
+
+type RemotePluginRegisterEventType string
+
+const (
+	REGISTER_EVENT_TYPE_HAND_SHAKE           RemotePluginRegisterEventType = "handshake"
+	REGISTER_EVENT_TYPE_ASSET_CHUNK          RemotePluginRegisterEventType = "asset_chunk"
+	REGISTER_EVENT_TYPE_MANIFEST_DECLARATION RemotePluginRegisterEventType = "manifest_declaration"
+	REGISTER_EVENT_TYPE_TOOL_DECLARATION     RemotePluginRegisterEventType = "tool_declaration"
+	REGISTER_EVENT_TYPE_MODEL_DECLARATION    RemotePluginRegisterEventType = "model_declaration"
+	REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION RemotePluginRegisterEventType = "endpoint_declaration"
+	REGISTER_EVENT_TYPE_END                  RemotePluginRegisterEventType = "end"
+)
+
+type RemotePluginRegisterAssetChunk struct {
+	Filename string `json:"filename" validate:"required"`
+	Data     string `json:"data" validate:"required"`
+	End      bool   `json:"end"` // if true, it's the last chunk of the file
+}
+
+type RemotePluginRegisterHandshake struct {
+	Key string `json:"key" validate:"required"`
+}
+
+type RemotePluginRegisterPayload struct {
+	Type RemotePluginRegisterEventType `json:"type" validate:"required"`
+	Data json.RawMessage               `json:"data" validate:"required"`
+}

+ 2 - 0
internal/utils/parser/json.go

@@ -22,6 +22,8 @@ func UnmarshalJsonBytes[T any](data []byte) (T, error) {
 	typ := reflect.TypeOf(result)
 	if typ.Kind() == reflect.Map {
 		return result, nil
+	} else if typ.Kind() == reflect.String {
+		return result, nil
 	}
 
 	if err := validators.GlobalEntitiesValidator.Struct(&result); err != nil {