message.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  import enum
  from typing import Optional

  from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
  from pydantic import BaseModel
  class LLMRunResult(BaseModel):
      """Result record for one LLM run.

      Carries the returned text, the prompt/completion token counts and two
      optional extras. The optional fields default to ``None`` and are
      annotated as ``Optional`` explicitly so the declared type matches the
      default instead of relying on implicit-Optional inference.
      """

      # Text content of the result.
      content: str
      # Token count recorded for the prompt.
      prompt_tokens: int
      # Token count recorded for the completion.
      completion_tokens: int
      # Optional list attached to the result; its element schema is not
      # defined in this file (presumably source references -- confirm with
      # the producer).
      source: Optional[list] = None
      # Optional function-call payload; None when absent.
      function_call: Optional[dict] = None
  class MessageType(enum.Enum):
      """Role tag carried by a prompt message.

      The member values are the lowercase role strings 'user', 'assistant'
      and 'system'.
      """

      USER = 'user'
      ASSISTANT = 'assistant'
      SYSTEM = 'system'
  class PromptMessage(BaseModel):
      """A single prompt message in this module's own message format.

      ``function_call`` defaults to ``None`` and is annotated as ``Optional``
      explicitly so the declared type matches the default instead of relying
      on implicit-Optional inference.
      """

      # Role of the message; defaults to a user message. The field name
      # shadows the ``type`` builtin inside the class body but is part of the
      # public interface, so it is kept as-is.
      type: MessageType = MessageType.USER
      # Message text; defaults to the empty string.
      content: str = ''
      # Optional function-call payload; None when absent.
      function_call: Optional[dict] = None
  def to_lc_messages(messages: list[PromptMessage]):
      """Convert prompt messages into LangChain message objects.

      USER becomes ``HumanMessage``, ASSISTANT becomes ``AIMessage`` and
      SYSTEM becomes ``SystemMessage``. For ASSISTANT messages a truthy
      ``function_call`` is forwarded under ``additional_kwargs['function_call']``.
      A message whose type matches none of these is skipped.

      :param messages: prompt messages to convert, in order
      :return: list of LangChain messages in the same order
      """
      converted = []
      for prompt_message in messages:
          kind = prompt_message.type
          if kind == MessageType.USER:
              lc_message = HumanMessage(content=prompt_message.content)
          elif kind == MessageType.ASSISTANT:
              # Only attach the function call when one is actually set.
              extra = (
                  {'function_call': prompt_message.function_call}
                  if prompt_message.function_call
                  else {}
              )
              lc_message = AIMessage(content=prompt_message.content, additional_kwargs=extra)
          elif kind == MessageType.SYSTEM:
              lc_message = SystemMessage(content=prompt_message.content)
          else:
              # No matching role: nothing is emitted for this message.
              continue
          converted.append(lc_message)
      return converted
  def to_prompt_messages(messages: list[BaseMessage]):
      """Convert LangChain messages into ``PromptMessage`` objects.

      ``HumanMessage`` and ``FunctionMessage`` both map to the USER type,
      ``AIMessage`` to ASSISTANT and ``SystemMessage`` to SYSTEM. For an
      ``AIMessage`` the ``function_call`` entry of ``additional_kwargs`` is
      carried over when the key is present. Instances of any other message
      class are skipped.

      :param messages: LangChain messages to convert, in order
      :return: list of ``PromptMessage`` in the same order
      """
      converted = []
      for lc_message in messages:
          # Resolve the role first; the checks run in a fixed order.
          if isinstance(lc_message, HumanMessage):
              kind = MessageType.USER
          elif isinstance(lc_message, AIMessage):
              kind = MessageType.ASSISTANT
          elif isinstance(lc_message, SystemMessage):
              kind = MessageType.SYSTEM
          elif isinstance(lc_message, FunctionMessage):
              # There is no dedicated function role; it is stored as USER.
              kind = MessageType.USER
          else:
              continue

          fields = {'content': lc_message.content, 'type': kind}
          # Only the AIMessage branch yields ASSISTANT, so this guard limits
          # the function_call copy to AI messages.
          if kind is MessageType.ASSISTANT and 'function_call' in lc_message.additional_kwargs:
              fields['function_call'] = lc_message.additional_kwargs['function_call']
          converted.append(PromptMessage(**fields))
      return converted
  def str_to_prompt_messages(texts: list[str]):
      """Wrap each text in a ``PromptMessage`` with the default type.

      :param texts: message texts, in order
      :return: list with one ``PromptMessage`` per text, in the same order
      """
      return [PromptMessage(content=text) for text in texts]