ext_storage.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import os
  2. import shutil
  3. from contextlib import closing
  4. from typing import Union, Generator
  5. import boto3
  6. from botocore.exceptions import ClientError
  7. from flask import Flask
  8. class Storage:
  9. def __init__(self):
  10. self.storage_type = None
  11. self.bucket_name = None
  12. self.client = None
  13. self.folder = None
  14. def init_app(self, app: Flask):
  15. self.storage_type = app.config.get('STORAGE_TYPE')
  16. if self.storage_type == 's3':
  17. self.bucket_name = app.config.get('S3_BUCKET_NAME')
  18. self.client = boto3.client(
  19. 's3',
  20. aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
  21. aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
  22. endpoint_url=app.config.get('S3_ENDPOINT'),
  23. region_name=app.config.get('S3_REGION')
  24. )
  25. else:
  26. self.folder = app.config.get('STORAGE_LOCAL_PATH')
  27. if not os.path.isabs(self.folder):
  28. self.folder = os.path.join(app.root_path, self.folder)
  29. def save(self, filename, data):
  30. if self.storage_type == 's3':
  31. self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
  32. else:
  33. if not self.folder or self.folder.endswith('/'):
  34. filename = self.folder + filename
  35. else:
  36. filename = self.folder + '/' + filename
  37. folder = os.path.dirname(filename)
  38. os.makedirs(folder, exist_ok=True)
  39. with open(os.path.join(os.getcwd(), filename), "wb") as f:
  40. f.write(data)
  41. def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
  42. if stream:
  43. return self.load_stream(filename)
  44. else:
  45. return self.load_once(filename)
  46. def load_once(self, filename: str) -> bytes:
  47. if self.storage_type == 's3':
  48. try:
  49. with closing(self.client) as client:
  50. data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
  51. except ClientError as ex:
  52. if ex.response['Error']['Code'] == 'NoSuchKey':
  53. raise FileNotFoundError("File not found")
  54. else:
  55. raise
  56. else:
  57. if not self.folder or self.folder.endswith('/'):
  58. filename = self.folder + filename
  59. else:
  60. filename = self.folder + '/' + filename
  61. if not os.path.exists(filename):
  62. raise FileNotFoundError("File not found")
  63. with open(filename, "rb") as f:
  64. data = f.read()
  65. return data
  66. def load_stream(self, filename: str) -> Generator:
  67. def generate(filename: str = filename) -> Generator:
  68. if self.storage_type == 's3':
  69. try:
  70. with closing(self.client) as client:
  71. response = client.get_object(Bucket=self.bucket_name, Key=filename)
  72. for chunk in response['Body'].iter_chunks():
  73. yield chunk
  74. except ClientError as ex:
  75. if ex.response['Error']['Code'] == 'NoSuchKey':
  76. raise FileNotFoundError("File not found")
  77. else:
  78. raise
  79. else:
  80. if not self.folder or self.folder.endswith('/'):
  81. filename = self.folder + filename
  82. else:
  83. filename = self.folder + '/' + filename
  84. if not os.path.exists(filename):
  85. raise FileNotFoundError("File not found")
  86. with open(filename, "rb") as f:
  87. while chunk := f.read(4096): # Read in chunks of 4KB
  88. yield chunk
  89. return generate()
  90. def download(self, filename, target_filepath):
  91. if self.storage_type == 's3':
  92. with closing(self.client) as client:
  93. client.download_file(self.bucket_name, filename, target_filepath)
  94. else:
  95. if not self.folder or self.folder.endswith('/'):
  96. filename = self.folder + filename
  97. else:
  98. filename = self.folder + '/' + filename
  99. if not os.path.exists(filename):
  100. raise FileNotFoundError("File not found")
  101. shutil.copyfile(filename, target_filepath)
  102. def exists(self, filename):
  103. if self.storage_type == 's3':
  104. with closing(self.client) as client:
  105. try:
  106. client.head_object(Bucket=self.bucket_name, Key=filename)
  107. return True
  108. except:
  109. return False
  110. else:
  111. if not self.folder or self.folder.endswith('/'):
  112. filename = self.folder + filename
  113. else:
  114. filename = self.folder + '/' + filename
  115. return os.path.exists(filename)
  116. storage = Storage()
  117. def init_app(app: Flask):
  118. storage.init_app(app)