base_callback.py

from typing import Optional

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.ai_model import AIModel

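# Maps color names to ANSI SGR parameter strings, e.g. "32;1" renders bold
# green; consumed by Callback._get_colored_text below.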
_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}

class Callback:
    """
    Base class for callbacks.
    Only for LLM.
    """

    raise_error: bool = False

    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        """
        Before invoke callback

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    def on_new_chunk(
        self,
        llm_instance: AIModel,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ):
        """
        On new chunk callback

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    def on_after_invoke(
        self,
        llm_instance: AIModel,
        result: LLMResult,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        """
        After invoke callback

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    def on_invoke_error(
        self,
        llm_instance: AIModel,
        ex: Exception,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        """
        Invoke error callback

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        raise NotImplementedError()

    def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
        """Print text with highlighting and no end characters."""
        text_to_print = self._get_colored_text(text, color) if color else text
        print(text_to_print, end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """Get colored text."""
        color_str = _TEXT_COLOR_MAPPING[color]
        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
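
# A minimal sketch of a concrete subclass, shown for illustration only; it is
# not part of this module. The class name LoggingCallback is hypothetical, and
# the chunk.delta.message.content access path is assumed from the
# LLMResultChunk entity imported above. Every hook is overridden so none of
# the NotImplementedError defaults fire.
class LoggingCallback(Callback):
    def on_before_invoke(self, llm_instance, model, credentials, prompt_messages,
                         model_parameters, tools=None, stop=None, stream=True,
                         user=None) -> None:
        # Announce the invocation before any tokens are generated.
        self.print_text(f"[before] model={model} stream={stream}\n", color="blue")

    def on_new_chunk(self, llm_instance, chunk, model, credentials, prompt_messages,
                     model_parameters, tools=None, stop=None, stream=True, user=None):
        # Echo each streamed delta as it arrives; content is assumed to be a
        # plain string here, and print_text's default end="" keeps the
        # deltas joined on one line.
        self.print_text(str(chunk.delta.message.content or ""), color="green")

    def on_after_invoke(self, llm_instance, result, model, credentials,
                        prompt_messages, model_parameters, tools=None, stop=None,
                        stream=True, user=None) -> None:
        self.print_text(f"\n[after] model={model}\n", color="yellow")

    def on_invoke_error(self, llm_instance, ex, model, credentials, prompt_messages,
                        model_parameters, tools=None, stop=None, stream=True,
                        user=None) -> None:
        # With raise_error left False, the caller decides whether to re-raise;
        # this hook only reports the failure.
        self.print_text(f"\n[error] {ex}\n", color="red")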