|
@@ -3,6 +3,7 @@ import logging
|
|
|
|
|
|
from httpx import get
|
|
|
|
|
|
+from core.entities.provider_entities import ProviderConfig
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
|
|
from core.tools.entities.common_entities import I18nObject
|
|
@@ -10,8 +11,6 @@ from core.tools.entities.tool_bundle import ApiToolBundle
|
|
|
from core.tools.entities.tool_entities import (
|
|
|
ApiProviderAuthType,
|
|
|
ApiProviderSchemaType,
|
|
|
- ProviderConfig,
|
|
|
- ToolCredentialsOption,
|
|
|
)
|
|
|
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
|
|
from core.tools.tool_label_manager import ToolLabelManager
|
|
@@ -45,8 +44,8 @@ class ApiToolManageService:
|
|
|
required=True,
|
|
|
default="none",
|
|
|
options=[
|
|
|
- ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
|
|
- ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
|
|
+ ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
|
|
+ ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
|
|
],
|
|
|
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
|
|
),
|
|
@@ -79,15 +78,14 @@ class ApiToolManageService:
|
|
|
raise ValueError(f"invalid schema: {str(e)}")
|
|
|
|
|
|
@staticmethod
|
|
|
- def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
|
|
+ def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
|
|
"""
|
|
|
convert schema to tool bundles
|
|
|
|
|
|
:return: the list of tool bundles, description
|
|
|
"""
|
|
|
try:
|
|
|
- tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
|
|
- return tool_bundles
|
|
|
+ return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
|
|
except Exception as e:
|
|
|
raise ValueError(f"invalid schema: {str(e)}")
|
|
|
|
|
@@ -111,7 +109,7 @@ class ApiToolManageService:
|
|
|
raise ValueError(f"invalid schema type {schema}")
|
|
|
|
|
|
# check if the provider exists
|
|
|
- provider: ApiToolProvider = (
|
|
|
+ provider: ApiToolProvider | None = (
|
|
|
db.session.query(ApiToolProvider)
|
|
|
.filter(
|
|
|
ApiToolProvider.tenant_id == tenant_id,
|
|
@@ -158,7 +156,13 @@ class ApiToolManageService:
|
|
|
provider_controller.load_bundled_tools(tool_bundles)
|
|
|
|
|
|
# encrypt credentials
|
|
|
- tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ config=provider_controller.get_credentials_schema(),
|
|
|
+ provider_type=provider_controller.provider_type.value,
|
|
|
+ provider_identity=provider_controller.identity.name
|
|
|
+ )
|
|
|
+
|
|
|
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
|
|
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
|
|
|
|
@@ -195,21 +199,21 @@ class ApiToolManageService:
|
|
|
return {"schema": schema}
|
|
|
|
|
|
@staticmethod
|
|
|
- def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
|
|
+ def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
|
|
|
"""
|
|
|
list api tool provider tools
|
|
|
"""
|
|
|
- provider: ApiToolProvider = (
|
|
|
+ provider: ApiToolProvider | None = (
|
|
|
db.session.query(ApiToolProvider)
|
|
|
.filter(
|
|
|
ApiToolProvider.tenant_id == tenant_id,
|
|
|
- ApiToolProvider.name == provider,
|
|
|
+ ApiToolProvider.name == provider_name,
|
|
|
)
|
|
|
.first()
|
|
|
)
|
|
|
|
|
|
if provider is None:
|
|
|
- raise ValueError(f"you have not added provider {provider}")
|
|
|
+ raise ValueError(f"you have not added provider {provider_name}")
|
|
|
|
|
|
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
|
|
labels = ToolLabelManager.get_tool_labels(controller)
|
|
@@ -243,7 +247,7 @@ class ApiToolManageService:
|
|
|
raise ValueError(f"invalid schema type {schema}")
|
|
|
|
|
|
# check if the provider exists
|
|
|
- provider: ApiToolProvider = (
|
|
|
+ provider: ApiToolProvider | None = (
|
|
|
db.session.query(ApiToolProvider)
|
|
|
.filter(
|
|
|
ApiToolProvider.tenant_id == tenant_id,
|
|
@@ -282,7 +286,12 @@ class ApiToolManageService:
|
|
|
provider_controller.load_bundled_tools(tool_bundles)
|
|
|
|
|
|
# get original credentials if exists
|
|
|
- tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ config=provider_controller.get_credentials_schema(),
|
|
|
+ provider_type=provider_controller.provider_type.value,
|
|
|
+ provider_identity=provider_controller.identity.name
|
|
|
+ )
|
|
|
|
|
|
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
|
|
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
|
@@ -310,7 +319,7 @@ class ApiToolManageService:
|
|
|
"""
|
|
|
delete tool provider
|
|
|
"""
|
|
|
- provider: ApiToolProvider = (
|
|
|
+ provider: ApiToolProvider | None = (
|
|
|
db.session.query(ApiToolProvider)
|
|
|
.filter(
|
|
|
ApiToolProvider.tenant_id == tenant_id,
|
|
@@ -360,7 +369,7 @@ class ApiToolManageService:
|
|
|
if tool_bundle is None:
|
|
|
raise ValueError(f"invalid tool name {tool_name}")
|
|
|
|
|
|
- db_provider: ApiToolProvider = (
|
|
|
+ db_provider: ApiToolProvider | None = (
|
|
|
db.session.query(ApiToolProvider)
|
|
|
.filter(
|
|
|
ApiToolProvider.tenant_id == tenant_id,
|
|
@@ -396,7 +405,12 @@ class ApiToolManageService:
|
|
|
|
|
|
# decrypt credentials
|
|
|
if db_provider.id:
|
|
|
- tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ config=provider_controller.get_credentials_schema(),
|
|
|
+ provider_type=provider_controller.provider_type.value,
|
|
|
+ provider_identity=provider_controller.identity.name
|
|
|
+ )
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
|
|
# check if the credential has changed, save the original credential
|
|
|
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
|
@@ -444,7 +458,7 @@ class ApiToolManageService:
|
|
|
# add icon
|
|
|
ToolTransformService.repack_provider(user_provider)
|
|
|
|
|
|
- tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
|
|
+ tools = provider_controller.get_tools(tenant_id=tenant_id)
|
|
|
|
|
|
for tool in tools:
|
|
|
user_provider.tools.append(
|