app_tool_provider.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import logging
  2. from typing import Any, Optional
  3. from core.tools.entities.common_entities import I18nObject
  4. from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
  5. from core.tools.provider.tool_provider import ToolProviderController
  6. from core.tools.tool.api_tool import ApiTool
  7. from core.tools.tool.tool import Tool
  8. from extensions.ext_database import db
  9. from models.model import App, AppModelConfig
  10. from models.tools import PublishedAppTool
# Module-level logger; get_tools uses it to report published tools whose app record is missing.
logger = logging.getLogger(__name__)
  12. class AppToolProviderEntity(ToolProviderController):
  13. @property
  14. def provider_type(self) -> ToolProviderType:
  15. return ToolProviderType.APP
  16. def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
  17. pass
  18. def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
  19. pass
  20. def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
  21. db_tools: list[PublishedAppTool] = (
  22. db.session.query(PublishedAppTool)
  23. .filter(
  24. PublishedAppTool.user_id == user_id,
  25. )
  26. .all()
  27. )
  28. if not db_tools or len(db_tools) == 0:
  29. return []
  30. tools: list[Tool] = []
  31. for db_tool in db_tools:
  32. tool: dict[str, Any] = {
  33. "identity": {
  34. "author": db_tool.author,
  35. "name": db_tool.tool_name,
  36. "label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name},
  37. "icon": "",
  38. },
  39. "description": {
  40. "human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans},
  41. "llm": db_tool.llm_description,
  42. },
  43. "parameters": [],
  44. }
  45. # get app from db
  46. app: Optional[App] = db_tool.app
  47. if not app:
  48. logger.error(f"app {db_tool.app_id} not found")
  49. continue
  50. app_model_config: AppModelConfig = app.app_model_config
  51. user_input_form_list = app_model_config.user_input_form_list
  52. for input_form in user_input_form_list:
  53. # get type
  54. form_type = list(input_form.keys())[0]
  55. default = input_form[form_type]["default"]
  56. required = input_form[form_type]["required"]
  57. label = input_form[form_type]["label"]
  58. variable_name = input_form[form_type]["variable_name"]
  59. options = input_form[form_type].get("options", [])
  60. if form_type in {"paragraph", "text-input"}:
  61. tool["parameters"].append(
  62. ToolParameter(
  63. name=variable_name,
  64. label=I18nObject(en_US=label, zh_Hans=label),
  65. human_description=I18nObject(en_US=label, zh_Hans=label),
  66. llm_description=label,
  67. form=ToolParameter.ToolParameterForm.FORM,
  68. type=ToolParameter.ToolParameterType.STRING,
  69. required=required,
  70. default=default,
  71. placeholder=I18nObject(en_US="", zh_Hans=""),
  72. )
  73. )
  74. elif form_type == "select":
  75. tool["parameters"].append(
  76. ToolParameter(
  77. name=variable_name,
  78. label=I18nObject(en_US=label, zh_Hans=label),
  79. human_description=I18nObject(en_US=label, zh_Hans=label),
  80. llm_description=label,
  81. form=ToolParameter.ToolParameterForm.FORM,
  82. type=ToolParameter.ToolParameterType.SELECT,
  83. required=required,
  84. default=default,
  85. placeholder=I18nObject(en_US="", zh_Hans=""),
  86. options=[
  87. ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
  88. for option in options
  89. ],
  90. )
  91. )
  92. tools.append(ApiTool(**tool))
  93. return tools