浏览代码

Merge branch 'fix/chore-fix' of github.com:langgenius/dify into fix/chore-fix

Yeuoly 6 月之前
父节点
当前提交
21fd58caf9

+ 3 - 1
api/core/app/features/hosting_moderation/hosting_moderation.py

@@ -24,6 +24,8 @@ class HostingModerationFeature:
             if isinstance(prompt_message.content, str):
                 text += prompt_message.content + "\n"
 
-        moderation_result = moderation.check_moderation(model_config, text)
+        moderation_result = moderation.check_moderation(
+            tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
+        )
 
         return moderation_result

+ 18 - 7
api/core/helper/moderation.py

@@ -1,27 +1,32 @@
 import logging
 import random
+from typing import cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities import DEFAULT_PLUGIN_ID
+from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
+from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from extensions.ext_hosting_provider import hosting_configuration
 from models.provider import ProviderType
 
 logger = logging.getLogger(__name__)
 
 
-def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
+def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
     moderation_config = hosting_configuration.moderation_config
+    openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
     if (
         moderation_config
         and moderation_config.enabled is True
-        and "openai" in hosting_configuration.provider_map
-        and hosting_configuration.provider_map["openai"].enabled is True
+        and openai_provider_name in hosting_configuration.provider_map
+        and hosting_configuration.provider_map[openai_provider_name].enabled is True
     ):
         using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
         provider_name = model_config.provider
         if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
-            hosting_openai_config = hosting_configuration.provider_map["openai"]
+            hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]
 
             if hosting_openai_config.credentials is None:
                 return False
@@ -36,9 +41,15 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
             text_chunk = random.choice(text_chunks)
 
             try:
-                model_type_instance = OpenAIModerationModel()
+                model_provider_factory = ModelProviderFactory(tenant_id)
+
+                # Get model instance of LLM
+                model_type_instance = model_provider_factory.get_model_type_instance(
+                    provider=openai_provider_name, model_type=ModelType.MODERATION
+                )
+                model_type_instance = cast(ModerationModel, model_type_instance)
                 moderation_result = model_type_instance.invoke(
-                    model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
+                    model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
                 )
 
                 if moderation_result is True:

+ 3 - 0
api/events/event_handlers/deduct_quota_when_message_created.py

@@ -22,6 +22,9 @@ def handle(sender, **kwargs):
 
     system_configuration = provider_configuration.system_configuration
 
+    if not system_configuration.current_quota_type:
+        return
+
     quota_unit = None
     for quota_configuration in system_configuration.quota_configurations:
         if quota_configuration.quota_type == system_configuration.current_quota_type: