|
@@ -1,4 +1,6 @@
|
|
|
+import re
|
|
|
from typing import Any
|
|
|
+from urllib.parse import urlparse
|
|
|
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
|
|
@@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
|
|
|
|
|
|
|
|
class VannaProvider(BuiltinToolProviderController):
|
|
|
+ def _get_protocol_and_main_domain(self, url):
|
|
|
+ parsed_url = urlparse(url)
|
|
|
+ protocol = parsed_url.scheme
|
|
|
+ hostname = parsed_url.hostname
|
|
|
+ port = f":{parsed_url.port}" if parsed_url.port else ""
|
|
|
+
|
|
|
+ # Check if the hostname is an IP address
|
|
|
+ is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None
|
|
|
+
|
|
|
+ # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain
|
|
|
+ main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port
|
|
|
+ return f"{protocol}://{main_domain}"
|
|
|
+
|
|
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
|
|
+ base_url = credentials.get("base_url")
|
|
|
+ if not base_url:
|
|
|
+ base_url = "https://ask.vanna.ai/rpc"
|
|
|
+ else:
|
|
|
+ base_url = base_url.removesuffix("/")
|
|
|
+ credentials["base_url"] = base_url
|
|
|
try:
|
|
|
VannaTool().fork_tool_runtime(
|
|
|
runtime={
|
|
@@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
|
|
|
tool_parameters={
|
|
|
"model": "chinook",
|
|
|
"db_type": "SQLite",
|
|
|
- "url": "https://vanna.ai/Chinook.sqlite",
|
|
|
+ "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
|
|
|
"query": "What are the top 10 customers by sales?",
|
|
|
},
|
|
|
)
|