input_moderation.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import logging
  2. from collections.abc import Mapping
  3. from typing import Any, Optional
  4. from core.app.app_config.entities import AppConfig
  5. from core.moderation.base import ModerationAction, ModerationError
  6. from core.moderation.factory import ModerationFactory
  7. from core.ops.entities.trace_entity import TraceTaskName
  8. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
  9. from core.ops.utils import measure_time
  10. logger = logging.getLogger(__name__)
  11. class InputModeration:
  12. def check(
  13. self,
  14. app_id: str,
  15. tenant_id: str,
  16. app_config: AppConfig,
  17. inputs: Mapping[str, Any],
  18. query: str,
  19. message_id: str,
  20. trace_manager: Optional[TraceQueueManager] = None,
  21. ) -> tuple[bool, Mapping[str, Any], str]:
  22. """
  23. Process sensitive_word_avoidance.
  24. :param app_id: app id
  25. :param tenant_id: tenant id
  26. :param app_config: app config
  27. :param inputs: inputs
  28. :param query: query
  29. :param message_id: message id
  30. :param trace_manager: trace manager
  31. :return:
  32. """
  33. inputs = dict(inputs)
  34. if not app_config.sensitive_word_avoidance:
  35. return False, inputs, query
  36. sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
  37. moderation_type = sensitive_word_avoidance_config.type
  38. moderation_factory = ModerationFactory(
  39. name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config
  40. )
  41. with measure_time() as timer:
  42. moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
  43. if trace_manager:
  44. trace_manager.add_trace_task(
  45. TraceTask(
  46. TraceTaskName.MODERATION_TRACE,
  47. message_id=message_id,
  48. moderation_result=moderation_result,
  49. inputs=inputs,
  50. timer=timer,
  51. )
  52. )
  53. if not moderation_result.flagged:
  54. return False, inputs, query
  55. if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
  56. raise ModerationError(moderation_result.preset_response)
  57. elif moderation_result.action == ModerationAction.OVERRIDDEN:
  58. inputs = moderation_result.inputs
  59. query = moderation_result.query
  60. return True, inputs, query