浏览代码

feat: Add caching mechanism for plugin model schemas (#14898)

Yeuoly 1 月之前
父节点
当前提交
4668c4996a

+ 9 - 0
api/contexts/__init__.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from contexts.wrapper import RecyclableContextVar
 
 if TYPE_CHECKING:
+    from core.model_runtime.entities.model_entities import AIModelEntity
     from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
     from core.tools.plugin_tool.provider import PluginToolProviderController
     from core.workflow.entities.variable_pool import VariablePool
@@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo
 plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
     ContextVar("plugin_tool_providers")
 )
+
 plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
 
 plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
     ContextVar("plugin_model_providers")
 )
+
 plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
     ContextVar("plugin_model_providers_lock")
 )
+
+plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
+
+plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
+    ContextVar("plugin_model_schemas")
+)

+ 32 - 9
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -1,8 +1,11 @@
 import decimal
+import hashlib
+from threading import Lock
 from typing import Optional
 
 from pydantic import BaseModel, ConfigDict, Field
 
+import contexts
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.model_entities import (
@@ -139,15 +142,35 @@ class AIModel(BaseModel):
         :return: model schema
         """
         plugin_model_manager = PluginModelManager()
-        return plugin_model_manager.get_model_schema(
-            tenant_id=self.tenant_id,
-            user_id="unknown",
-            plugin_id=self.plugin_id,
-            provider=self.provider_name,
-            model_type=self.model_type.value,
-            model=model,
-            credentials=credentials or {},
-        )
+        cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
+        # sort credentials
+        sorted_credentials = sorted(credentials.items()) if credentials else []
+        cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
+
+        try:
+            contexts.plugin_model_schemas.get()
+        except LookupError:
+            contexts.plugin_model_schemas.set({})
+            contexts.plugin_model_schema_lock.set(Lock())
+
+        with contexts.plugin_model_schema_lock.get():
+            if cache_key in contexts.plugin_model_schemas.get():
+                return contexts.plugin_model_schemas.get()[cache_key]
+
+            schema = plugin_model_manager.get_model_schema(
+                tenant_id=self.tenant_id,
+                user_id="unknown",
+                plugin_id=self.plugin_id,
+                provider=self.provider_name,
+                model_type=self.model_type.value,
+                model=model,
+                credentials=credentials or {},
+            )
+
+            if schema:
+                contexts.plugin_model_schemas.get()[cache_key] = schema
+
+            return schema
 
     def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """

+ 29 - 10
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import os
 from collections.abc import Sequence
@@ -206,17 +207,35 @@ class ModelProviderFactory:
         Get model schema
         """
         plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
-        model_schema = self.plugin_model_manager.get_model_schema(
-            tenant_id=self.tenant_id,
-            user_id="unknown",
-            plugin_id=plugin_id,
-            provider=provider_name,
-            model_type=model_type.value,
-            model=model,
-            credentials=credentials,
-        )
+        cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
+        # sort credentials
+        sorted_credentials = sorted(credentials.items()) if credentials else []
+        cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
 
-        return model_schema
+        try:
+            contexts.plugin_model_schemas.get()
+        except LookupError:
+            contexts.plugin_model_schemas.set({})
+            contexts.plugin_model_schema_lock.set(Lock())
+
+        with contexts.plugin_model_schema_lock.get():
+            if cache_key in contexts.plugin_model_schemas.get():
+                return contexts.plugin_model_schemas.get()[cache_key]
+
+            schema = self.plugin_model_manager.get_model_schema(
+                tenant_id=self.tenant_id,
+                user_id="unknown",
+                plugin_id=plugin_id,
+                provider=provider_name,
+                model_type=model_type.value,
+                model=model,
+                credentials=credentials or {},
+            )
+
+            if schema:
+                contexts.plugin_model_schemas.get()[cache_key] = schema
+
+            return schema
 
     def get_models(
         self,