llm_callback_handler.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import logging
  2. import time
  3. from typing import Any, Dict, List, Union, Optional
  4. from langchain.callbacks.base import BaseCallbackHandler
  5. from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
  6. from core.callback_handler.entity.llm_message import LLMMessage
  7. from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
  8. from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
  9. from core.llm.streamable_open_ai import StreamableOpenAI
  class LLMCallbackHandler(BaseCallbackHandler):
      """LangChain callback handler that records a single LLM call as an ``LLMMessage``.

      On start it stores the prompt and its token count; on end it stores the
      completion, its token count and the latency, then hands the message to
      ``conversation_message_task.save_message``. In streaming mode each new
      token is forwarded to the conversation as it arrives.
      """

      def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                   conversation_message_task: ConversationMessageTask):
          """
          :param llm: the model being called; used here for token counting
              (``get_num_tokens`` / ``get_messages_tokens``).
          :param conversation_message_task: receives the generated text and the
              finished ``LLMMessage``.
          """
          self.llm = llm
          self.llm_message = LLMMessage()
          # perf_counter() timestamp set in on_llm_start; latency is measured from it.
          self.start_at = None
          self.conversation_message_task = conversation_message_task

      @property
      def always_verbose(self) -> bool:
          """Whether to call verbose callbacks even if verbose is False."""
          return True

      def on_llm_start(
              self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
      ) -> None:
          """Record the start time and the prompt together with its token count.

          For chat models (``'Chat'`` in ``serialized['name']``) every prompt string
          is split on the first ``': '`` into ``<role>: <content>``. The ``human``
          and ``ai`` roles are stored as ``user`` / ``assistant``; any other role is
          kept as-is and counted as a system message. For other models only
          ``prompts[0]`` is recorded, with role ``user``.
          """
          self.start_at = time.perf_counter()

          if 'Chat' in serialized['name']:
              real_prompts = []
              messages = []
              for prompt in prompts:
                  role, content = prompt.split(': ', maxsplit=1)
                  if role == 'human':
                      role = 'user'
                      message = HumanMessage(content=content)
                  elif role == 'ai':
                      role = 'assistant'
                      message = AIMessage(content=content)
                  else:
                      message = SystemMessage(content=content)

                  real_prompts.append({
                      "role": role,
                      "text": content
                  })
                  messages.append(message)

              self.llm_message.prompt = real_prompts
              # Chat prompts are counted as a message list, not as raw text.
              self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
          else:
              self.llm_message.prompt = [{
                  "role": 'user',
                  "text": prompts[0]
              }]
              self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

      def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
          """Record latency and completion, then save the message.

          In non-streaming mode the completion is taken from the first generation
          of ``response`` and also appended to the conversation. In streaming mode
          the completion was already accumulated by ``on_llm_new_token``, so only
          its token count is computed here.
          """
          end_at = time.perf_counter()
          self.llm_message.latency = end_at - self.start_at

          if not self.conversation_message_task.streaming:
              completion = response.generations[0][0].text
              self.conversation_message_task.append_message_text(completion)
              self.llm_message.completion = completion
              self.llm_message.completion_tokens = self._get_completion_tokens(response, completion)
          else:
              self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)

          self.conversation_message_task.save_message(self.llm_message)

      def _get_completion_tokens(self, response: LLMResult, completion: str) -> int:
          """Return the provider-reported completion token count, or count it locally.

          ``LLMResult.llm_output`` is optional and its contents are provider
          specific. Indexing it directly would raise and the message would never be
          saved. So a missing ``token_usage`` / ``completion_tokens`` falls back
          to ``self.llm.get_num_tokens(completion)``.
          """
          token_usage = (response.llm_output or {}).get('token_usage') or {}
          completion_tokens = token_usage.get('completion_tokens')
          if completion_tokens is None:
              completion_tokens = self.llm.get_num_tokens(completion)
          return completion_tokens

      def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
          """Forward a streamed token to the conversation and accumulate it.

          If the conversation has been stopped, ``append_message_text`` raises
          ``ConversationTaskStoppedException``. The partial message is then saved
          via ``on_llm_error`` and the exception is re-raised to abort generation.
          """
          try:
              self.conversation_message_task.append_message_text(token)
          except ConversationTaskStoppedException as ex:
              self.on_llm_error(error=ex)
              raise

          self.llm_message.completion += token

      def on_llm_error(
              self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
      ) -> None:
          """Handle an error raised during the LLM call.

          A ``ConversationTaskStoppedException`` in streaming mode saves the
          partial completion, with latency and token count, as a stopped message.
          The same exception in non-streaming mode is ignored here. Any other
          error is logged.
          """
          if isinstance(error, ConversationTaskStoppedException):
              if self.conversation_message_task.streaming:
                  end_at = time.perf_counter()
                  self.llm_message.latency = end_at - self.start_at
                  self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
                  self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
          else:
              logging.error(error)

      # The remaining callbacks are intentionally no-ops: this handler only
      # tracks LLM start/token/end/error events.

      def on_chain_start(
              self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
      ) -> None:
          pass

      def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
          pass

      def on_chain_error(
              self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
      ) -> None:
          pass

      def on_tool_start(
              self,
              serialized: Dict[str, Any],
              input_str: str,
              **kwargs: Any,
      ) -> None:
          pass

      def on_agent_action(
              self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
      ) -> Any:
          pass

      def on_tool_end(
              self,
              output: str,
              color: Optional[str] = None,
              observation_prefix: Optional[str] = None,
              llm_prefix: Optional[str] = None,
              **kwargs: Any,
      ) -> None:
          pass

      def on_tool_error(
              self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
      ) -> None:
          pass

      def on_text(
              self,
              text: str,
              color: Optional[str] = None,
              end: str = "",
              **kwargs: Optional[str],
      ) -> None:
          pass

      def on_agent_finish(
              self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
      ) -> None:
          pass