Explorar o código

feat: invoke all models

Yeuoly hai 1 ano
pai
achega
dfe1d2e74e

+ 68 - 0
internal/types/entities/model_entities/rerank_test.go

@@ -0,0 +1,68 @@
+package model_entities
+
+import (
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+func TestRerankFullFunction(t *testing.T) {
+	const (
+		rerank = `
+		{
+			"model": "rerank",
+			"docs": [
+				{
+					"index": 1,
+					"text": "text",
+					"score": 0.1
+				}
+			]
+		}`
+	)
+
+	_, err := parser.UnmarshalJsonBytes[RerankResult]([]byte(rerank))
+	if err != nil {
+		t.Error(err)
+	}
+}
+
+func TestRerankWrongDocs(t *testing.T) {
+	const (
+		rerank = `
+		{
+			"model": "rerank",
+			"docs": [
+				{
+					"index": 1,
+					"text": "text"
+				}
+			]
+		}`
+	)
+
+	_, err := parser.UnmarshalJsonBytes[RerankResult]([]byte(rerank))
+	if err == nil {
+		t.Error("should have error")
+	}
+}
+
+func TestRerankWrongDocIndex(t *testing.T) {
+	const (
+		rerank = `
+		{
+			"model": "rerank",
+			"docs": [
+				{
+					"text": "text",
+					"score": 0.1
+				}
+			]
+		}`
+	)
+
+	_, err := parser.UnmarshalJsonBytes[RerankResult]([]byte(rerank))
+	if err == nil {
+		t.Error("should have error")
+	}
+}

+ 18 - 0
internal/types/entities/model_entities/text_embedding.go

@@ -1 +1,19 @@
 package model_entities
+
+import "github.com/shopspring/decimal"
+
+type EmbeddingUsage struct {
+	Tokens      *int            `json:"tokens" validate:"required"`
+	TotalTokens *int            `json:"total_tokens" validate:"required"`
+	UnitPrice   decimal.Decimal `json:"unit_price" validate:"required"`
+	PriceUnit   decimal.Decimal `json:"price_unit" validate:"required"`
+	TotalPrice  decimal.Decimal `json:"total_price" validate:"required"`
+	Currency    *string         `json:"currency" validate:"required"`
+	Latency     *float64        `json:"latency" validate:"required"`
+}
+
+type TextEmbeddingResult struct {
+	Model      string         `json:"model" validate:"required"`
+	Embeddings [][]float64    `json:"embeddings" validate:"required,dive"`
+	Usage      EmbeddingUsage `json:"usage" validate:"required"`
+}

+ 84 - 0
internal/types/entities/model_entities/text_embedding_test.go

@@ -0,0 +1,84 @@
+package model_entities
+
+import (
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+func TestTextEmbeddingFullFunction(t *testing.T) {
+	const (
+		text_embedding = `
+		{
+			"model": "text_embedding",
+			"embeddings": [[
+				0.1, 0.2, 0.3
+			]],
+			"usage" : {
+				"tokens": 3,
+				"total_tokens": 100,
+				"unit_price": 0.1,
+				"price_unit": 1,
+				"total_price": 10,
+				"currency": "usd",
+				"latency": 0.1
+			}
+		}`
+	)
+
+	_, err := parser.UnmarshalJsonBytes[TextEmbeddingResult]([]byte(text_embedding))
+	if err != nil {
+		t.Error(err)
+	}
+}
+
+func TestTextEmbeddingWrongUsage(t *testing.T) {
+	const (
+		text_embedding = `
+		{
+			"model": "text_embedding",
+			"embeddings": [[
+				0.1, 0.2, 0.3
+			]],
+			"usage" : {
+				"tokens": 3,
+				"total_tokens": 100,
+				"unit_price": 0.1,
+				"price_unit": 1,
+				"total_price": 10,
+				"currency": "usd"
+			}
+		}`
+	)
+
+	_, err := parser.UnmarshalJsonBytes[TextEmbeddingResult]([]byte(text_embedding))
+	if err == nil {
+		t.Error("should have error")
+	}
+}
+
+func TestTextEmbeddingWrongEmbeddings(t *testing.T) {
+	const (
+		text_embedding = `
+		{
+			"model": "text_embedding",
+			"embeddings": [
+				0.1, 0.2, 0.3
+			],
+			"usage" : {
+				"tokens": 3,
+				"total_tokens": 100,
+				"unit_price": 0.1,
+				"price_unit": 1,
+				"total_price": 10,
+				"currency": "usd",
+				"latency": 0.1
+			}
+		}`
+	)
+
+	_, err := parser.UnmarshalJsonBytes[TextEmbeddingResult]([]byte(text_embedding))
+	if err == nil {
+		t.Error("should have error")
+	}
+}

+ 1 - 1
internal/types/entities/plugin_entities/common.go

@@ -12,7 +12,7 @@ type I18nObject struct {
 	PtBr   string `json:"pt_BR" validate:"lt=1024"`
 }
 
-func isGenericType(fl validator.FieldLevel) bool {
+func isBasicType(fl validator.FieldLevel) bool {
 	// allowed int, string, bool, float64
 	switch fl.Field().Kind() {
 	case reflect.Int, reflect.String, reflect.Bool, reflect.Float64:

+ 1 - 1
internal/types/entities/plugin_entities/model_configuration.go

@@ -274,5 +274,5 @@ func init() {
 
 	validators.GlobalEntitiesValidator.RegisterValidation("parameter_rule", isParameterRule)
 
-	validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
+	validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
 }

+ 1 - 1
internal/types/entities/plugin_entities/tool_configuration.go

@@ -249,7 +249,7 @@ func init() {
 		},
 	)
 
-	validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
+	validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
 }
 
 func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration, error) {

+ 37 - 4
internal/types/entities/requests/model.go

@@ -4,14 +4,47 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 )
 
+type BaseRequestInvokeModel struct {
+	Provider    string                   `json:"provider"`
+	ModelType   model_entities.ModelType `json:"model_type" validate:"required,model_type"`
+	Model       string                   `json:"model"`
+	Credentials map[string]any           `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+}
+
 type RequestInvokeLLM struct {
-	Provider        string                             `json:"provider"`
-	ModelType       model_entities.ModelType           `json:"model_type" validate:"required,model_type"`
-	Model           string                             `json:"model"`
+	BaseRequestInvokeModel
+
 	ModelParameters map[string]any                     `json:"model_parameters" validate:"omitempty,dive,is_basic_type"`
 	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages" validate:"omitempty,dive"`
 	Tools           []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
 	Stop            []string                           `json:"stop" validate:"omitempty"`
 	Stream          bool                               `json:"stream"`
-	Credentials     map[string]any                     `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+}
+
+type RequestInvokeTextEmbedding struct {
+	BaseRequestInvokeModel
+
+	Texts []string `json:"texts" validate:"required,dive"`
+}
+
+type RequestInvokeRerank struct {
+	BaseRequestInvokeModel
+
+	Query          string   `json:"query" validate:"required"`
+	Docs           []string `json:"docs" validate:"required,dive"`
+	ScoreThreshold float64  `json:"score_threshold"`
+	TopN           int      `json:"top_n"`
+}
+
+type RequestInvokeTTS struct {
+	BaseRequestInvokeModel
+
+	ContentText string `json:"content_text" validate:"required"`
+	Voice       string `json:"voice" validate:"required"`
+}
+
+type RequestInvokeSpeech2Text struct {
+	BaseRequestInvokeModel
+
+	File string `json:"file" validate:"required"` // base64 encoded voice file
 }