lambda.go 9.8 KB


  1. package aws
// This file contains functions for interacting with AWS Lambda.
// It takes a Docker image and pushes it to ECR (the push step itself is still
// a TODO in PushImageToECR), creates a Lambda function from it, and deploys it.
// It also creates a function URL for the Lambda function with IAM auth enabled.
  5. import (
  6. "context"
  7. "encoding/base64"
  8. "fmt"
  9. "strings"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/config"
  12. "github.com/aws/aws-sdk-go-v2/credentials"
  13. "github.com/aws/aws-sdk-go-v2/service/ecr"
  14. "github.com/aws/aws-sdk-go-v2/service/lambda"
  15. lambdatypes "github.com/aws/aws-sdk-go-v2/service/lambda/types"
  16. "github.com/aws/aws-sdk-go-v2/service/sts"
  17. "github.com/langgenius/dify-plugin-daemon/internal/types/app"
  18. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  19. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  20. )
  21. var (
  22. aws_lambda_config *aws.Config
  23. lambda_client *lambda.Client
  24. lambda_account_id string
  25. )
  26. // InitLambda initializes the AWS configuration and validates the credentials
  27. // It takes a pointer to the app.Config struct as an argument
  28. func InitLambda(app *app.Config) {
  29. // Check if required AWS Lambda configuration is provided
  30. if app.AWSLambdaRegion == nil || app.AWSLambdaAccessKey == nil || app.AWSLambdaSecretKey == nil {
  31. log.Panic("AWSLambdaRegion, AWSLambdaAccessKey, and AWSLambdaSecretKey must be set")
  32. }
  33. // Load AWS configuration with provided credentials
  34. c, err := config.LoadDefaultConfig(
  35. context.TODO(),
  36. config.WithRegion(*app.AWSLambdaRegion),
  37. config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
  38. *app.AWSLambdaAccessKey,
  39. *app.AWSLambdaSecretKey,
  40. "",
  41. )),
  42. )
  43. // Handle error if AWS config loading fails
  44. if err != nil {
  45. log.Panic("Failed to load AWS Lambda config: %v", err)
  46. }
  47. log.Info("AWS Lambda config loaded")
  48. // Create STS client to validate AWS credentials
  49. stsClient := sts.NewFromConfig(c)
  50. identity, err := stsClient.GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{})
  51. if err != nil {
  52. log.Panic("Failed to validate AWS Lambda credentials: %v", err)
  53. }
  54. // Get the account ID
  55. lambda_account_id = *identity.Account
  56. // Create the Lambda client
  57. lambda_client = lambda.NewFromConfig(c)
  58. log.Info("AWS Lambda credentials validated successfully")
  59. // Store the AWS configuration globally
  60. aws_lambda_config = &c
  61. }
  62. type LambdaFunction struct {
  63. FunctionName string
  64. FunctionARN string
  65. FunctionURL string
  66. }
  67. // PushImageToECR pushes a Docker image to ECR
  68. func PushImageToECR(ctx context.Context, plugin_runtime entities.PluginRuntimeInterface) (string, error) {
  69. ecr_client := ecr.NewFromConfig(*aws_lambda_config)
  70. // Create ECR repository if it doesn't exist
  71. identity, err := plugin_runtime.Identity()
  72. if err != nil {
  73. return "", fmt.Errorf("failed to get plugin identity: %v", err)
  74. }
  75. image_name := fmt.Sprintf("dify-plugin-%s-%s", identity, plugin_runtime.Checksum())
  76. repo_name := fmt.Sprintf("dify-plugin-%s", image_name)
  77. _, err = ecr_client.CreateRepository(ctx, &ecr.CreateRepositoryInput{
  78. RepositoryName: aws.String(repo_name),
  79. })
  80. if err != nil && !strings.Contains(err.Error(), "RepositoryAlreadyExistsException") {
  81. return "", fmt.Errorf("failed to create ECR repository: %v", err)
  82. }
  83. // Get ECR authorization token
  84. auth_output, err := ecr_client.GetAuthorizationToken(ctx, &ecr.GetAuthorizationTokenInput{})
  85. if err != nil {
  86. return "", fmt.Errorf("failed to get ECR authorization token: %v", err)
  87. }
  88. if len(auth_output.AuthorizationData) == 0 || auth_output.AuthorizationData[0].AuthorizationToken == nil {
  89. return "", fmt.Errorf("invalid ECR authorization data")
  90. }
  91. auth_token, err := base64.StdEncoding.DecodeString(*auth_output.AuthorizationData[0].AuthorizationToken)
  92. if err != nil {
  93. return "", fmt.Errorf("failed to decode ECR authorization token: %v", err)
  94. }
  95. // Extract username and password from auth token
  96. credentials := strings.SplitN(string(auth_token), ":", 2)
  97. if len(credentials) != 2 {
  98. return "", fmt.Errorf("invalid ECR credentials format")
  99. }
  100. // TODO: Use the extracted credentials to push the Docker image to ECR
  101. // This step typically involves using a Docker client library or executing Docker CLI commands
  102. if auth_output.AuthorizationData[0].ProxyEndpoint == nil {
  103. return "", fmt.Errorf("invalid ECR proxy endpoint")
  104. }
  105. return fmt.Sprintf("%s/%s:latest", *auth_output.AuthorizationData[0].ProxyEndpoint, repo_name), nil
  106. }
  107. // CreateLambdaFunction creates a Lambda function from an ECR image
  108. func CreateLambdaFunction(ctx context.Context, plugin_runtime entities.PluginRuntimeInterface, image_uri string) (*LambdaFunction, error) {
  109. function_name := fmt.Sprintf("dify-plugin-%s", plugin_runtime.Checksum())
  110. // Get or create the lambda execution role
  111. role_arn, err := getOrCreateLambdaExecutionRole(ctx)
  112. if err != nil {
  113. return nil, fmt.Errorf("failed to get or create Lambda execution role: %v", err)
  114. }
  115. create_output, err := lambda_client.CreateFunction(ctx, &lambda.CreateFunctionInput{
  116. FunctionName: aws.String(function_name),
  117. Role: aws.String(role_arn),
  118. PackageType: lambdatypes.PackageTypeImage,
  119. Code: &lambdatypes.FunctionCode{
  120. ImageUri: aws.String(image_uri),
  121. },
  122. })
  123. if err != nil {
  124. return nil, fmt.Errorf("failed to create Lambda function: %v", err)
  125. }
  126. if create_output.FunctionArn == nil {
  127. return nil, fmt.Errorf("invalid Lambda function creation output")
  128. }
  129. // Create function URL
  130. url_output, err := lambda_client.CreateFunctionUrlConfig(ctx, &lambda.CreateFunctionUrlConfigInput{
  131. FunctionName: aws.String(function_name),
  132. AuthType: lambdatypes.FunctionUrlAuthTypeAwsIam,
  133. })
  134. if err != nil {
  135. return nil, fmt.Errorf("failed to create function URL: %v", err)
  136. }
  137. if url_output.FunctionUrl == nil {
  138. return nil, fmt.Errorf("invalid function URL creation output")
  139. }
  140. return &LambdaFunction{
  141. FunctionName: function_name,
  142. FunctionARN: *create_output.FunctionArn,
  143. FunctionURL: *url_output.FunctionUrl,
  144. }, nil
  145. }
  146. // ListLambdaFunctions lists all Lambda functions with the "dify-plugin-" prefix
  147. func ListLambdaFunctions(ctx context.Context) ([]*LambdaFunction, error) {
  148. var functions []*LambdaFunction
  149. var marker *string
  150. for {
  151. output, err := lambda_client.ListFunctions(ctx, &lambda.ListFunctionsInput{
  152. Marker: marker,
  153. })
  154. if err != nil {
  155. return nil, fmt.Errorf("failed to list Lambda functions: %v", err)
  156. }
  157. for _, f := range output.Functions {
  158. if f.FunctionName == nil || f.FunctionArn == nil {
  159. continue
  160. }
  161. if strings.HasPrefix(*f.FunctionName, "dify-plugin-") {
  162. url_output, err := lambda_client.GetFunctionUrlConfig(ctx, &lambda.GetFunctionUrlConfigInput{
  163. FunctionName: f.FunctionName,
  164. })
  165. if err != nil {
  166. return nil, fmt.Errorf("failed to get function URL for %s: %v", *f.FunctionName, err)
  167. }
  168. if url_output.FunctionUrl == nil {
  169. return nil, fmt.Errorf("invalid function URL output for %s", *f.FunctionName)
  170. }
  171. functions = append(functions, &LambdaFunction{
  172. FunctionName: *f.FunctionName,
  173. FunctionARN: *f.FunctionArn,
  174. FunctionURL: *url_output.FunctionUrl,
  175. })
  176. }
  177. }
  178. if output.NextMarker == nil {
  179. break
  180. }
  181. marker = output.NextMarker
  182. }
  183. return functions, nil
  184. }
  185. // GetLambdaFunction retrieves a specific Lambda function by its checksum
  186. func GetLambdaFunction(ctx context.Context, identity string, checksum string) (*LambdaFunction, error) {
  187. function_name := fmt.Sprintf("dify-plugin-%s-%s", identity, checksum)
  188. output, err := lambda_client.GetFunction(ctx, &lambda.GetFunctionInput{
  189. FunctionName: aws.String(function_name),
  190. })
  191. if err != nil {
  192. return nil, fmt.Errorf("failed to get Lambda function: %v", err)
  193. }
  194. if output.Configuration == nil || output.Configuration.FunctionName == nil || output.Configuration.FunctionArn == nil {
  195. return nil, fmt.Errorf("invalid GetFunction output")
  196. }
  197. url_output, err := lambda_client.GetFunctionUrlConfig(ctx, &lambda.GetFunctionUrlConfigInput{
  198. FunctionName: aws.String(function_name),
  199. })
  200. if err != nil {
  201. return nil, fmt.Errorf("failed to get function URL: %v", err)
  202. }
  203. if url_output.FunctionUrl == nil {
  204. return nil, fmt.Errorf("invalid function URL output")
  205. }
  206. return &LambdaFunction{
  207. FunctionName: *output.Configuration.FunctionName,
  208. FunctionARN: *output.Configuration.FunctionArn,
  209. FunctionURL: *url_output.FunctionUrl,
  210. }, nil
  211. }
  212. // UpdateLambdaFunction updates an existing Lambda function with a new image
  213. func UpdateLambdaFunction(ctx context.Context, plugin_runtime entities.PluginRuntimeInterface, image_uri string) error {
  214. // Get the function name
  215. identity, err := plugin_runtime.Identity()
  216. if err != nil {
  217. return fmt.Errorf("failed to get plugin identity: %v", err)
  218. }
  219. function_name := fmt.Sprintf("dify-plugin-%s-%s", identity, plugin_runtime.Checksum())
  220. _, err = lambda_client.UpdateFunctionCode(ctx, &lambda.UpdateFunctionCodeInput{
  221. FunctionName: aws.String(function_name),
  222. ImageUri: aws.String(image_uri),
  223. })
  224. if err != nil {
  225. return fmt.Errorf("failed to update Lambda function: %v", err)
  226. }
  227. return nil
  228. }
  229. // DeleteLambdaFunction deletes a Lambda function and its associated function URL
  230. func DeleteLambdaFunction(ctx context.Context, plugin_runtime entities.PluginRuntimeInterface) error {
  231. // Get the function name
  232. identity, err := plugin_runtime.Identity()
  233. if err != nil {
  234. return fmt.Errorf("failed to get plugin identity: %v", err)
  235. }
  236. function_name := fmt.Sprintf("dify-plugin-%s-%s", identity, plugin_runtime.Checksum())
  237. // Delete function URL
  238. _, err = lambda_client.DeleteFunctionUrlConfig(ctx, &lambda.DeleteFunctionUrlConfigInput{
  239. FunctionName: aws.String(function_name),
  240. })
  241. if err != nil {
  242. return fmt.Errorf("failed to delete function URL: %v", err)
  243. }
  244. // Delete Lambda function
  245. _, err = lambda_client.DeleteFunction(ctx, &lambda.DeleteFunctionInput{
  246. FunctionName: aws.String(function_name),
  247. })
  248. if err != nil {
  249. return fmt.Errorf("failed to delete Lambda function: %v", err)
  250. }
  251. return nil
  252. }