s3.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package persistence
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "io"
  7. "github.com/aws/aws-sdk-go-v2/aws"
  8. "github.com/aws/aws-sdk-go-v2/config"
  9. "github.com/aws/aws-sdk-go-v2/credentials"
  10. "github.com/aws/aws-sdk-go-v2/service/s3"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  12. )
  13. type S3Wrapper struct {
  14. client *s3.Client
  15. bucket string
  16. }
  17. func NewS3Wrapper(region string, access_key string, secret_key string, bucket string) (*S3Wrapper, error) {
  18. c, err := config.LoadDefaultConfig(
  19. context.TODO(),
  20. config.WithRegion(region),
  21. config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
  22. access_key,
  23. secret_key,
  24. "",
  25. )),
  26. )
  27. if err != nil {
  28. log.Panic("Failed to load AWS S3 config: %v", err)
  29. }
  30. s3_client := s3.NewFromConfig(c)
  31. log.Info("AWS S3 config loaded")
  32. // check
  33. _, err = s3_client.HeadBucket(context.TODO(), &s3.HeadBucketInput{
  34. Bucket: aws.String(bucket),
  35. })
  36. if err != nil {
  37. log.Panic("Failed to head bucket: %v", err)
  38. }
  39. return &S3Wrapper{
  40. client: s3_client,
  41. bucket: bucket,
  42. }, nil
  43. }
  44. func (s *S3Wrapper) Save(tenant_id string, plugin_checksum string, key string, data []byte) error {
  45. // save to s3
  46. _, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
  47. Bucket: aws.String(s.bucket),
  48. Key: aws.String(key),
  49. Body: bytes.NewReader(data),
  50. })
  51. if err != nil {
  52. return err
  53. }
  54. return nil
  55. }
  56. func (s *S3Wrapper) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) {
  57. // load from s3
  58. resp, err := s.client.GetObject(context.TODO(), &s3.GetObjectInput{
  59. Bucket: aws.String(s.bucket),
  60. Key: aws.String(key),
  61. })
  62. if err != nil {
  63. return nil, err
  64. }
  65. return io.ReadAll(resp.Body)
  66. }
  67. func (s *S3Wrapper) Delete(tenant_id string, plugin_checksum string, key string) error {
  68. _, err := s.client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
  69. Bucket: aws.String(s.bucket),
  70. Key: aws.String(key),
  71. })
  72. return err
  73. }
  74. func (s *S3Wrapper) StateSize(tenant_id string, plugin_checksum string, key string) (int64, error) {
  75. // get object size
  76. resp, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
  77. Bucket: aws.String(s.bucket),
  78. Key: aws.String(key),
  79. })
  80. if err != nil {
  81. return 0, err
  82. }
  83. if resp.ContentLength == nil {
  84. return 0, fmt.Errorf("content length not found")
  85. }
  86. return *resp.ContentLength, nil
  87. }