import re
from typing import Any
from urllib.parse import urlparse

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class VannaProvider(BuiltinToolProviderController):
    """Builtin tool provider for Vanna; validates credentials via a sample tool call."""

    def _get_protocol_and_main_domain(self, url: str) -> str:
        """Return ``scheme://host[:port]`` for *url*, with the host reduced to its main domain.

        For dotted-quad IPv4 hosts and IPv6 literals the full host is kept;
        for any other hostname only the last two dot-separated labels are
        kept (``ask.vanna.ai`` -> ``vanna.ai``). The port, when present and
        non-zero, is appended in both cases.

        Args:
            url: An absolute URL including a scheme and a host.

        Returns:
            The scheme, the (possibly shortened) host and the optional port,
            without any path, query or fragment.

        Raises:
            ValueError: If *url* has no scheme or no host, or if its port is
                not a valid port number (raised by ``urlparse``'s ``port``).
        """
        parsed_url = urlparse(url)
        protocol = parsed_url.scheme
        hostname = parsed_url.hostname
        # Without this guard a scheme-less value such as "ask.vanna.ai/rpc"
        # yields hostname=None and fails later with an opaque TypeError.
        if not protocol or not hostname:
            raise ValueError(f"Invalid URL {url!r}: an absolute URL with a scheme and a host is required")

        port = f":{parsed_url.port}" if parsed_url.port else ""

        if ":" in hostname:
            # urlparse strips the brackets from IPv6 literals; they must be
            # restored so the host can be followed by a port or a path.
            main_domain = f"[{hostname}]{port}"
        elif re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None:
            # IPv4 address: keep every octet.
            main_domain = f"{hostname}{port}"
        else:
            # Regular hostname: keep only the last two labels.
            main_domain = ".".join(hostname.split(".")[-2:]) + port

        return f"{protocol}://{main_domain}"

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Validate *credentials* by running a sample query through ``VannaTool``.

        The ``base_url`` entry is normalised in place: a missing or empty
        value is replaced by ``https://ask.vanna.ai/rpc``, otherwise one
        trailing ``/`` is removed. The tool is then invoked against a
        ``Chinook.sqlite`` URL built from the protocol and main domain of
        that base URL.

        Args:
            credentials: Provider credentials; ``base_url`` is rewritten in
                this dict with the normalised value.

        Raises:
            ToolProviderCredentialValidationError: If building the sample
                request or invoking the tool raises any exception; the
                original exception is attached as the cause.
        """
        base_url = credentials.get("base_url")
        if not base_url:
            base_url = "https://ask.vanna.ai/rpc"
        else:
            base_url = base_url.removesuffix("/")
        credentials["base_url"] = base_url

        try:
            VannaTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id="",
                tool_parameters={
                    "model": "chinook",
                    "db_type": "SQLite",
                    "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
                    "query": "What are the top 10 customers by sales?",
                },
            )
        except Exception as e:
            # Chain the cause so the underlying failure stays visible in tracebacks.
            raise ToolProviderCredentialValidationError(str(e)) from e