api.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from typing import Optional
  2. from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
  3. from core.external_data_tool.base import ExternalDataTool
  4. from core.helper import encrypter
  5. from extensions.ext_database import db
  6. from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
  7. class ApiExternalDataTool(ExternalDataTool):
  8. """
  9. The api external data tool.
  10. """
  11. name: str = "api"
  12. """the unique name of external data tool"""
  13. @classmethod
  14. def validate_config(cls, tenant_id: str, config: dict) -> None:
  15. """
  16. Validate the incoming form config data.
  17. :param tenant_id: the id of workspace
  18. :param config: the form config data
  19. :return:
  20. """
  21. # own validation logic
  22. api_based_extension_id = config.get("api_based_extension_id")
  23. if not api_based_extension_id:
  24. raise ValueError("api_based_extension_id is required")
  25. # get api_based_extension
  26. api_based_extension = db.session.query(APIBasedExtension).filter(
  27. APIBasedExtension.tenant_id == tenant_id,
  28. APIBasedExtension.id == api_based_extension_id
  29. ).first()
  30. if not api_based_extension:
  31. raise ValueError("api_based_extension_id is invalid")
  32. def query(self, inputs: dict, query: Optional[str] = None) -> str:
  33. """
  34. Query the external data tool.
  35. :param inputs: user inputs
  36. :param query: the query of chat app
  37. :return: the tool query result
  38. """
  39. # get params from config
  40. api_based_extension_id = self.config.get("api_based_extension_id")
  41. # get api_based_extension
  42. api_based_extension = db.session.query(APIBasedExtension).filter(
  43. APIBasedExtension.tenant_id == self.tenant_id,
  44. APIBasedExtension.id == api_based_extension_id
  45. ).first()
  46. if not api_based_extension:
  47. raise ValueError("[External data tool] API query failed, variable: {}, "
  48. "error: api_based_extension_id is invalid"
  49. .format(self.config.get('variable')))
  50. # decrypt api_key
  51. api_key = encrypter.decrypt_token(
  52. tenant_id=self.tenant_id,
  53. token=api_based_extension.api_key
  54. )
  55. try:
  56. # request api
  57. requestor = APIBasedExtensionRequestor(
  58. api_endpoint=api_based_extension.api_endpoint,
  59. api_key=api_key
  60. )
  61. except Exception as e:
  62. raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
  63. self.config.get('variable'),
  64. e
  65. ))
  66. response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
  67. 'app_id': self.app_id,
  68. 'tool_variable': self.variable,
  69. 'inputs': inputs,
  70. 'query': query
  71. })
  72. if 'result' not in response_json:
  73. raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
  74. .format(self.config.get('variable')))
  75. return response_json['result']