factory.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from collections.abc import Mapping
  2. from typing import Any, Optional, cast
  3. from core.extension.extensible import ExtensionModule
  4. from extensions.ext_code_based_extension import code_based_extension
  5. class ExternalDataToolFactory:
  6. def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
  7. extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
  8. self.__extension_instance = extension_class(
  9. tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
  10. )
  11. @classmethod
  12. def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
  13. """
  14. Validate the incoming form config data.
  15. :param name: the name of external data tool
  16. :param tenant_id: the id of workspace
  17. :param config: the form config data
  18. :return:
  19. """
  20. code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
  21. extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
  22. # FIXME mypy issue here, figure out how to fix it
  23. extension_class.validate_config(tenant_id, config) # type: ignore
  24. def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str:
  25. """
  26. Query the external data tool.
  27. :param inputs: user inputs
  28. :param query: the query of chat app
  29. :return: the tool query result
  30. """
  31. return cast(str, self.__extension_instance.query(inputs, query))