"""Console API resources for managing workspace model providers.

Each resource delegates to ``ModelProviderService`` (or ``BillingService`` for
the checkout URL) and is registered on the console ``api`` at the bottom of
this module.
"""

import io

from flask import send_file
from flask_login import current_user  # type: ignore
from flask_restful import Resource, reqparse  # type: ignore
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService


class ModelProviderListApi(Resource):
    """List the model providers of the current user's tenant."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return ``{"data": [...]}`` with the tenant's provider list.

        The optional ``model_type`` query argument must be one of the
        ``ModelType`` values; when omitted, ``None`` is passed to the service.
        """
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument(
            "model_type",
            type=str,
            required=False,
            nullable=True,
            choices=[mt.value for mt in ModelType],
            location="args",
        )
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
        provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))

        return jsonable_encoder({"data": provider_list})


class ModelProviderCredentialApi(Resource):
    """Read the stored credentials of a single provider."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider: str):
        """Return ``{"credentials": ...}`` for ``provider`` in the current tenant."""
        tenant_id = current_user.current_tenant_id

        model_provider_service = ModelProviderService()
        credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)

        return {"credentials": credentials}


class ModelProviderValidateApi(Resource):
    """Validate provider credentials without saving them."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str):
        """Validate the ``credentials`` JSON object for ``provider``.

        Returns ``{"result": "success"}`` when validation passes, otherwise
        ``{"result": "error", "error": <message>}``. Only
        ``CredentialsValidateFailedError`` is translated into an error
        payload; any other exception propagates.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        model_provider_service = ModelProviderService()

        result = True
        error = ""

        try:
            model_provider_service.provider_credentials_validate(
                tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
            )
        except CredentialsValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {"result": "success" if result else "error"}

        if not result:
            # An exception with an empty message still yields a non-empty error field.
            response["error"] = error or "Unknown error"

        return response


class ModelProviderApi(Resource):
    """Save or remove the credentials of a single provider."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str):
        """Save the ``credentials`` JSON object for ``provider``.

        Returns ``({"result": "success"}, 201)``.

        Raises:
            Forbidden: the current user is not an admin or owner.
            ValueError: the service rejected the credentials; the original
                ``CredentialsValidateFailedError`` is attached as the cause.
        """
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        model_provider_service = ModelProviderService()

        try:
            model_provider_service.save_provider_credentials(
                tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
            )
        except CredentialsValidateFailedError as ex:
            # Chain explicitly so the validation failure stays visible in tracebacks.
            raise ValueError(str(ex)) from ex

        return {"result": "success"}, 201

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider: str):
        """Remove the stored credentials for ``provider``.

        Returns ``({"result": "success"}, 204)``.

        Raises:
            Forbidden: the current user is not an admin or owner.
        """
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        model_provider_service = ModelProviderService()
        model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)

        # NOTE(review): 204 is "No Content", so clients should not rely on this
        # JSON body being delivered. Kept as-is to preserve the existing response.
        return {"result": "success"}, 204


class ModelProviderIconApi(Resource):
    """
    Get model provider icon
    """

    # NOTE(review): unlike the other resources in this module, this handler has
    # no setup/login decorators and takes tenant_id from the URL - presumably so
    # icons can be loaded directly by the browser; confirm this is intended.
    def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
        """Send the icon bytes for ``provider`` with the mimetype from the service.

        Raises:
            ValueError: the service returned no icon for the requested
                provider / icon type / language combination.
        """
        model_provider_service = ModelProviderService()
        icon, mimetype = model_provider_service.get_model_provider_icon(
            tenant_id=tenant_id,
            provider=provider,
            icon_type=icon_type,
            lang=lang,
        )
        if icon is None:
            raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}")
        return send_file(io.BytesIO(icon), mimetype=mimetype)


class PreferredProviderTypeUpdateApi(Resource):
    """Switch a provider between its "system" and "custom" variants."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str):
        """Set the preferred provider type for ``provider``.

        The JSON body must contain ``preferred_provider_type`` set to either
        ``"system"`` or ``"custom"``. Returns ``{"result": "success"}``.

        Raises:
            Forbidden: the current user is not an admin or owner.
        """
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument(
            "preferred_provider_type",
            type=str,
            required=True,
            nullable=False,
            choices=["system", "custom"],
            location="json",
        )
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
        model_provider_service.switch_preferred_provider(
            tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
        )

        return {"result": "success"}


class ModelProviderPaymentCheckoutUrlApi(Resource):
    """Fetch the payment checkout link for a paid provider."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider: str):
        """Return the payment link data from ``BillingService`` for ``provider``.

        Raises:
            ValueError: ``provider`` is anything other than ``"anthropic"``.
        """
        if provider != "anthropic":
            raise ValueError(f"provider name {provider} is invalid")

        # Presumably raises when the user is neither owner nor admin; the
        # return value is not inspected here - verify against BillingService.
        BillingService.is_tenant_owner_or_admin(current_user)

        data = BillingService.get_model_provider_payment_link(
            provider_name=provider,
            tenant_id=current_user.current_tenant_id,
            account_id=current_user.id,
            prefilled_email=current_user.email,
        )
        return data


api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")

api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")

api.add_resource(
    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
    ModelProviderIconApi,
    "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)