entities.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from typing import Any, Literal, Union
  2. from pydantic import BaseModel, ValidationInfo, field_validator
  3. from core.tools.entities.tool_entities import ToolSelector
  4. from core.workflow.nodes.base.entities import BaseNodeData
  5. class AgentEntity(BaseModel):
  6. agent_strategy_provider_name: str # redundancy
  7. agent_strategy_name: str
  8. agent_strategy_label: str # redundancy
  9. agent_configurations: dict[str, Any]
  10. plugin_unique_identifier: str
  11. @field_validator("agent_configurations", mode="before")
  12. @classmethod
  13. def validate_agent_configurations(cls, value, values: ValidationInfo):
  14. if not isinstance(value, dict):
  15. raise ValueError("agent_configurations must be a dictionary")
  16. for key in values.data.get("agent_configurations", {}):
  17. value = values.data.get("agent_configurations", {}).get(key)
  18. if isinstance(value, dict):
  19. # convert dict to ToolSelector
  20. return ToolSelector(**value)
  21. elif isinstance(value, ToolSelector):
  22. return value
  23. elif isinstance(value, list):
  24. # convert list[ToolSelector] to ToolSelector
  25. if all(isinstance(val, dict) for val in value):
  26. return [ToolSelector(**val) for val in value]
  27. elif all(isinstance(val, ToolSelector) for val in value):
  28. return value
  29. else:
  30. raise ValueError("value must be a list of ToolSelector")
  31. else:
  32. raise ValueError("value must be a dictionary or ToolSelector")
  33. return value
  34. class AgentNodeData(BaseNodeData, AgentEntity):
  35. class AgentInput(BaseModel):
  36. # TODO: check this type
  37. value: Union[list[str], list[ToolSelector], Any]
  38. type: Literal["mixed", "variable", "constant"]
  39. @field_validator("type", mode="before")
  40. @classmethod
  41. def check_type(cls, value, validation_info: ValidationInfo):
  42. typ = value
  43. value = validation_info.data.get("value")
  44. if typ == "mixed" and not isinstance(value, str):
  45. raise ValueError("value must be a string")
  46. elif typ == "variable":
  47. if not isinstance(value, list):
  48. raise ValueError("value must be a list")
  49. for val in value:
  50. if not isinstance(val, str):
  51. raise ValueError("value must be a list of strings")
  52. elif typ == "constant":
  53. if isinstance(value, list):
  54. # convert dict to ToolSelector
  55. if all(isinstance(val, dict) for val in value) or all(
  56. isinstance(val, ToolSelector) for val in value
  57. ):
  58. return value
  59. else:
  60. raise ValueError("value must be a list of ToolSelector")
  61. elif isinstance(value, dict):
  62. # convert dict to ToolSelector
  63. return ToolSelector(**value)
  64. elif isinstance(value, ToolSelector):
  65. return value
  66. else:
  67. raise ValueError("value must be a list of ToolSelector")
  68. return typ
  69. agent_parameters: dict[str, AgentInput]