# llm_callback_handler.py
import logging
import time
from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
  10. class LLMCallbackHandler(BaseCallbackHandler):
  11. raise_error: bool = True
  12. def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
  13. conversation_message_task: ConversationMessageTask):
  14. self.llm = llm
  15. self.llm_message = LLMMessage()
  16. self.start_at = None
  17. self.conversation_message_task = conversation_message_task
  18. @property
  19. def always_verbose(self) -> bool:
  20. """Whether to call verbose callbacks even if verbose is False."""
  21. return True
  22. def on_chat_model_start(
  23. self,
  24. serialized: Dict[str, Any],
  25. messages: List[List[BaseMessage]],
  26. **kwargs: Any
  27. ) -> Any:
  28. self.start_at = time.perf_counter()
  29. real_prompts = []
  30. for message in messages[0]:
  31. if message.type == 'human':
  32. role = 'user'
  33. elif message.type == 'ai':
  34. role = 'assistant'
  35. else:
  36. role = 'system'
  37. real_prompts.append({
  38. "role": role,
  39. "text": message.content
  40. })
  41. self.llm_message.prompt = real_prompts
  42. self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
  43. def on_llm_start(
  44. self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
  45. ) -> None:
  46. self.start_at = time.perf_counter()
  47. self.llm_message.prompt = [{
  48. "role": 'user',
  49. "text": prompts[0]
  50. }]
  51. self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
  52. def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
  53. end_at = time.perf_counter()
  54. self.llm_message.latency = end_at - self.start_at
  55. if not self.conversation_message_task.streaming:
  56. self.conversation_message_task.append_message_text(response.generations[0][0].text)
  57. self.llm_message.completion = response.generations[0][0].text
  58. self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
  59. else:
  60. self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
  61. self.conversation_message_task.save_message(self.llm_message)
  62. def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
  63. try:
  64. self.conversation_message_task.append_message_text(token)
  65. except ConversationTaskStoppedException as ex:
  66. self.on_llm_error(error=ex)
  67. raise ex
  68. self.llm_message.completion += token
  69. def on_llm_error(
  70. self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
  71. ) -> None:
  72. """Do nothing."""
  73. if isinstance(error, ConversationTaskStoppedException):
  74. if self.conversation_message_task.streaming:
  75. end_at = time.perf_counter()
  76. self.llm_message.latency = end_at - self.start_at
  77. self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
  78. self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
  79. else:
  80. logging.error(error)