瀏覽代碼

feat(vannaai): add base_url configuration (#10294)

Benjamin 5 月之前
父節點
當前提交
d7b4d0756e

+ 2 - 1
api/core/tools/provider/builtin/vanna/tools/vanna.py

@@ -35,7 +35,8 @@ class VannaTool(BuiltinTool):
         password = tool_parameters.get("password", "")
         port = tool_parameters.get("port", 0)
 
-        vn = VannaDefault(model=model, api_key=api_key)
+        base_url = self.runtime.credentials.get("base_url", None)
+        vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url})
 
         db_type = tool_parameters.get("db_type", "")
         if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:

+ 22 - 1
api/core/tools/provider/builtin/vanna/vanna.py

@@ -1,4 +1,6 @@
+import re
 from typing import Any
+from urllib.parse import urlparse
 
 from core.tools.errors import ToolProviderCredentialValidationError
 from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
@@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 
 
 class VannaProvider(BuiltinToolProviderController):
+    def _get_protocol_and_main_domain(self, url):
+        parsed_url = urlparse(url)
+        protocol = parsed_url.scheme
+        hostname = parsed_url.hostname
+        port = f":{parsed_url.port}" if parsed_url.port else ""
+
+        # Check if the hostname is an IP address
+        is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None
+
+        # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain
+        main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port
+        return f"{protocol}://{main_domain}"
+
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        base_url = credentials.get("base_url")
+        if not base_url:
+            base_url = "https://ask.vanna.ai/rpc"
+        else:
+            base_url = base_url.removesuffix("/")
+        credentials["base_url"] = base_url
         try:
             VannaTool().fork_tool_runtime(
                 runtime={
@@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
                 tool_parameters={
                     "model": "chinook",
                     "db_type": "SQLite",
-                    "url": "https://vanna.ai/Chinook.sqlite",
+                    "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
                     "query": "What are the top 10 customers by sales?",
                 },
             )

+ 7 - 0
api/core/tools/provider/builtin/vanna/vanna.yaml

@@ -26,3 +26,10 @@ credentials_for_provider:
       en_US: Get your API key from Vanna.AI
       zh_Hans: 从 Vanna.AI 获取你的 API key
     url: https://vanna.ai/account/profile
+  base_url:
+    type: text-input
+    required: false
+    label:
+      en_US: Vanna.AI Endpoint Base URL
+    placeholder:
+      en_US: https://ask.vanna.ai/rpc