keywords.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from collections.abc import Sequence
  2. from typing import Any
  3. from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
  4. class KeywordsModeration(Moderation):
  5. name: str = "keywords"
  6. @classmethod
  7. def validate_config(cls, tenant_id: str, config: dict) -> None:
  8. """
  9. Validate the incoming form config data.
  10. :param tenant_id: the id of workspace
  11. :param config: the form config data
  12. :return:
  13. """
  14. cls._validate_inputs_and_outputs_config(config, True)
  15. if not config.get("keywords"):
  16. raise ValueError("keywords is required")
  17. if len(config.get("keywords", [])) > 10000:
  18. raise ValueError("keywords length must be less than 10000")
  19. keywords_row_len = config["keywords"].split("\n")
  20. if len(keywords_row_len) > 100:
  21. raise ValueError("the number of rows for the keywords must be less than 100")
  22. def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
  23. flagged = False
  24. preset_response = ""
  25. if self.config is None:
  26. raise ValueError("The config is not set.")
  27. if self.config["inputs_config"]["enabled"]:
  28. preset_response = self.config["inputs_config"]["preset_response"]
  29. if query:
  30. inputs["query__"] = query
  31. # Filter out empty values
  32. keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
  33. flagged = self._is_violated(inputs, keywords_list)
  34. return ModerationInputsResult(
  35. flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
  36. )
  37. def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
  38. flagged = False
  39. preset_response = ""
  40. if self.config is None:
  41. raise ValueError("The config is not set.")
  42. if self.config["outputs_config"]["enabled"]:
  43. # Filter out empty values
  44. keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
  45. flagged = self._is_violated({"text": text}, keywords_list)
  46. preset_response = self.config["outputs_config"]["preset_response"]
  47. return ModerationOutputsResult(
  48. flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
  49. )
  50. def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
  51. return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
  52. def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
  53. return any(keyword.lower() in str(value).lower() for keyword in keywords_list)