Browse Source

feat: aws lambda init environment impl

Yeuoly 11 months ago
parent
commit
366aefcf06

+ 72 - 0
internal/core/plugin_manager/aws_manager/environment.go

@@ -1,6 +1,78 @@
 package aws_manager
 package aws_manager
 
 
+import (
+	"fmt"
+	"os"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+)
+
+var (
+	AWS_LAUNCH_LOCK_PREFIX = "aws_launch_lock_"
+)
+
 func (r *AWSPluginRuntime) InitEnvironment() error {
 func (r *AWSPluginRuntime) InitEnvironment() error {
+	r.Log("Starting to initialize environment")
+	// check if the plugin has already been initialized, at most 300s
+	if err := cache.Lock(AWS_LAUNCH_LOCK_PREFIX+r.Checksum(), 300*time.Second, 300*time.Second); err != nil {
+		return err
+	}
+	defer cache.Unlock(AWS_LAUNCH_LOCK_PREFIX + r.Checksum())
+	r.Log("Started to initialize environment")
+
+	identity, err := r.Identity()
+	if err != nil {
+		return err
+	}
+	function, err := fetchLambda(identity, r.Checksum())
+	if err != nil {
+		if err != ErrNoLambdaFunction {
+			return err
+		}
+	} else {
+		// found, return directly
+		r.lambda_url = function.FunctionURL
+		r.lambda_name = function.FunctionName
+		r.Log(fmt.Sprintf("Found existing lambda function: %s", r.lambda_name))
+		return nil
+	}
+
+	// create it if not found
+	r.Log("Creating new lambda function")
+
+	// create lambda function
+	packager := NewPackager(r, r.decoder)
+	context, err := packager.Pack()
+	if err != nil {
+		return err
+	}
+	defer os.Remove(context.Name())
+	defer context.Close()
+
+	response, err := launchLambda(identity, r.Checksum(), context)
+	if err != nil {
+		return err
+	}
+
+	for response.Next() {
+		response, err := response.Read()
+		if err != nil {
+			return err
+		}
+
+		switch response.Event {
+		case Error:
+			return fmt.Errorf("error: %s", response.Message)
+		case LambdaUrl:
+			r.lambda_url = response.Message
+		case Lambda:
+			r.lambda_name = response.Message
+		case Info:
+			r.Log(fmt.Sprintf("installing: %s", response.Message))
+		}
+	}
+
 	return nil
 	return nil
 }
 }
 
 

+ 2 - 10
internal/core/plugin_manager/aws_manager/packager_test.go

@@ -116,19 +116,11 @@ func TestPackager_Pack(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	defer func() {
 	defer func() {
+		f.Close()
 		os.Remove(f.Name())
 		os.Remove(f.Name())
 	}()
 	}()
 
 
-	// read tar file and check if there is a dockerfile
-	// Open the tar file
-	tar_gz_file, err := os.Open(f.Name())
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer tar_gz_file.Close()
-
-	// Create a new gzip reader
-	gzip_reader, err := gzip.NewReader(tar_gz_file)
+	gzip_reader, err := gzip.NewReader(f)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}

+ 0 - 1
internal/core/plugin_manager/aws_manager/run.go

@@ -3,7 +3,6 @@ package aws_manager
 import "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 import "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 
 
 func (r *AWSPluginRuntime) StartPlugin() error {
 func (r *AWSPluginRuntime) StartPlugin() error {
-
 	return nil
 	return nil
 }
 }
 
 

+ 122 - 0
internal/core/plugin_manager/aws_manager/serverless_connector.go

@@ -0,0 +1,122 @@
+package aws_manager
+
+import (
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+)
+
+var ()
+
+type LambdaFunction struct {
+	FunctionName string `json:"function_name" validate:"required"`
+	FunctionARN  string `json:"function_arn" validate:"required"`
+	FunctionURL  string `json:"function_url" validate:"required"`
+}
+
+// Ping the serverless connector, return error if failed
+func ping() error {
+	response, err := http_requests.PostAndParse[entities.GenericResponse[string]](
+		client,
+		"/ping",
+		http_requests.HttpHeader(map[string]string{
+			"Authorization": SERVERLESS_CONNECTOR_API_KEY,
+		}),
+	)
+	if err != nil {
+		return err
+	}
+
+	if response.Code != 0 {
+		return fmt.Errorf("unexpected response from serverless connector: %s", response.Message)
+	}
+
+	if response.Data != "pong" {
+		return fmt.Errorf("unexpected response from serverless connector: %s", response.Data)
+	}
+	return nil
+}
+
+var (
+	ErrNoLambdaFunction = errors.New("no lambda function found")
+)
+
+// Fetch the lambda function from serverless connector, return error if failed
+func fetchLambda(identity string, checksum string) (*LambdaFunction, error) {
+	request := map[string]any{
+		"config": map[string]any{
+			"identity": identity,
+			"checksum": checksum,
+		},
+	}
+
+	response, err := http_requests.PostAndParse[entities.GenericResponse[LambdaFunction]](
+		client,
+		"/v1/lambda/fetch",
+		http_requests.HttpHeader(map[string]string{
+			"Authorization": SERVERLESS_CONNECTOR_API_KEY,
+		}),
+		http_requests.HttpPayloadJson(request),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	if response.Code != 0 {
+		if response.Code == -404 {
+			return nil, ErrNoLambdaFunction
+		}
+		return nil, fmt.Errorf("unexpected response from serverless connector: %s", response.Message)
+	}
+
+	return &response.Data, nil
+}
+
+type LaunchAWSLambdaFunctionEvent string
+
+const (
+	Error     LaunchAWSLambdaFunctionEvent = "error"
+	Info      LaunchAWSLambdaFunctionEvent = "info"
+	Lambda    LaunchAWSLambdaFunctionEvent = "lambda"
+	LambdaUrl LaunchAWSLambdaFunctionEvent = "lambda_url"
+	Done      LaunchAWSLambdaFunctionEvent = "done"
+)
+
+type LaunchAWSLambdaFunctionResponse struct {
+	Event   LaunchAWSLambdaFunctionEvent `json:"event"`
+	Message string                       `json:"message"`
+}
+
+// Launch the lambda function from serverless connector, it will receive the context_tar as the input
+// and build it a docker image, then run it on serverless platform like AWS Lambda
+// it returns a event stream, the caller should consider it as a async operation
+func launchLambda(identity string, checksum string, context_tar io.Reader) (*stream.StreamResponse[LaunchAWSLambdaFunctionResponse], error) {
+	response, err := http_requests.PostAndParseStream[LaunchAWSLambdaFunctionResponse](
+		client,
+		"/v1/lambda/launch",
+		http_requests.HttpHeader(map[string]string{
+			"Authorization": SERVERLESS_CONNECTOR_API_KEY,
+		}),
+		http_requests.HttpReadTimeout(300),
+		http_requests.HttpWriteTimeout(300),
+		http_requests.HttpPayloadMultipart(
+			map[string]string{
+				"identity": identity,
+				"checksum": checksum,
+			},
+			map[string]io.Reader{
+				"context": context_tar,
+			},
+		),
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return response, nil
+}

+ 41 - 0
internal/core/plugin_manager/aws_manager/serverless_connector_client.go

@@ -0,0 +1,41 @@
+package aws_manager
+
+import (
+	"net"
+	"net/http"
+	"net/url"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+)
+
+var (
+	SERVERLESS_CONNECTOR_API_KEY string
+	baseurl                      *url.URL
+	client                       *http.Client
+)
+
+func Init(config *app.Config) {
+	var err error
+	baseurl, err = url.Parse(*config.DifyPluginServerlessConnectorURL)
+	if err != nil {
+		log.Panic("Failed to parse serverless connector url", err)
+	}
+
+	client = &http.Client{
+		Transport: &http.Transport{
+			Dial: (&net.Dialer{
+				Timeout:   5 * time.Second,
+				KeepAlive: 15 * time.Second,
+			}).Dial,
+			IdleConnTimeout: 120 * time.Second,
+		},
+	}
+
+	SERVERLESS_CONNECTOR_API_KEY = *config.DifyPluginServerlessConnectorAPIKey
+
+	if err := ping(); err != nil {
+		log.Panic("Failed to ping serverless connector", err)
+	}
+}

+ 8 - 0
internal/core/plugin_manager/aws_manager/type.go

@@ -2,10 +2,18 @@ package aws_manager
 
 
 import (
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/positive_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/positive_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 )
 )
 
 
 type AWSPluginRuntime struct {
 type AWSPluginRuntime struct {
 	positive_manager.PositivePluginRuntime
 	positive_manager.PositivePluginRuntime
 	entities.PluginRuntime
 	entities.PluginRuntime
+
+	// access url for the lambda function
+	lambda_url  string
+	lambda_name string
+
+	// plugin decoder used to manage the plugin
+	decoder decoder.PluginDecoder
 }
 }

+ 2 - 2
internal/core/plugin_manager/watcher.go

@@ -18,7 +18,7 @@ import (
 
 
 func (p *PluginManager) startLocalWatcher(config *app.Config) {
 func (p *PluginManager) startLocalWatcher(config *app.Config) {
 	go func() {
 	go func() {
-		log.Info("start to handle new plugins in path: %s", config.StoragePath)
+		log.Info("start to handle new plugins in path: %s", config.PluginStoragePath)
 		p.handleNewPlugins(config)
 		p.handleNewPlugins(config)
 		for range time.NewTicker(time.Second * 30).C {
 		for range time.NewTicker(time.Second * 30).C {
 			p.handleNewPlugins(config)
 			p.handleNewPlugins(config)
@@ -46,7 +46,7 @@ func (p *PluginManager) startRemoteWatcher(config *app.Config) {
 
 
 func (p *PluginManager) handleNewPlugins(config *app.Config) {
 func (p *PluginManager) handleNewPlugins(config *app.Config) {
 	// load local plugins firstly
 	// load local plugins firstly
-	for plugin := range p.loadNewPlugins(config.StoragePath) {
+	for plugin := range p.loadNewPlugins(config.PluginStoragePath) {
 		var plugin_interface entities.PluginRuntimeInterface
 		var plugin_interface entities.PluginRuntimeInterface
 
 
 		if config.Platform == app.PLATFORM_AWS_LAMBDA {
 		if config.Platform == app.PLATFORM_AWS_LAMBDA {

+ 26 - 2
internal/types/app/config.go

@@ -20,8 +20,9 @@ type Config struct {
 
 
 	PluginWebhookEnabled bool `envconfig:"PLUGIN_WEBHOOK_ENABLED"`
 	PluginWebhookEnabled bool `envconfig:"PLUGIN_WEBHOOK_ENABLED"`
 
 
-	StoragePath        string `envconfig:"STORAGE_PLUGIN_PATH"  validate:"required"`
-	ProcessCachingPath string `envconfig:"PROCESS_CACHING_PATH"  validate:"required"`
+	PluginStoragePath  string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"`
+	PluginWorkingPath  string `envconfig:"PLUGIN_WORKING_PATH"`
+	ProcessCachingPath string `envconfig:"PROCESS_CACHING_PATH"`
 
 
 	Platform PlatformType `envconfig:"PLATFORM" validate:"required"`
 	Platform PlatformType `envconfig:"PLATFORM" validate:"required"`
 
 
@@ -43,6 +44,9 @@ type Config struct {
 	LifetimeStateGCInterval             int `envconfig:"LIFETIME_STATE_GC_INTERVAL" validate:"required"`
 	LifetimeStateGCInterval             int `envconfig:"LIFETIME_STATE_GC_INTERVAL" validate:"required"`
 
 
 	DifyInvocationConnectionIdleTimeout int `envconfig:"DIFY_INVOCATION_CONNECTION_IDLE_TIMEOUT" validate:"required"`
 	DifyInvocationConnectionIdleTimeout int `envconfig:"DIFY_INVOCATION_CONNECTION_IDLE_TIMEOUT" validate:"required"`
+
+	DifyPluginServerlessConnectorURL    *string `envconfig:"DIFY_PLUGIN_SERVERLESS_CONNECTOR_URL"`
+	DifyPluginServerlessConnectorAPIKey *string `envconfig:"DIFY_PLUGIN_SERVERLESS_CONNECTOR_API_KEY"`
 }
 }
 
 
 func (c *Config) Validate() error {
 func (c *Config) Validate() error {
@@ -67,6 +71,26 @@ func (c *Config) Validate() error {
 		}
 		}
 	}
 	}
 
 
+	if c.Platform == PLATFORM_AWS_LAMBDA {
+		if c.DifyPluginServerlessConnectorURL == nil {
+			return fmt.Errorf("dify plugin serverless connector url is empty")
+		}
+
+		if c.DifyPluginServerlessConnectorAPIKey == nil {
+			return fmt.Errorf("dify plugin serverless connector api key is empty")
+		}
+	} else if c.Platform == PLATFORM_LOCAL {
+		if c.PluginWorkingPath == "" {
+			return fmt.Errorf("plugin working path is empty")
+		}
+
+		if c.ProcessCachingPath == "" {
+			return fmt.Errorf("process caching path is empty")
+		}
+	} else {
+		return fmt.Errorf("invalid platform")
+	}
+
 	return nil
 	return nil
 }
 }
 
 

+ 6 - 0
internal/types/entities/response.go

@@ -21,3 +21,9 @@ func NewErrorResponse(code int, message string) *Response {
 		Data:    nil,
 		Data:    nil,
 	}
 	}
 }
 }
+
+type GenericResponse[T any] struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Data    T      `json:"data"`
+}

+ 23 - 0
internal/types/entities/runtime.go

@@ -5,6 +5,7 @@ import (
 	"crypto/sha256"
 	"crypto/sha256"
 	"encoding/gob"
 	"encoding/gob"
 	"encoding/hex"
 	"encoding/hex"
+	"fmt"
 	"hash/fnv"
 	"hash/fnv"
 	"time"
 	"time"
 
 
@@ -22,6 +23,7 @@ type (
 		PluginRuntimeTimeLifeInterface
 		PluginRuntimeTimeLifeInterface
 		PluginRuntimeSessionIOInterface
 		PluginRuntimeSessionIOInterface
 		PluginRuntimeDockerInterface
 		PluginRuntimeDockerInterface
+		PluginRuntimeLogInterface
 	}
 	}
 
 
 	PluginRuntimeTimeLifeInterface interface {
 	PluginRuntimeTimeLifeInterface interface {
@@ -70,6 +72,15 @@ type (
 		AddRestarts()
 		AddRestarts()
 	}
 	}
 
 
+	PluginRuntimeLogInterface interface {
+		// Log adds a log to the plugin runtime state
+		Log(string)
+		// Warn adds a warning to the plugin runtime state
+		Warn(string)
+		// Error adds an error to the plugin runtime state
+		Error(string)
+	}
+
 	PluginRuntimeSessionIOInterface interface {
 	PluginRuntimeSessionIOInterface interface {
 		Listen(session_id string) *BytesIOListener
 		Listen(session_id string) *BytesIOListener
 		Write(session_id string, data []byte)
 		Write(session_id string, data []byte)
@@ -164,6 +175,18 @@ func (r *PluginRuntime) TriggerStop() {
 	}
 	}
 }
 }
 
 
+func (s *PluginRuntime) Log(log string) {
+	s.State.Logs = append(s.State.Logs, fmt.Sprintf("[Info] %s: %s", time.Now().Format(time.RFC3339), log))
+}
+
+func (s *PluginRuntime) Warn(log string) {
+	s.State.Logs = append(s.State.Logs, fmt.Sprintf("[Warn] %s: %s", time.Now().Format(time.RFC3339), log))
+}
+
+func (s *PluginRuntime) Error(log string) {
+	s.State.Logs = append(s.State.Logs, fmt.Sprintf("[Error] %s: %s", time.Now().Format(time.RFC3339), log))
+}
+
 type PluginRuntimeType string
 type PluginRuntimeType string
 
 
 const (
 const (

+ 11 - 0
internal/utils/http_requests/http_options.go

@@ -1,5 +1,7 @@
 package http_requests
 package http_requests
 
 
+import "io"
+
 type HttpOptions struct {
 type HttpOptions struct {
 	Type  string
 	Type  string
 	Value interface{}
 	Value interface{}
@@ -39,6 +41,15 @@ func HttpPayloadJson(payload interface{}) HttpOptions {
 	return HttpOptions{"payloadJson", payload}
 	return HttpOptions{"payloadJson", payload}
 }
 }
 
 
+// which is used for POST method only
+// payload follows the form data format, and files is a map from filename to file
+func HttpPayloadMultipart(payload map[string]string, files map[string]io.Reader) HttpOptions {
+	return HttpOptions{"payloadMultipart", map[string]interface{}{
+		"payload": payload,
+		"files":   files,
+	}}
+}
+
 func HttpWithDirectReferer() HttpOptions {
 func HttpWithDirectReferer() HttpOptions {
 	return HttpOptions{"directReferer", true}
 	return HttpOptions{"directReferer", true}
 }
 }

+ 33 - 0
internal/utils/http_requests/http_request.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
 	"io"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"net/http"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -39,6 +40,38 @@ func buildHttpRequest(method string, url string, options ...HttpOptions) (*http.
 				q.Add(k, v)
 				q.Add(k, v)
 			}
 			}
 			req.Body = io.NopCloser(strings.NewReader(q.Encode()))
 			req.Body = io.NopCloser(strings.NewReader(q.Encode()))
+		case "payloadMultipart":
+			buffer := new(bytes.Buffer)
+			writer := multipart.NewWriter(buffer)
+
+			files := option.Value.(map[string]any)["files"].(map[string]io.Reader)
+			for filename, reader := range files {
+				part, err := writer.CreateFormFile(filename, filename)
+				if err != nil {
+					writer.Close()
+					return nil, err
+				}
+				_, err = io.Copy(part, reader)
+				if err != nil {
+					writer.Close()
+					return nil, err
+				}
+			}
+
+			payload := option.Value.(map[string]any)["payload"].(map[string]string)
+			for k, v := range payload {
+				if err := writer.WriteField(k, v); err != nil {
+					writer.Close()
+					return nil, err
+				}
+			}
+
+			if err := writer.Close(); err != nil {
+				return nil, err
+			}
+
+			req.Body = io.NopCloser(buffer)
+			req.Header.Set("Content-Type", writer.FormDataContentType())
 		case "payloadText":
 		case "payloadText":
 			req.Body = io.NopCloser(strings.NewReader(option.Value.(string)))
 			req.Body = io.NopCloser(strings.NewReader(option.Value.(string)))
 			req.Header.Set("Content-Type", "text/plain")
 			req.Header.Set("Content-Type", "text/plain")

+ 4 - 1
internal/utils/http_requests/http_warpper.go

@@ -98,6 +98,8 @@ func RequestAndParseStream[T any](client *http.Client, url string, method string
 
 
 	routine.Submit(func() {
 	routine.Submit(func() {
 		scanner := bufio.NewScanner(resp.Body)
 		scanner := bufio.NewScanner(resp.Body)
+		defer resp.Body.Close()
+
 		for scanner.Scan() {
 		for scanner.Scan() {
 			data := scanner.Bytes()
 			data := scanner.Bytes()
 			if len(data) == 0 {
 			if len(data) == 0 {
@@ -112,7 +114,8 @@ func RequestAndParseStream[T any](client *http.Client, url string, method string
 			// unmarshal
 			// unmarshal
 			t, err := parser.UnmarshalJsonBytes[T](data)
 			t, err := parser.UnmarshalJsonBytes[T](data)
 			if err != nil {
 			if err != nil {
-				continue
+				ch.WriteError(err)
+				break
 			}
 			}
 
 
 			ch.Write(t)
 			ch.Write(t)