tool_provider.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. from abc import ABC, abstractmethod
  2. from typing import List, Dict, Any, Optional
  3. from pydantic import BaseModel
  4. from core.tools.entities.tool_entities import ToolProviderType, \
  5. ToolProviderIdentity, ToolParamter, ToolProviderCredentials
  6. from core.tools.tool.tool import Tool
  7. from core.tools.entities.user_entities import UserToolProviderCredentials
  8. from core.tools.errors import ToolNotFoundError, \
  9. ToolParamterValidationError, ToolProviderCredentialValidationError
  10. class ToolProviderController(BaseModel, ABC):
  11. identity: Optional[ToolProviderIdentity] = None
  12. tools: Optional[List[Tool]] = None
  13. credentials_schema: Optional[Dict[str, ToolProviderCredentials]] = None
  14. def get_credentials_schema(self) -> Dict[str, ToolProviderCredentials]:
  15. """
  16. returns the credentials schema of the provider
  17. :return: the credentials schema
  18. """
  19. return self.credentials_schema.copy()
  20. def user_get_credentials_schema(self) -> UserToolProviderCredentials:
  21. """
  22. returns the credentials schema of the provider, this method is used for user
  23. :return: the credentials schema
  24. """
  25. credentials = self.credentials_schema.copy()
  26. return UserToolProviderCredentials(credentials=credentials)
  27. @abstractmethod
  28. def get_tools(self) -> List[Tool]:
  29. """
  30. returns a list of tools that the provider can provide
  31. :return: list of tools
  32. """
  33. pass
  34. @abstractmethod
  35. def get_tool(self, tool_name: str) -> Tool:
  36. """
  37. returns a tool that the provider can provide
  38. :return: tool
  39. """
  40. pass
  41. def get_parameters(self, tool_name: str) -> List[ToolParamter]:
  42. """
  43. returns the parameters of the tool
  44. :param tool_name: the name of the tool, defined in `get_tools`
  45. :return: list of parameters
  46. """
  47. tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
  48. if tool is None:
  49. raise ToolNotFoundError(f'tool {tool_name} not found')
  50. return tool.parameters
  51. @property
  52. def app_type(self) -> ToolProviderType:
  53. """
  54. returns the type of the provider
  55. :return: type of the provider
  56. """
  57. return ToolProviderType.BUILT_IN
  58. def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
  59. """
  60. validate the parameters of the tool and set the default value if needed
  61. :param tool_name: the name of the tool, defined in `get_tools`
  62. :param tool_parameters: the parameters of the tool
  63. """
  64. tool_parameters_schema = self.get_parameters(tool_name)
  65. tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
  66. for parameter in tool_parameters_schema:
  67. tool_parameters_need_to_validate[parameter.name] = parameter
  68. for parameter in tool_parameters:
  69. if parameter not in tool_parameters_need_to_validate:
  70. raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
  71. # check type
  72. parameter_schema = tool_parameters_need_to_validate[parameter]
  73. if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
  74. if not isinstance(tool_parameters[parameter], str):
  75. raise ToolParamterValidationError(f'parameter {parameter} should be string')
  76. elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
  77. if not isinstance(tool_parameters[parameter], (int, float)):
  78. raise ToolParamterValidationError(f'parameter {parameter} should be number')
  79. if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
  80. raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
  81. if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
  82. raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
  83. elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
  84. if not isinstance(tool_parameters[parameter], bool):
  85. raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
  86. elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
  87. if not isinstance(tool_parameters[parameter], str):
  88. raise ToolParamterValidationError(f'parameter {parameter} should be string')
  89. options = parameter_schema.options
  90. if not isinstance(options, list):
  91. raise ToolParamterValidationError(f'parameter {parameter} options should be list')
  92. if tool_parameters[parameter] not in [x.value for x in options]:
  93. raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
  94. tool_parameters_need_to_validate.pop(parameter)
  95. for parameter in tool_parameters_need_to_validate:
  96. parameter_schema = tool_parameters_need_to_validate[parameter]
  97. if parameter_schema.required:
  98. raise ToolParamterValidationError(f'parameter {parameter} is required')
  99. # the parameter is not set currently, set the default value if needed
  100. if parameter_schema.default is not None:
  101. default_value = parameter_schema.default
  102. # parse default value into the correct type
  103. if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
  104. parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
  105. default_value = str(default_value)
  106. elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
  107. default_value = float(default_value)
  108. elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
  109. default_value = bool(default_value)
  110. tool_parameters[parameter] = default_value
  111. def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
  112. """
  113. validate the format of the credentials of the provider and set the default value if needed
  114. :param credentials: the credentials of the tool
  115. """
  116. credentials_schema = self.credentials_schema
  117. if credentials_schema is None:
  118. return
  119. credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
  120. for credential_name in credentials_schema:
  121. credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
  122. for credential_name in credentials:
  123. if credential_name not in credentials_need_to_validate:
  124. raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
  125. # check type
  126. credential_schema = credentials_need_to_validate[credential_name]
  127. if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
  128. credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
  129. if not isinstance(credentials[credential_name], str):
  130. raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
  131. elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
  132. if not isinstance(credentials[credential_name], str):
  133. raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
  134. options = credential_schema.options
  135. if not isinstance(options, list):
  136. raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
  137. if credentials[credential_name] not in [x.value for x in options]:
  138. raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
  139. credentials_need_to_validate.pop(credential_name)
  140. for credential_name in credentials_need_to_validate:
  141. credential_schema = credentials_need_to_validate[credential_name]
  142. if credential_schema.required:
  143. raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
  144. # the credential is not set currently, set the default value if needed
  145. if credential_schema.default is not None:
  146. default_value = credential_schema.default
  147. # parse default value into the correct type
  148. if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
  149. credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
  150. credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
  151. default_value = str(default_value)
  152. credentials[credential_name] = default_value
  153. def validate_credentials(self, credentials: Dict[str, Any]) -> None:
  154. """
  155. validate the credentials of the provider
  156. :param tool_name: the name of the tool, defined in `get_tools`
  157. :param credentials: the credentials of the tool
  158. """
  159. # validate credentials format
  160. self.validate_credentials_format(credentials)
  161. # validate credentials
  162. self._validate_credentials(credentials)
  163. @abstractmethod
  164. def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
  165. """
  166. validate the credentials of the provider
  167. :param tool_name: the name of the tool, defined in `get_tools`
  168. :param credentials: the credentials of the tool
  169. """
  170. pass