瀏覽代碼

feat: add new interfaces

Yeuoly 10 月之前
父節點
當前提交
5d35741a6c

+ 2 - 2
internal/core/plugin_daemon/model_service.go

@@ -115,9 +115,9 @@ func GetTTSModelVoices(
 	session *session_manager.Session,
 	request *requests.RequestGetTTSModelVoices,
 ) (
-	*stream.Stream[model_entities.TTSModelVoice], error,
+	*stream.Stream[model_entities.GetTTSVoicesResponse], error,
 ) {
-	return genericInvokePlugin[requests.RequestGetTTSModelVoices, model_entities.TTSModelVoice](
+	return genericInvokePlugin[requests.RequestGetTTSModelVoices, model_entities.GetTTSVoicesResponse](
 		session,
 		request,
 		1,

+ 1 - 1
internal/service/invoke_model.go

@@ -243,7 +243,7 @@ func GetTTSModelVoices(
 	defer session.Close()
 
 	baseSSEService(
-		func() (*stream.Stream[model_entities.TTSModelVoice], error) {
+		func() (*stream.Stream[model_entities.GetTTSVoicesResponse], error) {
 			return plugin_daemon.GetTTSModelVoices(session, &r.Data)
 		},
 		ctx,

+ 3 - 0
internal/types/entities/model_entities/ai_model.go

@@ -1,4 +1,7 @@
 package model_entities
 
+import "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+
 type GetModelSchemasResponse struct {
+	AIModels []plugin_entities.ModelDeclaration `json:"ai_models" validate:"required,dive"`
 }

+ 1 - 0
internal/types/entities/model_entities/llm.go

@@ -200,4 +200,5 @@ type LLMResultChunkDelta struct {
 }
 
 type LLMGetNumTokensResponse struct {
+	NumTokens int `json:"num_tokens" validate:"required"`
 }

+ 1 - 0
internal/types/entities/model_entities/text_embedding.go

@@ -19,4 +19,5 @@ type TextEmbeddingResult struct {
 }
 
 type GetTextEmbeddingNumTokensResponse struct {
+	NumTokens int `json:"num_tokens" validate:"required"`
 }

+ 6 - 0
internal/types/entities/model_entities/tts.go

@@ -5,4 +5,10 @@ type TTSResult struct {
 }
 
 type TTSModelVoice struct {
+	Name  string `json:"name" validate:"required"`
+	Value string `json:"value" validate:"required"`
+}
+
+type GetTTSVoicesResponse struct {
+	Voices []TTSModelVoice `json:"voices" validate:"required,dive"`
 }

+ 1 - 5
internal/types/entities/plugin_entities/model_declaration.go

@@ -116,10 +116,6 @@ type ModelPriceConfig struct {
 	Currency string           `json:"currency" validate:"required"`
 }
 
-type ModelPricing struct {
-	PricePerUnit float64 `json:"price_per_unit" validate:"required"`
-}
-
 type ModelDeclaration struct {
 	Model           string                         `json:"model" validate:"required,lt=256"`
 	Label           I18nObject                     `json:"label" validate:"required"`
@@ -129,7 +125,7 @@ type ModelDeclaration struct {
 	ModelProperties map[string]any                 `json:"model_properties" validate:"omitempty,dive,is_basic_type"`
 	Deprecated      bool                           `json:"deprecated"`
 	ParameterRules  []ModelParameterRule           `json:"parameter_rules" validate:"omitempty,lte=128,dive,parameter_rule"`
-	PriceConfig     *ModelPriceConfig              `json:"price_config" validate:"omitempty"`
+	PriceConfig     *ModelPriceConfig              `json:"pricing" validate:"omitempty"`
 }
 
 type ModelProviderFormType string

+ 12 - 0
internal/types/entities/requests/model.go

@@ -107,13 +107,25 @@ type RequestValidateModelCredentials struct {
 }
 
 type RequestGetTTSModelVoices struct {
+	Model       string         `json:"model" validate:"required"`
+	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+	Language    string         `json:"language" validate:"omitempty"`
 }
 
 type RequestGetTextEmbeddingNumTokens struct {
+	Model       string         `json:"model" validate:"required"`
+	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+	Texts       []string       `json:"texts" validate:"required,dive"`
 }
 
 type RequestGetLLMNumTokens struct {
+	Model          string                             `json:"model" validate:"required"`
+	Credentials    map[string]any                     `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+	PromptMessages []model_entities.PromptMessage     `json:"prompt_messages"  validate:"omitempty,dive"`
+	Tools          []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
 }
 
 type RequestGetAIModelSchema struct {
+	Model       string         `json:"model" validate:"required"`
+	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }