tool_entities.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import base64
  2. import enum
  3. from enum import Enum
  4. from typing import Any, Mapping, Optional, Union
  5. from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
  6. from core.entities.provider_entities import ProviderConfig
  7. from core.plugin.entities.parameters import (
  8. PluginParameter,
  9. PluginParameterOption,
  10. PluginParameterType,
  11. as_normal_type,
  12. cast_parameter_value,
  13. init_frontend_parameter,
  14. )
  15. from core.tools.entities.common_entities import I18nObject
  16. from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
  17. class ToolLabelEnum(Enum):
  18. SEARCH = "search"
  19. IMAGE = "image"
  20. VIDEOS = "videos"
  21. WEATHER = "weather"
  22. FINANCE = "finance"
  23. DESIGN = "design"
  24. TRAVEL = "travel"
  25. SOCIAL = "social"
  26. NEWS = "news"
  27. MEDICAL = "medical"
  28. PRODUCTIVITY = "productivity"
  29. EDUCATION = "education"
  30. BUSINESS = "business"
  31. ENTERTAINMENT = "entertainment"
  32. UTILITIES = "utilities"
  33. OTHER = "other"
  34. class ToolProviderType(enum.StrEnum):
  35. """
  36. Enum class for tool provider
  37. """
  38. PLUGIN = "plugin"
  39. BUILT_IN = "builtin"
  40. WORKFLOW = "workflow"
  41. API = "api"
  42. APP = "app"
  43. DATASET_RETRIEVAL = "dataset-retrieval"
  44. @classmethod
  45. def value_of(cls, value: str) -> "ToolProviderType":
  46. """
  47. Get value of given mode.
  48. :param value: mode value
  49. :return: mode
  50. """
  51. for mode in cls:
  52. if mode.value == value:
  53. return mode
  54. raise ValueError(f"invalid mode value {value}")
  55. class ApiProviderSchemaType(Enum):
  56. """
  57. Enum class for api provider schema type.
  58. """
  59. OPENAPI = "openapi"
  60. SWAGGER = "swagger"
  61. OPENAI_PLUGIN = "openai_plugin"
  62. OPENAI_ACTIONS = "openai_actions"
  63. @classmethod
  64. def value_of(cls, value: str) -> "ApiProviderSchemaType":
  65. """
  66. Get value of given mode.
  67. :param value: mode value
  68. :return: mode
  69. """
  70. for mode in cls:
  71. if mode.value == value:
  72. return mode
  73. raise ValueError(f"invalid mode value {value}")
  74. class ApiProviderAuthType(Enum):
  75. """
  76. Enum class for api provider auth type.
  77. """
  78. NONE = "none"
  79. API_KEY = "api_key"
  80. @classmethod
  81. def value_of(cls, value: str) -> "ApiProviderAuthType":
  82. """
  83. Get value of given mode.
  84. :param value: mode value
  85. :return: mode
  86. """
  87. for mode in cls:
  88. if mode.value == value:
  89. return mode
  90. raise ValueError(f"invalid mode value {value}")
  91. class ToolInvokeMessage(BaseModel):
  92. class TextMessage(BaseModel):
  93. text: str
  94. class JsonMessage(BaseModel):
  95. json_object: dict
  96. class BlobMessage(BaseModel):
  97. blob: bytes
  98. class FileMessage(BaseModel):
  99. pass
  100. class VariableMessage(BaseModel):
  101. variable_name: str = Field(..., description="The name of the variable")
  102. variable_value: str = Field(..., description="The value of the variable")
  103. stream: bool = Field(default=False, description="Whether the variable is streamed")
  104. @field_validator("variable_value", mode="before")
  105. @classmethod
  106. def transform_variable_value(cls, value, values) -> Any:
  107. """
  108. Only basic types and lists are allowed.
  109. """
  110. if not isinstance(value, dict | list | str | int | float | bool):
  111. raise ValueError("Only basic types and lists are allowed.")
  112. # if stream is true, the value must be a string
  113. if values.get("stream"):
  114. if not isinstance(value, str):
  115. raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
  116. return value
  117. @field_validator("variable_name", mode="before")
  118. @classmethod
  119. def transform_variable_name(cls, value) -> str:
  120. """
  121. The variable name must be a string.
  122. """
  123. if value in {"json", "text", "files"}:
  124. raise ValueError(f"The variable name '{value}' is reserved.")
  125. return value
  126. class LogMessage(BaseModel):
  127. class LogStatus(Enum):
  128. START = "start"
  129. ERROR = "error"
  130. SUCCESS = "success"
  131. id: str
  132. label: str = Field(..., description="The label of the log")
  133. parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
  134. error: Optional[str] = Field(default=None, description="The error message")
  135. status: LogStatus = Field(..., description="The status of the log")
  136. data: Mapping[str, Any] = Field(..., description="Detailed log data")
  137. class MessageType(Enum):
  138. TEXT = "text"
  139. IMAGE = "image"
  140. LINK = "link"
  141. BLOB = "blob"
  142. JSON = "json"
  143. IMAGE_LINK = "image_link"
  144. BINARY_LINK = "binary_link"
  145. VARIABLE = "variable"
  146. FILE = "file"
  147. LOG = "log"
  148. type: MessageType = MessageType.TEXT
  149. """
  150. plain text, image url or link url
  151. """
  152. message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
  153. meta: dict[str, Any] | None = None
  154. @field_validator("message", mode="before")
  155. @classmethod
  156. def decode_blob_message(cls, v):
  157. if isinstance(v, dict) and "blob" in v:
  158. try:
  159. v["blob"] = base64.b64decode(v["blob"])
  160. except Exception:
  161. pass
  162. return v
  163. @field_serializer("message")
  164. def serialize_message(self, v):
  165. if isinstance(v, self.BlobMessage):
  166. return {"blob": base64.b64encode(v.blob).decode("utf-8")}
  167. return v
  168. class ToolInvokeMessageBinary(BaseModel):
  169. mimetype: str = Field(..., description="The mimetype of the binary")
  170. url: str = Field(..., description="The url of the binary")
  171. file_var: Optional[dict[str, Any]] = None
  172. class ToolParameter(PluginParameter):
  173. """
  174. Overrides type
  175. """
  176. class ToolParameterType(enum.StrEnum):
  177. """
  178. removes TOOLS_SELECTOR from PluginParameterType
  179. """
  180. STRING = PluginParameterType.STRING.value
  181. NUMBER = PluginParameterType.NUMBER.value
  182. BOOLEAN = PluginParameterType.BOOLEAN.value
  183. SELECT = PluginParameterType.SELECT.value
  184. SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
  185. FILE = PluginParameterType.FILE.value
  186. FILES = PluginParameterType.FILES.value
  187. APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
  188. MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
  189. # deprecated, should not use.
  190. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
  191. def as_normal_type(self):
  192. return as_normal_type(self)
  193. def cast_value(self, value: Any):
  194. return cast_parameter_value(self, value)
  195. class ToolParameterForm(Enum):
  196. SCHEMA = "schema" # should be set while adding tool
  197. FORM = "form" # should be set before invoking tool
  198. LLM = "llm" # will be set by LLM
  199. type: ToolParameterType = Field(..., description="The type of the parameter")
  200. human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
  201. form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
  202. llm_description: Optional[str] = None
  203. @classmethod
  204. def get_simple_instance(
  205. cls,
  206. name: str,
  207. llm_description: str,
  208. typ: ToolParameterType,
  209. required: bool,
  210. options: Optional[list[str]] = None,
  211. ) -> "ToolParameter":
  212. """
  213. get a simple tool parameter
  214. :param name: the name of the parameter
  215. :param llm_description: the description presented to the LLM
  216. :param type: the type of the parameter
  217. :param required: if the parameter is required
  218. :param options: the options of the parameter
  219. """
  220. # convert options to ToolParameterOption
  221. if options:
  222. option_objs = [
  223. PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
  224. for option in options
  225. ]
  226. else:
  227. option_objs = []
  228. return cls(
  229. name=name,
  230. label=I18nObject(en_US="", zh_Hans=""),
  231. placeholder=None,
  232. human_description=I18nObject(en_US="", zh_Hans=""),
  233. type=typ,
  234. form=cls.ToolParameterForm.LLM,
  235. llm_description=llm_description,
  236. required=required,
  237. options=option_objs,
  238. )
  239. def init_frontend_parameter(self, value: Any):
  240. return init_frontend_parameter(self, self.type, value)
  241. class ToolProviderIdentity(BaseModel):
  242. author: str = Field(..., description="The author of the tool")
  243. name: str = Field(..., description="The name of the tool")
  244. description: I18nObject = Field(..., description="The description of the tool")
  245. icon: str = Field(..., description="The icon of the tool")
  246. label: I18nObject = Field(..., description="The label of the tool")
  247. tags: Optional[list[ToolLabelEnum]] = Field(
  248. default=[],
  249. description="The tags of the tool",
  250. )
  251. class ToolIdentity(BaseModel):
  252. author: str = Field(..., description="The author of the tool")
  253. name: str = Field(..., description="The name of the tool")
  254. label: I18nObject = Field(..., description="The label of the tool")
  255. provider: str = Field(..., description="The provider of the tool")
  256. icon: Optional[str] = None
  257. class ToolDescription(BaseModel):
  258. human: I18nObject = Field(..., description="The description presented to the user")
  259. llm: str = Field(..., description="The description presented to the LLM")
  260. class ToolEntity(BaseModel):
  261. identity: ToolIdentity
  262. parameters: list[ToolParameter] = Field(default_factory=list)
  263. description: Optional[ToolDescription] = None
  264. output_schema: Optional[dict] = None
  265. has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
  266. # pydantic configs
  267. model_config = ConfigDict(protected_namespaces=())
  268. @field_validator("parameters", mode="before")
  269. @classmethod
  270. def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
  271. return v or []
  272. class ToolProviderEntity(BaseModel):
  273. identity: ToolProviderIdentity
  274. plugin_id: Optional[str] = None
  275. credentials_schema: list[ProviderConfig] = Field(default_factory=list)
  276. class ToolProviderEntityWithPlugin(ToolProviderEntity):
  277. tools: list[ToolEntity] = Field(default_factory=list)
  278. class WorkflowToolParameterConfiguration(BaseModel):
  279. """
  280. Workflow tool configuration
  281. """
  282. name: str = Field(..., description="The name of the parameter")
  283. description: str = Field(..., description="The description of the parameter")
  284. form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
  285. class ToolInvokeMeta(BaseModel):
  286. """
  287. Tool invoke meta
  288. """
  289. time_cost: float = Field(..., description="The time cost of the tool invoke")
  290. error: Optional[str] = None
  291. tool_config: Optional[dict] = None
  292. @classmethod
  293. def empty(cls) -> "ToolInvokeMeta":
  294. """
  295. Get an empty instance of ToolInvokeMeta
  296. """
  297. return cls(time_cost=0.0, error=None, tool_config={})
  298. @classmethod
  299. def error_instance(cls, error: str) -> "ToolInvokeMeta":
  300. """
  301. Get an instance of ToolInvokeMeta with error
  302. """
  303. return cls(time_cost=0.0, error=error, tool_config={})
  304. def to_dict(self) -> dict:
  305. return {
  306. "time_cost": self.time_cost,
  307. "error": self.error,
  308. "tool_config": self.tool_config,
  309. }
  310. class ToolLabel(BaseModel):
  311. """
  312. Tool label
  313. """
  314. name: str = Field(..., description="The name of the tool")
  315. label: I18nObject = Field(..., description="The label of the tool")
  316. icon: str = Field(..., description="The icon of the tool")
  317. class ToolInvokeFrom(Enum):
  318. """
  319. Enum class for tool invoke
  320. """
  321. WORKFLOW = "workflow"
  322. AGENT = "agent"
  323. PLUGIN = "plugin"
  324. class ToolSelector(BaseModel):
  325. dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
  326. class Parameter(BaseModel):
  327. name: str = Field(..., description="The name of the parameter")
  328. type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
  329. required: bool = Field(..., description="Whether the parameter is required")
  330. description: str = Field(..., description="The description of the parameter")
  331. default: Optional[Union[int, float, str]] = None
  332. options: Optional[list[PluginParameterOption]] = None
  333. provider_id: str = Field(..., description="The id of the provider")
  334. tool_name: str = Field(..., description="The name of the tool")
  335. tool_description: str = Field(..., description="The description of the tool")
  336. tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
  337. tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
  338. def to_plugin_parameter(self) -> dict[str, Any]:
  339. return self.model_dump()