datasets.py

import flask_restful  # type: ignore
from flask import request
from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal, marshal_with, reqparse  # type: ignore
from werkzeug.exceptions import Forbidden, NotFound

import services
from configs import dify_config
from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
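
# Console API resources for dataset management: list/create/update/delete datasets,
# query history, indexing estimates and status, dataset-scoped API keys, and the
# retrieval-method capabilities of the configured vector store.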


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 and 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description
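
# Both validators are plugged into reqparse as `type=` callables below: raising
# ValueError makes flask-restful reject the request with HTTP 400 and the
# argument's `help` message.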


class DatasetListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        ids = request.args.getlist("ids")
        # provider = request.args.get("provider", default="vendor")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        include_all = request.args.get("include_all", default="false").lower() == "true"

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(
                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
            )

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item["indexing_technique"] == "high_quality":
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item["embedding_available"] = item_model in model_names
            else:
                item["embedding_available"] = True

            if item.get("permission") == "partial_members":
                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
                item.update({"partial_member_list": part_users_list})
            else:
                item.update({"partial_member_list": []})

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name is required. Name must be between 1 and 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            nullable=True,
            required=False,
        )
        parser.add_argument(
            "provider",
            type=str,
            nullable=True,
            choices=Dataset.PROVIDER_LIST,
            required=False,
            default="vendor",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            nullable=True,
            required=False,
        )
        args = parser.parse_args()

        # The role of the current user in the tenant-account join table must be admin, owner, editor, or
        # dataset_operator.
        if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=DatasetPermissionEnum.ONLY_ME,
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201
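

# Single-dataset resource: fetch detail (with embedding availability and partial
# member list), update settings and permissions, and delete.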
class DatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        data = marshal(dataset, dataset_detail_fields)

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        if data["indexing_technique"] == "high_quality":
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            data["embedding_available"] = item_model in model_names
        else:
            data["embedding_available"] = True

        if data.get("permission") == "partial_members":
            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
            data.update({"partial_member_list": part_users_list})

        return data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            help="Name is required. Name must be between 1 and 40 characters.",
            type=_validate_name,
        )
        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "permission",
            type=str,
            location="json",
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            help="Invalid permission.",
        )
        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
        parser.add_argument(
            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
        )
        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid partial member list.")
        parser.add_argument(
            "external_retrieval_model",
            type=dict,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external retrieval model.",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge id.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge api id.",
        )
        args = parser.parse_args()
        data = request.get_json()

        # check embedding model setting
        if data.get("indexing_technique") == "high_quality":
            DatasetService.check_embedding_model_setting(
                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
            )

        # The role of the current user in the tenant-account join table must be admin, owner, editor, or
        # dataset_operator.
        DatasetPermissionService.check_permission(
            current_user, dataset, data.get("permission"), data.get("partial_member_list")
        )

        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)

        if dataset is None:
            raise NotFound("Dataset not found.")

        result_data = marshal(dataset, dataset_detail_fields)
        tenant_id = current_user.current_tenant_id

        if data.get("partial_member_list") and data.get("permission") == "partial_members":
            DatasetPermissionService.update_partial_member_list(
                tenant_id, dataset_id_str, data.get("partial_member_list")
            )
        # clear partial member list when permission is only_me or all_team_members
        elif (
            data.get("permission") == DatasetPermissionEnum.ONLY_ME
            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
        ):
            DatasetPermissionService.clear_partial_member_list(dataset_id_str)

        partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        result_data.update({"partial_member_list": partial_member_list})

        return result_data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        dataset_id_str = str(dataset_id)

        # The role of the current user in the tenant-account join table must be admin, owner, or editor
        if not current_user.is_editor or current_user.is_dataset_operator:
            raise Forbidden()

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                return {"result": "success"}, 204
            else:
                raise NotFound("Dataset not found.")
        except services.errors.dataset.DatasetInUseError:
            raise DatasetInUseError()


class DatasetUseCheckApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)

        dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
        return {"is_using": dataset_is_using}, 200


class DatasetQueryApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)

        response = {
            "data": marshal(dataset_queries, dataset_query_detail_fields),
            "has_more": len(dataset_queries) == limit,
            "limit": limit,
            "total": total,
            "page": page,
        }
        return response, 200
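

# Pre-indexing cost estimate: builds an ExtractSetting per data source
# (uploaded file, Notion page, or crawled URL) and delegates to IndexingRunner
# for the token/price estimation.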
class DatasetIndexingEstimateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
        parser.add_argument(
            "indexing_technique",
            type=str,
            required=True,
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            location="json",
        )
        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
        parser.add_argument(
            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
        )
        args = parser.parse_args()

        # validate args
        DocumentService.estimate_args_validate(args)

        extract_settings = []
        if args["info_list"]["data_source_type"] == "upload_file":
            file_ids = args["info_list"]["file_info_list"]["file_ids"]
            file_details = (
                db.session.query(UploadFile)
                .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
                .all()
            )
            if not file_details:
                raise NotFound("File not found.")
            for file_detail in file_details:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
                )
                extract_settings.append(extract_setting)
        elif args["info_list"]["data_source_type"] == "notion_import":
            notion_info_list = args["info_list"]["notion_info_list"]
            for notion_info in notion_info_list:
                workspace_id = notion_info["workspace_id"]
                for page in notion_info["pages"]:
                    extract_setting = ExtractSetting(
                        datasource_type="notion_import",
                        notion_info={
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page["page_id"],
                            "notion_page_type": page["type"],
                            "tenant_id": current_user.current_tenant_id,
                        },
                        document_model=args["doc_form"],
                    )
                    extract_settings.append(extract_setting)
        elif args["info_list"]["data_source_type"] == "website_crawl":
            website_info_list = args["info_list"]["website_info_list"]
            for url in website_info_list["urls"]:
                extract_setting = ExtractSetting(
                    datasource_type="website_crawl",
                    website_info={
                        "provider": website_info_list["provider"],
                        "job_id": website_info_list["job_id"],
                        "url": url,
                        "tenant_id": current_user.current_tenant_id,
                        "mode": "crawl",
                        "only_main_content": website_info_list["only_main_content"],
                    },
                    document_model=args["doc_form"],
                )
                extract_settings.append(extract_setting)
        else:
            raise ValueError("Data source type is not supported.")

        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(
                current_user.current_tenant_id,
                extract_settings,
                args["process_rule"],
                args["doc_form"],
                args["doc_language"],
                args["dataset_id"],
                args["indexing_technique"],
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider in Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except Exception as e:
            raise IndexingEstimateError(str(e))

        return response.model_dump(), 200


class DatasetRelatedAppListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)

        related_apps = []
        for app_dataset_join in app_dataset_joins:
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        return {"data": related_apps, "total": len(related_apps)}, 200


class DatasetIndexingStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id = str(dataset_id)
        documents = (
            db.session.query(Document)
            .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
            .all()
        )
        documents_status = []
        for document in documents:
            completed_segments = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.document_id == str(document.id),
                DocumentSegment.status != "re_segment",
            ).count()
            total_segments = DocumentSegment.query.filter(
                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
            ).count()
            document.completed_segments = completed_segments
            document.total_segments = total_segments
            documents_status.append(marshal(document, document_status_fields))
        data = {"data": documents_status}
        return data
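

# Dataset-scoped API keys: at most `max_keys` per tenant, each token prefixed
# with "dataset-"; creation and deletion require the admin or owner role.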
class DatasetApiKeyApi(Resource):
    max_keys = 10
    token_prefix = "dataset-"
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_list)
    def get(self):
        keys = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
            .all()
        )
        return {"items": keys}

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_fields)
    def post(self):
        # The role of the current user in the tenant-account join table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
            .count()
        )

        if current_key_count >= self.max_keys:
            flask_restful.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code="max_keys_exceeded",
            )

        key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = current_user.current_tenant_id
        api_token.token = key
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()
        return api_token, 200


class DatasetApiDeleteApi(Resource):
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, api_key_id):
        api_key_id = str(api_key_id)

        # The role of the current user in the tenant-account join table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        key = (
            db.session.query(ApiToken)
            .filter(
                ApiToken.tenant_id == current_user.current_tenant_id,
                ApiToken.type == self.resource_type,
                ApiToken.id == api_key_id,
            )
            .first()
        )

        if key is None:
            flask_restful.abort(404, message="API key not found")

        db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
        db.session.commit()

        return {"result": "success"}, 204


class DatasetApiBaseUrlApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
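

# Maps the configured vector store to the retrieval methods it supports:
# stores without a full-text index expose semantic search only; the rest also
# expose full-text and hybrid search.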
class DatasetRetrievalSettingApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        vector_type = dify_config.VECTOR_STORE
        match vector_type:
            case (
                VectorType.RELYT
                | VectorType.TIDB_VECTOR
                | VectorType.CHROMA
                | VectorType.TENCENT
                | VectorType.PGVECTO_RS
                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
                | VectorType.OCEANBASE
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
                VectorType.QDRANT
                | VectorType.WEAVIATE
                | VectorType.OPENSEARCH
                | VectorType.ANALYTICDB
                | VectorType.MYSCALE
                | VectorType.ORACLE
                | VectorType.ELASTICSEARCH
                | VectorType.ELASTICSEARCH_JA
                | VectorType.PGVECTOR
                | VectorType.TIDB_ON_QDRANT
                | VectorType.LINDORM
                | VectorType.COUCHBASE
                | VectorType.MILVUS
            ):
                return {
                    "retrieval_method": [
                        RetrievalMethod.SEMANTIC_SEARCH.value,
                        RetrievalMethod.FULL_TEXT_SEARCH.value,
                        RetrievalMethod.HYBRID_SEARCH.value,
                    ]
                }
            case _:
                raise ValueError(f"Unsupported vector db type {vector_type}.")


class DatasetRetrievalSettingMockApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        match vector_type:
            case (
                VectorType.MILVUS
                | VectorType.RELYT
                | VectorType.TIDB_VECTOR
                | VectorType.CHROMA
                | VectorType.TENCENT
                | VectorType.PGVECTO_RS
                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
                | VectorType.OCEANBASE
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
                VectorType.QDRANT
                | VectorType.WEAVIATE
                | VectorType.OPENSEARCH
                | VectorType.ANALYTICDB
                | VectorType.MYSCALE
                | VectorType.ORACLE
                | VectorType.ELASTICSEARCH
                | VectorType.ELASTICSEARCH_JA
                | VectorType.COUCHBASE
                | VectorType.PGVECTOR
                | VectorType.LINDORM
            ):
                return {
                    "retrieval_method": [
                        RetrievalMethod.SEMANTIC_SEARCH.value,
                        RetrievalMethod.FULL_TEXT_SEARCH.value,
                        RetrievalMethod.HYBRID_SEARCH.value,
                    ]
                }
            case _:
                raise ValueError(f"Unsupported vector db type {vector_type}.")


class DatasetErrorDocs(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)

        return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200


class DatasetPermissionUserListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)

        return {
            "data": partial_members_list,
        }, 200


class DatasetAutoDisableLogApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200


api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
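
# A minimal usage sketch (hypothetical values; assumes this module's blueprint
# is mounted under the console API prefix, e.g. /console/api, and that the
# request carries a valid console session):
#
#   POST /console/api/datasets
#   {"name": "support-docs", "indexing_technique": "high_quality"}
#   -> 201, marshalled dataset detail
#
#   GET /console/api/datasets?page=1&limit=20
#   -> 200, {"data": [...], "has_more": false, "limit": 20, "total": 1, "page": 1}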