# entities.py — data models for the LLM workflow node.
  1. from collections.abc import Sequence
  2. from typing import Any, Optional
  3. from pydantic import BaseModel, Field, field_validator
  4. from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
  5. from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
  6. from core.workflow.entities.variable_entities import VariableSelector
  7. from core.workflow.nodes.base import BaseNodeData
  8. class ModelConfig(BaseModel):
  9. provider: str
  10. name: str
  11. mode: LLMMode
  12. completion_params: dict[str, Any] = {}
class ContextConfig(BaseModel):
    """Settings for injecting retrieved context into the prompt."""

    # Whether context injection is turned on for this node.
    enabled: bool
    # Selector path of the variable that holds the context; None when unset.
    variable_selector: Optional[list[str]] = None
class VisionConfigOptions(BaseModel):
    """Options controlling how image inputs are supplied to the model."""

    # Defaults to the system files variable (["sys", "files"]); the lambda
    # factory gives each instance its own fresh list.
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    # Image detail level passed to the model; HIGH by default.
    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
  19. class VisionConfig(BaseModel):
  20. enabled: bool = False
  21. configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
  22. @field_validator("configs", mode="before")
  23. @classmethod
  24. def convert_none_configs(cls, v: Any):
  25. if v is None:
  26. return VisionConfigOptions()
  27. return v
  28. class PromptConfig(BaseModel):
  29. jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
  30. @field_validator("jinja2_variables", mode="before")
  31. @classmethod
  32. def convert_none_jinja2_variables(cls, v: Any):
  33. if v is None:
  34. return []
  35. return v
class LLMNodeChatModelMessage(ChatModelMessage):
    """Chat prompt message extended with node-level template fields."""

    # Plain-text template body; empty by default.
    text: str = ""
    # Jinja2 variant of the template, used when Jinja2 editing is enabled.
    jinja2_text: Optional[str] = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    """Completion prompt template extended with an optional Jinja2 body."""

    # Jinja2 variant of the template, used when Jinja2 editing is enabled.
    jinja2_text: Optional[str] = None
  41. class LLMNodeData(BaseNodeData):
  42. model: ModelConfig
  43. prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
  44. prompt_config: PromptConfig = Field(default_factory=PromptConfig)
  45. memory: Optional[MemoryConfig] = None
  46. context: ContextConfig
  47. vision: VisionConfig = Field(default_factory=VisionConfig)
  48. @field_validator("prompt_config", mode="before")
  49. @classmethod
  50. def convert_none_prompt_config(cls, v: Any):
  51. if v is None:
  52. return PromptConfig()
  53. return v