message.py

import enum

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel


class LLMRunResult(BaseModel):
    """Result of a single LLM call: generated text plus token usage."""
    content: str
    prompt_tokens: int
    completion_tokens: int


class MessageType(enum.Enum):
    HUMAN = 'human'
    ASSISTANT = 'assistant'
    SYSTEM = 'system'


class PromptMessage(BaseModel):
    type: MessageType = MessageType.HUMAN
    content: str = ''


def to_lc_messages(messages: list[PromptMessage]) -> list[BaseMessage]:
    """Convert internal PromptMessage objects to LangChain message objects."""
    lc_messages = []
    for message in messages:
        if message.type == MessageType.HUMAN:
            lc_messages.append(HumanMessage(content=message.content))
        elif message.type == MessageType.ASSISTANT:
            lc_messages.append(AIMessage(content=message.content))
        elif message.type == MessageType.SYSTEM:
            lc_messages.append(SystemMessage(content=message.content))
    return lc_messages


def to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
    """Convert LangChain message objects back to internal PromptMessage objects."""
    prompt_messages = []
    for message in messages:
        if isinstance(message, HumanMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
        elif isinstance(message, AIMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
        elif isinstance(message, SystemMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
    return prompt_messages


def str_to_prompt_messages(texts: list[str]) -> list[PromptMessage]:
    """Wrap plain strings as human PromptMessage objects."""
    prompt_messages = []
    for text in texts:
        prompt_messages.append(PromptMessage(content=text))
    return prompt_messages
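

# Illustrative usage sketch (an assumption, not part of the original module):
# round-trips a short conversation through the LangChain message types using
# the helpers defined above.
if __name__ == '__main__':
    conversation = [
        PromptMessage(type=MessageType.SYSTEM, content='You are a helpful assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Hello!'),
        PromptMessage(type=MessageType.ASSISTANT, content='Hi! How can I help?'),
    ]
    lc_messages = to_lc_messages(conversation)
    # Pydantic models compare by field values, so the round trip should be lossless.
    assert to_prompt_messages(lc_messages) == conversation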