# datasets.py — console API controllers for dataset management.
  1. import flask_restful # type: ignore
  2. from flask import request
  3. from flask_login import current_user # type: ignore # type: ignore
  4. from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
  5. from werkzeug.exceptions import Forbidden, NotFound
  6. import services
  7. from configs import dify_config
  8. from controllers.console import api
  9. from controllers.console.apikey import api_key_fields, api_key_list
  10. from controllers.console.app.error import ProviderNotInitializeError
  11. from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
  12. from controllers.console.wraps import (
  13. account_initialization_required,
  14. cloud_edition_billing_rate_limit_check,
  15. enterprise_license_required,
  16. setup_required,
  17. )
  18. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  19. from core.indexing_runner import IndexingRunner
  20. from core.model_runtime.entities.model_entities import ModelType
  21. from core.plugin.entities.plugin import ModelProviderID
  22. from core.provider_manager import ProviderManager
  23. from core.rag.datasource.vdb.vector_type import VectorType
  24. from core.rag.extractor.entity.extract_setting import ExtractSetting
  25. from core.rag.retrieval.retrieval_methods import RetrievalMethod
  26. from extensions.ext_database import db
  27. from fields.app_fields import related_app_list
  28. from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
  29. from fields.document_fields import document_status_fields
  30. from libs.login import login_required
  31. from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
  32. from models.dataset import DatasetPermissionEnum
  33. from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
  34. def _validate_name(name):
  35. if not name or len(name) < 1 or len(name) > 40:
  36. raise ValueError("Name must be between 1 to 40 characters.")
  37. return name
  38. def _validate_description_length(description):
  39. if len(description) > 400:
  40. raise ValueError("Description cannot exceed 400 characters.")
  41. return description
class DatasetListApi(Resource):
    """Console endpoints: list datasets (GET) and create an empty dataset (POST)."""

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        """Return a paginated dataset list annotated with edit-permission and
        embedding-model availability for the current tenant."""
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        ids = request.args.getlist("ids")
        # provider = request.args.get("provider", default="vendor")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        auth_type = request.args.get("authType", default=None, type=int)
        creator_dept = request.args.get("creatorDept")
        creator = request.args.get("creator", default=None, type=str)
        category_ids = request.args.getlist("category_ids")
        include_all = request.args.get("include_all", default="false").lower() == "true"
        # Explicit ids take precedence over the filtered/paginated listing.
        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets2(
                page, limit, current_user.current_tenant_id, current_user, search, tag_ids,
                category_ids, auth_type, creator_dept, creator, include_all
            )
        # Check embedding setting: collect the tenant's active TEXT_EMBEDDING
        # models as "model:provider" keys for the availability check below.
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            # Expose whether the current user may edit this dataset.
            item["has_edit_permission"] = DatasetService.has_edit_permission(current_user.id, item["id"])
            # Convert embedding_model_provider to plugin standard format.
            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                if item_model in model_names:
                    item["embedding_available"] = True
                else:
                    item["embedding_available"] = False
            else:
                # Economy indexing (or no provider set) needs no embedding model.
                item["embedding_available"] = True
            if item.get("permission") == "partial_members":
                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
                item.update({"partial_member_list": part_users_list})
            else:
                item.update({"partial_member_list": []})
        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        """Create an empty dataset owned by the current user (permission: only me)."""
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            nullable=True,
            required=False,
        )
        parser.add_argument(
            "provider",
            type=str,
            nullable=True,
            choices=Dataset.PROVIDER_LIST,
            required=False,
            default="vendor",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            nullable=True,
            required=False,
        )
        args = parser.parse_args()
        # The role of the current user in the ta table must be admin, owner,
        # editor, or dataset_operator.
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=DatasetPermissionEnum.ONLY_ME,
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()
        return marshal(dataset, dataset_detail_fields), 201
  161. class DatasetApi(Resource):
  162. @setup_required
  163. @login_required
  164. @account_initialization_required
  165. def get(self, dataset_id):
  166. dataset_id_str = str(dataset_id)
  167. dataset = DatasetService.get_dataset(dataset_id_str)
  168. if dataset is None:
  169. raise NotFound("Dataset not found.")
  170. try:
  171. DatasetService.check_dataset_permission(dataset, current_user)
  172. except services.errors.account.NoPermissionError as e:
  173. raise Forbidden(str(e))
  174. data = marshal(dataset, dataset_detail_fields)
  175. if dataset.indexing_technique == "high_quality":
  176. if dataset.embedding_model_provider:
  177. provider_id = ModelProviderID(dataset.embedding_model_provider)
  178. data["embedding_model_provider"] = str(provider_id)
  179. if data.get("permission") == "partial_members":
  180. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  181. data.update({"partial_member_list": part_users_list})
  182. # check embedding setting
  183. provider_manager = ProviderManager()
  184. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  185. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  186. model_names = []
  187. for embedding_model in embedding_models:
  188. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  189. if data["indexing_technique"] == "high_quality":
  190. item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
  191. if item_model in model_names:
  192. data["embedding_available"] = True
  193. else:
  194. data["embedding_available"] = False
  195. else:
  196. data["embedding_available"] = True
  197. if data.get("permission") == "partial_members":
  198. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  199. data.update({"partial_member_list": part_users_list})
  200. return data, 200
  201. @setup_required
  202. @login_required
  203. @account_initialization_required
  204. @cloud_edition_billing_rate_limit_check("knowledge")
  205. def patch(self, dataset_id):
  206. dataset_id_str = str(dataset_id)
  207. dataset = DatasetService.get_dataset(dataset_id_str)
  208. if dataset is None:
  209. raise NotFound("Dataset not found.")
  210. parser = reqparse.RequestParser()
  211. parser.add_argument(
  212. "name",
  213. nullable=False,
  214. help="type is required. Name must be between 1 to 40 characters.",
  215. type=_validate_name,
  216. )
  217. parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
  218. parser.add_argument(
  219. "indexing_technique",
  220. type=str,
  221. location="json",
  222. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  223. nullable=True,
  224. help="Invalid indexing technique.",
  225. )
  226. parser.add_argument(
  227. "permission",
  228. type=str,
  229. location="json",
  230. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  231. help="Invalid permission.",
  232. )
  233. parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  234. parser.add_argument(
  235. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  236. )
  237. parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  238. parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  239. parser.add_argument(
  240. "external_retrieval_model",
  241. type=dict,
  242. required=False,
  243. nullable=True,
  244. location="json",
  245. help="Invalid external retrieval model.",
  246. )
  247. parser.add_argument(
  248. "external_knowledge_id",
  249. type=str,
  250. required=False,
  251. nullable=True,
  252. location="json",
  253. help="Invalid external knowledge id.",
  254. )
  255. parser.add_argument(
  256. "external_knowledge_api_id",
  257. type=str,
  258. required=False,
  259. nullable=True,
  260. location="json",
  261. help="Invalid external knowledge api id.",
  262. )
  263. args = parser.parse_args()
  264. data = request.get_json()
  265. # check embedding model setting
  266. if (
  267. data.get("indexing_technique") == "high_quality"
  268. and data.get("embedding_model_provider") is not None
  269. and data.get("embedding_model") is not None
  270. ):
  271. DatasetService.check_embedding_model_setting(
  272. dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
  273. )
  274. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  275. DatasetPermissionService.check_permission(
  276. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  277. )
  278. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  279. if dataset is None:
  280. raise NotFound("Dataset not found.")
  281. result_data = marshal(dataset, dataset_detail_fields)
  282. tenant_id = current_user.current_tenant_id
  283. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  284. DatasetPermissionService.update_partial_member_list(
  285. tenant_id, dataset_id_str, data.get("partial_member_list")
  286. )
  287. # clear partial member list when permission is only_me or all_team_members
  288. elif (
  289. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  290. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  291. ):
  292. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  293. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  294. result_data.update({"partial_member_list": partial_member_list})
  295. return result_data, 200
  296. @setup_required
  297. @login_required
  298. @account_initialization_required
  299. @cloud_edition_billing_rate_limit_check("knowledge")
  300. def delete(self, dataset_id):
  301. dataset_id_str = str(dataset_id)
  302. # The role of the current user in the ta table must be admin, owner, or editor
  303. if not current_user.is_editor or current_user.is_dataset_operator:
  304. raise Forbidden()
  305. try:
  306. if DatasetService.delete_dataset(dataset_id_str, current_user):
  307. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  308. return {"result": "success"}, 204
  309. else:
  310. raise NotFound("Dataset not found.")
  311. except services.errors.dataset.DatasetInUseError:
  312. raise DatasetInUseError()
  313. class DatasetUseCheckApi(Resource):
  314. @setup_required
  315. @login_required
  316. @account_initialization_required
  317. def get(self, dataset_id):
  318. dataset_id_str = str(dataset_id)
  319. dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
  320. return {"is_using": dataset_is_using}, 200
  321. class DatasetQueryApi(Resource):
  322. @setup_required
  323. @login_required
  324. @account_initialization_required
  325. def get(self, dataset_id):
  326. dataset_id_str = str(dataset_id)
  327. dataset = DatasetService.get_dataset(dataset_id_str)
  328. if dataset is None:
  329. raise NotFound("Dataset not found.")
  330. try:
  331. DatasetService.check_dataset_permission(dataset, current_user)
  332. except services.errors.account.NoPermissionError as e:
  333. raise Forbidden(str(e))
  334. page = request.args.get("page", default=1, type=int)
  335. limit = request.args.get("limit", default=20, type=int)
  336. dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
  337. response = {
  338. "data": marshal(dataset_queries, dataset_query_detail_fields),
  339. "has_more": len(dataset_queries) == limit,
  340. "limit": limit,
  341. "total": total,
  342. "page": page,
  343. }
  344. return response, 200
  345. class DatasetIndexingEstimateApi(Resource):
  346. @setup_required
  347. @login_required
  348. @account_initialization_required
  349. def post(self):
  350. parser = reqparse.RequestParser()
  351. parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
  352. parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
  353. parser.add_argument(
  354. "indexing_technique",
  355. type=str,
  356. required=True,
  357. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  358. nullable=True,
  359. location="json",
  360. )
  361. parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
  362. parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
  363. parser.add_argument(
  364. "doc_language", type=str, default="English", required=False, nullable=False, location="json"
  365. )
  366. args = parser.parse_args()
  367. # validate args
  368. DocumentService.estimate_args_validate(args)
  369. extract_settings = []
  370. if args["info_list"]["data_source_type"] == "upload_file":
  371. file_ids = args["info_list"]["file_info_list"]["file_ids"]
  372. file_details = (
  373. db.session.query(UploadFile)
  374. .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
  375. .all()
  376. )
  377. if file_details is None:
  378. raise NotFound("File not found.")
  379. if file_details:
  380. for file_detail in file_details:
  381. extract_setting = ExtractSetting(
  382. datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
  383. )
  384. extract_settings.append(extract_setting)
  385. elif args["info_list"]["data_source_type"] == "notion_import":
  386. notion_info_list = args["info_list"]["notion_info_list"]
  387. for notion_info in notion_info_list:
  388. workspace_id = notion_info["workspace_id"]
  389. for page in notion_info["pages"]:
  390. extract_setting = ExtractSetting(
  391. datasource_type="notion_import",
  392. notion_info={
  393. "notion_workspace_id": workspace_id,
  394. "notion_obj_id": page["page_id"],
  395. "notion_page_type": page["type"],
  396. "tenant_id": current_user.current_tenant_id,
  397. },
  398. document_model=args["doc_form"],
  399. )
  400. extract_settings.append(extract_setting)
  401. elif args["info_list"]["data_source_type"] == "website_crawl":
  402. website_info_list = args["info_list"]["website_info_list"]
  403. for url in website_info_list["urls"]:
  404. extract_setting = ExtractSetting(
  405. datasource_type="website_crawl",
  406. website_info={
  407. "provider": website_info_list["provider"],
  408. "job_id": website_info_list["job_id"],
  409. "url": url,
  410. "tenant_id": current_user.current_tenant_id,
  411. "mode": "crawl",
  412. "only_main_content": website_info_list["only_main_content"],
  413. },
  414. document_model=args["doc_form"],
  415. )
  416. extract_settings.append(extract_setting)
  417. else:
  418. raise ValueError("Data source type not support")
  419. indexing_runner = IndexingRunner()
  420. try:
  421. response = indexing_runner.indexing_estimate(
  422. current_user.current_tenant_id,
  423. extract_settings,
  424. args["process_rule"],
  425. args["doc_form"],
  426. args["doc_language"],
  427. args["dataset_id"],
  428. args["indexing_technique"],
  429. )
  430. except LLMBadRequestError:
  431. raise ProviderNotInitializeError(
  432. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  433. )
  434. except ProviderTokenNotInitError as ex:
  435. raise ProviderNotInitializeError(ex.description)
  436. except Exception as e:
  437. raise IndexingEstimateError(str(e))
  438. return response.model_dump(), 200
  439. class DatasetRelatedAppListApi(Resource):
  440. @setup_required
  441. @login_required
  442. @account_initialization_required
  443. @marshal_with(related_app_list)
  444. def get(self, dataset_id):
  445. dataset_id_str = str(dataset_id)
  446. dataset = DatasetService.get_dataset(dataset_id_str)
  447. if dataset is None:
  448. raise NotFound("Dataset not found.")
  449. try:
  450. DatasetService.check_dataset_permission(dataset, current_user)
  451. except services.errors.account.NoPermissionError as e:
  452. raise Forbidden(str(e))
  453. app_dataset_joins = DatasetService.get_related_apps(dataset.id)
  454. related_apps = []
  455. for app_dataset_join in app_dataset_joins:
  456. app_model = app_dataset_join.app
  457. if app_model:
  458. related_apps.append(app_model)
  459. return {"data": related_apps, "total": len(related_apps)}, 200
  460. class DatasetIndexingStatusApi(Resource):
  461. @setup_required
  462. @login_required
  463. @account_initialization_required
  464. def get(self, dataset_id):
  465. dataset_id = str(dataset_id)
  466. documents = (
  467. db.session.query(Document)
  468. .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
  469. .all()
  470. )
  471. documents_status = []
  472. for document in documents:
  473. completed_segments = DocumentSegment.query.filter(
  474. DocumentSegment.completed_at.isnot(None),
  475. DocumentSegment.document_id == str(document.id),
  476. DocumentSegment.status != "re_segment",
  477. ).count()
  478. total_segments = DocumentSegment.query.filter(
  479. DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
  480. ).count()
  481. document.completed_segments = completed_segments
  482. document.total_segments = total_segments
  483. documents_status.append(marshal(document, document_status_fields))
  484. data = {"data": documents_status}
  485. return data
  486. class DatasetApiKeyApi(Resource):
  487. max_keys = 10
  488. token_prefix = "dataset-"
  489. resource_type = "dataset"
  490. @setup_required
  491. @login_required
  492. @account_initialization_required
  493. @marshal_with(api_key_list)
  494. def get(self):
  495. keys = (
  496. db.session.query(ApiToken)
  497. .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
  498. .all()
  499. )
  500. return {"items": keys}
  501. @setup_required
  502. @login_required
  503. @account_initialization_required
  504. @marshal_with(api_key_fields)
  505. def post(self):
  506. # The role of the current user in the ta table must be admin or owner
  507. if not current_user.is_admin_or_owner:
  508. raise Forbidden()
  509. current_key_count = (
  510. db.session.query(ApiToken)
  511. .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
  512. .count()
  513. )
  514. if current_key_count >= self.max_keys:
  515. flask_restful.abort(
  516. 400,
  517. message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
  518. code="max_keys_exceeded",
  519. )
  520. key = ApiToken.generate_api_key(self.token_prefix, 24)
  521. api_token = ApiToken()
  522. api_token.tenant_id = current_user.current_tenant_id
  523. api_token.token = key
  524. api_token.type = self.resource_type
  525. db.session.add(api_token)
  526. db.session.commit()
  527. return api_token, 200
  528. class DatasetApiDeleteApi(Resource):
  529. resource_type = "dataset"
  530. @setup_required
  531. @login_required
  532. @account_initialization_required
  533. def delete(self, api_key_id):
  534. api_key_id = str(api_key_id)
  535. # The role of the current user in the ta table must be admin or owner
  536. if not current_user.is_admin_or_owner:
  537. raise Forbidden()
  538. key = (
  539. db.session.query(ApiToken)
  540. .filter(
  541. ApiToken.tenant_id == current_user.current_tenant_id,
  542. ApiToken.type == self.resource_type,
  543. ApiToken.id == api_key_id,
  544. )
  545. .first()
  546. )
  547. if key is None:
  548. flask_restful.abort(404, message="API key not found")
  549. db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
  550. db.session.commit()
  551. return {"result": "success"}, 204
  552. class DatasetApiBaseUrlApi(Resource):
  553. @setup_required
  554. @login_required
  555. @account_initialization_required
  556. def get(self):
  557. return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
  558. class DatasetRetrievalSettingApi(Resource):
  559. @setup_required
  560. @login_required
  561. @account_initialization_required
  562. def get(self):
  563. vector_type = dify_config.VECTOR_STORE
  564. match vector_type:
  565. case (
  566. VectorType.RELYT
  567. | VectorType.TIDB_VECTOR
  568. | VectorType.CHROMA
  569. | VectorType.TENCENT
  570. | VectorType.PGVECTO_RS
  571. | VectorType.BAIDU
  572. | VectorType.VIKINGDB
  573. | VectorType.UPSTASH
  574. ):
  575. return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
  576. case (
  577. VectorType.QDRANT
  578. | VectorType.WEAVIATE
  579. | VectorType.OPENSEARCH
  580. | VectorType.ANALYTICDB
  581. | VectorType.MYSCALE
  582. | VectorType.ORACLE
  583. | VectorType.ELASTICSEARCH
  584. | VectorType.ELASTICSEARCH_JA
  585. | VectorType.PGVECTOR
  586. | VectorType.TIDB_ON_QDRANT
  587. | VectorType.LINDORM
  588. | VectorType.COUCHBASE
  589. | VectorType.MILVUS
  590. | VectorType.OPENGAUSS
  591. | VectorType.OCEANBASE
  592. ):
  593. return {
  594. "retrieval_method": [
  595. RetrievalMethod.SEMANTIC_SEARCH.value,
  596. RetrievalMethod.FULL_TEXT_SEARCH.value,
  597. RetrievalMethod.HYBRID_SEARCH.value,
  598. ]
  599. }
  600. case _:
  601. raise ValueError(f"Unsupported vector db type {vector_type}.")
  602. class DatasetRetrievalSettingMockApi(Resource):
  603. @setup_required
  604. @login_required
  605. @account_initialization_required
  606. def get(self, vector_type):
  607. match vector_type:
  608. case (
  609. VectorType.MILVUS
  610. | VectorType.RELYT
  611. | VectorType.TIDB_VECTOR
  612. | VectorType.CHROMA
  613. | VectorType.TENCENT
  614. | VectorType.PGVECTO_RS
  615. | VectorType.BAIDU
  616. | VectorType.VIKINGDB
  617. | VectorType.UPSTASH
  618. ):
  619. return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
  620. case (
  621. VectorType.QDRANT
  622. | VectorType.WEAVIATE
  623. | VectorType.OPENSEARCH
  624. | VectorType.ANALYTICDB
  625. | VectorType.MYSCALE
  626. | VectorType.ORACLE
  627. | VectorType.ELASTICSEARCH
  628. | VectorType.ELASTICSEARCH_JA
  629. | VectorType.COUCHBASE
  630. | VectorType.PGVECTOR
  631. | VectorType.LINDORM
  632. | VectorType.OPENGAUSS
  633. | VectorType.OCEANBASE
  634. ):
  635. return {
  636. "retrieval_method": [
  637. RetrievalMethod.SEMANTIC_SEARCH.value,
  638. RetrievalMethod.FULL_TEXT_SEARCH.value,
  639. RetrievalMethod.HYBRID_SEARCH.value,
  640. ]
  641. }
  642. case _:
  643. raise ValueError(f"Unsupported vector db type {vector_type}.")
  644. class DatasetErrorDocs(Resource):
  645. @setup_required
  646. @login_required
  647. @account_initialization_required
  648. def get(self, dataset_id):
  649. dataset_id_str = str(dataset_id)
  650. dataset = DatasetService.get_dataset(dataset_id_str)
  651. if dataset is None:
  652. raise NotFound("Dataset not found.")
  653. results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
  654. return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
  655. class DatasetPermissionUserListApi(Resource):
  656. @setup_required
  657. @login_required
  658. @account_initialization_required
  659. def get(self, dataset_id):
  660. dataset_id_str = str(dataset_id)
  661. dataset = DatasetService.get_dataset(dataset_id_str)
  662. if dataset is None:
  663. raise NotFound("Dataset not found.")
  664. try:
  665. DatasetService.check_dataset_permission(dataset, current_user)
  666. except services.errors.account.NoPermissionError as e:
  667. raise Forbidden(str(e))
  668. partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  669. return {
  670. "data": partial_members_list,
  671. }, 200
  672. class DatasetAutoDisableLogApi(Resource):
  673. @setup_required
  674. @login_required
  675. @account_initialization_required
  676. def get(self, dataset_id):
  677. dataset_id_str = str(dataset_id)
  678. dataset = DatasetService.get_dataset(dataset_id_str)
  679. if dataset is None:
  680. raise NotFound("Dataset not found.")
  681. return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
class DatasetCountApi(Resource):
    """Aggregate counts (datasets, tags, departments) for dashboard display."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return dataset/tag/department counts for a tenant."""
        # tenant_id = current_user.current_tenant_id
        # NOTE(review): tenant_id is taken from the query string instead of
        # current_user.current_tenant_id — confirm the service layer enforces
        # tenant scoping, otherwise this can expose another tenant's counts.
        tenant_id = request.args.get("tenant_id", default=None, type=str)
        datasets_count = DatasetService.get_datasets_count(tenant_id, current_user)
        tags_count = DatasetService.get_tags_count(tenant_id)
        # depts_count is not implemented; hard-coded to 0.
        response = {"datasets_count": datasets_count, "tags_count": tags_count, "depts_count": 0}
        return {"data": response}, 200
class DatasetUpdateStatsApi(Resource):
    """Histogram of dataset last-update recency for dashboard display."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Get dataset update statistics."""
        # NOTE(review): tenant_id is taken from the query string — verify the
        # service layer enforces tenant scoping.
        tenant_id = request.args.get("tenant_id", default=None, type=str)
        stats = DatasetService.get_dataset_update_stats(tenant_id)
        # Convert to the shape the frontend expects; the period labels are
        # Chinese UI strings consumed verbatim by the frontend.
        response = {
            "data": [
                {"period": "半年以上", "count": stats["over_180_days"]},
                {"period": "半年以内", "count": stats["within_180_days"]},
                {"period": "90天内", "count": stats["within_90_days"]},
                {"period": "30天内", "count": stats["within_30_days"]},
                {"period": "7天内", "count": stats["within_7_days"]},
                {"period": "3天内", "count": stats["within_3_days"]},
            ]
        }
        return response, 200
  713. class DatasetTypeStatsApi(Resource):
  714. @setup_required
  715. @login_required
  716. @account_initialization_required
  717. def get(self):
  718. tenant_id = current_user.current_tenant_id
  719. response = DatasetService.get_dataset_type_stats(tenant_id)
  720. return {"data": response}, 200
# Route registrations: bind each Resource class to its console API path.
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
api.add_resource(DatasetCountApi, "/datasets/count")
api.add_resource(DatasetUpdateStatsApi, "/datasets/update-stats")
api.add_resource(DatasetTypeStatsApi, "/datasets/type-stats")