llm_callback_handler.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import logging
  2. import time
  3. from typing import Any, Dict, List, Union, Optional
  4. from langchain.callbacks.base import BaseCallbackHandler
  5. from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
  6. from core.callback_handler.entity.llm_message import LLMMessage
  7. from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
  8. from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
  9. from core.llm.streamable_open_ai import StreamableOpenAI
  class LLMCallbackHandler(BaseCallbackHandler):
      """Callback handler that records one LLM call into an ``LLMMessage``.

      On start it stores the prompt and its token count, while tokens stream
      in it forwards them to the conversation task and accumulates the
      completion, and on end (or on a user stop) it fills in latency and
      completion token count and asks the conversation task to save the message.
      All chain/tool/agent/text callbacks are intentionally no-ops.
      """

      def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                   conversation_message_task: ConversationMessageTask):
          """Bind the handler to the model used for token counting and to the task.

          Args:
              llm: Model wrapper; its ``get_num_tokens`` / ``get_messages_tokens``
                  are used to count prompt and completion tokens.
              conversation_message_task: Task that receives streamed text and
                  persists the finished ``LLMMessage``.
          """
          self.llm = llm
          self.llm_message = LLMMessage()
          # Set by on_llm_start; stays None until an LLM call has begun.
          self.start_at = None
          self.conversation_message_task = conversation_message_task

      @property
      def always_verbose(self) -> bool:
          """Whether to call verbose callbacks even if verbose is False."""
          return True

      def _elapsed_since_start(self) -> float:
          """Return seconds elapsed since ``on_llm_start`` recorded ``start_at``.

          Returns 0.0 when ``start_at`` was never set, so an end/error callback
          that arrives without a preceding start does not fail with a TypeError.
          """
          end_at = time.perf_counter()
          if self.start_at is None:
              return 0.0
          return end_at - self.start_at

      def on_llm_start(
          self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
      ) -> None:
          """Record the prompt and its token count and start the latency timer.

          Args:
              serialized: Model description; ``serialized['name']`` containing
                  ``'Chat'`` selects the chat-style prompt handling.
              prompts: For chat models, one ``"<role>: <content>"`` string per
                  message; otherwise only ``prompts[0]`` is used as the prompt.
          """
          self.start_at = time.perf_counter()

          if 'Chat' in serialized['name']:
              real_prompts = []
              messages = []
              for prompt in prompts:
                  # Split on the first ": " only; the content may contain more.
                  role, separator, content = prompt.partition(': ')
                  if not separator:
                      # No "<role>: " prefix at all. Mirror the non-chat branch
                      # and treat the raw text as a user message instead of
                      # failing on tuple unpacking.
                      role = 'user'
                      content = prompt
                      message = HumanMessage(content=content)
                  elif role == 'human':
                      role = 'user'
                      message = HumanMessage(content=content)
                  elif role == 'ai':
                      role = 'assistant'
                      message = AIMessage(content=content)
                  else:
                      # Any other role keeps its name and is counted as a
                      # system message.
                      message = SystemMessage(content=content)

                  real_prompts.append({
                      "role": role,
                      "text": content
                  })
                  messages.append(message)

              self.llm_message.prompt = real_prompts
              self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
          else:
              self.llm_message.prompt = [{
                  "role": 'user',
                  "text": prompts[0]
              }]
              self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

      def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
          """Finalize latency and completion data, then save the message.

          In non-streaming mode the full text of the first generation is pushed
          to the conversation task and stored as the completion. In streaming
          mode the completion accumulated by ``on_llm_new_token`` is used.

          Args:
              response: Result of the LLM call.
          """
          self.llm_message.latency = self._elapsed_since_start()

          if not self.conversation_message_task.streaming:
              text = response.generations[0][0].text
              self.conversation_message_task.append_message_text(text)
              self.llm_message.completion = text

              # Prefer the provider-reported usage; llm_output may be None or
              # lack usage data, in which case count the tokens locally the same
              # way the streaming branch does.
              usage = (response.llm_output or {}).get('token_usage') or {}
              completion_tokens = usage.get('completion_tokens')
              if completion_tokens is None:
                  completion_tokens = self.llm.get_num_tokens(text)
              self.llm_message.completion_tokens = completion_tokens
          else:
              self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)

          self.conversation_message_task.save_message(self.llm_message)

      def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
          """Forward a streamed token to the task and append it to the completion.

          A ``ConversationTaskStoppedException`` raised while forwarding is routed
          to ``on_llm_error`` and not re-raised, so the token is still appended to
          the accumulated completion afterwards.
          """
          try:
              self.conversation_message_task.append_message_text(token)
          except ConversationTaskStoppedException as ex:
              # NOTE(review): if append_message_text keeps raising for every
              # later token, on_llm_error will save the message once per token;
              # confirm whether the exception should be re-raised here.
              self.on_llm_error(error=ex)

          self.llm_message.completion += token

      def on_llm_error(
          self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
      ) -> None:
          """Save the partial message on a user stop; log any other error.

          When ``error`` is a ``ConversationTaskStoppedException`` and the task is
          streaming, latency and completion tokens are computed from what has
          been accumulated so far and the message is saved with
          ``by_stopped=True``. A stop in non-streaming mode does nothing. Every
          other error is logged and otherwise ignored.
          """
          if isinstance(error, ConversationTaskStoppedException):
              if self.conversation_message_task.streaming:
                  self.llm_message.latency = self._elapsed_since_start()
                  self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
                  self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
          else:
              logging.error(error)

      def on_chain_start(
          self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
      ) -> None:
          """No-op: chain events are not recorded by this handler."""
          pass

      def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
          """No-op: chain events are not recorded by this handler."""
          pass

      def on_chain_error(
          self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
      ) -> None:
          """No-op: chain errors are ignored by this handler."""
          pass

      def on_tool_start(
          self,
          serialized: Dict[str, Any],
          input_str: str,
          **kwargs: Any,
      ) -> None:
          """No-op: tool events are not recorded by this handler."""
          pass

      def on_agent_action(
          self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
      ) -> Any:
          """No-op: agent actions are not recorded by this handler."""
          pass

      def on_tool_end(
          self,
          output: str,
          color: Optional[str] = None,
          observation_prefix: Optional[str] = None,
          llm_prefix: Optional[str] = None,
          **kwargs: Any,
      ) -> None:
          """No-op: tool events are not recorded by this handler."""
          pass

      def on_tool_error(
          self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
      ) -> None:
          """No-op: tool errors are ignored by this handler."""
          pass

      def on_text(
          self,
          text: str,
          color: Optional[str] = None,
          end: str = "",
          **kwargs: Optional[str],
      ) -> None:
          """No-op: arbitrary text events are not recorded by this handler."""
          pass

      def on_agent_finish(
          self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
      ) -> None:
          """No-op: agent finish events are not recorded by this handler."""
          pass