ext_storage.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import os
  2. import shutil
  3. from collections.abc import Generator
  4. from contextlib import closing
  5. from typing import Union
  6. import boto3
  7. from botocore.exceptions import ClientError
  8. from flask import Flask
  class Storage:
      """File storage facade backed either by S3 or by a local directory.

      ``init_app`` picks the backend from the Flask config key ``STORAGE_TYPE``.
      The exact value ``'s3'`` selects the S3 backend. Any other value,
      including a missing key, selects the local-filesystem backend.
      """

      def __init__(self):
          self.storage_type = None  # 's3' selects S3; any other value means local disk
          self.bucket_name = None   # S3 bucket name (S3 backend only)
          self.client = None        # boto3 S3 client, created once and reused for every call
          self.folder = None        # local root directory (local backend only)

      def init_app(self, app: "Flask"):
          """Configure the backend from ``app.config``.

          S3 reads ``S3_BUCKET_NAME``, ``S3_SECRET_KEY``, ``S3_ACCESS_KEY``,
          ``S3_ENDPOINT`` and ``S3_REGION``. Local storage reads
          ``STORAGE_LOCAL_PATH``; a relative value is resolved against
          ``app.root_path``.

          NOTE(review): if ``STORAGE_LOCAL_PATH`` is unset, ``os.path.isabs(None)``
          raises TypeError here.
          """
          self.storage_type = app.config.get('STORAGE_TYPE')
          if self.storage_type == 's3':
              self.bucket_name = app.config.get('S3_BUCKET_NAME')
              self.client = boto3.client(
                  's3',
                  aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
                  aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
                  endpoint_url=app.config.get('S3_ENDPOINT'),
                  region_name=app.config.get('S3_REGION')
              )
          else:
              self.folder = app.config.get('STORAGE_LOCAL_PATH')
              if not os.path.isabs(self.folder):
                  self.folder = os.path.join(app.root_path, self.folder)

      def _local_path(self, filename: str) -> str:
          """Return the local filesystem path for *filename* under ``self.folder``.

          This uses plain string concatenation, not ``os.path.join``, so a
          *filename* starting with '/' is still placed under the folder. An
          empty folder yields *filename* unchanged, which is resolved against
          the current working directory.

          NOTE(review): '..' segments in *filename* are not sanitized.
          """
          if not self.folder or self.folder.endswith('/'):
              return self.folder + filename
          return self.folder + '/' + filename

      def save(self, filename, data):
          """Store the bytes *data* under the key / relative path *filename*.

          For local storage, missing parent directories are created first.
          """
          if self.storage_type == 's3':
              self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
              return
          path = self._local_path(filename)
          parent = os.path.dirname(path)
          # dirname() is '' for a bare filename, and os.makedirs('') would raise.
          if parent:
              os.makedirs(parent, exist_ok=True)
          with open(path, "wb") as f:
              f.write(data)

      def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
          """Return the content of *filename*.

          Returns all bytes at once, or a chunk generator when *stream* is True.
          """
          if stream:
              return self.load_stream(filename)
          return self.load_once(filename)

      def load_once(self, filename: str) -> bytes:
          """Read *filename* fully into memory.

          Raises:
              FileNotFoundError: the local file is missing, or S3 reports
                  ``NoSuchKey``. Other S3 ``ClientError``s are re-raised unchanged.
          """
          if self.storage_type == 's3':
              try:
                  # Use the shared client directly; closing it here would tear
                  # down its connection pool for every other caller.
                  return self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
              except ClientError as ex:
                  if ex.response['Error']['Code'] == 'NoSuchKey':
                      raise FileNotFoundError("File not found") from ex
                  raise
          path = self._local_path(filename)
          if not os.path.exists(path):
              raise FileNotFoundError("File not found")
          with open(path, "rb") as f:
              return f.read()

      def load_stream(self, filename: str) -> Generator:
          """Return a generator yielding the content of *filename* in chunks.

          The work is lazy. Nothing is opened and no error (including
          FileNotFoundError) is raised until the generator is first iterated.
          Local files are read in 4096-byte chunks.
          """
          def generate(filename: str = filename) -> Generator:
              if self.storage_type == 's3':
                  try:
                      response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
                      yield from response['Body'].iter_chunks()
                  except ClientError as ex:
                      if ex.response['Error']['Code'] == 'NoSuchKey':
                          raise FileNotFoundError("File not found") from ex
                      raise
                  return
              path = self._local_path(filename)
              if not os.path.exists(path):
                  raise FileNotFoundError("File not found")
              with open(path, "rb") as f:
                  while chunk := f.read(4096):  # read in chunks of 4KB
                      yield chunk

          return generate()

      def download(self, filename, target_filepath):
          """Copy stored *filename* to the local path *target_filepath*.

          Local storage raises FileNotFoundError when the source is missing.
          S3 errors propagate from boto3's ``download_file`` untranslated.
          """
          if self.storage_type == 's3':
              self.client.download_file(self.bucket_name, filename, target_filepath)
              return
          path = self._local_path(filename)
          if not os.path.exists(path):
              raise FileNotFoundError("File not found")
          shutil.copyfile(path, target_filepath)

      def exists(self, filename):
          """Return True if *filename* is present in storage.

          For S3, any error response from ``head_object`` (for example a
          missing key) is treated as "does not exist". Non-S3 failures such as
          KeyboardInterrupt are no longer swallowed.
          """
          if self.storage_type == 's3':
              try:
                  self.client.head_object(Bucket=self.bucket_name, Key=filename)
              except ClientError:
                  return False
              return True
          return os.path.exists(self._local_path(filename))
  # Shared Storage instance for the whole application.
  # It has no backend configured until init_app() below is called.
  storage = Storage()


  def init_app(app: Flask) -> None:
      """Configure the shared ``storage`` instance from the Flask *app* config."""
      storage.init_app(app)