entities.py

from collections.abc import Sequence
from typing import Any, Optional

from pydantic import BaseModel, Field, field_validator

from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData


class ModelConfig(BaseModel):
    provider: str
    name: str
    mode: LLMMode = LLMMode.COMPLETION
    completion_params: dict[str, Any] = {}


class ContextConfig(BaseModel):
    enabled: bool
    variable_selector: Optional[list[str]] = None


class VisionConfigOptions(BaseModel):
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH


class VisionConfig(BaseModel):
    enabled: bool = False
    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)

    @field_validator("configs", mode="before")
    @classmethod
    def convert_none_configs(cls, v: Any):
        if v is None:
            return VisionConfigOptions()
        return v


class PromptConfig(BaseModel):
    jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)

    @field_validator("jinja2_variables", mode="before")
    @classmethod
    def convert_none_jinja2_variables(cls, v: Any):
        if v is None:
            return []
        return v


class LLMNodeChatModelMessage(ChatModelMessage):
    text: str = ""
    jinja2_text: Optional[str] = None


class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    jinja2_text: Optional[str] = None


class LLMNodeData(BaseNodeData):
    model: ModelConfig
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @field_validator("prompt_config", mode="before")
    @classmethod
    def convert_none_prompt_config(cls, v: Any):
        if v is None:
            return PromptConfig()
        return v
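

if __name__ == "__main__":
    # Illustrative sketch, not part of the original module: the mode="before"
    # validators above normalize an explicit None into the field's default, so
    # serialized node configs that carry null sections still validate cleanly.
    vision = VisionConfig.model_validate({"enabled": True, "configs": None})
    assert vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH

    prompt_config = PromptConfig.model_validate({"jinja2_variables": None})
    assert prompt_config.jinja2_variables == []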