model_tool_provider.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. from copy import deepcopy
  2. from typing import Any
  3. from core.entities.model_entities import ModelStatus
  4. from core.errors.error import ProviderTokenNotInitError
  5. from core.model_manager import ModelInstance
  6. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  7. from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
  8. from core.tools.entities.common_entities import I18nObject
  9. from core.tools.entities.tool_entities import (
  10. ModelToolPropertyKey,
  11. ToolDescription,
  12. ToolIdentity,
  13. ToolParameter,
  14. ToolProviderCredentials,
  15. ToolProviderIdentity,
  16. ToolProviderType,
  17. )
  18. from core.tools.errors import ToolNotFoundError
  19. from core.tools.provider.tool_provider import ToolProviderController
  20. from core.tools.tool.model_tool import ModelTool
  21. from core.tools.tool.tool import Tool
  22. from core.tools.utils.configuration import ModelToolConfigurationManager
  class ModelToolProviderController(ToolProviderController):
      """
      Tool provider that exposes a model provider's vision-capable LLMs as tools.

      Every LLM of the wrapped provider that declares ``ModelFeature.VISION`` and has a
      model-tool configuration registered in ``ModelToolConfigurationManager`` becomes an
      image-to-text ``ModelTool``.
      """

      # The model provider configuration this controller wraps. When not supplied at
      # construction time it is resolved lazily in `_get_model_tools`.
      configuration: ProviderConfiguration = None
      # Set by `from_db`: True only when every model of the provider is ModelStatus.ACTIVE.
      is_active: bool = False

      def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
          """
          Initialize the provider controller.

          :param configuration: the model provider configuration to wrap, or None to
              resolve it later by provider name
          :param kwargs: forwarded unchanged to ``ToolProviderController.__init__``
          """
          super().__init__(**kwargs)
          self.configuration = configuration

      @staticmethod
      def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController | None':
          """
          Build a controller from a stored model provider configuration.

          :param configuration: the model provider configuration; may be None
          :return: the controller, or None when ``configuration`` is None
          :raises RuntimeError: if no model-tool configuration is registered for the provider
          """
          if configuration is None:
              return None

          # The provider only counts as active if all of its models are active.
          is_active = all(
              model.status == ModelStatus.ACTIVE
              for model in configuration.get_provider_models()
          )

          provider_name = configuration.provider.provider
          model_tool_configuration = ModelToolConfigurationManager.get_configuration(provider_name)
          if model_tool_configuration is None:
              raise RuntimeError(f'no configuration found for provider {provider_name}')

          # Start from the model provider's own label, then let every non-empty field of
          # the model-tool configuration's label override it, so missing translations fall
          # back to the provider label instead of staying empty.
          provider_label = configuration.provider.label
          label = I18nObject(
              en_US=provider_label.en_US,
              zh_Hans=provider_label.zh_Hans
          )
          override_label = model_tool_configuration.label
          if override_label:
              if override_label.en_US:
                  label.en_US = override_label.en_US
              if override_label.zh_Hans:
                  label.zh_Hans = override_label.zh_Hans

          return ModelToolProviderController(
              is_active=is_active,
              identity=ToolProviderIdentity(
                  author='Dify',
                  name=provider_name,
                  description=I18nObject(
                      zh_Hans=f'{label.zh_Hans} 模型能力提供商',
                      en_US=f'{label.en_US} model capability provider'
                  ),
                  label=I18nObject(
                      zh_Hans=label.zh_Hans,
                      en_US=label.en_US
                  ),
                  icon=configuration.provider.icon_small.en_US,
              ),
              configuration=configuration,
              credentials_schema={},
          )

      @staticmethod
      def _is_vision_llm(model) -> bool:
          """Return True if ``model`` is an LLM that declares the VISION feature."""
          return model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or [])

      @staticmethod
      def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
          """
          Check whether the configuration has at least one model usable as a tool.

          :param configuration: the model provider configuration to inspect
          :return: True if any of its models is a vision-capable LLM
          """
          return any(
              ModelToolProviderController._is_vision_llm(model)
              for model in configuration.get_provider_models()
          )

      @staticmethod
      def _build_vision_tool(configuration: ProviderConfiguration, model, model_configuration) -> ModelTool:
          """Create the image-to-text ModelTool for one vision-capable model of ``configuration``."""
          provider_instance = configuration.get_provider_instance()
          provider_model_bundle = ProviderModelBundle(
              configuration=configuration,
              provider_instance=provider_instance,
              model_type_instance=provider_instance.get_model_instance(model.model_type),
          )
          try:
              model_instance = ModelInstance(provider_model_bundle, model.model)
          except ProviderTokenNotInitError:
              # Credentials are not initialized; the tool is still built, without a model instance.
              model_instance = None

          return ModelTool(
              identity=ToolIdentity(
                  author='Dify',
                  name=model.model,
                  label=model_configuration.label,
              ),
              parameters=[
                  ToolParameter(
                      name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
                      label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
                      human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
                      type=ToolParameter.ToolParameterType.STRING,
                      form=ToolParameter.ToolParameterForm.LLM,
                      required=True,
                      default=Tool.VARIABLE_KEY.IMAGE.value
                  )
              ],
              description=ToolDescription(
                  human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
                  llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
              ),
              is_team_authorization=model.status == ModelStatus.ACTIVE,
              tool_type=ModelTool.ModelToolType.VISION,
              model_instance=model_instance,
              model=model.model,
          )

      def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
          """
          Build the vision tools offered by this provider and cache them on ``self.tools``.

          If ``self.configuration`` is unset, it is looked up among the tenant's provider
          configurations by this controller's identity name.

          :param tenant_id: tenant used for the lazy configuration lookup; a placeholder
              all-``f`` tenant id is used when empty
          :return: list of tools (empty and not cached when no configuration is found)
          :raises RuntimeError: if no model-tool configuration is registered for the provider
          """
          if self.configuration is None:
              lookup_tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
              configurations = ProviderManager().get_configurations(tenant_id=lookup_tenant_id).values()
              # `item.provider` is the provider entity; its `.provider` attribute is the name
              # string that `identity.name` is built from in `from_db`.
              self.configuration = next(
                  (item for item in configurations if item.provider.provider == self.identity.name),
                  None,
              )

          configuration = self.configuration
          if not configuration:
              return []

          provider_name = configuration.provider.provider
          if ModelToolConfigurationManager.get_configuration(provider_name) is None:
              raise RuntimeError(f'no configuration found for provider {provider_name}')

          tools: list[ModelTool] = []
          for model in configuration.get_provider_models():
              model_configuration = ModelToolConfigurationManager.get_model_configuration(provider_name, model.model)
              # Only models with a model-tool configuration that are vision LLMs become tools.
              if model_configuration is None or not self._is_vision_llm(model):
                  continue
              tools.append(self._build_vision_tool(configuration, model, model_configuration))

          self.tools = tools
          return tools

      def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
          """
          Return the credentials schema of the provider.

          Model tools reuse the model provider's credentials, so the schema is empty.

          :return: an empty schema
          """
          return {}

      def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
          """
          Return the tools this provider offers, rebuilding and caching them.

          :param user_id: the requesting user (not used by this provider)
          :param tenant_id: tenant used to resolve the provider configuration if unset
          :return: list of tools
          """
          return self._get_model_tools(tenant_id=tenant_id)

      def _ensure_tools(self) -> list[ModelTool]:
          """
          Return the cached tools, building them on first use.

          :return: the cached tool list, or an empty list if no configuration could be resolved
          """
          if self.tools is None:
              tenant_id = self.configuration.tenant_id if self.configuration is not None else None
              self.get_tools(user_id='', tenant_id=tenant_id)
          # `_get_model_tools` leaves `self.tools` unset when no configuration is found.
          return self.tools or []

      def get_tool(self, tool_name: str) -> ModelTool:
          """
          Get a tool by name.

          :param tool_name: the name of the tool
          :return: the tool
          :raises ValueError: if no tool with that name exists
          """
          for tool in self._ensure_tools():
              if tool.identity.name == tool_name:
                  return tool
          raise ValueError(f'tool {tool_name} not found')

      def get_parameters(self, tool_name: str) -> list[ToolParameter]:
          """
          Return the parameters of a tool.

          :param tool_name: the name of the tool, as returned by `get_tools`
          :return: list of parameters
          :raises ToolNotFoundError: if no tool with that name exists
          """
          tool = next((item for item in self._ensure_tools() if item.identity.name == tool_name), None)
          if tool is None:
              raise ToolNotFoundError(f'tool {tool_name} not found')
          return tool.parameters

      @property
      def app_type(self) -> ToolProviderType:
          """
          Return the type of the provider.

          :return: always ``ToolProviderType.MODEL``
          """
          return ToolProviderType.MODEL

      def validate_credentials(self, credentials: dict[str, Any]) -> None:
          """
          Validate the credentials of the provider.

          No-op: model tools have no credentials of their own.

          :param credentials: the credentials to validate (ignored)
          """
          pass

      def _validate_credentials(self, credentials: dict[str, Any]) -> None:
          """
          Validate the credentials of the provider (internal hook).

          No-op: model tools have no credentials of their own.

          :param credentials: the credentials to validate (ignored)
          """
          pass