# load_balancing_config.py (4.1 KB)
  1. from flask_restful import Resource, reqparse # type: ignore
  2. from werkzeug.exceptions import Forbidden
  3. from controllers.console import api
  4. from controllers.console.wraps import account_initialization_required, setup_required
  5. from core.model_runtime.entities.model_entities import ModelType
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from libs.login import current_user, login_required
  8. from models.account import TenantAccountRole
  9. from services.model_load_balancing_service import ModelLoadBalancingService
  10. class LoadBalancingCredentialsValidateApi(Resource):
  11. @setup_required
  12. @login_required
  13. @account_initialization_required
  14. def post(self, provider: str):
  15. if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
  16. raise Forbidden()
  17. tenant_id = current_user.current_tenant_id
  18. parser = reqparse.RequestParser()
  19. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  20. parser.add_argument(
  21. "model_type",
  22. type=str,
  23. required=True,
  24. nullable=False,
  25. choices=[mt.value for mt in ModelType],
  26. location="json",
  27. )
  28. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  29. args = parser.parse_args()
  30. # validate model load balancing credentials
  31. model_load_balancing_service = ModelLoadBalancingService()
  32. result = True
  33. error = ""
  34. try:
  35. model_load_balancing_service.validate_load_balancing_credentials(
  36. tenant_id=tenant_id,
  37. provider=provider,
  38. model=args["model"],
  39. model_type=args["model_type"],
  40. credentials=args["credentials"],
  41. )
  42. except CredentialsValidateFailedError as ex:
  43. result = False
  44. error = str(ex)
  45. response = {"result": "success" if result else "error"}
  46. if not result:
  47. response["error"] = error
  48. return response
  49. class LoadBalancingConfigCredentialsValidateApi(Resource):
  50. @setup_required
  51. @login_required
  52. @account_initialization_required
  53. def post(self, provider: str, config_id: str):
  54. if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
  55. raise Forbidden()
  56. tenant_id = current_user.current_tenant_id
  57. parser = reqparse.RequestParser()
  58. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  59. parser.add_argument(
  60. "model_type",
  61. type=str,
  62. required=True,
  63. nullable=False,
  64. choices=[mt.value for mt in ModelType],
  65. location="json",
  66. )
  67. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  68. args = parser.parse_args()
  69. # validate model load balancing config credentials
  70. model_load_balancing_service = ModelLoadBalancingService()
  71. result = True
  72. error = ""
  73. try:
  74. model_load_balancing_service.validate_load_balancing_credentials(
  75. tenant_id=tenant_id,
  76. provider=provider,
  77. model=args["model"],
  78. model_type=args["model_type"],
  79. credentials=args["credentials"],
  80. config_id=config_id,
  81. )
  82. except CredentialsValidateFailedError as ex:
  83. result = False
  84. error = str(ex)
  85. response = {"result": "success" if result else "error"}
  86. if not result:
  87. response["error"] = error
  88. return response
  89. # Load Balancing Config
  90. api.add_resource(
  91. LoadBalancingCredentialsValidateApi,
  92. "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
  93. )
  94. api.add_resource(
  95. LoadBalancingConfigCredentialsValidateApi,
  96. "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
  97. )