models.py 14 KB


  1. import logging
  2. from flask_login import current_user # type: ignore
  3. from flask_restful import Resource, reqparse # type: ignore
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.wraps import account_initialization_required, setup_required
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from libs.login import login_required
  11. from services.model_load_balancing_service import ModelLoadBalancingService
  12. from services.model_provider_service import ModelProviderService
  class DefaultModelApi(Resource):
      """Console resource for reading and updating the current tenant's default models."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return the default model for one model type.

          Query args:
              model_type: required; must be one of the ``ModelType`` values.

          Returns:
              ``{"data": <default model entity>}`` passed through ``jsonable_encoder``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          default_model_entity = model_provider_service.get_default_model_of_model_type(
              tenant_id=tenant_id, model_type=args["model_type"]
          )
          return jsonable_encoder({"data": default_model_entity})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Update the default model for each entry of ``model_settings``.

          JSON body:
              model_settings: required list of objects. Every entry needs a valid
                  ``model_type``. Entries without ``provider`` are skipped; entries
                  with ``provider`` must also carry ``model``.

          Entries are validated and applied one at a time, in order, so entries
          preceding an invalid one have already been applied when the error is raised.

          Returns:
              ``{"result": "success"}`` once every entry has been processed.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: an entry is not an object, its ``model_type`` is missing
                  or unknown, or it names a ``provider`` without a ``model``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()

          # Built once: the accepted values do not change between entries.
          # Kept as a list so an unhashable ``model_type`` still compares cleanly.
          valid_model_types = [mt.value for mt in ModelType]

          for model_setting in args["model_settings"]:
              # List items arrive as arbitrary JSON values; reject anything that is
              # not an object up front instead of failing with a TypeError below.
              if not isinstance(model_setting, dict):
                  raise ValueError("invalid model type")

              if model_setting.get("model_type") not in valid_model_types:
                  raise ValueError("invalid model type")

              # No provider given: nothing to update for this model type.
              if "provider" not in model_setting:
                  continue

              if "model" not in model_setting:
                  raise ValueError("invalid model")

              try:
                  model_provider_service.update_default_model_of_model_type(
                      tenant_id=tenant_id,
                      model_type=model_setting["model_type"],
                      provider=model_setting["provider"],
                      model=model_setting["model"],
                  )
              except Exception:
                  logging.exception(
                      "Failed to update default model, model type: %s, model:%s",
                      model_setting["model_type"],
                      model_setting.get("model"),
                  )
                  # Bare raise re-raises the active exception with its traceback intact.
                  raise

          return {"result": "success"}
  class ModelProviderModelApi(Resource):
      """Console resource for listing, saving and removing models of one provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider):
          """Return the provider's models for the current tenant as ``{"data": models}``."""
          tenant_id = current_user.current_tenant_id

          model_provider_service = ModelProviderService()
          models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)

          return jsonable_encoder({"data": models})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Save load-balancing settings and, unless predefined, the model credentials.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              credentials: optional object passed to ``save_model_credentials``.
              load_balancing: optional object; when its ``enabled`` entry is truthy it
                  must also contain ``configs``.
              config_from: optional; the credentials step is skipped when this equals
                  ``"predefined-model"``.

          Returns:
              ``({"result": "success"}, 200)``.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: load balancing is enabled without ``configs``, or credential
                  validation failed (the original error is attached as ``__cause__``).
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
          parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
          parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
          args = parser.parse_args()

          model_load_balancing_service = ModelLoadBalancingService()

          # Missing key, null, empty object and falsy "enabled" all mean "disabled".
          load_balancing = args.get("load_balancing")
          if load_balancing and load_balancing.get("enabled"):
              if "configs" not in load_balancing:
                  raise ValueError("invalid load balancing configs")

              # save load balancing configs
              model_load_balancing_service.update_load_balancing_configs(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=args["model"],
                  model_type=args["model_type"],
                  configs=load_balancing["configs"],
              )

              # enable load balancing
              model_load_balancing_service.enable_model_load_balancing(
                  tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
              )
          else:
              # disable load balancing
              model_load_balancing_service.disable_model_load_balancing(
                  tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
              )

              if args.get("config_from", "") != "predefined-model":
                  model_provider_service = ModelProviderService()

                  try:
                      model_provider_service.save_model_credentials(
                          tenant_id=tenant_id,
                          provider=provider,
                          model=args["model"],
                          model_type=args["model_type"],
                          credentials=args["credentials"],
                      )
                  except CredentialsValidateFailedError as ex:
                      logging.exception(
                          "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
                          tenant_id,
                          args.get("model"),
                          args.get("model_type"),
                      )
                      # Chain the cause so the validation failure stays visible in tracebacks.
                      raise ValueError(str(ex)) from ex

          return {"result": "success"}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Remove the stored credentials of one model.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.

          Returns:
              ``({"result": "success"}, 204)``.

          Raises:
              Forbidden: the current user is not an admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()

          model_provider_service = ModelProviderService()
          model_provider_service.remove_model_credentials(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )

          # NOTE(review): 204 is "No Content", so clients will likely never see this
          # body - confirm whether the payload or the status code is the intended one.
          return {"result": "success"}, 204
  class ModelProviderModelCredentialApi(Resource):
      """Console resource exposing a model's credentials and load-balancing settings."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return credentials plus load-balancing state for one model.

          Query args:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.

          Returns:
              ``{"credentials": ..., "load_balancing": {"enabled": ..., "configs": ...}}``.
          """
          workspace_id = current_user.current_tenant_id

          query_parser = reqparse.RequestParser()
          query_parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          query_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="args",
          )
          query = query_parser.parse_args()

          provider_service = ModelProviderService()
          stored_credentials = provider_service.get_model_credentials(
              tenant_id=workspace_id,
              provider=provider,
              model_type=query["model_type"],
              model=query["model"],
          )

          balancing_service = ModelLoadBalancingService()
          # The service hands back an (enabled flag, configs) pair.
          balancing_state = balancing_service.get_load_balancing_configs(
              tenant_id=workspace_id,
              provider=provider,
              model=query["model"],
              model_type=query["model_type"],
          )
          enabled, configs = balancing_state

          load_balancing = {"enabled": enabled, "configs": configs}
          return {"credentials": stored_credentials, "load_balancing": load_balancing}
  class ModelProviderModelEnableApi(Resource):
      """Console resource that enables one model of a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Enable the model named in the JSON body.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.

          Returns:
              ``{"result": "success"}``.
          """
          workspace_id = current_user.current_tenant_id

          body_parser = reqparse.RequestParser()
          body_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          body_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="json",
          )
          payload = body_parser.parse_args()

          ModelProviderService().enable_model(
              tenant_id=workspace_id,
              provider=provider,
              model=payload["model"],
              model_type=payload["model_type"],
          )
          return {"result": "success"}
  class ModelProviderModelDisableApi(Resource):
      """Console resource that disables one model of a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Disable the model named in the JSON body.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.

          Returns:
              ``{"result": "success"}``.
          """
          workspace_id = current_user.current_tenant_id

          body_parser = reqparse.RequestParser()
          body_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          body_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="json",
          )
          payload = body_parser.parse_args()

          ModelProviderService().disable_model(
              tenant_id=workspace_id,
              provider=provider,
              model=payload["model"],
              model_type=payload["model_type"],
          )
          return {"result": "success"}
  class ModelProviderModelValidateApi(Resource):
      """Console resource that checks model credentials without reporting failure as an exception."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Validate the supplied credentials for one model.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              credentials: required object handed to ``model_credentials_validate``.

          Returns:
              ``{"result": "success"}`` when validation passes, otherwise
              ``{"result": "error", "error": <message>}``. Only
              ``CredentialsValidateFailedError`` is converted into the error form.
          """
          workspace_id = current_user.current_tenant_id

          body_parser = reqparse.RequestParser()
          body_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          body_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="json",
          )
          body_parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          payload = body_parser.parse_args()

          provider_service = ModelProviderService()

          try:
              provider_service.model_credentials_validate(
                  tenant_id=workspace_id,
                  provider=provider,
                  model=payload["model"],
                  model_type=payload["model_type"],
                  credentials=payload["credentials"],
              )
          except CredentialsValidateFailedError as failure:
              # A failed validation is an expected outcome, reported in the payload.
              return {"result": "error", "error": str(failure)}

          return {"result": "success"}
  class ModelProviderModelParameterRuleApi(Resource):
      """Console resource returning the parameter rules of one model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return ``{"data": parameter_rules}`` for the model given in the ``model`` query arg."""
          query_parser = reqparse.RequestParser()
          query_parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          query = query_parser.parse_args()

          workspace_id = current_user.current_tenant_id
          provider_service = ModelProviderService()

          rules = provider_service.get_model_parameter_rules(
              tenant_id=workspace_id,
              provider=provider,
              model=query["model"],
          )
          payload = {"data": rules}
          return jsonable_encoder(payload)
  class ModelProviderAvailableModelApi(Resource):
      """Console resource listing the current tenant's models of a given type."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, model_type):
          """Return ``{"data": models}`` for the model type taken from the URL."""
          workspace_id = current_user.current_tenant_id
          provider_service = ModelProviderService()

          matching_models = provider_service.get_models_by_model_type(
              tenant_id=workspace_id,
              model_type=model_type,
          )
          payload = {"data": matching_models}
          return jsonable_encoder(payload)
  # --- Route registration -------------------------------------------------
  # Provider-scoped model routes. ``provider`` uses the ``path`` converter.
  api.add_resource(
      ModelProviderModelApi,
      "/workspaces/current/model-providers/<path:provider>/models",
  )

  # The enable/disable routes are registered with explicit endpoint names.
  api.add_resource(
      ModelProviderModelEnableApi,
      "/workspaces/current/model-providers/<path:provider>/models/enable",
      endpoint="model-provider-model-enable",
  )
  api.add_resource(
      ModelProviderModelDisableApi,
      "/workspaces/current/model-providers/<path:provider>/models/disable",
      endpoint="model-provider-model-disable",
  )

  # Credential lookup and validation.
  api.add_resource(
      ModelProviderModelCredentialApi,
      "/workspaces/current/model-providers/<path:provider>/models/credentials",
  )
  api.add_resource(
      ModelProviderModelValidateApi,
      "/workspaces/current/model-providers/<path:provider>/models/credentials/validate",
  )

  # Parameter rules of a single model.
  api.add_resource(
      ModelProviderModelParameterRuleApi,
      "/workspaces/current/model-providers/<path:provider>/models/parameter-rules",
  )

  # Workspace-wide routes that are not tied to a single provider.
  api.add_resource(
      ModelProviderAvailableModelApi,
      "/workspaces/current/models/model-types/<string:model_type>",
  )
  api.add_resource(
      DefaultModelApi,
      "/workspaces/current/default-model",
  )