# llm_chain.py
  1. from typing import List, Dict, Any, Optional
  2. from langchain import LLMChain as LCLLMChain
  3. from langchain.callbacks.manager import CallbackManagerForChainRun
  4. from langchain.schema import LLMResult, Generation
  5. from langchain.schema.language_model import BaseLanguageModel
  6. from core.model_providers.models.entity.message import to_prompt_messages
  7. from core.model_providers.models.llm.base import BaseLLM
  8. from core.third_party.langchain.llms.fake import FakeLLM
  9. class LLMChain(LCLLMChain):
  10. model_instance: BaseLLM
  11. """The language model instance to use."""
  12. llm: BaseLanguageModel = FakeLLM(response="")
  13. def generate(
  14. self,
  15. input_list: List[Dict[str, Any]],
  16. run_manager: Optional[CallbackManagerForChainRun] = None,
  17. ) -> LLMResult:
  18. """Generate LLM result from inputs."""
  19. prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
  20. messages = prompts[0].to_messages()
  21. prompt_messages = to_prompt_messages(messages)
  22. result = self.model_instance.run(
  23. messages=prompt_messages,
  24. stop=stop
  25. )
  26. generations = [
  27. [Generation(text=result.content)]
  28. ]
  29. return LLMResult(generations=generations)