moderation_handler.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import logging
  2. import threading
  3. import time
  4. from typing import Any, Optional, Dict
  5. from flask import current_app, Flask
  6. from pydantic import BaseModel
  7. from core.moderation.base import ModerationAction, ModerationOutputsResult
  8. from core.moderation.factory import ModerationFactory
  9. logger = logging.getLogger(__name__)
  10. class ModerationRule(BaseModel):
  11. type: str
  12. config: Dict[str, Any]
  13. class OutputModerationHandler(BaseModel):
  14. DEFAULT_BUFFER_SIZE: int = 300
  15. tenant_id: str
  16. app_id: str
  17. rule: ModerationRule
  18. on_message_replace_func: Any
  19. thread: Optional[threading.Thread] = None
  20. thread_running: bool = True
  21. buffer: str = ''
  22. is_final_chunk: bool = False
  23. final_output: Optional[str] = None
  24. class Config:
  25. arbitrary_types_allowed = True
  26. def should_direct_output(self):
  27. return self.final_output is not None
  28. def get_final_output(self):
  29. return self.final_output
  30. def append_new_token(self, token: str):
  31. self.buffer += token
  32. if not self.thread:
  33. self.thread = self.start_thread()
  34. def moderation_completion(self, completion: str, public_event: bool = False) -> str:
  35. self.buffer = completion
  36. self.is_final_chunk = True
  37. result = self.moderation(
  38. tenant_id=self.tenant_id,
  39. app_id=self.app_id,
  40. moderation_buffer=completion
  41. )
  42. if not result or not result.flagged:
  43. return completion
  44. if result.action == ModerationAction.DIRECT_OUTPUT:
  45. final_output = result.preset_response
  46. else:
  47. final_output = result.text
  48. if public_event:
  49. self.on_message_replace_func(final_output)
  50. return final_output
  51. def start_thread(self) -> threading.Thread:
  52. buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
  53. thread = threading.Thread(target=self.worker, kwargs={
  54. 'flask_app': current_app._get_current_object(),
  55. 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
  56. })
  57. thread.start()
  58. return thread
  59. def stop_thread(self):
  60. if self.thread and self.thread.is_alive():
  61. self.thread_running = False
  62. def worker(self, flask_app: Flask, buffer_size: int):
  63. with flask_app.app_context():
  64. current_length = 0
  65. while self.thread_running:
  66. moderation_buffer = self.buffer
  67. buffer_length = len(moderation_buffer)
  68. if not self.is_final_chunk:
  69. chunk_length = buffer_length - current_length
  70. if 0 <= chunk_length < buffer_size:
  71. time.sleep(1)
  72. continue
  73. current_length = buffer_length
  74. result = self.moderation(
  75. tenant_id=self.tenant_id,
  76. app_id=self.app_id,
  77. moderation_buffer=moderation_buffer
  78. )
  79. if not result or not result.flagged:
  80. continue
  81. if result.action == ModerationAction.DIRECT_OUTPUT:
  82. final_output = result.preset_response
  83. self.final_output = final_output
  84. else:
  85. final_output = result.text + self.buffer[len(moderation_buffer):]
  86. # trigger replace event
  87. if self.thread_running:
  88. self.on_message_replace_func(final_output)
  89. if result.action == ModerationAction.DIRECT_OUTPUT:
  90. break
  91. def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
  92. try:
  93. moderation_factory = ModerationFactory(
  94. name=self.rule.type,
  95. app_id=app_id,
  96. tenant_id=tenant_id,
  97. config=self.rule.config
  98. )
  99. result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
  100. return result
  101. except Exception as e:
  102. logger.error("Moderation Output error: %s", e)
  103. return None