generator.py

import os

from flask_login import current_user
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.app.error import (
    CompletionRequestError,
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required


class RuleGenerateApi(Resource):
    """Generate a prompt rule configuration from a natural-language instruction."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
        args = parser.parse_args()

        account = current_user
        # Cap the generated rule config length; defaults to 512 tokens if the env var is unset.
        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))

        # Core-layer errors are translated into console API errors below.
        try:
            rules = LLMGenerator.generate_rule_config(
                tenant_id=account.current_tenant_id,
                instruction=args["instruction"],
                model_config=args["model_config"],
                no_variable=args["no_variable"],
                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)

        return rules


class RuleCodeGenerateApi(Resource):
    """Generate a code snippet from a natural-language instruction."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
        parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
        args = parser.parse_args()

        account = current_user
        # Cap the generated code length; defaults to 1024 tokens if the env var is unset.
        CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))

        # Core-layer errors are translated into console API errors below.
        try:
            code_result = LLMGenerator.generate_code(
                tenant_id=account.current_tenant_id,
                instruction=args["instruction"],
                model_config=args["model_config"],
                code_language=args["code_language"],
                max_tokens=CODE_GENERATION_MAX_TOKENS,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)

        return code_result


api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
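
# --- Usage sketch (illustrative only; not executed by this module) ----------
# Both endpoints accept a JSON body matching the request parsers above. The
# "/console/api" prefix, the authentication mechanism, and the payload values
# shown here are assumptions for illustration; the actual mount point, auth,
# and model_config shape depend on how the console blueprint is registered
# and which provider/model is configured in the deployment.
#
#   POST /console/api/rule-generate
#   {
#       "instruction": "Generate a prompt for summarizing support tickets",
#       "model_config": {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
#       "no_variable": false
#   }
#
#   POST /console/api/rule-code-generate
#   {
#       "instruction": "Write a function that removes duplicates from a list",
#       "model_config": {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
#       "no_variable": false,
#       "code_language": "python"
#   }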