|
@@ -1,10 +1,12 @@
|
|
|
import logging
|
|
|
import os
|
|
|
from collections.abc import Sequence
|
|
|
+from threading import Lock
|
|
|
from typing import Optional
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
+import contexts
|
|
|
from core.entities import DEFAULT_PLUGIN_ID
|
|
|
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
|
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
|
@@ -71,13 +73,24 @@ class ModelProviderFactory:
|
|
|
Get all plugin model providers
|
|
|
:return: list of plugin model providers
|
|
|
"""
|
|
|
- # Fetch plugin model providers
|
|
|
- plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
|
|
+ # check if context is set
|
|
|
+ try:
|
|
|
+ contexts.plugin_model_providers.get()
|
|
|
+ except LookupError:
|
|
|
+ contexts.plugin_model_providers.set([])
|
|
|
+ contexts.plugin_model_providers_lock.set(Lock())
|
|
|
|
|
|
- for provider in plugin_providers:
|
|
|
- provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
|
|
+ with contexts.plugin_model_providers_lock.get():
|
|
|
+ plugin_model_providers = contexts.plugin_model_providers.get()
|
|
|
+
|
|
|
+ # Fetch plugin model providers
|
|
|
+ plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
|
|
+
|
|
|
+ for provider in plugin_providers:
|
|
|
+ provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
|
|
+ plugin_model_providers.append(provider)
|
|
|
|
|
|
- return plugin_providers
|
|
|
+ return plugin_model_providers
|
|
|
|
|
|
def get_provider_schema(self, provider: str) -> ProviderEntity:
|
|
|
"""
|