openai_moderation.py

from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
from core.model_providers.model_factory import ModelFactory


class OpenAIModeration(Moderation):
    name: str = "openai_moderation"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""

        if self.config['inputs_config']['enabled']:
            preset_response = self.config['inputs_config']['preset_response']

            # Fold the query into the inputs so both are moderated together.
            if query:
                inputs['query__'] = query

            flagged = self._is_violated(inputs)

        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT,
                                      preset_response=preset_response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        flagged = False
        preset_response = ""

        if self.config['outputs_config']['enabled']:
            flagged = self._is_violated({'text': text})
            preset_response = self.config['outputs_config']['preset_response']

        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT,
                                       preset_response=preset_response)

    def _is_violated(self, inputs: dict):
        # Join all input values into one text blob and check it in a single call.
        text = '\n'.join(inputs.values())

        openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")

        # Per the variable name and the negation below, `run` returns a truthy
        # value when the text passes moderation, so a violation is its negation.
        is_not_invalid = openai_moderation.run(text)
        return not is_not_invalid
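
For context, a minimal usage sketch follows. It assumes the `Moderation` base class accepts `tenant_id` and `config` in its constructor (consistent with how `self.tenant_id` and `self.config` are read above); the tenant id, preset responses, and input values are placeholder examples, not part of the original file.

# --- Usage sketch; constructor signature and values are assumptions ---

config = {
    "inputs_config": {
        "enabled": True,
        "preset_response": "Your input was flagged by moderation.",
    },
    "outputs_config": {
        "enabled": True,
        "preset_response": "The response was withheld by moderation.",
    },
}

# Validates the inputs/outputs config shape before use.
OpenAIModeration.validate_config(tenant_id="tenant-123", config=config)

# Assumed constructor: Moderation(tenant_id, config).
moderation = OpenAIModeration(tenant_id="tenant-123", config=config)

result = moderation.moderation_for_inputs({"name": "Alice"}, query="Some user query")
if result.flagged:
    print(result.preset_response)  # the configured canned reply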