# base_callback.py
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.ai_model import AIModel
  7. _TEXT_COLOR_MAPPING = {
  8. "blue": "36;1",
  9. "yellow": "33;1",
  10. "pink": "38;5;200",
  11. "green": "32;1",
  12. "red": "31;1",
  13. }
  14. class Callback(ABC):
  15. """
  16. Base class for callbacks.
  17. Only for LLM.
  18. """
  19. raise_error: bool = False
  20. @abstractmethod
  21. def on_before_invoke(
  22. self,
  23. llm_instance: AIModel,
  24. model: str,
  25. credentials: dict,
  26. prompt_messages: list[PromptMessage],
  27. model_parameters: dict,
  28. tools: Optional[list[PromptMessageTool]] = None,
  29. stop: Optional[Sequence[str]] = None,
  30. stream: bool = True,
  31. user: Optional[str] = None,
  32. ) -> None:
  33. """
  34. Before invoke callback
  35. :param llm_instance: LLM instance
  36. :param model: model name
  37. :param credentials: model credentials
  38. :param prompt_messages: prompt messages
  39. :param model_parameters: model parameters
  40. :param tools: tools for tool calling
  41. :param stop: stop words
  42. :param stream: is stream response
  43. :param user: unique user id
  44. """
  45. raise NotImplementedError()
  46. @abstractmethod
  47. def on_new_chunk(
  48. self,
  49. llm_instance: AIModel,
  50. chunk: LLMResultChunk,
  51. model: str,
  52. credentials: dict,
  53. prompt_messages: list[PromptMessage],
  54. model_parameters: dict,
  55. tools: Optional[list[PromptMessageTool]] = None,
  56. stop: Optional[Sequence[str]] = None,
  57. stream: bool = True,
  58. user: Optional[str] = None,
  59. ):
  60. """
  61. On new chunk callback
  62. :param llm_instance: LLM instance
  63. :param chunk: chunk
  64. :param model: model name
  65. :param credentials: model credentials
  66. :param prompt_messages: prompt messages
  67. :param model_parameters: model parameters
  68. :param tools: tools for tool calling
  69. :param stop: stop words
  70. :param stream: is stream response
  71. :param user: unique user id
  72. """
  73. raise NotImplementedError()
  74. @abstractmethod
  75. def on_after_invoke(
  76. self,
  77. llm_instance: AIModel,
  78. result: LLMResult,
  79. model: str,
  80. credentials: dict,
  81. prompt_messages: list[PromptMessage],
  82. model_parameters: dict,
  83. tools: Optional[list[PromptMessageTool]] = None,
  84. stop: Optional[Sequence[str]] = None,
  85. stream: bool = True,
  86. user: Optional[str] = None,
  87. ) -> None:
  88. """
  89. After invoke callback
  90. :param llm_instance: LLM instance
  91. :param result: result
  92. :param model: model name
  93. :param credentials: model credentials
  94. :param prompt_messages: prompt messages
  95. :param model_parameters: model parameters
  96. :param tools: tools for tool calling
  97. :param stop: stop words
  98. :param stream: is stream response
  99. :param user: unique user id
  100. """
  101. raise NotImplementedError()
  102. @abstractmethod
  103. def on_invoke_error(
  104. self,
  105. llm_instance: AIModel,
  106. ex: Exception,
  107. model: str,
  108. credentials: dict,
  109. prompt_messages: list[PromptMessage],
  110. model_parameters: dict,
  111. tools: Optional[list[PromptMessageTool]] = None,
  112. stop: Optional[Sequence[str]] = None,
  113. stream: bool = True,
  114. user: Optional[str] = None,
  115. ) -> None:
  116. """
  117. Invoke error callback
  118. :param llm_instance: LLM instance
  119. :param ex: exception
  120. :param model: model name
  121. :param credentials: model credentials
  122. :param prompt_messages: prompt messages
  123. :param model_parameters: model parameters
  124. :param tools: tools for tool calling
  125. :param stop: stop words
  126. :param stream: is stream response
  127. :param user: unique user id
  128. """
  129. raise NotImplementedError()
  130. def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
  131. """Print text with highlighting and no end characters."""
  132. text_to_print = self._get_colored_text(text, color) if color else text
  133. print(text_to_print, end=end)
  134. def _get_colored_text(self, text: str, color: str) -> str:
  135. """Get colored text."""
  136. color_str = _TEXT_COLOR_MAPPING[color]
  137. return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"