# base_callback.py
  1. from abc import ABC, abstractmethod
  2. from typing import Optional
  3. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
  4. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  5. from core.model_runtime.model_providers.__base.ai_model import AIModel
  6. _TEXT_COLOR_MAPPING = {
  7. "blue": "36;1",
  8. "yellow": "33;1",
  9. "pink": "38;5;200",
  10. "green": "32;1",
  11. "red": "31;1",
  12. }
  13. class Callback(ABC):
  14. """
  15. Base class for callbacks.
  16. Only for LLM.
  17. """
  18. raise_error: bool = False
  19. @abstractmethod
  20. def on_before_invoke(
  21. self,
  22. llm_instance: AIModel,
  23. model: str,
  24. credentials: dict,
  25. prompt_messages: list[PromptMessage],
  26. model_parameters: dict,
  27. tools: Optional[list[PromptMessageTool]] = None,
  28. stop: Optional[list[str]] = None,
  29. stream: bool = True,
  30. user: Optional[str] = None,
  31. ) -> None:
  32. """
  33. Before invoke callback
  34. :param llm_instance: LLM instance
  35. :param model: model name
  36. :param credentials: model credentials
  37. :param prompt_messages: prompt messages
  38. :param model_parameters: model parameters
  39. :param tools: tools for tool calling
  40. :param stop: stop words
  41. :param stream: is stream response
  42. :param user: unique user id
  43. """
  44. raise NotImplementedError()
  45. @abstractmethod
  46. def on_new_chunk(
  47. self,
  48. llm_instance: AIModel,
  49. chunk: LLMResultChunk,
  50. model: str,
  51. credentials: dict,
  52. prompt_messages: list[PromptMessage],
  53. model_parameters: dict,
  54. tools: Optional[list[PromptMessageTool]] = None,
  55. stop: Optional[list[str]] = None,
  56. stream: bool = True,
  57. user: Optional[str] = None,
  58. ):
  59. """
  60. On new chunk callback
  61. :param llm_instance: LLM instance
  62. :param chunk: chunk
  63. :param model: model name
  64. :param credentials: model credentials
  65. :param prompt_messages: prompt messages
  66. :param model_parameters: model parameters
  67. :param tools: tools for tool calling
  68. :param stop: stop words
  69. :param stream: is stream response
  70. :param user: unique user id
  71. """
  72. raise NotImplementedError()
  73. @abstractmethod
  74. def on_after_invoke(
  75. self,
  76. llm_instance: AIModel,
  77. result: LLMResult,
  78. model: str,
  79. credentials: dict,
  80. prompt_messages: list[PromptMessage],
  81. model_parameters: dict,
  82. tools: Optional[list[PromptMessageTool]] = None,
  83. stop: Optional[list[str]] = None,
  84. stream: bool = True,
  85. user: Optional[str] = None,
  86. ) -> None:
  87. """
  88. After invoke callback
  89. :param llm_instance: LLM instance
  90. :param result: result
  91. :param model: model name
  92. :param credentials: model credentials
  93. :param prompt_messages: prompt messages
  94. :param model_parameters: model parameters
  95. :param tools: tools for tool calling
  96. :param stop: stop words
  97. :param stream: is stream response
  98. :param user: unique user id
  99. """
  100. raise NotImplementedError()
  101. @abstractmethod
  102. def on_invoke_error(
  103. self,
  104. llm_instance: AIModel,
  105. ex: Exception,
  106. model: str,
  107. credentials: dict,
  108. prompt_messages: list[PromptMessage],
  109. model_parameters: dict,
  110. tools: Optional[list[PromptMessageTool]] = None,
  111. stop: Optional[list[str]] = None,
  112. stream: bool = True,
  113. user: Optional[str] = None,
  114. ) -> None:
  115. """
  116. Invoke error callback
  117. :param llm_instance: LLM instance
  118. :param ex: exception
  119. :param model: model name
  120. :param credentials: model credentials
  121. :param prompt_messages: prompt messages
  122. :param model_parameters: model parameters
  123. :param tools: tools for tool calling
  124. :param stop: stop words
  125. :param stream: is stream response
  126. :param user: unique user id
  127. """
  128. raise NotImplementedError()
  129. def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
  130. """Print text with highlighting and no end characters."""
  131. text_to_print = self._get_colored_text(text, color) if color else text
  132. print(text_to_print, end=end)
  133. def _get_colored_text(self, text: str, color: str) -> str:
  134. """Get colored text."""
  135. color_str = _TEXT_COLOR_MAPPING[color]
  136. return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"