openai_moderation.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from core.model_manager import ModelManager
  2. from core.model_runtime.entities.model_entities import ModelType
  3. from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
  4. class OpenAIModeration(Moderation):
  5. name: str = "openai_moderation"
  6. @classmethod
  7. def validate_config(cls, tenant_id: str, config: dict) -> None:
  8. """
  9. Validate the incoming form config data.
  10. :param tenant_id: the id of workspace
  11. :param config: the form config data
  12. :return:
  13. """
  14. cls._validate_inputs_and_outputs_config(config, True)
  15. def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
  16. flagged = False
  17. preset_response = ""
  18. if self.config["inputs_config"]["enabled"]:
  19. preset_response = self.config["inputs_config"]["preset_response"]
  20. if query:
  21. inputs["query__"] = query
  22. flagged = self._is_violated(inputs)
  23. return ModerationInputsResult(
  24. flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
  25. )
  26. def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
  27. flagged = False
  28. preset_response = ""
  29. if self.config["outputs_config"]["enabled"]:
  30. flagged = self._is_violated({"text": text})
  31. preset_response = self.config["outputs_config"]["preset_response"]
  32. return ModerationOutputsResult(
  33. flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
  34. )
  35. def _is_violated(self, inputs: dict):
  36. text = "\n".join(str(inputs.values()))
  37. model_manager = ModelManager()
  38. model_instance = model_manager.get_model_instance(
  39. tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
  40. )
  41. openai_moderation = model_instance.invoke_moderation(text=text)
  42. return openai_moderation