|
- from copy import deepcopy
- from typing import Any
- from core.entities.model_entities import ModelStatus
- from core.errors.error import ProviderTokenNotInitError
- from core.model_manager import ModelInstance
- from core.model_runtime.entities.model_entities import ModelFeature, ModelType
- from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
- from core.tools.entities.common_entities import I18nObject
- from core.tools.entities.tool_entities import (
- ModelToolPropertyKey,
- ToolDescription,
- ToolIdentity,
- ToolParameter,
- ToolProviderCredentials,
- ToolProviderIdentity,
- ToolProviderType,
- )
- from core.tools.errors import ToolNotFoundError
- from core.tools.provider.tool_provider import ToolProviderController
- from core.tools.tool.model_tool import ModelTool
- from core.tools.tool.tool import Tool
- from core.tools.utils.configuration import ModelToolConfigurationManager
class ModelToolProviderController(ToolProviderController):
    """
    Tool provider that exposes the capabilities of a model provider as tools.

    Currently one tool is produced per LLM of the provider that supports the
    VISION feature and has a model tool configuration (see `_get_model_tools`).
    """
    # model provider configuration backing this tool provider; when None it is
    # resolved by provider name on the first call to `_get_model_tools`
    configuration: ProviderConfiguration = None
    # set by `from_db`: True only when every model of the provider is ModelStatus.ACTIVE
    is_active: bool = False

    def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
        """
        init the provider

        :param configuration: the model provider configuration, may be None
        :param kwargs: remaining fields, forwarded to `ToolProviderController.__init__`
        """
        super().__init__(**kwargs)
        self.configuration = configuration

    @staticmethod
    def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
        """
        init the provider from db

        :param configuration: the configuration of the provider
        :return: the provider controller, or None when `configuration` is None
        :raises RuntimeError: if no model tool configuration exists for the provider
        """
        if configuration is None:
            return None

        provider_name = configuration.provider.provider

        # the tool provider is active only if all of its models are active
        is_active = all(
            model.status == ModelStatus.ACTIVE
            for model in configuration.get_provider_models()
        )

        # get the model tool configuration of the provider
        model_tool_configuration = ModelToolConfigurationManager.get_configuration(provider_name)
        if model_tool_configuration is None:
            raise RuntimeError(f'no configuration found for provider {provider_name}')

        # prefer the label of the model tool configuration; any language it leaves
        # empty falls back to the model provider's own label
        provider_label = configuration.provider.label
        if model_tool_configuration.label:
            label = deepcopy(model_tool_configuration.label)
            if not label.en_US:
                label.en_US = provider_label.en_US
            if not label.zh_Hans:
                label.zh_Hans = provider_label.zh_Hans
        else:
            label = I18nObject(
                en_US=provider_label.en_US,
                zh_Hans=provider_label.zh_Hans
            )

        return ModelToolProviderController(
            is_active=is_active,
            identity=ToolProviderIdentity(
                author='Dify',
                name=provider_name,
                description=I18nObject(
                    zh_Hans=f'{label.zh_Hans} 模型能力提供商',
                    en_US=f'{label.en_US} model capability provider'
                ),
                label=I18nObject(
                    zh_Hans=label.zh_Hans,
                    en_US=label.en_US
                ),
                icon=configuration.provider.icon_small.en_US,
            ),
            configuration=configuration,
            credentials_schema={},
        )

    @staticmethod
    def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
        """
        check if the configuration has a model that can be used as a tool,
        i.e. at least one LLM that supports the VISION feature
        """
        return any(
            ModelToolProviderController._is_vision_llm(model)
            for model in configuration.get_provider_models()
        )

    @staticmethod
    def _is_vision_llm(model) -> bool:
        """check whether a provider model is an LLM that supports the VISION feature"""
        return model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or [])

    def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
        """
        returns a list of tools that the provider can provide

        If `self.configuration` is unset, it is looked up by provider name among the
        tenant's provider configurations. One vision tool is built per vision-capable
        LLM that has a model tool configuration; the result is cached on `self.tools`.

        :param tenant_id: tenant whose provider configurations are searched when
            `self.configuration` is None; a placeholder id is used when omitted
        :return: list of tools, empty when no configuration could be resolved
        :raises RuntimeError: if no model tool configuration exists for the provider
        """
        # placeholder tenant used when the caller does not supply one
        tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
        if self.configuration is None:
            configurations = ProviderManager().get_configurations(tenant_id=tenant_id).values()
            # `configuration.provider` is the provider entity; its `.provider` attribute
            # is the provider name, which is what `identity.name` holds (see `from_db`)
            self.configuration = next(
                (c for c in configurations if c.provider.provider == self.identity.name),
                None,
            )

        tools: list[ModelTool] = []
        if not self.configuration:
            return tools

        configuration = self.configuration
        provider_name = configuration.provider.provider
        provider_configuration = ModelToolConfigurationManager.get_configuration(provider_name)
        if provider_configuration is None:
            raise RuntimeError(f'no configuration found for provider {provider_name}')

        for model in configuration.get_provider_models():
            model_configuration = ModelToolConfigurationManager.get_model_configuration(provider_name, model.model)
            if model_configuration is None or not self._is_vision_llm(model):
                continue
            model_instance = self._create_model_instance(configuration, model)
            tools.append(self._build_vision_tool(model, model_configuration, model_instance))

        self.tools = tools
        return tools

    @staticmethod
    def _create_model_instance(configuration: ProviderConfiguration, model) -> ModelInstance:
        """
        build a model instance for `model` from the provider configuration;
        returns None when constructing it raises ProviderTokenNotInitError
        """
        provider_instance = configuration.get_provider_instance()
        provider_model_bundle = ProviderModelBundle(
            configuration=configuration,
            provider_instance=provider_instance,
            model_type_instance=provider_instance.get_model_instance(model.model_type),
        )
        try:
            return ModelInstance(provider_model_bundle, model.model)
        except ProviderTokenNotInitError:
            return None

    @staticmethod
    def _build_vision_tool(model, model_configuration, model_instance) -> ModelTool:
        """build the image-to-text tool exposed for a vision-capable LLM"""
        return ModelTool(
            identity=ToolIdentity(
                author='Dify',
                name=model.model,
                label=model_configuration.label,
            ),
            parameters=[
                ToolParameter(
                    name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
                    label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
                    human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
                    type=ToolParameter.ToolParameterType.STRING,
                    form=ToolParameter.ToolParameterForm.LLM,
                    required=True,
                    default=Tool.VARIABLE_KEY.IMAGE.value
                )
            ],
            description=ToolDescription(
                human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
                llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
            ),
            # tools of models that are not ACTIVE are reported as not team-authorized
            is_team_authorization=model.status == ModelStatus.ACTIVE,
            tool_type=ModelTool.ModelToolType.VISION,
            model_instance=model_instance,
            model=model.model,
        )

    def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
        """
        returns the credentials schema of the provider

        :return: always an empty schema; this provider declares no credentials of its own
        """
        return {}

    def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
        """
        returns a list of tools that the provider can provide

        :param user_id: the user id (not used by this provider)
        :param tenant_id: the tenant id, used to resolve the provider configuration if needed
        :return: list of tools
        """
        return self._get_model_tools(tenant_id=tenant_id)

    def get_tool(self, tool_name: str) -> ModelTool:
        """
        get tool by name, building the tool list on first use

        :param tool_name: the name of the tool
        :return: the tool
        :raises ValueError: if no tool with that name exists
        """
        if self.tools is None:
            # the configuration may not be resolved yet; _get_model_tools accepts a None tenant
            tenant_id = self.configuration.tenant_id if self.configuration else None
            self.get_tools(user_id='', tenant_id=tenant_id)
        # self.tools stays None when no configuration could be resolved
        for tool in self.tools or []:
            if tool.identity.name == tool_name:
                return tool
        raise ValueError(f'tool {tool_name} not found')

    def get_parameters(self, tool_name: str) -> list[ToolParameter]:
        """
        returns the parameters of the tool

        :param tool_name: the name of the tool, defined in `get_tools`
        :return: list of parameters
        :raises ToolNotFoundError: if no tool with that name exists
        """
        tenant_id = self.configuration.tenant_id if self.configuration else None
        tools = self.get_tools(user_id='', tenant_id=tenant_id)
        tool = next((t for t in tools if t.identity.name == tool_name), None)
        if tool is None:
            raise ToolNotFoundError(f'tool {tool_name} not found')
        return tool.parameters

    @property
    def app_type(self) -> ToolProviderType:
        """
        returns the type of the provider

        :return: type of the provider
        """
        return ToolProviderType.MODEL

    def validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        No-op: this provider declares an empty credentials schema, so there is
        nothing to validate.

        :param credentials: the credentials of the provider (ignored)
        """
        pass

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        No-op: this provider declares an empty credentials schema, so there is
        nothing to validate.

        :param credentials: the credentials of the provider (ignored)
        """
        pass
|