import datetime

import pytz
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse

from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.workspace.error import (
    AccountAlreadyInitedError,
    CurrentPasswordIncorrectError,
    InvalidInvitationCodeError,
    RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError


class AccountInitApi(Resource):
    """First-time initialization of the logged-in account."""

    @setup_required
    @login_required
    def post(self):
        """Set language, timezone and theme for the current account and mark it active.

        On the CLOUD edition an unused invitation code must be supplied; it is
        marked as used and attributed to this account and its current tenant.

        Raises:
            AccountAlreadyInitedError: the account's status is already "active".
            ValueError: CLOUD edition and no invitation code was supplied.
            InvalidInvitationCodeError: no InvitationCode row with that code and status "unused".
        """
        account = current_user

        if account.status == "active":
            raise AccountAlreadyInitedError()

        parser = reqparse.RequestParser()

        if dify_config.EDITION == "CLOUD":
            parser.add_argument("invitation_code", type=str, location="json")

        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        parser.add_argument("timezone", type=timezone, required=True, location="json")
        args = parser.parse_args()

        # Timestamps are stored as naive datetimes holding UTC (tzinfo stripped).
        now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

        if dify_config.EDITION == "CLOUD":
            if not args["invitation_code"]:
                raise ValueError("invitation_code is required")

            # check invitation code
            invitation_code = (
                db.session.query(InvitationCode)
                .filter(
                    InvitationCode.code == args["invitation_code"],
                    InvitationCode.status == "unused",
                )
                .first()
            )

            if not invitation_code:
                raise InvalidInvitationCodeError()

            invitation_code.status = "used"
            invitation_code.used_at = now
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_theme = "light"
        account.status = "active"
        account.initialized_at = now
        # Single commit so the invitation-code consumption and account activation persist together.
        db.session.commit()

        return {"result": "success"}


class AccountProfileApi(Resource):
    """Read-only view of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    @enterprise_license_required
    def get(self):
        """Return the current account, marshalled with ``account_fields``."""
        return current_user


class AccountNameApi(Resource):
    """Update the display name of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account name (3-30 characters) and return the updated account.

        Raises:
            ValueError: the name is shorter than 3 or longer than 30 characters.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()

        # Validate account name length
        if len(args["name"]) < 3 or len(args["name"]) > 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        updated_account = AccountService.update_account(current_user, name=args["name"])

        return updated_account


class AccountAvatarApi(Resource):
    """Update the avatar of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the supplied ``avatar`` string on the account and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("avatar", type=str, required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, avatar=args["avatar"])

        return updated_account


class AccountInterfaceLanguageApi(Resource):
    """Update the UI language of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_language`` (validated by ``supported_language``) and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])

        return updated_account


class AccountInterfaceThemeApi(Resource):
    """Update the UI theme of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_theme`` to "light" or "dark" and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])

        return updated_account


class AccountTimezoneApi(Resource):
    """Update the timezone of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account timezone to a valid IANA name and return the updated account.

        Raises:
            ValueError: the value is not one of pytz's known timezone names.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("timezone", type=str, required=True, location="json")
        args = parser.parse_args()

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai.
        # all_timezones_set has the same contents as all_timezones but gives O(1) membership.
        if args["timezone"] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args["timezone"])

        return updated_account


class AccountPasswordApi(Resource):
    """Change the password of the logged-in account."""

    # No @marshal_with(account_fields) here: this endpoint returns a status dict,
    # and marshalling it with account fields would replace it with all-null account keys.
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Change the password after checking that both new-password entries match.

        ``password`` (the current password) is optional in the request; the
        check against it is delegated to ``AccountService.update_account_password``.

        Raises:
            RepeatPasswordNotMatchError: ``new_password`` != ``repeat_new_password``.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("password", type=str, required=False, location="json")
        parser.add_argument("new_password", type=str, required=True, location="json")
        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
        args = parser.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the controller-layer error type.
            raise CurrentPasswordIncorrectError()

        return {"result": "success"}


class AccountIntegrateApi(Resource):
    """List the supported OAuth providers and whether the logged-in account is bound to each."""

    integrate_fields = {
        "provider": fields.String,
        "created_at": TimestampField,
        "is_bound": fields.Boolean,
        "link": fields.String,
    }

    integrate_list_fields = {
        "data": fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per provider in ``providers``.

        A bound provider reports its stored ``created_at`` and no link; an
        unbound provider reports a login link built from the request's root URL.
        """
        account = current_user

        account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()

        base_url = request.url_root.rstrip("/")
        oauth_base_path = "/console/api/oauth/login"
        providers = ["github", "google"]

        integrate_data = []
        for provider in providers:
            # First stored integration for this provider, or None if the account is not bound.
            existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
            # NOTE(review): "id" is not declared in integrate_fields, so marshal_with drops it from the response.
            if existing_integrate:
                integrate_data.append(
                    {
                        "id": existing_integrate.id,
                        "provider": provider,
                        "created_at": existing_integrate.created_at,
                        "is_bound": True,
                        "link": None,
                    }
                )
            else:
                integrate_data.append(
                    {
                        "id": None,
                        "provider": provider,
                        "created_at": None,
                        "is_bound": False,
                        "link": f"{base_url}{oauth_base_path}/{provider}",
                    }
                )

        return {"data": integrate_data}


# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')