"""Console API endpoints for reading and updating the logged-in user's account."""

import datetime

import pytz
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse

from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import (
    AccountAlreadyInitedError,
    CurrentPasswordIncorrectError,
    InvalidInvitationCodeError,
    RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError


class AccountInitApi(Resource):
    """Complete first-time initialization of the logged-in account (``/account/init``)."""

    @setup_required
    @login_required
    def post(self):
        """Activate the current account and record its language and timezone.

        On the CLOUD edition an unused invitation code must also be supplied;
        it is marked as used and attributed to this account and its current
        tenant.

        Raises:
            AccountAlreadyInitedError: the account status is already ``'active'``.
            ValueError: CLOUD edition and no ``invitation_code`` was supplied.
            InvalidInvitationCodeError: no invitation code matches that is still ``'unused'``.
        """
        account = current_user

        if account.status == 'active':
            raise AccountAlreadyInitedError()

        parser = reqparse.RequestParser()

        # Invitation codes are only accepted (and required) on the CLOUD edition.
        if dify_config.EDITION == 'CLOUD':
            parser.add_argument('invitation_code', type=str, location='json')

        parser.add_argument('interface_language', type=supported_language, required=True, location='json')
        parser.add_argument('timezone', type=timezone, required=True, location='json')
        args = parser.parse_args()

        if dify_config.EDITION == 'CLOUD':
            if not args['invitation_code']:
                raise ValueError('invitation_code is required')

            # Only a code that has not been consumed yet may be redeemed.
            invitation_code = db.session.query(InvitationCode).filter(
                InvitationCode.code == args['invitation_code'],
                InvitationCode.status == 'unused',
            ).first()

            if not invitation_code:
                raise InvalidInvitationCodeError()

            invitation_code.status = 'used'
            # Stored as a naive datetime holding the current UTC time (tzinfo stripped).
            invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args['interface_language']
        account.timezone = args['timezone']
        account.interface_theme = 'light'
        account.status = 'active'
        # Same naive-UTC convention as ``used_at`` above.
        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        # One commit persists both the invitation-code and the account changes.
        db.session.commit()

        return {'result': 'success'}


class AccountProfileApi(Resource):
    """Read-only view of the logged-in account (``/account/profile``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return the current user, serialized with ``account_fields``."""
        return current_user


class AccountNameApi(Resource):
    """Rename the logged-in account (``/account/name``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the account name after validating its length.

        Raises:
            ValueError: the name is missing/null or not 3-30 characters long.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        args = parser.parse_args()

        name = args['name']
        # reqparse arguments are nullable by default, so ``{"name": null}``
        # arrives as None; reject it here instead of crashing in len().
        if name is None or not 3 <= len(name) <= 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        updated_account = AccountService.update_account(current_user, name=name)

        return updated_account


class AccountAvatarApi(Resource):
    """Change the logged-in account's avatar (``/account/avatar``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the supplied ``avatar`` string on the account and return the account."""
        parser = reqparse.RequestParser()
        parser.add_argument('avatar', type=str, required=True, location='json')
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, avatar=args['avatar'])

        return updated_account


class AccountInterfaceLanguageApi(Resource):
    """Change the UI language of the logged-in account (``/account/interface-language``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update ``interface_language``; the value is validated by ``supported_language``."""
        parser = reqparse.RequestParser()
        parser.add_argument('interface_language', type=supported_language, required=True, location='json')
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_language=args['interface_language'])

        return updated_account


class AccountInterfaceThemeApi(Resource):
    """Change the UI theme of the logged-in account (``/account/interface-theme``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update ``interface_theme``; only ``'light'`` and ``'dark'`` are accepted."""
        parser = reqparse.RequestParser()
        parser.add_argument('interface_theme', type=str, choices=['light', 'dark'], required=True, location='json')
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme'])

        return updated_account


class AccountTimezoneApi(Resource):
    """Change the timezone of the logged-in account (``/account/timezone``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the account timezone after checking it is a known IANA name.

        Raises:
            ValueError: the value is not in pytz's timezone database.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('timezone', type=str, required=True, location='json')
        args = parser.parse_args()

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai.
        # ``all_timezones_set`` holds the same names as ``all_timezones`` but
        # supports a hash lookup instead of a linear list scan.
        if args['timezone'] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args['timezone'])

        return updated_account


class AccountPasswordApi(Resource):
    """Change the password of the logged-in account (``/account/password``)."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set a new password, verifying the current one via ``AccountService``.

        Returns the (updated) current account. The endpoint is wrapped in
        ``marshal_with(account_fields)``, so returning a plain
        ``{"result": "success"}`` dict would have its ``result`` key dropped
        and yield an all-null account object.

        Raises:
            RepeatPasswordNotMatchError: ``new_password`` != ``repeat_new_password``.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        # The current password is optional here; whether it is required is
        # presumably decided by AccountService.update_account_password — verify there.
        parser.add_argument('password', type=str, required=False, location='json')
        parser.add_argument('new_password', type=str, required=True, location='json')
        parser.add_argument('repeat_new_password', type=str, required=True, location='json')
        args = parser.parse_args()

        if args['new_password'] != args['repeat_new_password']:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args['password'], args['new_password'])
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the console API error type.
            raise CurrentPasswordIncorrectError()

        return current_user


class AccountIntegrateApi(Resource):
    """List third-party login providers and whether the account is bound to each (``/account/integrates``)."""

    integrate_fields = {
        'provider': fields.String,
        'created_at': TimestampField,
        'is_bound': fields.Boolean,
        'link': fields.String,
    }

    integrate_list_fields = {
        'data': fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per supported provider.

        A bound provider reports its binding time and no link. An unbound
        provider reports the OAuth login URL the user can follow to bind it.
        """
        account = current_user

        account_integrates = db.session.query(AccountIntegrate).filter(
            AccountIntegrate.account_id == account.id).all()

        base_url = request.url_root.rstrip('/')
        oauth_base_path = "/console/api/oauth/login"
        providers = ["github", "google"]

        integrate_data = []
        for provider in providers:
            # First stored binding for this provider, if any.
            existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
            # Note: 'id' is not declared in integrate_fields, so marshal_with
            # omits it from the serialized response.
            if existing_integrate:
                integrate_data.append({
                    'id': existing_integrate.id,
                    'provider': provider,
                    'created_at': existing_integrate.created_at,
                    'is_bound': True,
                    'link': None,
                })
            else:
                integrate_data.append({
                    'id': None,
                    'provider': provider,
                    'created_at': None,
                    'is_bound': False,
                    'link': f'{base_url}{oauth_base_path}/{provider}',
                })

        return {'data': integrate_data}


# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')
api.add_resource(AccountNameApi, '/account/name')
api.add_resource(AccountAvatarApi, '/account/avatar')
api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
api.add_resource(AccountTimezoneApi, '/account/timezone')
api.add_resource(AccountPasswordApi, '/account/password')
api.add_resource(AccountIntegrateApi, '/account/integrates')
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')