provider_checkout_service.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import datetime
  2. import logging
  3. import stripe
  4. from flask import current_app
  5. from core.model_providers.model_provider_factory import ModelProviderFactory
  6. from extensions.ext_database import db
  7. from models.account import Account
  8. from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType
  9. class ProviderCheckout:
  10. def __init__(self, stripe_checkout_session):
  11. self.stripe_checkout_session = stripe_checkout_session
  12. def get_checkout_url(self):
  13. return self.stripe_checkout_session.url
  14. class ProviderCheckoutService:
  15. def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout:
  16. # check provider name is valid
  17. model_provider_rules = ModelProviderFactory.get_provider_rules()
  18. if provider_name not in model_provider_rules:
  19. raise ValueError(f'provider name {provider_name} is invalid')
  20. model_provider_rule = model_provider_rules[provider_name]
  21. # check provider name can be paid
  22. self._check_provider_payable(provider_name, model_provider_rule)
  23. # get stripe checkout product id
  24. paid_provider = self._get_paid_provider(tenant_id, provider_name)
  25. model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
  26. model_provider = model_provider_class(provider=paid_provider)
  27. payment_info = model_provider.get_payment_info()
  28. if not payment_info:
  29. raise ValueError(f'provider name {provider_name} not support payment')
  30. payment_product_id = payment_info['product_id']
  31. payment_min_quantity = payment_info['min_quantity']
  32. payment_max_quantity = payment_info['max_quantity']
  33. # create provider order
  34. provider_order = ProviderOrder(
  35. tenant_id=tenant_id,
  36. provider_name=provider_name,
  37. account_id=account.id,
  38. payment_product_id=payment_product_id,
  39. quantity=1,
  40. payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
  41. )
  42. db.session.add(provider_order)
  43. db.session.flush()
  44. line_item = {
  45. 'price': f'{payment_product_id}',
  46. 'quantity': payment_min_quantity
  47. }
  48. if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
  49. line_item['adjustable_quantity'] = {
  50. 'enabled': True,
  51. 'minimum': payment_min_quantity,
  52. 'maximum': payment_max_quantity
  53. }
  54. try:
  55. # create stripe checkout session
  56. checkout_session = stripe.checkout.Session.create(
  57. line_items=[
  58. line_item
  59. ],
  60. mode='payment',
  61. success_url=current_app.config.get("CONSOLE_WEB_URL")
  62. + f'?provider_name={provider_name}&payment_result=succeeded',
  63. cancel_url=current_app.config.get("CONSOLE_WEB_URL")
  64. + f'?provider_name={provider_name}&payment_result=cancelled',
  65. automatic_tax={'enabled': True},
  66. )
  67. except Exception as e:
  68. logging.exception(e)
  69. raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later')
  70. provider_order.payment_id = checkout_session.id
  71. db.session.commit()
  72. return ProviderCheckout(checkout_session)
  73. def fulfill_provider_order(self, event, line_items):
  74. provider_order = db.session.query(ProviderOrder) \
  75. .filter(ProviderOrder.payment_id == event['data']['object']['id']) \
  76. .first()
  77. if not provider_order:
  78. raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')
  79. if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
  80. raise ValueError(
  81. f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
  82. provider_order.transaction_id = event['data']['object']['payment_intent']
  83. provider_order.currency = event['data']['object']['currency']
  84. provider_order.total_amount = event['data']['object']['amount_subtotal']
  85. provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
  86. provider_order.paid_at = datetime.datetime.utcnow()
  87. provider_order.updated_at = provider_order.paid_at
  88. # update provider quota
  89. provider = db.session.query(Provider).filter(
  90. Provider.tenant_id == provider_order.tenant_id,
  91. Provider.provider_name == provider_order.provider_name,
  92. Provider.provider_type == ProviderType.SYSTEM.value,
  93. Provider.quota_type == ProviderQuotaType.PAID.value
  94. ).first()
  95. if not provider:
  96. raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
  97. f'provider name: {provider_order.provider_name}')
  98. model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
  99. model_provider = model_provider_class(provider=provider)
  100. payment_info = model_provider.get_payment_info()
  101. quantity = line_items['data'][0]['quantity']
  102. if not payment_info:
  103. increase_quota = 0
  104. else:
  105. increase_quota = int(payment_info['increase_quota']) * quantity
  106. if increase_quota > 0:
  107. provider.quota_limit += increase_quota
  108. provider.is_valid = True
  109. db.session.commit()
  110. def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
  111. if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
  112. raise ValueError(f'provider name {provider_name} not support payment')
  113. if 'system_config' not in model_provider_rule:
  114. raise ValueError(f'provider name {provider_name} not support payment')
  115. if 'supported_quota_types' not in model_provider_rule['system_config']:
  116. raise ValueError(f'provider name {provider_name} not support payment')
  117. if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
  118. raise ValueError(f'provider name {provider_name} not support payment')
  119. def _get_paid_provider(self, tenant_id: str, provider_name: str):
  120. paid_provider = db.session.query(Provider) \
  121. .filter(
  122. Provider.tenant_id == tenant_id,
  123. Provider.provider_name == provider_name,
  124. Provider.provider_type == ProviderType.SYSTEM.value,
  125. Provider.quota_type == ProviderQuotaType.PAID.value,
  126. ).first()
  127. if not paid_provider:
  128. paid_provider = Provider(
  129. tenant_id=tenant_id,
  130. provider_name=provider_name,
  131. provider_type=ProviderType.SYSTEM.value,
  132. quota_type=ProviderQuotaType.PAID.value,
  133. quota_limit=0,
  134. quota_used=0,
  135. )
  136. db.session.add(paid_provider)
  137. db.session.commit()
  138. return paid_provider