Преглед на файлове

suhh-新增部门增删改查接口、部门关联账号增删改查、修改知识库查询接口

‘suhuihui’ преди 4 месеца
родител
ревизия
5bf1cf291b

+ 14 - 5
api/controllers/console/datasets/datasets.py

@@ -68,8 +68,17 @@ class DatasetListApi(Resource):
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets2(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids,
-                category_ids, auth_type, creator_dept, creator, include_all
+                page,
+                limit,
+                current_user.current_tenant_id,
+                current_user,
+                search,
+                tag_ids,
+                category_ids,
+                auth_type,
+                creator_dept,
+                creator,
+                include_all,
             )
 
         # check embedding setting
@@ -85,7 +94,7 @@ class DatasetListApi(Resource):
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             # 返回编辑授权
-            item["has_edit_permission"] = DatasetService.has_edit_permission(current_user.id,item["id"])
+            item["has_edit_permission"] = DatasetService.has_edit_permission(current_user.id, item["id"])
             # convert embedding_model_provider to plugin standard format
             if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
                 item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
@@ -222,8 +231,8 @@ class DatasetApi(Resource):
         if data.get("permission") == "partial_members":
             part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
             data.update({"partial_member_list": part_users_list})
-        edit_user_id_list = DatasetPermissionService.get_dataset_edit_user_ids(dataset_id_str)
-        data.update({"edit_user_ids": edit_user_id_list})
+        data["has_edit_permission"] = DatasetService.has_edit_permission(current_user.id, dataset_id_str)
+
         return data, 200
 
     @setup_required

+ 109 - 2
api/controllers/console/dept/depts.py

@@ -1,4 +1,7 @@
+from flask import jsonify, request
+from flask_login import current_user
 from flask_restful import Resource  # type: ignore
+from werkzeug.exceptions import NotFound
 
 from controllers.console import api
 from controllers.console.wraps import (
@@ -14,12 +17,59 @@ class DeptAccountListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        dept_account_list=DeptService.get_dept_account_list()
+        dept_account_list = DeptService.get_dept_account_list()
         response = {
+            "result": "success",
             "data": dept_account_list,
         }
         return response, 200
 
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        data = request.get_json()
+        if not data:
+            raise NotFound("Invalid JSON")
+        dept_account_list = data.get("account_ids")
+        dept_id = data.get("dept_id")
+
+        DeptService.save_dept_account_list(dept_id, dept_account_list)
+        response = {
+            "result": "success",
+        }
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self):
+        data = request.get_json()
+        if not data:
+            raise NotFound("Invalid JSON")
+        dept_account_list = data.get("account_ids")
+        dept_id = data.get("dept_id")
+
+        DeptService.delete_dept_account_list(dept_id, dept_account_list)
+        response = {
+            "result": "success",
+        }
+        return response, 200
+
+
+class DeptAccountApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dept_id):
+        dept_account = DeptService.get_dept_account(dept_id)
+        response = {
+            "result": "success",
+            "data": dept_account,
+        }
+        return response, 200
+
+
 class DeptListApi(Resource):
     @setup_required
     @login_required
@@ -27,9 +77,66 @@ class DeptListApi(Resource):
     def get(self):
         dept_list = DeptService.get_dept_list()
         response = {
+            "result": "success",
             "data": dept_list,
         }
         return response, 200
 
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        data = request.get_json()
+        if not data:
+            raise NotFound("Invalid JSON")
+        dept_name = data.get("dept_name")
+        dept_id = data.get("dept_id")
+        dept_by_name = DeptService.get_dept_by_name(dept_name)
+        if dept_id not in {None, ""}:
+            dept_by_id = DeptService.get_dept_by_id(dept_id)
+            if dept_by_id != None:
+                if dept_by_id.dept_name != dept_name:
+                    # 修改
+                    if dept_by_name != None:
+                        return jsonify({"error": "'dept_name' repeat"}), 400
+                    else:
+                        DeptService.update_dept(dept_id, dept_name, current_user)
+            else:
+                raise NotFound("Dept not found.")
+        else:
+            if dept_by_name != None:
+                return jsonify({"error": "'dept_name' repeat"}), 400
+            else:
+                DeptService.save_dept(dept_name, current_user)
+        return {"result": "success"}, 204
+
+
+class DeptApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dept_id):
+        dept_id = str(dept_id)
+        dept = DeptService.get_dept_by_id(dept_id)
+        response = {
+            "result": "success",
+            "data": dept,
+        }
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dept_id):
+        dept_id = str(dept_id)
+        dept = DeptService.get_dept_by_id(dept_id)
+        if dept is None:
+            raise NotFound("Dept not found.")
+        DeptService.delete_dept(dept)
+        return {"result": "success"}, 204
+
+
 api.add_resource(DeptAccountListApi, "/dept/dept-accounts")
-api.add_resource(DeptListApi, "/dept")
+api.add_resource(DeptAccountApi, "/dept/dept-accounts/<uuid:dept_id>")
+api.add_resource(DeptListApi, "/depts")
+api.add_resource(DeptApi, "/depts/<uuid:dept_id>")

+ 26 - 0
api/controllers/console/workspace/account.py

@@ -292,6 +292,30 @@ class AccountDeleteUpdateFeedbackApi(Resource):
         return {"result": "success"}
 
 
+class AccountDeptApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(account_fields)
+    def post(self):
+        data = request.get_json()
+        account_id = data.get("account_id")
+        dept_id = data.get("dept_id")
+        AccountService.update_account_dept(dept_id, account_id)
+
+        return {"result": "success"}, 204
+
+
+class AccountNoDeptApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        data = AccountService.get_no_dept_accounts()
+
+        return {"result": "success", "data": data}
+
+
 # Register API resources
 api.add_resource(AccountInitApi, "/account/init")
 api.add_resource(AccountProfileApi, "/account/profile")
@@ -305,5 +329,7 @@ api.add_resource(AccountIntegrateApi, "/account/integrates")
 api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
 api.add_resource(AccountDeleteApi, "/account/delete")
 api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
+api.add_resource(AccountDeptApi, "/account/dept")
+api.add_resource(AccountNoDeptApi, "/account/nodept")
 # api.add_resource(AccountEmailApi, '/account/email')
 # api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

+ 1 - 0
api/fields/member_fields.py

@@ -17,6 +17,7 @@ account_fields = {
     "last_login_at": TimestampField,
     "last_login_ip": fields.String,
     "created_at": TimestampField,
+    "dept_id": fields.String,
 }
 
 account_with_role_fields = {

+ 7 - 5
api/models/account.py

@@ -39,7 +39,7 @@ class Account(UserMixin, Base):
     initialized_at = db.Column(db.DateTime)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    dept_id = db.Column(StringUUID, nullable=True)
+    dept_id = db.Column(db.String(255), nullable=True)
 
     @property
     def is_password_set(self):
@@ -185,10 +185,12 @@ class TenantAccountRole(enum.StrEnum):
     def is_editing_role(role: str) -> bool:
         if not role:
             return False
-        return role in {TenantAccountRole.OWNER,
-                        TenantAccountRole.ADMIN,
-                        TenantAccountRole.EDITOR,
-                        TenantAccountRole.LEADER}
+        return role in {
+            TenantAccountRole.OWNER,
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.LEADER,
+        }
 
     @staticmethod
     def is_dataset_edit_role(role: str) -> bool:

+ 3 - 3
api/models/dept.py

@@ -14,9 +14,7 @@ class DeptStatus(enum.StrEnum):
 
 class Dept(db.Model):
     __tablename__ = "dept"
-    __table_args__ = (
-        db.PrimaryKeyConstraint("dept_id", name="dept_pkey"),
-    )
+    __table_args__ = (db.PrimaryKeyConstraint("dept_id", name="dept_pkey"),)
 
     dept_id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     dept_name = db.Column(db.String(255), nullable=False)
@@ -25,6 +23,8 @@ class Dept(db.Model):
     status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_by = db.Column(StringUUID, nullable=False)
+    updated_by = db.Column(StringUUID, nullable=True)
 
     def get_status(self) -> DeptStatus:
         status_str = self.status

+ 28 - 2
api/services/account_service.py

@@ -3,13 +3,12 @@ import json
 import logging
 import random
 import secrets
-import uuid
 from datetime import UTC, datetime, timedelta
 from hashlib import sha256
 from typing import Any, Optional, cast
 
 from pydantic import BaseModel
-from sqlalchemy import func
+from sqlalchemy import func, or_
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized
 
@@ -31,6 +30,7 @@ from models.account import (
     TenantAccountRole,
     TenantStatus,
 )
+from models.dept import Dept
 from models.model import DifySetup
 from services.billing_service import BillingService
 from services.errors.account import (
@@ -342,6 +342,32 @@ class AccountService:
         return account
 
     @staticmethod
+    def update_account_dept(dept_id, account_id) -> Optional[Account]:
+        db.session.query(Account).filter(Account.id == account_id).update(
+            {"dept_id": dept_id, "updated_at": datetime.now()}
+        )
+        db.session.commit()
+
+    def delete_account_dept(self: Dept):
+        db.session.query(Account).filter(Account.dept_id == self.dept_id).update(
+            {"dept_id": "", "updated_at": datetime.now()}
+        )
+        db.session.commit()
+
+    @staticmethod
+    def get_no_dept_accounts():
+        no_dept_accounts = []
+
+        condition = or_(Account.dept_id == None, Account.dept_id == "")
+        accounts = db.session.query(Account).filter(condition).all()
+        for account_row in accounts:
+            no_dept_accounts.append(
+                {"account_id": account_row.id, "email": account_row.email, "name": account_row.name}
+            )
+
+        return no_dept_accounts
+
+    @staticmethod
     def update_login_info(account: Account, *, ip_address: str) -> None:
         """Update last login time and ip"""
         account.last_login_at = datetime.now(UTC).replace(tzinfo=None)

+ 80 - 19
api/services/dept_service.py

@@ -1,41 +1,102 @@
+import datetime
+
 from extensions.ext_database import db
 from models.account import Account
 from models.dept import Dept
+from services.account_service import AccountService
 
 
 class DeptService:
-
     @staticmethod
     def get_dept_account_list():
-        dept_list=[]
-        account_list=[]
+        dept_list = []
+        account_list = []
 
-        dept_results = (
-            db.session.query(Dept.dept_id, Dept.dept_name)
-            .filter(Dept.status == 'active')
-            .all()
-        )
+        dept_results = db.session.query(Dept.dept_id, Dept.dept_name).filter(Dept.status == "active").all()
 
         account_results = (
-            db.session.query(Account.dept_id, Account.id,Account.email)
-            .filter(Account.status == 'active')
-            .all()
+            db.session.query(Account.dept_id, Account.id, Account.email).filter(Account.status == "active").all()
         )
         for dept_row in dept_results:
             for account_row in account_results:
                 if account_row.dept_id == dept_row.dept_id:
-                   account_list.append({"account_id":account_row.id,"email":account_row.email})
-            dept_list.append({"dept_id": dept_row.dept_id, "dept_name": dept_row.dept_name,"accounts":account_list})
+                    account_list.append({"account_id": account_row.id, "email": account_row.email})
+            dept_list.append({"dept_id": dept_row.dept_id, "dept_name": dept_row.dept_name, "accounts": account_list})
 
         return dept_list
+
     @staticmethod
-    def get_dept_list():
-        dept_list = []
-        dept_results = (
-            db.session.query(Dept.dept_id, Dept.dept_name)
-            .filter(Dept.status == 'active')
+    def get_dept_account(dept_id):
+        dept_account = []
+        account_results = (
+            db.session.query(Account.dept_id, Account.id, Account.email)
+            .filter(Account.status == "active", Account.dept_id == str(dept_id))
             .all()
         )
+        for row in account_results:
+            dept_account.append({"account_id": row.id, "email": row.email})
+        return dept_account
+
+    @staticmethod
+    def get_dept_list():
+        dept_list = []
+        dept_results = db.session.query(Dept.dept_id, Dept.dept_name).filter(Dept.status == "active").all()
         for dept_row in dept_results:
             dept_list.append({"dept_id": dept_row.dept_id, "dept_name": dept_row.dept_name})
-        return  dept_list
+        return dept_list
+
+    @staticmethod
+    def get_dept_by_name(dept_name):
+        dept = db.session.query(Dept).filter(Dept.status == "active", Dept.dept_name == dept_name).first()
+        return dept
+
+    @staticmethod
+    def get_dept_by_id_name(dept_id, dept_name):
+        dept = (
+            db.session.query(Dept)
+            .filter(Dept.status == "active", Dept.dept_name == dept_name, Dept.dept_id == dept_id)
+            .first()
+        )
+        return dept
+
+    @staticmethod
+    def get_dept_by_id(dept_id):
+        dept = db.session.query(Dept).filter(Dept.status == "active", Dept.dept_id == dept_id).first()
+        return dept
+
+    @staticmethod
+    def save_dept(dept_name, current_user):
+        dept = Dept(
+            dept_name=dept_name,
+            status="active",
+            created_at=datetime.datetime.now(),
+            created_by=current_user.id,
+        )
+        db.session.add(dept)
+        db.session.flush()
+        db.session.commit()
+
+    @staticmethod
+    def update_dept(dept_id, dept_name, current_user):
+        db.session.query(Dept).filter(Dept.dept_id == dept_id).update(
+            {"updated_by": current_user.id, "updated_at": datetime.datetime.now(), "dept_name": dept_name}
+        )
+        db.session.commit()
+
+    @staticmethod
+    def delete_dept(dept):
+        AccountService.delete_account_dept(dept)
+        db.session.delete(dept)
+        db.session.commit()
+
+    @staticmethod
+    def save_dept_account_list(dept_id, dept_account_list):
+        for dept_account in dept_account_list:
+            account_id = dept_account.get("account_id")
+            AccountService.update_account_dept(dept_id, account_id)
+
+    @staticmethod
+    def delete_dept_account_list(dept_id, dept_account_list):
+        for dept_account in dept_account_list:
+            account_id = dept_account.get("account_id")
+            AccountService.update_account_dept("", account_id)