vanna.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import re
  2. from typing import Any
  3. from urllib.parse import urlparse
  4. from core.tools.errors import ToolProviderCredentialValidationError
  5. from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
  6. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  7. class VannaProvider(BuiltinToolProviderController):
  8. def _get_protocol_and_main_domain(self, url):
  9. parsed_url = urlparse(url)
  10. protocol = parsed_url.scheme
  11. hostname = parsed_url.hostname
  12. port = f":{parsed_url.port}" if parsed_url.port else ""
  13. # Check if the hostname is an IP address
  14. is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None
  15. # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain
  16. main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port
  17. return f"{protocol}://{main_domain}"
  18. def _validate_credentials(self, credentials: dict[str, Any]) -> None:
  19. base_url = credentials.get("base_url")
  20. if not base_url:
  21. base_url = "https://ask.vanna.ai/rpc"
  22. else:
  23. base_url = base_url.removesuffix("/")
  24. credentials["base_url"] = base_url
  25. try:
  26. VannaTool().fork_tool_runtime(
  27. runtime={
  28. "credentials": credentials,
  29. }
  30. ).invoke(
  31. user_id="",
  32. tool_parameters={
  33. "model": "chinook",
  34. "db_type": "SQLite",
  35. "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
  36. "query": "What are the top 10 customers by sales?",
  37. },
  38. )
  39. except Exception as e:
  40. raise ToolProviderCredentialValidationError(str(e))