s3_storage.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package s3
  import (
  	"bytes"
  	"context"
  	"errors"
  	"fmt"
  	"io"
  	"strings"
  	"time"

  	"github.com/aws/aws-sdk-go-v2/aws"
  	"github.com/aws/aws-sdk-go-v2/config"
  	"github.com/aws/aws-sdk-go-v2/credentials"
  	"github.com/aws/aws-sdk-go-v2/service/s3"
  	"github.com/langgenius/dify-plugin-daemon/internal/oss"
  	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  )
  15. type S3Storage struct {
  16. bucket string
  17. client *s3.Client
  18. }
  19. func NewS3Storage(useAws bool, endpoint string, usePathStyle bool, ak string, sk string, bucket string, region string) (oss.OSS, error) {
  20. var cfg aws.Config
  21. var err error
  22. var client *s3.Client
  23. if useAws {
  24. if ak == "" && sk == "" {
  25. cfg, err = config.LoadDefaultConfig(
  26. context.TODO(),
  27. config.WithRegion(region),
  28. )
  29. } else {
  30. cfg, err = config.LoadDefaultConfig(
  31. context.TODO(),
  32. config.WithRegion(region),
  33. config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
  34. ak,
  35. sk,
  36. "",
  37. )),
  38. )
  39. }
  40. if err != nil {
  41. return nil, err
  42. }
  43. client = s3.NewFromConfig(cfg, func(options *s3.Options) {
  44. if endpoint != "" {
  45. options.BaseEndpoint = aws.String(endpoint)
  46. }
  47. })
  48. } else {
  49. client = s3.New(s3.Options{
  50. Credentials: credentials.NewStaticCredentialsProvider(ak, sk, ""),
  51. UsePathStyle: usePathStyle,
  52. Region: region,
  53. EndpointResolver: s3.EndpointResolverFunc(
  54. func(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) {
  55. return aws.Endpoint{
  56. URL: endpoint,
  57. HostnameImmutable: false,
  58. SigningName: "s3",
  59. PartitionID: "aws",
  60. SigningRegion: region,
  61. SigningMethod: "v4",
  62. Source: aws.EndpointSourceCustom,
  63. }, nil
  64. }),
  65. })
  66. }
  67. // check bucket
  68. _, err = client.HeadBucket(context.TODO(), &s3.HeadBucketInput{
  69. Bucket: aws.String(bucket),
  70. })
  71. if err != nil {
  72. _, err = client.CreateBucket(context.TODO(), &s3.CreateBucketInput{
  73. Bucket: aws.String(bucket),
  74. })
  75. if err != nil {
  76. return nil, err
  77. }
  78. }
  79. return &S3Storage{bucket: bucket, client: client}, nil
  80. }
  81. func (s *S3Storage) Save(key string, data []byte) error {
  82. _, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
  83. Bucket: aws.String(s.bucket),
  84. Key: aws.String(key),
  85. Body: bytes.NewReader(data),
  86. })
  87. return err
  88. }
  89. func (s *S3Storage) Load(key string) ([]byte, error) {
  90. resp, err := s.client.GetObject(context.TODO(), &s3.GetObjectInput{
  91. Bucket: aws.String(s.bucket),
  92. Key: aws.String(key),
  93. })
  94. if err != nil {
  95. return nil, err
  96. }
  97. return io.ReadAll(resp.Body)
  98. }
  99. func (s *S3Storage) Exists(key string) (bool, error) {
  100. _, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
  101. Bucket: aws.String(s.bucket),
  102. Key: aws.String(key),
  103. })
  104. return err == nil, nil
  105. }
  106. func (s *S3Storage) Delete(key string) error {
  107. _, err := s.client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
  108. Bucket: aws.String(s.bucket),
  109. Key: aws.String(key),
  110. })
  111. return err
  112. }
  113. func (s *S3Storage) List(prefix string) ([]oss.OSSPath, error) {
  114. // append a slash to the prefix if it doesn't end with one
  115. if !strings.HasSuffix(prefix, "/") {
  116. prefix = prefix + "/"
  117. }
  118. var keys []oss.OSSPath
  119. input := &s3.ListObjectsV2Input{
  120. Bucket: aws.String(s.bucket),
  121. Prefix: aws.String(prefix),
  122. }
  123. paginator := s3.NewListObjectsV2Paginator(s.client, input)
  124. for paginator.HasMorePages() {
  125. page, err := paginator.NextPage(context.TODO())
  126. if err != nil {
  127. return nil, err
  128. }
  129. for _, obj := range page.Contents {
  130. // remove prefix
  131. key := strings.TrimPrefix(*obj.Key, prefix)
  132. // remove leading slash
  133. key = strings.TrimPrefix(key, "/")
  134. keys = append(keys, oss.OSSPath{
  135. Path: key,
  136. IsDir: false,
  137. })
  138. }
  139. }
  140. return keys, nil
  141. }
  142. func (s *S3Storage) State(key string) (oss.OSSState, error) {
  143. resp, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
  144. Bucket: aws.String(s.bucket),
  145. Key: aws.String(key),
  146. })
  147. if err != nil {
  148. return oss.OSSState{}, err
  149. }
  150. if resp.ContentLength == nil {
  151. resp.ContentLength = parser.ToPtr[int64](0)
  152. }
  153. if resp.LastModified == nil {
  154. resp.LastModified = parser.ToPtr(time.Time{})
  155. }
  156. return oss.OSSState{
  157. Size: *resp.ContentLength,
  158. LastModified: *resp.LastModified,
  159. }, nil
  160. }
  161. func (s *S3Storage) Type() string {
  162. return oss.OSS_TYPE_S3
  163. }