瀏覽代碼

feat: text-embedding and moderation

Yeuoly 1 年之前
父節點
當前提交
184fb1efa0

+ 6 - 6
internal/core/plugin_daemon/model_service.go

@@ -143,9 +143,9 @@ func InvokeTTS(
 	session *session_manager.Session,
 	request *requests.RequestInvokeTTS,
 ) (
-	*stream.StreamResponse[string], error,
+	*stream.StreamResponse[model_entities.TTSResult], error,
 ) {
-	return genericInvokePlugin[requests.RequestInvokeTTS, string](
+	return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult](
 		session,
 		request,
 		1,
@@ -158,9 +158,9 @@ func InvokeSpeech2Text(
 	session *session_manager.Session,
 	request *requests.RequestInvokeSpeech2Text,
 ) (
-	*stream.StreamResponse[string], error,
+	*stream.StreamResponse[model_entities.Speech2TextResult], error,
 ) {
-	return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string](
+	return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult](
 		session,
 		request,
 		1,
@@ -173,9 +173,9 @@ func InvokeModeration(
 	session *session_manager.Session,
 	request *requests.RequestInvokeModeration,
 ) (
-	*stream.StreamResponse[bool], error,
+	*stream.StreamResponse[model_entities.ModerationResult], error,
 ) {
-	return genericInvokePlugin[requests.RequestInvokeModeration, bool](
+	return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult](
 		session,
 		request,
 		1,

+ 55 - 0
internal/server/controller.go

@@ -32,3 +32,58 @@ func InvokeLLM(c *gin.Context) {
 		},
 	)
 }
+
+func InvokeTextEmbedding(c *gin.Context) { // gin handler for text-embedding invocations; routed at POST /plugin/text_embedding/invoke
+	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding] // local alias: plugin-invoke request whose Data is a text-embedding payload
+
+	BindRequest[request]( // BindRequest is defined elsewhere; presumably it decodes the request into `request` before running the callback — TODO confirm
+		c,
+		func(itr request) {
+			service.InvokeTextEmbedding(&itr, c) // hand the bound request and the gin context to the service layer
+		},
+	)
+}
+
+func InvokeRerank(c *gin.Context) { // gin handler for rerank invocations; routed at POST /plugin/rerank/invoke
+	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank] // local alias: plugin-invoke request whose Data is a rerank payload
+
+	BindRequest[request]( // BindRequest is defined elsewhere; presumably it decodes the request into `request` before running the callback — TODO confirm
+		c,
+		func(itr request) {
+			service.InvokeRerank(&itr, c) // hand the bound request and the gin context to the service layer
+		},
+	)
+}
+
+func InvokeTTS(c *gin.Context) { // gin handler for text-to-speech invocations; routed at POST /plugin/tts/invoke
+	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS] // local alias: plugin-invoke request whose Data is a TTS payload
+
+	BindRequest[request]( // BindRequest is defined elsewhere; presumably it decodes the request into `request` before running the callback — TODO confirm
+		c,
+		func(itr request) {
+			service.InvokeTTS(&itr, c) // hand the bound request and the gin context to the service layer
+		},
+	)
+}
+
+func InvokeSpeech2Text(c *gin.Context) { // gin handler for speech-to-text invocations; routed at POST /plugin/speech2text/invoke
+	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text] // local alias: plugin-invoke request whose Data is a speech-to-text payload
+
+	BindRequest[request]( // BindRequest is defined elsewhere; presumably it decodes the request into `request` before running the callback — TODO confirm
+		c,
+		func(itr request) {
+			service.InvokeSpeech2Text(&itr, c) // hand the bound request and the gin context to the service layer
+		},
+	)
+}
+
+func InvokeModeration(c *gin.Context) { // gin handler for moderation invocations; routed at POST /plugin/moderation/invoke
+	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration] // local alias: plugin-invoke request whose Data is a moderation payload
+
+	BindRequest[request]( // BindRequest is defined elsewhere; presumably it decodes the request into `request` before running the callback — TODO confirm
+		c,
+		func(itr request) {
+			service.InvokeModeration(&itr, c) // hand the bound request and the gin context to the service layer
+		},
+	)
+}

+ 5 - 0
internal/server/http.go

@@ -13,6 +13,11 @@ func server(config *app.Config) {
 	engine.GET("/health/check", HealthCheck)
 	engine.POST("/plugin/tool/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTool)
 	engine.POST("/plugin/llm/invoke", CheckingKey(config.PluginInnerApiKey), InvokeLLM)
+	engine.POST("/plugin/text_embedding/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTextEmbedding)
+	engine.POST("/plugin/rerank/invoke", CheckingKey(config.PluginInnerApiKey), InvokeRerank)
+	engine.POST("/plugin/tts/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTTS)
+	engine.POST("/plugin/speech2text/invoke", CheckingKey(config.PluginInnerApiKey), InvokeSpeech2Text)
+	engine.POST("/plugin/moderation/invoke", CheckingKey(config.PluginInnerApiKey), InvokeModeration)
 
 	engine.Run(fmt.Sprintf(":%d", config.SERVER_PORT))
 }

+ 50 - 0
internal/service/invoke.go

@@ -30,3 +30,53 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
 		return plugin_daemon.InvokeLLM(session, &r.Data)
 	}, ctx)
 }
+
+func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding], ctx *gin.Context) {
+	// Service entry point for text-embedding calls: open a session built from the request's tenant ID, user ID and marshaled plugin name/version.
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close() // NOTE(review): fires when this function returns — confirm baseSSEService blocks until the stream is finished
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) { // baseSSEService is defined elsewhere; presumably relays the stream to ctx as server-sent events — TODO confirm
+		return plugin_daemon.InvokeTextEmbedding(session, &r.Data) // only the payload (r.Data) is passed on to the plugin daemon
+	}, ctx)
+}
+
+func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank], ctx *gin.Context) {
+	// Service entry point for rerank calls: open a session built from the request's tenant ID, user ID and marshaled plugin name/version.
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close() // NOTE(review): fires when this function returns — confirm baseSSEService blocks until the stream is finished
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.RerankResult], error) { // baseSSEService is defined elsewhere; presumably relays the stream to ctx as server-sent events — TODO confirm
+		return plugin_daemon.InvokeRerank(session, &r.Data) // only the payload (r.Data) is passed on to the plugin daemon
+	}, ctx)
+}
+
+func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS], ctx *gin.Context) {
+	// Service entry point for text-to-speech calls: open a session built from the request's tenant ID, user ID and marshaled plugin name/version.
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close() // NOTE(review): fires when this function returns — confirm baseSSEService blocks until the stream is finished
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.TTSResult], error) { // baseSSEService is defined elsewhere; presumably relays the stream to ctx as server-sent events — TODO confirm
+		return plugin_daemon.InvokeTTS(session, &r.Data) // only the payload (r.Data) is passed on to the plugin daemon
+	}, ctx)
+}
+
+func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text], ctx *gin.Context) {
+	// Service entry point for speech-to-text calls: open a session built from the request's tenant ID, user ID and marshaled plugin name/version.
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close() // NOTE(review): fires when this function returns — confirm baseSSEService blocks until the stream is finished
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) { // baseSSEService is defined elsewhere; presumably relays the stream to ctx as server-sent events — TODO confirm
+		return plugin_daemon.InvokeSpeech2Text(session, &r.Data) // only the payload (r.Data) is passed on to the plugin daemon
+	}, ctx)
+}
+
+func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration], ctx *gin.Context) {
+	// Service entry point for moderation calls: open a session built from the request's tenant ID, user ID and marshaled plugin name/version.
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close() // NOTE(review): fires when this function returns — confirm baseSSEService blocks until the stream is finished
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ModerationResult], error) { // baseSSEService is defined elsewhere; presumably relays the stream to ctx as server-sent events — TODO confirm
+		return plugin_daemon.InvokeModeration(session, &r.Data) // only the payload (r.Data) is passed on to the plugin daemon
+	}, ctx)
+}

+ 4 - 0
internal/types/entities/model_entities/moderation.go

@@ -1 +1,5 @@
 package model_entities
+
+type ModerationResult struct { // element type of the stream returned by InvokeModeration; wraps what used to be a bare bool
+	Result bool `json:"result"` // JSON key "result"; NOTE(review): presumably true means the input was flagged — confirm against the plugin contract
+}

+ 4 - 0
internal/types/entities/model_entities/speech2text.go

@@ -1 +1,5 @@
 package model_entities
+
+type Speech2TextResult struct { // element type of the stream returned by InvokeSpeech2Text; wraps what used to be a bare string
+	Result string `json:"result"` // JSON key "result"; NOTE(review): presumably the transcribed text — confirm against the plugin contract
+}

+ 4 - 0
internal/types/entities/model_entities/tts.go

@@ -1 +1,5 @@
 package model_entities
+
+type TTSResult struct { // element type of the stream returned by InvokeTTS; wraps what used to be a bare string
+	Result string `json:"result"` // JSON key "result"; NOTE(review): presumably encoded audio data, encoding not visible here — confirm
+}

+ 7 - 7
internal/types/entities/requests/model.go

@@ -6,7 +6,7 @@ import (
 
 type BaseRequestInvokeModel struct { // provider/model/credential fields embedded by RequestInvokeLLM and RequestInvokeTTS
 	Provider    string                   `json:"provider" validate:"required"`
-	ModelType   model_entities.ModelType `json:"model_type" mapstructure:"model_type" validate:"required,model_type"`
+	ModelType   model_entities.ModelType `json:"model_type"  validate:"required,model_type"` // NOTE(review): "model_type" rule is presumably a custom validator registered elsewhere — confirm
 	Model       string                   `json:"model" validate:"required"`
 	Credentials map[string]any           `json:"credentials" validate:"omitempty,dive,is_basic_type"` // optional; NOTE(review): "is_basic_type" is presumably a custom validator as well — confirm
 }
@@ -23,11 +23,11 @@ func (r *BaseRequestInvokeModel) ToCallerArguments() map[string]any {
 type RequestInvokeLLM struct { // request payload for LLM invocation; embeds the shared provider/model/credential fields
 	BaseRequestInvokeModel
 
-	ModelParameters map[string]any                     `json:"model_parameters" mapstructure:"model_parameters" validate:"omitempty,dive,is_basic_type"`
-	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages" mapstructure:"prompt_messages" validate:"omitempty,dive"`
+	ModelParameters map[string]any                     `json:"model_parameters"  validate:"omitempty,dive,is_basic_type"` // optional; each value is checked with the is_basic_type rule
+	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages"  validate:"omitempty,dive"` // optional; each message is validated individually (dive)
 	Tools           []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
 	Stop            []string                           `json:"stop" validate:"omitempty"`
-	Stream          bool                               `json:"stream" mapstructure:"stream"`
+	Stream          bool                               `json:"stream" ` // no validation rule; NOTE(review): presumably selects streamed output — confirm
 }
 
 type RequestInvokeTextEmbedding struct {
@@ -41,14 +41,14 @@ type RequestInvokeRerank struct {
 
 	Query          string   `json:"query" validate:"required"`
 	Docs           []string `json:"docs" validate:"required,dive"`
-	ScoreThreshold float64  `json:"score_threshold" mapstructure:"score_threshold"`
-	TopN           int      `json:"top_n" mapstructure:"top_n"`
+	ScoreThreshold float64  `json:"score_threshold" `
+	TopN           int      `json:"top_n" `
 }
 
 type RequestInvokeTTS struct { // request payload accepted by InvokeTTS; embeds the shared provider/model/credential fields
 	BaseRequestInvokeModel
 
-	ContentText string `json:"content_text" mapstructure:"content_text" validate:"required"`
+	ContentText string `json:"content_text"  validate:"required"` // required; NOTE(review): presumably the text to synthesize — confirm
 	Voice       string `json:"voice" validate:"required"` // required
 }