Ver código fonte

feat: add plugin_model_providers context

takatost 4 meses atrás
pai
commit
d5c708c62b

+ 5 - 0
api/contexts/__init__.py

@@ -4,11 +4,16 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from core.tools.plugin_tool.provider import PluginToolProviderController
+    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
     from core.workflow.entities.variable_pool import VariablePool
 
+
 tenant_id: ContextVar[str] = ContextVar("tenant_id")
 
 workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
 
 plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
 plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
+
+plugin_model_providers: ContextVar[list["PluginModelProviderEntity"]] = ContextVar("plugin_model_providers")
+plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")

+ 18 - 5
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -1,10 +1,12 @@
 import logging
 import os
 from collections.abc import Sequence
+from threading import Lock
 from typing import Optional
 
 from pydantic import BaseModel
 
+import contexts
 from core.entities import DEFAULT_PLUGIN_ID
 from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -71,13 +73,24 @@ class ModelProviderFactory:
         Get all plugin model providers
         :return: list of plugin model providers
         """
-        # Fetch plugin model providers
-        plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
+        # check if context is set
+        try:
+            contexts.plugin_model_providers.get()
+        except LookupError:
+            contexts.plugin_model_providers.set([])
+            contexts.plugin_model_providers_lock.set(Lock())
 
-        for provider in plugin_providers:
-            provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
+        with contexts.plugin_model_providers_lock.get():
+            plugin_model_providers = contexts.plugin_model_providers.get()
+
+            # Fetch plugin model providers
+            plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
+
+            for provider in plugin_providers:
+                provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
+                plugin_model_providers.append(provider)
 
-        return plugin_providers
+            return plugin_model_providers
 
     def get_provider_schema(self, provider: str) -> ProviderEntity:
         """