diff --git a/cmd/kaniko-ecr/main.go b/cmd/kaniko-ecr/main.go index 457d5ef..1652369 100644 --- a/cmd/kaniko-ecr/main.go +++ b/cmd/kaniko-ecr/main.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" ecrv1 "github.com/aws/aws-sdk-go/service/ecr" + ecrpublicv1 "github.com/aws/aws-sdk-go/service/ecrpublic" "github.com/aws/smithy-go" "github.com/hashicorp/go-version" "github.com/joho/godotenv" @@ -245,6 +246,8 @@ func run(c *cli.Context) error { registry := c.String("registry") region := c.String("region") noPush := c.Bool("no-push") + assumeRole := c.String("assume-role") + externalId := c.String("external-id") dockerConfig, err := createDockerConfig( c.String("docker-registry"), @@ -253,8 +256,8 @@ func run(c *cli.Context) error { c.String("access-key"), c.String("secret-key"), registry, - c.String("assume-role"), - c.String("external-id"), + assumeRole, + externalId, region, noPush, ) @@ -273,7 +276,7 @@ func run(c *cli.Context) error { // only create repository when pushing and create-repository is true if !noPush && c.Bool("create-repository") { - if err := createRepository(region, repo, registry); err != nil { + if err := createRepository(region, repo, registry, assumeRole, externalId); err != nil { return err } } @@ -283,7 +286,7 @@ func run(c *cli.Context) error { if err != nil { logrus.Fatal(err) } - if err := uploadLifeCyclePolicy(region, repo, string(contents)); err != nil { + if err := uploadLifeCyclePolicy(region, repo, string(contents), assumeRole, externalId); err != nil { logrus.Fatal(fmt.Sprintf("error uploading ECR lifecycle policy: %v", err)) } } @@ -293,7 +296,7 @@ func run(c *cli.Context) error { if err != nil { logrus.Fatal(err) } - if err := uploadRepositoryPolicy(region, repo, registry, string(contents)); err != nil { + if err := uploadRepositoryPolicy(region, repo, registry, string(contents), assumeRole, externalId); err != nil { logrus.Fatal(fmt.Sprintf("error uploading ECR lifecycle policy: %v", err)) } } @@ -383,7 +386,7 @@ func createDockerConfig(dockerRegistry, dockerUsername, dockerPassword, accessKe return dockerConfig, nil } -func createRepository(region, repo, registry string) error { +func createRepository(region, repo, registry, assumeRole, externalId string) error { if registry == "" { return fmt.Errorf("registry must be specified") } @@ -392,22 +395,29 @@ func createRepository(region, repo, registry string) error { return fmt.Errorf("repo must be specified") } - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) - if err != nil { - return errors.Wrap(err, "failed to load aws config") - } - var createErr error - //create public repo - //if registry string starts with public domain (ex: public.ecr.aws/example-registry) - if isRegistryPublic(registry) { - svc := ecrpublic.NewFromConfig(cfg) - _, createErr = svc.CreateRepository(context.TODO(), &ecrpublic.CreateRepositoryInput{RepositoryName: &repo}) - //create private repo + if assumeRole != "" { + if isRegistryPublic(registry) { + _, createErr = getAssumeRoleEcrPublicSvc(region, assumeRole, externalId).CreateRepository(&ecrpublicv1.CreateRepositoryInput{RepositoryName: &repo}) + } else { + _, createErr = getAssumeRoleEcrSvc(region, assumeRole, externalId).CreateRepository(&ecrv1.CreateRepositoryInput{RepositoryName: &repo}) + } } else { - svc := ecr.NewFromConfig(cfg) - _, createErr = svc.CreateRepository(context.TODO(), &ecr.CreateRepositoryInput{RepositoryName: &repo}) + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + return errors.Wrap(err, "failed to load aws config") + } + //create public repo + //if registry string starts with public domain (ex: public.ecr.aws/example-registry) + if isRegistryPublic(registry) { + svc := ecrpublic.NewFromConfig(cfg) + _, createErr = svc.CreateRepository(context.TODO(), &ecrpublic.CreateRepositoryInput{RepositoryName: &repo}) + //create private repo + } else { + svc := ecr.NewFromConfig(cfg) + _, createErr = svc.CreateRepository(context.TODO(), &ecr.CreateRepositoryInput{RepositoryName: &repo}) + } } var apiError smithy.APIError @@ -418,46 +428,67 @@ func createRepository(region, repo, registry string) error { return nil } -func uploadLifeCyclePolicy(region, repo, lifecyclePolicy string) (err error) { - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) - if err != nil { - return errors.Wrap(err, "failed to load aws config") - } +func uploadLifeCyclePolicy(region, repo, lifecyclePolicy, assumeRole, externalId string) (err error) { + if assumeRole != "" { + input := &ecrv1.PutLifecyclePolicyInput{ + LifecyclePolicyText: aws.String(lifecyclePolicy), + RepositoryName: aws.String(repo), + } + _, err = getAssumeRoleEcrSvc(region, assumeRole, externalId).PutLifecyclePolicy(input) + } else { + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + return errors.Wrap(err, "failed to load aws config") + } - svc := ecr.NewFromConfig(cfg) + svc := ecr.NewFromConfig(cfg) - input := &ecr.PutLifecyclePolicyInput{ - LifecyclePolicyText: aws.String(lifecyclePolicy), - RepositoryName: aws.String(repo), + input := &ecr.PutLifecyclePolicyInput{ + LifecyclePolicyText: aws.String(lifecyclePolicy), + RepositoryName: aws.String(repo), + } + _, err = svc.PutLifecyclePolicy(context.TODO(), input) } - _, err = svc.PutLifecyclePolicy(context.TODO(), input) return err } -func uploadRepositoryPolicy(region, repo, registry, repositoryPolicy string) (err error) { - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) - if err != nil { - return errors.Wrap(err, "failed to load aws config") - } - - if isRegistryPublic(registry) { - svc := ecrpublic.NewFromConfig(cfg) - - input := &ecrpublic.SetRepositoryPolicyInput{ - PolicyText: aws.String(repositoryPolicy), - RepositoryName: aws.String(repo), +func uploadRepositoryPolicy(region, repo, registry, repositoryPolicy, assumeRole, externalId string) (err error) { + if assumeRole != "" { + if isRegistryPublic(registry) { + input := &ecrpublicv1.SetRepositoryPolicyInput{ + PolicyText: aws.String(repositoryPolicy), + RepositoryName: aws.String(repo), + } + _, err = getAssumeRoleEcrPublicSvc(region, assumeRole, externalId).SetRepositoryPolicy(input) + } else { + input := &ecrv1.SetRepositoryPolicyInput{ + PolicyText: aws.String(repositoryPolicy), + RepositoryName: aws.String(repo), + } + _, err = getAssumeRoleEcrSvc(region, assumeRole, externalId).SetRepositoryPolicy(input) } - _, err = svc.SetRepositoryPolicy(context.TODO(), input) } else { - - svc := ecr.NewFromConfig(cfg) - - input := &ecr.SetRepositoryPolicyInput{ - PolicyText: aws.String(repositoryPolicy), - RepositoryName: aws.String(repo), + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + return errors.Wrap(err, "failed to load aws config") + } + + if isRegistryPublic(registry) { + svc := ecrpublic.NewFromConfig(cfg) + input := &ecrpublic.SetRepositoryPolicyInput{ + PolicyText: aws.String(repositoryPolicy), + RepositoryName: aws.String(repo), + } + _, err = svc.SetRepositoryPolicy(context.TODO(), input) + } else { + svc := ecr.NewFromConfig(cfg) + input := &ecr.SetRepositoryPolicyInput{ + PolicyText: aws.String(repositoryPolicy), + RepositoryName: aws.String(repo), + } + _, err = svc.SetRepositoryPolicy(context.TODO(), input) } - _, err = svc.SetRepositoryPolicy(context.TODO(), input) } return err @@ -507,6 +538,36 @@ func getAuthInfo(svc *ecrv1.ECR) (username, password, registry string, err error return } +func getAssumeRoleEcrSvc(region, assumeRole, externalId string) *ecrv1.ECR { + sess, err := session.NewSession(&awsv1.Config{Region: ®ion}) + if err != nil { + logrus.Fatal(err, "failed to create aws session") + } + + return ecrv1.New(sess, &awsv1.Config{ + Credentials: stscreds.NewCredentials(sess, assumeRole, func(p *stscreds.AssumeRoleProvider) { + if externalId != "" { + p.ExternalID = &externalId + } + }), + }) +} + +func getAssumeRoleEcrPublicSvc(region, assumeRole, externalId string) *ecrpublicv1.ECRPublic { + sess, err := session.NewSession(&awsv1.Config{Region: ®ion}) + if err != nil { + logrus.Fatal(err, "failed to create aws session") + } + + return ecrpublicv1.New(sess, &awsv1.Config{ + Credentials: stscreds.NewCredentials(sess, assumeRole, func(p *stscreds.AssumeRoleProvider) { + if externalId != "" { + p.ExternalID = &externalId + } + }), + }) +} + func isRegistryPublic(registry string) bool { return strings.HasPrefix(registry, ecrPublicDomain) }