浏览代码

feat: add get_tts_model_voices/get_num_tokens/get_model_schema

Yeuoly 10 月之前
父节点
当前提交
f179f0f663

+ 4 - 0
internal/core/plugin_daemon/access_types/access.go

@@ -22,4 +22,8 @@ const (
 	PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS PluginAccessAction = "validate_provider_credentials"
 	PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS    PluginAccessAction = "validate_model_credentials"
 	PLUGIN_ACCESS_ACTION_INVOKE_ENDPOINT               PluginAccessAction = "invoke_endpoint"
+	PLUGIN_ACCESS_ACTION_GET_TTS_MODEL_VOICES          PluginAccessAction = "get_tts_model_voices"
+	PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS PluginAccessAction = "get_text_embedding_num_tokens"
+	PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMA           PluginAccessAction = "get_ai_model_schema"
+	PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS            PluginAccessAction = "get_llm_num_tokens"
 )

+ 52 - 0
internal/core/plugin_daemon/model_service.go

@@ -110,3 +110,55 @@ func ValidateModelCredentials(
 		1,
 	)
 }
+
+func GetTTSModelVoices(
+	session *session_manager.Session,
+	request *requests.RequestGetTTSModelVoices,
+) (
+	*stream.Stream[model_entities.TTSModelVoice], error,
+) {
+	return genericInvokePlugin[requests.RequestGetTTSModelVoices, model_entities.TTSModelVoice](
+		session,
+		request,
+		1,
+	)
+}
+
+func GetTextEmbeddingNumTokens(
+	session *session_manager.Session,
+	request *requests.RequestGetTextEmbeddingNumTokens,
+) (
+	*stream.Stream[model_entities.GetTextEmbeddingNumTokensResponse], error,
+) {
+	return genericInvokePlugin[requests.RequestGetTextEmbeddingNumTokens, model_entities.GetTextEmbeddingNumTokensResponse](
+		session,
+		request,
+		1,
+	)
+}
+
+func GetLLMNumTokens(
+	session *session_manager.Session,
+	request *requests.RequestGetLLMNumTokens,
+) (
+	*stream.Stream[model_entities.LLMGetNumTokensResponse], error,
+) {
+	return genericInvokePlugin[requests.RequestGetLLMNumTokens, model_entities.LLMGetNumTokensResponse](
+		session,
+		request,
+		1,
+	)
+}
+
+func GetAIModelSchema(
+	session *session_manager.Session,
+	request *requests.RequestGetAIModelSchema,
+) (
+	*stream.Stream[model_entities.GetModelSchemasResponse], error,
+) {
+	return genericInvokePlugin[requests.RequestGetAIModelSchema, model_entities.GetModelSchemasResponse](
+		session,
+		request,
+		1,
+	)
+}

+ 3 - 1
internal/server/controllers/base.go

@@ -35,7 +35,9 @@ func BindRequest[T any](r *gin.Context, success func(T)) {
 	success(request)
 }
 
-func BindRequestWithPluginUniqueIdentifier[T any](r *gin.Context, success func(T, plugin_entities.PluginUniqueIdentifier)) {
+func BindRequestWithPluginUniqueIdentifier[T any](r *gin.Context, success func(
+	T, plugin_entities.PluginUniqueIdentifier,
+)) {
 	BindRequest(r, func(req T) {
 		plugin_unique_identifier := r.GetHeader(constants.X_PLUGIN_IDENTIFIER)
 		if plugin_unique_identifier == "" {

+ 52 - 0
internal/server/controllers/model.go

@@ -111,3 +111,55 @@ func ValidateModelCredentials(config *app.Config) gin.HandlerFunc {
 		)
 	}
 }
+
+func GetTTSModelVoices(config *app.Config) gin.HandlerFunc {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestGetTTSModelVoices]
+
+	return func(c *gin.Context) {
+		BindRequest(
+			c,
+			func(itr request) {
+				service.GetTTSModelVoices(&itr, c, config.PluginMaxExecutionTimeout)
+			},
+		)
+	}
+}
+
+func GetTextEmbeddingNumTokens(config *app.Config) gin.HandlerFunc {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestGetTextEmbeddingNumTokens]
+
+	return func(c *gin.Context) {
+		BindRequest(
+			c,
+			func(itr request) {
+				service.GetTextEmbeddingNumTokens(&itr, c, config.PluginMaxExecutionTimeout)
+			},
+		)
+	}
+}
+
+func GetLLMNumTokens(config *app.Config) gin.HandlerFunc {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestGetLLMNumTokens]
+
+	return func(c *gin.Context) {
+		BindRequest(
+			c,
+			func(itr request) {
+				service.GetLLMNumTokens(&itr, c, config.PluginMaxExecutionTimeout)
+			},
+		)
+	}
+}
+
+func GetAIModelSchema(config *app.Config) gin.HandlerFunc {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestGetAIModelSchema]
+
+	return func(c *gin.Context) {
+		BindRequest(
+			c,
+			func(itr request) {
+				service.GetAIModelSchema(&itr, c, config.PluginMaxExecutionTimeout)
+			},
+		)
+	}
+}

+ 2 - 2
internal/server/controllers/tool.go

@@ -12,7 +12,7 @@ func InvokeTool(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTool]
 
 	return func(c *gin.Context) {
-		BindRequest[request](
+		BindRequest(
 			c,
 			func(itr request) {
 				service.InvokeTool(&itr, c, config.PluginMaxExecutionTimeout)
@@ -25,7 +25,7 @@ func ValidateToolCredentials(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials]
 
 	return func(c *gin.Context) {
-		BindRequest[request](
+		BindRequest(
 			c,
 			func(itr request) {
 				service.ValidateToolCredentials(&itr, c, config.PluginMaxExecutionTimeout)

+ 2 - 1
internal/server/http_server.go

@@ -101,7 +101,8 @@ func (app *App) pluginGroup(group *gin.RouterGroup, config *app.Config) {
 	group.Use(CheckingKey(config.PluginInnerApiKey))
 
 	group.GET("/asset/:id", controllers.GetAsset)
-	group.POST("/install", controllers.InstallPluginFromPkg(config))
+	group.POST("/install/pkg", controllers.InstallPluginFromPkg(config))
+	group.POST("/install/id", controllers.InstallPluginFromIdentifier(config))
 	group.POST("/uninstall", controllers.UninstallPlugin)
 	group.GET("/list", controllers.ListPlugins)
 }

+ 104 - 0
internal/service/invoke_model.go

@@ -224,3 +224,107 @@ func ValidateModelCredentials(
 		max_timeout_seconds,
 	)
 }
+
+func GetTTSModelVoices(
+	r *plugin_entities.InvokePluginRequest[requests.RequestGetTTSModelVoices],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_GET_TTS_MODEL_VOICES,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.TTSModelVoice], error) {
+			return plugin_daemon.GetTTSModelVoices(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func GetTextEmbeddingNumTokens(
+	r *plugin_entities.InvokePluginRequest[requests.RequestGetTextEmbeddingNumTokens],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.GetTextEmbeddingNumTokensResponse], error) {
+			return plugin_daemon.GetTextEmbeddingNumTokens(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func GetAIModelSchema(
+	r *plugin_entities.InvokePluginRequest[requests.RequestGetAIModelSchema],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMA,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.GetModelSchemasResponse], error) {
+			return plugin_daemon.GetAIModelSchema(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func GetLLMNumTokens(
+	r *plugin_entities.InvokePluginRequest[requests.RequestGetLLMNumTokens],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.LLMGetNumTokensResponse], error) {
+			return plugin_daemon.GetLLMNumTokens(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}

+ 4 - 0
internal/types/entities/model_entities/ai_model.go

@@ -0,0 +1,4 @@
+package model_entities
+
+type GetModelSchemasResponse struct {
+}

+ 3 - 0
internal/types/entities/model_entities/llm.go

@@ -198,3 +198,6 @@ type LLMResultChunkDelta struct {
 	Usage        *LLMUsage     `json:"usage" validate:"omitempty"`
 	FinishReason *string       `json:"finish_reason" validate:"omitempty"`
 }
+
+type LLMGetNumTokensResponse struct {
+}

+ 3 - 0
internal/types/entities/model_entities/text_embedding.go

@@ -17,3 +17,6 @@ type TextEmbeddingResult struct {
 	Embeddings [][]float64    `json:"embeddings" validate:"required,dive"`
 	Usage      EmbeddingUsage `json:"usage" validate:"required"`
 }
+
+type GetTextEmbeddingNumTokensResponse struct {
+}

+ 3 - 0
internal/types/entities/model_entities/tts.go

@@ -3,3 +3,6 @@ package model_entities
 type TTSResult struct {
 	Result string `json:"result"` // in hex
 }
+
+type TTSModelVoice struct {
+}

+ 1 - 1
internal/types/entities/plugin_entities/request.go

@@ -6,7 +6,7 @@ type InvokePluginUserIdentity struct {
 }
 
 type BasePluginIdentifier struct {
-	PluginUniqueIdentifier PluginUniqueIdentifier `json:"plugin_unique_identifier" binding:"required"`
+	PluginUniqueIdentifier PluginUniqueIdentifier `json:"plugin_unique_identifier"`
 }
 
 type InvokePluginRequest[T any] struct {

+ 12 - 0
internal/types/entities/requests/model.go

@@ -105,3 +105,15 @@ type RequestValidateModelCredentials struct {
 	Model       string                   `json:"model" validate:"required"`
 	Credentials map[string]any           `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
+
+type RequestGetTTSModelVoices struct {
+}
+
+type RequestGetTextEmbeddingNumTokens struct {
+}
+
+type RequestGetLLMNumTokens struct {
+}
+
+type RequestGetAIModelSchema struct {
+}