supabase_storage.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import io
  2. from collections.abc import Generator
  3. from pathlib import Path
  4. from supabase import Client
  5. from extensions.storage.base_storage import BaseStorage
  6. class SupabaseStorage(BaseStorage):
  7. """Implementation for supabase obs storage."""
  8. def __init__(self):
  9. super().__init__()
  10. app_config = self.app.config
  11. self.bucket_name = app_config.get("SUPABASE_BUCKET_NAME")
  12. self.client = Client(
  13. supabase_url=app_config.get("SUPABASE_URL"), supabase_key=app_config.get("SUPABASE_API_KEY")
  14. )
  15. self.create_bucket(
  16. id=app_config.get("SUPABASE_BUCKET_NAME"), bucket_name=app_config.get("SUPABASE_BUCKET_NAME")
  17. )
  18. def create_bucket(self, id, bucket_name):
  19. if not self.bucket_exists():
  20. self.client.storage.create_bucket(id=id, name=bucket_name)
  21. def save(self, filename, data):
  22. self.client.storage.from_(self.bucket_name).upload(filename, data)
  23. def load_once(self, filename: str) -> bytes:
  24. content = self.client.storage.from_(self.bucket_name).download(filename)
  25. return content
  26. def load_stream(self, filename: str) -> Generator:
  27. def generate(filename: str = filename) -> Generator:
  28. result = self.client.storage.from_(self.bucket_name).download(filename)
  29. byte_stream = io.BytesIO(result)
  30. while chunk := byte_stream.read(4096): # Read in chunks of 4KB
  31. yield chunk
  32. return generate()
  33. def download(self, filename, target_filepath):
  34. result = self.client.storage.from_(self.bucket_name).download(filename)
  35. Path(result).write_bytes(result)
  36. def exists(self, filename):
  37. result = self.client.storage.from_(self.bucket_name).list(filename)
  38. if result.count() > 0:
  39. return True
  40. return False
  41. def delete(self, filename):
  42. self.client.storage.from_(self.bucket_name).remove(filename)
  43. def bucket_exists(self):
  44. buckets = self.client.storage.list_buckets()
  45. return any(bucket.name == self.bucket_name for bucket in buckets)