瀏覽代碼

feat: invoke model

Yeuoly 1 年之前
父節點
當前提交
d6ab1f6a14

+ 1 - 0
go.mod

@@ -31,6 +31,7 @@ require (
 	github.com/klauspost/cpuid/v2 v2.2.8 // indirect
 	github.com/leodido/go-urn v1.4.0 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect
+	github.com/mitchellh/mapstructure v1.5.0
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/panjf2000/ants v1.3.0

+ 3 - 0
go.sum

@@ -50,6 +50,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
 github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
+github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -86,6 +88,7 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
 github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
 github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
 github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
+go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=

+ 8 - 0
internal/server/validate.go

@@ -3,6 +3,7 @@ package server
 import (
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 )
 
 func BindRequest[T any](r *gin.Context, success func(T)) {
@@ -16,6 +17,13 @@ func BindRequest[T any](r *gin.Context, success func(T)) {
 		err = r.ShouldBind(&request)
 	}
 
+	// validate
+	if err := validators.GlobalEntitiesValidator.Struct(request); err != nil {
+		resp := entities.NewErrorResponse(-400, "Invalid request")
+		r.JSON(400, resp)
+		return
+	}
+
 	if err != nil {
 		resp := entities.NewErrorResponse(-400, "Invalid request")
 		r.JSON(400, resp)

+ 1 - 1
internal/types/entities/model_entities/llm.go

@@ -13,7 +13,7 @@ type ModelType string
 
 const (
 	MODEL_TYPE_LLM            ModelType = "llm"
-	MODEL_TYPE_TEXT_EMBEDDING ModelType = "text_embedding"
+	MODEL_TYPE_TEXT_EMBEDDING ModelType = "text-embedding"
 	MODEL_TYPE_RERANKING      ModelType = "rerank"
 	MODEL_TYPE_SPEECH2TEXT    ModelType = "speech2text"
 	MODEL_TYPE_TTS            ModelType = "tts"

+ 1 - 1
internal/types/entities/model_entities/tts.go

@@ -1,5 +1,5 @@
 package model_entities
 
 type TTSResult struct {
-	Result string `json:"result"`
+	Result string `json:"result"` // in hex
 }

+ 18 - 22
internal/types/entities/requests/model.go

@@ -5,24 +5,15 @@ import (
 )
 
 type BaseRequestInvokeModel struct {
-	Provider    string                   `json:"provider" validate:"required"`
-	ModelType   model_entities.ModelType `json:"model_type"  validate:"required,model_type"`
-	Model       string                   `json:"model" validate:"required"`
-	Credentials map[string]any           `json:"credentials" validate:"omitempty,dive,is_basic_type"`
-}
-
-func (r *BaseRequestInvokeModel) ToCallerArguments() map[string]any {
-	return map[string]any{
-		"provider":    r.Provider,
-		"model":       r.Model,
-		"model_type":  r.ModelType,
-		"credentials": r.Credentials,
-	}
+	Provider    string         `json:"provider" validate:"required"`
+	Model       string         `json:"model" validate:"required"`
+	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
 
 type RequestInvokeLLM struct {
 	BaseRequestInvokeModel
 
+	ModelType       model_entities.ModelType           `json:"model_type"  validate:"required,model_type,eq=llm"`
 	ModelParameters map[string]any                     `json:"model_parameters"  validate:"omitempty,dive,is_basic_type"`
 	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages"  validate:"omitempty,dive"`
 	Tools           []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
@@ -33,33 +24,38 @@ type RequestInvokeLLM struct {
 type RequestInvokeTextEmbedding struct {
 	BaseRequestInvokeModel
 
-	Texts []string `json:"texts" validate:"required,dive"`
+	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=text-embedding"`
+	Texts     []string                 `json:"texts" validate:"required,dive"`
 }
 
 type RequestInvokeRerank struct {
 	BaseRequestInvokeModel
 
-	Query          string   `json:"query" validate:"required"`
-	Docs           []string `json:"docs" validate:"required,dive"`
-	ScoreThreshold float64  `json:"score_threshold" `
-	TopN           int      `json:"top_n" `
+	ModelType      model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=rerank"`
+	Query          string                   `json:"query" validate:"required"`
+	Docs           []string                 `json:"docs" validate:"required,dive"`
+	ScoreThreshold float64                  `json:"score_threshold" `
+	TopN           int                      `json:"top_n" `
 }
 
 type RequestInvokeTTS struct {
 	BaseRequestInvokeModel
 
-	ContentText string `json:"content_text"  validate:"required"`
-	Voice       string `json:"voice" validate:"required"`
+	ModelType   model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=tts"`
+	ContentText string                   `json:"content_text"  validate:"required"`
+	Voice       string                   `json:"voice" validate:"required"`
 }
 
 type RequestInvokeSpeech2Text struct {
 	BaseRequestInvokeModel
 
-	File string `json:"file" validate:"required"` // base64 encoded voice file
+	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=speech2text"`
+	File      string                   `json:"file" validate:"required"` // hexing encoded voice file
 }
 
 type RequestInvokeModeration struct {
 	BaseRequestInvokeModel
 
-	Text string `json:"text" validate:"required"`
+	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=moderation"`
+	Text      string                   `json:"text" validate:"required"`
 }

+ 35 - 0
tests/benchmark/encoding/ascii85_test.go

@@ -0,0 +1,35 @@
+package encoding
+
+import (
+	"encoding/ascii85"
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/tests"
+)
+
+func BenchmarkAscii85(b *testing.B) {
+	var data = []byte("hello world")
+	var dst = make([]byte, ascii85.MaxEncodedLen(len(data)))
+	bytes := 0
+
+	b.Run("Encode", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			ascii85.Encode(dst, data)
+			bytes += len(data)
+		}
+	})
+
+	b.Log("Bytes encoded:", tests.ReadableBytes(bytes))
+
+	encoded := make([]byte, ascii85.MaxEncodedLen(len(data)))
+	bytes = 0
+	ascii85.Encode(encoded, data)
+	b.Run("Decode", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			ascii85.Decode(dst, encoded, true)
+			bytes += len(encoded)
+		}
+	})
+
+	b.Log("Bytes decoded:", tests.ReadableBytes(bytes))
+}

+ 36 - 0
tests/benchmark/encoding/b64_test.go

@@ -0,0 +1,36 @@
+package encoding
+
+import (
+	"encoding/base64"
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/tests"
+)
+
+func BenchmarkBase64(b *testing.B) {
+	var data = []byte("hello world")
+	bytes := 0
+	var dst = make([]byte, base64.StdEncoding.EncodedLen(len(data)))
+
+	b.Run("Encode", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			base64.StdEncoding.Encode(dst, data)
+			bytes += len(data)
+		}
+	})
+
+	b.Log("Bytes encoded:", tests.ReadableBytes(bytes))
+
+	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
+	bytes = 0
+	base64.StdEncoding.Encode(encoded, data)
+
+	b.Run("Decode", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			base64.StdEncoding.Decode(dst, encoded)
+			bytes += len(encoded)
+		}
+	})
+
+	b.Log("Bytes decoded:", tests.ReadableBytes(bytes))
+}

+ 35 - 0
tests/benchmark/encoding/hex_test.go

@@ -0,0 +1,35 @@
+package encoding
+
+import (
+	"encoding/hex"
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/tests"
+)
+
+func BenchmarkHex(b *testing.B) {
+	var data = []byte("hello world")
+	var dst = make([]byte, len(data)*2)
+	bytes := 0
+
+	b.Run("Encode", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			hex.Encode(dst, data)
+			bytes += len(data)
+		}
+	})
+
+	b.Log("Bytes encoded:", tests.ReadableBytes(bytes))
+
+	encoded := make([]byte, len(data)*2)
+	bytes = 0
+	hex.Encode(encoded, data)
+	b.Run("Decode", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			hex.Decode(dst, encoded)
+			bytes += len(encoded)
+		}
+	})
+
+	b.Log("Bytes decoded:", tests.ReadableBytes(bytes))
+}

+ 24 - 0
tests/benchmark/stdio/rw_test.go

@@ -0,0 +1,24 @@
+package stdio
+
+import (
+	"os"
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/tests"
+)
+
+func BenchmarkStdioBandWidth(b *testing.B) {
+	// open /dev/zero for reading
+	buf := make([]byte, 1024)
+	zero := os.NewFile(0, "/dev/zero")
+	bytes := 0
+
+	b.Run("Read", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			zero.Read(buf)
+			bytes += len(buf)
+		}
+	})
+
+	b.Log("Bytes read:", tests.ReadableBytes(bytes))
+}

+ 20 - 0
tests/visable.go

@@ -0,0 +1,20 @@
+package tests
+
+import "fmt"
+
+func ReadableBytes(l int) string {
+	// convert l bytes to a readable string
+	if l < 1024 {
+		return fmt.Sprintf("%d B", l)
+	}
+
+	if l < 1024*1024 {
+		return fmt.Sprintf("%.2f KB", float64(l)/1024)
+	}
+
+	if l < 1024*1024*1024 {
+		return fmt.Sprintf("%.2f MB", float64(l)/(1024*1024))
+	}
+
+	return fmt.Sprintf("%.2f GB", float64(l)/(1024*1024*1024))
+}