api.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from typing import Optional
  2. from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
  3. from core.external_data_tool.base import ExternalDataTool
  4. from core.helper import encrypter
  5. from extensions.ext_database import db
  6. from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
  7. class ApiExternalDataTool(ExternalDataTool):
  8. """
  9. The api external data tool.
  10. """
  11. name: str = "api"
  12. """the unique name of external data tool"""
  13. @classmethod
  14. def validate_config(cls, tenant_id: str, config: dict) -> None:
  15. """
  16. Validate the incoming form config data.
  17. :param tenant_id: the id of workspace
  18. :param config: the form config data
  19. :return:
  20. """
  21. # own validation logic
  22. api_based_extension_id = config.get("api_based_extension_id")
  23. if not api_based_extension_id:
  24. raise ValueError("api_based_extension_id is required")
  25. # get api_based_extension
  26. api_based_extension = (
  27. db.session.query(APIBasedExtension)
  28. .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
  29. .first()
  30. )
  31. if not api_based_extension:
  32. raise ValueError("api_based_extension_id is invalid")
  33. def query(self, inputs: dict, query: Optional[str] = None) -> str:
  34. """
  35. Query the external data tool.
  36. :param inputs: user inputs
  37. :param query: the query of chat app
  38. :return: the tool query result
  39. """
  40. # get params from config
  41. api_based_extension_id = self.config.get("api_based_extension_id")
  42. # get api_based_extension
  43. api_based_extension = (
  44. db.session.query(APIBasedExtension)
  45. .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
  46. .first()
  47. )
  48. if not api_based_extension:
  49. raise ValueError(
  50. "[External data tool] API query failed, variable: {}, "
  51. "error: api_based_extension_id is invalid".format(self.variable)
  52. )
  53. # decrypt api_key
  54. api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key)
  55. try:
  56. # request api
  57. requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
  58. except Exception as e:
  59. raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))
  60. response_json = requestor.request(
  61. point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
  62. params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query},
  63. )
  64. if "result" not in response_json:
  65. raise ValueError(
  66. "[External data tool] API query failed, variable: {}, error: result not found in response".format(
  67. self.variable
  68. )
  69. )
  70. if not isinstance(response_json["result"], str):
  71. raise ValueError(
  72. "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
  73. )
  74. return response_json["result"]