llm_chain.py

from typing import Any, Optional

from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.rag.retrieval.agent.fake_llm import FakeLLM


class LLMChain(LCLLMChain):
    model_config: ModelConfigWithCredentialsEntity
    """The model configuration (with credentials) to use."""
    # Placeholder LLM: LangChain's LLMChain requires an `llm`, but generation
    # is actually routed through Dify's ModelInstance in `generate` below.
    llm: BaseLanguageModel = FakeLLM(response="")
    parameters: dict[str, Any] = {}

    def generate(
        self,
        input_list: list[dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        # Convert the rendered LangChain messages into Dify prompt messages.
        messages = prompts[0].to_messages()
        prompt_messages = lc_messages_to_prompt_messages(messages)
        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )
        # Invoke the model synchronously (no streaming).
        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            stream=False,
            stop=stop,
            model_parameters=self.parameters,
        )
        # Wrap the single completion back into LangChain's LLMResult shape.
        generations = [[Generation(text=result.message.content)]]
        return LLMResult(generations=generations)
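

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of driving this chain, assuming the surrounding app has
# already built a ModelConfigWithCredentialsEntity as `model_config`; the
# prompt template, parameters, and input text below are hypothetical.
#
# from langchain import PromptTemplate
#
# chain = LLMChain(
#     model_config=model_config,
#     prompt=PromptTemplate.from_template("Summarize: {text}"),
#     parameters={"temperature": 0.1},
# )
# result = chain.generate([{"text": "Dify routes this prompt to its own ModelInstance."}])
# print(result.generations[0][0].text)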