瀏覽代碼

integrate model provider with plugin daemon

takatost 1 年之前
父節點
當前提交
18edeb8e0a

+ 2 - 2
api/controllers/console/workspace/load_balancing_config.py

@@ -113,10 +113,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
 # Load Balancing Config
 api.add_resource(
     LoadBalancingCredentialsValidateApi,
-    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
 )
 
 api.add_resource(
     LoadBalancingConfigCredentialsValidateApi,
-    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
 )

+ 9 - 9
api/controllers/console/workspace/model_providers.py

@@ -1,6 +1,6 @@
 import io
 
-from flask import request, send_file
+from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
@@ -126,11 +126,7 @@ class ModelProviderIconApi(Resource):
     Get model provider icon
     """
 
-    def get(self, provider: str, icon_type: str, lang: str):
-        tenant_id = request.args.get("tenant_id")
-        if not tenant_id:
-            return {"content": "Invalid request."}, 400
-
+    def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
         model_provider_service = ModelProviderService()
         icon, mimetype = model_provider_service.get_model_provider_icon(
             tenant_id=tenant_id,
@@ -139,6 +135,9 @@ class ModelProviderIconApi(Resource):
             lang=lang,
         )
 
+        if not icon:
+            return {"message": "Icon not found"}, 404
+
         return send_file(io.BytesIO(icon), mimetype=mimetype)
 
 
@@ -193,11 +192,12 @@ api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
 api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
 api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
 api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
-api.add_resource(
-    ModelProviderIconApi, "/workspaces/current/model-providers/<path:provider>/<string:icon_type>/<string:lang>"
-)
 
 api.add_resource(
     PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
 )
 api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
+api.add_resource(
+    ModelProviderIconApi,
+    "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
+)

+ 7 - 7
api/controllers/console/workspace/models.py

@@ -320,7 +320,7 @@ class ModelProviderModelValidateApi(Resource):
         response = {"result": "success" if result else "error"}
 
         if not result:
-            response["error"] = error
+            response["error"] = error or ""
 
         return response
 
@@ -357,26 +357,26 @@ class ModelProviderAvailableModelApi(Resource):
         return jsonable_encoder({"data": models})
 
 
-api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
+api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
 api.add_resource(
     ModelProviderModelEnableApi,
-    "/workspaces/current/model-providers/<string:provider>/models/enable",
+    "/workspaces/current/model-providers/<path:provider>/models/enable",
     endpoint="model-provider-model-enable",
 )
 api.add_resource(
     ModelProviderModelDisableApi,
-    "/workspaces/current/model-providers/<string:provider>/models/disable",
+    "/workspaces/current/model-providers/<path:provider>/models/disable",
     endpoint="model-provider-model-disable",
 )
 api.add_resource(
-    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
+    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
 )
 api.add_resource(
-    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
+    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
 )
 
 api.add_resource(
-    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
+    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
 )
 api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
 api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

+ 1 - 0
api/core/entities/__init__.py

@@ -0,0 +1 @@
+DEFAULT_PLUGIN_ID = "langgenius"

+ 2 - 2
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -1,7 +1,7 @@
 import decimal
 from typing import Optional
 
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -15,7 +15,7 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from core.plugin.manager.model import PluginModelManager
 
 
-class AIModel:
+class AIModel(BaseModel):
     """
     Base class for all models.
     """

+ 4 - 3
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -5,6 +5,7 @@ from typing import Optional
 
 from pydantic import BaseModel
 
+from core.entities import DEFAULT_PLUGIN_ID
 from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
@@ -132,7 +133,7 @@ class ModelProviderFactory:
             tenant_id=self.tenant_id,
             user_id="unknown",
             plugin_id=plugin_model_provider_entity.plugin_id,
-            provider=provider,
+            provider=plugin_model_provider_entity.provider,
             credentials=filtered_credentials,
         )
 
@@ -167,7 +168,7 @@ class ModelProviderFactory:
             tenant_id=self.tenant_id,
             user_id="unknown",
             plugin_id=plugin_model_provider_entity.plugin_id,
-            provider=provider,
+            provider=plugin_model_provider_entity.provider,
             model_type=model_type.value,
             model=model,
             credentials=filtered_credentials,
@@ -337,7 +338,7 @@ class ModelProviderFactory:
         :param provider: provider name
         :return: plugin id and provider name
         """
-        plugin_id = "langgenius"
+        plugin_id = DEFAULT_PLUGIN_ID
         provider_name = provider
         if "/" in provider:
             # get the plugin_id before provider

+ 11 - 3
api/core/plugin/manager/base.py

@@ -1,4 +1,5 @@
 import json
+import logging
 from collections.abc import Callable, Generator
 from typing import Optional, TypeVar
 
@@ -21,6 +22,8 @@ plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
 
 T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
 
+logger = logging.getLogger(__name__)
+
 
 class BasePluginManager:
     def _request(
@@ -44,9 +47,14 @@ class BasePluginManager:
         if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
             data = json.dumps(data)
 
-        response = requests.request(
-            method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
-        )
+        try:
+            response = requests.request(
+                method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
+            )
+        except requests.exceptions.ConnectionError as e:
+            logger.exception(f"Request to Plugin Daemon Service failed: {e}")
+            raise ValueError("Request to Plugin Daemon Service failed")
+
         return response
 
     def _stream_request(

+ 17 - 5
api/services/entities/model_provider_entities.py

@@ -50,6 +50,7 @@ class ProviderResponse(BaseModel):
     Model class for provider response.
     """
 
+    tenant_id: str
     provider: str
     label: I18nObject
     description: Optional[I18nObject] = None
@@ -71,7 +72,9 @@ class ProviderResponse(BaseModel):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
+        url_prefix = (
+            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
+        )
         if self.icon_small is not None:
             self.icon_small = I18nObject(
                 en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
@@ -88,6 +91,7 @@ class ProviderWithModelsResponse(BaseModel):
     Model class for provider with models response.
     """
 
+    tenant_id: str
     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -98,7 +102,9 @@ class ProviderWithModelsResponse(BaseModel):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
+        url_prefix = (
+            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
+        )
         if self.icon_small is not None:
             self.icon_small = I18nObject(
                 en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
@@ -115,10 +121,14 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
     Simple provider entity response.
     """
 
+    tenant_id: str
+
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
+        url_prefix = (
+            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
+        )
         if self.icon_small is not None:
             self.icon_small = I18nObject(
                 en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
@@ -150,5 +160,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
 
     provider: SimpleProviderEntityResponse
 
-    def __init__(self, model: ModelWithProviderEntity) -> None:
-        super().__init__(**model.model_dump())
+    def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
+        dump_model = model.model_dump()
+        dump_model["provider"]["tenant_id"] = tenant_id
+        super().__init__(**dump_model)

+ 5 - 1
api/services/model_provider_service.py

@@ -47,6 +47,7 @@ class ModelProviderService:
                     continue
 
             provider_response = ProviderResponse(
+                tenant_id=tenant_id,
                 provider=provider_configuration.provider.provider,
                 label=provider_configuration.provider.label,
                 description=provider_configuration.provider.description,
@@ -90,7 +91,8 @@ class ModelProviderService:
 
         # Get provider available models
         return [
-            ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
+            ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
+            for model in provider_configurations.get_models(provider=provider)
         ]
 
     def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
@@ -303,6 +305,7 @@ class ModelProviderService:
 
             providers_with_models.append(
                 ProviderWithModelsResponse(
+                    tenant_id=tenant_id,
                     provider=provider,
                     label=first_model.provider.label,
                     icon_small=first_model.provider.icon_small,
@@ -373,6 +376,7 @@ class ModelProviderService:
                     model=result.model,
                     model_type=result.model_type,
                     provider=SimpleProviderEntityResponse(
+                        tenant_id=tenant_id,
                         provider=result.provider.provider,
                         label=result.provider.label,
                         icon_small=result.provider.icon_small,