# llm_callback_handler.py
import logging
from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
  9. class LLMCallbackHandler(BaseCallbackHandler):
  10. raise_error: bool = True
  11. def __init__(self, model_instance: BaseLLM,
  12. conversation_message_task: ConversationMessageTask):
  13. self.model_instance = model_instance
  14. self.llm_message = LLMMessage()
  15. self.start_at = None
  16. self.conversation_message_task = conversation_message_task
  17. @property
  18. def always_verbose(self) -> bool:
  19. """Whether to call verbose callbacks even if verbose is False."""
  20. return True
  21. def on_chat_model_start(
  22. self,
  23. serialized: Dict[str, Any],
  24. messages: List[List[BaseMessage]],
  25. **kwargs: Any
  26. ) -> Any:
  27. real_prompts = []
  28. for message in messages[0]:
  29. if message.type == 'human':
  30. role = 'user'
  31. elif message.type == 'ai':
  32. role = 'assistant'
  33. else:
  34. role = 'system'
  35. real_prompts.append({
  36. "role": role,
  37. "text": message.content
  38. })
  39. self.llm_message.prompt = real_prompts
  40. self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
  41. def on_llm_start(
  42. self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
  43. ) -> None:
  44. self.llm_message.prompt = [{
  45. "role": 'user',
  46. "text": prompts[0]
  47. }]
  48. self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
  49. def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
  50. if not self.conversation_message_task.streaming:
  51. self.conversation_message_task.append_message_text(response.generations[0][0].text)
  52. self.llm_message.completion = response.generations[0][0].text
  53. if response.llm_output and 'token_usage' in response.llm_output:
  54. if 'prompt_tokens' in response.llm_output['token_usage']:
  55. self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
  56. if 'completion_tokens' in response.llm_output['token_usage']:
  57. self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
  58. else:
  59. self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
  60. [PromptMessage(content=self.llm_message.completion)])
  61. else:
  62. self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
  63. [PromptMessage(content=self.llm_message.completion)])
  64. self.conversation_message_task.save_message(self.llm_message)
  65. def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
  66. try:
  67. self.conversation_message_task.append_message_text(token)
  68. except ConversationTaskStoppedException as ex:
  69. self.on_llm_error(error=ex)
  70. raise ex
  71. self.llm_message.completion += token
  72. def on_llm_error(
  73. self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
  74. ) -> None:
  75. """Do nothing."""
  76. if isinstance(error, ConversationTaskStoppedException):
  77. if self.conversation_message_task.streaming:
  78. self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
  79. [PromptMessage(content=self.llm_message.completion)]
  80. )
  81. self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
  82. else:
  83. logging.debug("on_llm_error: %s", error)