# dataset.py
from flask import request
from flask_restful import marshal, reqparse  # type: ignore
from werkzeug.exceptions import NotFound

import services.dataset_service
import services.errors.dataset
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService
  15. def _validate_name(name):
  16. if not name or len(name) < 1 or len(name) > 40:
  17. raise ValueError("Name must be between 1 to 40 characters.")
  18. return name
  19. class DatasetListApi(DatasetApiResource):
  20. """Resource for datasets."""
  21. def get(self, tenant_id):
  22. """Resource for getting datasets."""
  23. page = request.args.get("page", default=1, type=int)
  24. limit = request.args.get("limit", default=20, type=int)
  25. # provider = request.args.get("provider", default="vendor")
  26. search = request.args.get("keyword", default=None, type=str)
  27. tag_ids = request.args.getlist("tag_ids")
  28. include_all = request.args.get("include_all", default="false").lower() == "true"
  29. datasets, total = DatasetService.get_datasets(
  30. page, limit, tenant_id, current_user, search, tag_ids, include_all
  31. )
  32. # check embedding setting
  33. provider_manager = ProviderManager()
  34. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  35. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  36. model_names = []
  37. for embedding_model in embedding_models:
  38. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  39. data = marshal(datasets, dataset_detail_fields)
  40. for item in data:
  41. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  42. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  43. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  44. if item_model in model_names:
  45. item["embedding_available"] = True
  46. else:
  47. item["embedding_available"] = False
  48. else:
  49. item["embedding_available"] = True
  50. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  51. return response, 200
  52. def post(self, tenant_id):
  53. """Resource for creating datasets."""
  54. parser = reqparse.RequestParser()
  55. parser.add_argument(
  56. "name",
  57. nullable=False,
  58. required=True,
  59. help="type is required. Name must be between 1 to 40 characters.",
  60. type=_validate_name,
  61. )
  62. parser.add_argument(
  63. "description",
  64. type=str,
  65. nullable=True,
  66. required=False,
  67. default="",
  68. )
  69. parser.add_argument(
  70. "indexing_technique",
  71. type=str,
  72. location="json",
  73. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  74. help="Invalid indexing technique.",
  75. )
  76. parser.add_argument(
  77. "permission",
  78. type=str,
  79. location="json",
  80. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  81. help="Invalid permission.",
  82. required=False,
  83. nullable=False,
  84. )
  85. parser.add_argument(
  86. "external_knowledge_api_id",
  87. type=str,
  88. nullable=True,
  89. required=False,
  90. default="_validate_name",
  91. )
  92. parser.add_argument(
  93. "provider",
  94. type=str,
  95. nullable=True,
  96. required=False,
  97. default="vendor",
  98. )
  99. parser.add_argument(
  100. "external_knowledge_id",
  101. type=str,
  102. nullable=True,
  103. required=False,
  104. )
  105. args = parser.parse_args()
  106. try:
  107. dataset = DatasetService.create_empty_dataset(
  108. tenant_id=tenant_id,
  109. name=args["name"],
  110. description=args["description"],
  111. indexing_technique=args["indexing_technique"],
  112. account=current_user,
  113. permission=args["permission"],
  114. provider=args["provider"],
  115. external_knowledge_api_id=args["external_knowledge_api_id"],
  116. external_knowledge_id=args["external_knowledge_id"],
  117. )
  118. except services.errors.dataset.DatasetNameDuplicateError:
  119. raise DatasetNameDuplicateError()
  120. return marshal(dataset, dataset_detail_fields), 200
  121. class DatasetApi(DatasetApiResource):
  122. """Resource for dataset."""
  123. def delete(self, _, dataset_id):
  124. """
  125. Deletes a dataset given its ID.
  126. Args:
  127. dataset_id (UUID): The ID of the dataset to be deleted.
  128. Returns:
  129. dict: A dictionary with a key 'result' and a value 'success'
  130. if the dataset was successfully deleted. Omitted in HTTP response.
  131. int: HTTP status code 204 indicating that the operation was successful.
  132. Raises:
  133. NotFound: If the dataset with the given ID does not exist.
  134. """
  135. dataset_id_str = str(dataset_id)
  136. try:
  137. if DatasetService.delete_dataset(dataset_id_str, current_user):
  138. return {"result": "success"}, 204
  139. else:
  140. raise NotFound("Dataset not found.")
  141. except services.errors.dataset.DatasetInUseError:
  142. raise DatasetInUseError()
# Register the service-API routes for dataset management.
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")