s3_storage.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package s3
  2. import (
  3. "bytes"
  4. "context"
  5. "io"
  6. "strings"
  7. "time"
  8. "github.com/aws/aws-sdk-go-v2/aws"
  9. "github.com/aws/aws-sdk-go-v2/config"
  10. "github.com/aws/aws-sdk-go-v2/credentials"
  11. "github.com/aws/aws-sdk-go-v2/service/s3"
  12. "github.com/langgenius/dify-plugin-daemon/internal/oss"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  14. )
  15. type S3Storage struct {
  16. bucket string
  17. client *s3.Client
  18. }
  19. func NewS3Storage(useAws bool, endpoint string, ak string, sk string, bucket string, region string) (oss.OSS, error) {
  20. var cfg aws.Config
  21. var err error
  22. var client *s3.Client
  23. if useAws {
  24. if ak == "" && sk == "" {
  25. cfg, err = config.LoadDefaultConfig(
  26. context.TODO(),
  27. config.WithRegion(region),
  28. )
  29. } else {
  30. cfg, err = config.LoadDefaultConfig(
  31. context.TODO(),
  32. config.WithRegion(region),
  33. config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
  34. ak,
  35. sk,
  36. "",
  37. )),
  38. )
  39. }
  40. if err != nil {
  41. return nil, err
  42. }
  43. client = s3.NewFromConfig(cfg, func(options *s3.Options) {
  44. options.BaseEndpoint = aws.String(endpoint)
  45. })
  46. } else {
  47. client = s3.New(s3.Options{
  48. Credentials: credentials.NewStaticCredentialsProvider(ak, sk, ""),
  49. UsePathStyle: true,
  50. Region: region,
  51. EndpointResolver: s3.EndpointResolverFunc(
  52. func(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) {
  53. return aws.Endpoint{
  54. URL: endpoint,
  55. HostnameImmutable: false,
  56. SigningName: "s3",
  57. PartitionID: "aws",
  58. SigningRegion: region,
  59. SigningMethod: "v4",
  60. Source: aws.EndpointSourceCustom,
  61. }, nil
  62. }),
  63. })
  64. }
  65. // check bucket
  66. _, err = client.HeadBucket(context.TODO(), &s3.HeadBucketInput{
  67. Bucket: aws.String(bucket),
  68. })
  69. if err != nil {
  70. _, err = client.CreateBucket(context.TODO(), &s3.CreateBucketInput{
  71. Bucket: aws.String(bucket),
  72. })
  73. if err != nil {
  74. return nil, err
  75. }
  76. }
  77. return &S3Storage{bucket: bucket, client: client}, nil
  78. }
  79. func (s *S3Storage) Save(key string, data []byte) error {
  80. _, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
  81. Bucket: aws.String(s.bucket),
  82. Key: aws.String(key),
  83. Body: bytes.NewReader(data),
  84. })
  85. return err
  86. }
  87. func (s *S3Storage) Load(key string) ([]byte, error) {
  88. resp, err := s.client.GetObject(context.TODO(), &s3.GetObjectInput{
  89. Bucket: aws.String(s.bucket),
  90. Key: aws.String(key),
  91. })
  92. if err != nil {
  93. return nil, err
  94. }
  95. return io.ReadAll(resp.Body)
  96. }
  97. func (s *S3Storage) Exists(key string) (bool, error) {
  98. _, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
  99. Bucket: aws.String(s.bucket),
  100. Key: aws.String(key),
  101. })
  102. return err == nil, nil
  103. }
  104. func (s *S3Storage) Delete(key string) error {
  105. _, err := s.client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
  106. Bucket: aws.String(s.bucket),
  107. Key: aws.String(key),
  108. })
  109. return err
  110. }
  111. func (s *S3Storage) List(prefix string) ([]oss.OSSPath, error) {
  112. // append a slash to the prefix if it doesn't end with one
  113. if !strings.HasSuffix(prefix, "/") {
  114. prefix = prefix + "/"
  115. }
  116. var keys []oss.OSSPath
  117. input := &s3.ListObjectsV2Input{
  118. Bucket: aws.String(s.bucket),
  119. Prefix: aws.String(prefix),
  120. }
  121. paginator := s3.NewListObjectsV2Paginator(s.client, input)
  122. for paginator.HasMorePages() {
  123. page, err := paginator.NextPage(context.TODO())
  124. if err != nil {
  125. return nil, err
  126. }
  127. for _, obj := range page.Contents {
  128. // remove prefix
  129. key := strings.TrimPrefix(*obj.Key, prefix)
  130. // remove leading slash
  131. key = strings.TrimPrefix(key, "/")
  132. keys = append(keys, oss.OSSPath{
  133. Path: key,
  134. IsDir: false,
  135. })
  136. }
  137. }
  138. return keys, nil
  139. }
  140. func (s *S3Storage) State(key string) (oss.OSSState, error) {
  141. resp, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
  142. Bucket: aws.String(s.bucket),
  143. Key: aws.String(key),
  144. })
  145. if err != nil {
  146. return oss.OSSState{}, err
  147. }
  148. if resp.ContentLength == nil {
  149. resp.ContentLength = parser.ToPtr[int64](0)
  150. }
  151. if resp.LastModified == nil {
  152. resp.LastModified = parser.ToPtr(time.Time{})
  153. }
  154. return oss.OSSState{
  155. Size: *resp.ContentLength,
  156. LastModified: *resp.LastModified,
  157. }, nil
  158. }