|
@@ -1,8 +1,11 @@
|
|
|
import decimal
|
|
|
+import hashlib
|
|
|
+from threading import Lock
|
|
|
from typing import Optional
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
|
|
+import contexts
|
|
|
from core.model_runtime.entities.common_entities import I18nObject
|
|
|
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
|
|
from core.model_runtime.entities.model_entities import (
|
|
@@ -139,15 +142,35 @@ class AIModel(BaseModel):
|
|
|
:return: model schema
|
|
|
"""
|
|
|
plugin_model_manager = PluginModelManager()
|
|
|
- return plugin_model_manager.get_model_schema(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- user_id="unknown",
|
|
|
- plugin_id=self.plugin_id,
|
|
|
- provider=self.provider_name,
|
|
|
- model_type=self.model_type.value,
|
|
|
- model=model,
|
|
|
- credentials=credentials or {},
|
|
|
- )
|
|
|
+ cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
|
|
+ # sort credentials
|
|
|
+ sorted_credentials = sorted(credentials.items()) if credentials else []
|
|
|
+ cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
|
|
+
|
|
|
+ try:
|
|
|
+ contexts.plugin_model_schemas.get()
|
|
|
+ except LookupError:
|
|
|
+ contexts.plugin_model_schemas.set({})
|
|
|
+ contexts.plugin_model_schema_lock.set(Lock())
|
|
|
+
|
|
|
+ with contexts.plugin_model_schema_lock.get():
|
|
|
+ if cache_key in contexts.plugin_model_schemas.get():
|
|
|
+ return contexts.plugin_model_schemas.get()[cache_key]
|
|
|
+
|
|
|
+ schema = plugin_model_manager.get_model_schema(
|
|
|
+ tenant_id=self.tenant_id,
|
|
|
+ user_id="unknown",
|
|
|
+ plugin_id=self.plugin_id,
|
|
|
+ provider=self.provider_name,
|
|
|
+ model_type=self.model_type.value,
|
|
|
+ model=model,
|
|
|
+ credentials=credentials or {},
|
|
|
+ )
|
|
|
+
|
|
|
+ if schema:
|
|
|
+ contexts.plugin_model_schemas.get()[cache_key] = schema
|
|
|
+
|
|
|
+ return schema
|
|
|
|
|
|
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
|
|
"""
|