tool_label_manager.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from core.tools.__base.tool_provider import ToolProviderController
  2. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  3. from core.tools.custom_tool.provider import ApiToolProviderController
  4. from core.tools.entities.values import default_tool_label_name_list
  5. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  6. from extensions.ext_database import db
  7. from models.tools import ToolLabelBinding
  8. class ToolLabelManager:
  9. @classmethod
  10. def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
  11. """
  12. Filter tool labels
  13. """
  14. tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
  15. return list(set(tool_labels))
  16. @classmethod
  17. def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
  18. """
  19. Update tool labels
  20. """
  21. labels = cls.filter_tool_labels(labels)
  22. if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  23. provider_id = controller.provider_id
  24. else:
  25. raise ValueError("Unsupported tool type")
  26. # delete old labels
  27. db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
  28. # insert new labels
  29. for label in labels:
  30. db.session.add(
  31. ToolLabelBinding(
  32. tool_id=provider_id,
  33. tool_type=controller.provider_type.value,
  34. label_name=label,
  35. )
  36. )
  37. db.session.commit()
  38. @classmethod
  39. def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
  40. """
  41. Get tool labels
  42. """
  43. if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  44. provider_id = controller.provider_id
  45. elif isinstance(controller, BuiltinToolProviderController):
  46. return controller.tool_labels
  47. else:
  48. raise ValueError("Unsupported tool type")
  49. labels = (
  50. db.session.query(ToolLabelBinding.label_name)
  51. .filter(
  52. ToolLabelBinding.tool_id == provider_id,
  53. ToolLabelBinding.tool_type == controller.provider_type.value,
  54. )
  55. .all()
  56. )
  57. return [label.label_name for label in labels]
  58. @classmethod
  59. def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
  60. """
  61. Get tools labels
  62. :param tool_providers: list of tool providers
  63. :return: dict of tool labels
  64. :key: tool id
  65. :value: list of tool labels
  66. """
  67. if not tool_providers:
  68. return {}
  69. for controller in tool_providers:
  70. if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  71. raise ValueError("Unsupported tool type")
  72. provider_ids = []
  73. for controller in tool_providers:
  74. assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
  75. provider_ids.append(controller.provider_id)
  76. labels: list[ToolLabelBinding] = (
  77. db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
  78. )
  79. tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
  80. for label in labels:
  81. tool_labels[label.tool_id].append(label.label_name)
  82. return tool_labels