diff --git a/cmd/kaniko-acr/main.go b/cmd/kaniko-acr/main.go index 01aa3d4..e862caf 100644 --- a/cmd/kaniko-acr/main.go +++ b/cmd/kaniko-acr/main.go @@ -20,6 +20,7 @@ import ( "github.com/urfave/cli" kaniko "github.com/drone/drone-kaniko" + azureutil "github.com/drone/drone-kaniko/internal/azure" "github.com/drone/drone-kaniko/pkg/artifact" "github.com/drone/drone-kaniko/pkg/docker" "github.com/drone/drone-kaniko/pkg/utils" @@ -168,7 +169,7 @@ func main() { cli.StringFlag{ Name: "tenant-id", Usage: "Azure Tenant Id", - EnvVar: "TENANT_ID", + EnvVar: "TENANT_ID,AZURE_TENANT_ID,PLUGIN_TENANT_ID", }, cli.StringFlag{ Name: "subscription-id", @@ -177,8 +178,18 @@ func main() { }, cli.StringFlag{ Name: "client-id", - Usage: "Azure Client Id", - EnvVar: "CLIENT_ID", + Usage: "Azure Client ID (also called App ID)", + EnvVar: "CLIENT_ID,AZURE_CLIENT_ID,PLUGIN_CLIENT_ID,AZURE_APP_ID", + }, + cli.StringFlag{ + Name: "oidc-token-id", + Usage: "OIDC ID token to exchange for Azure AD access token (federated credentials)", + EnvVar: "PLUGIN_OIDC_TOKEN_ID", + }, + cli.StringFlag{ + Name: "azure-authority-host", + Usage: "Azure authority host base URL (e.g., https://login.microsoftonline.com, https://login.microsoftonline.us)", + EnvVar: "AZURE_AUTHORITY_HOST", }, cli.StringFlag{ Name: "snapshot-mode", @@ -417,9 +428,17 @@ func run(c *cli.Context) error { registry := c.String("registry") noPush := c.Bool("no-push") - publicUrl, err := setupAuth( - c.String("tenant-id"), - c.String("client-id"), + clientID := c.String("client-id") + tenantID := c.String("tenant-id") + oidcIdToken := c.String("oidc-token-id") + authorityHost := c.String("azure-authority-host") + + var publicUrl string + var err error + publicUrl, err = setupAuth( + tenantID, + clientID, + oidcIdToken, c.String("client-cert"), c.String("client-secret"), c.String("subscription-id"), @@ -427,6 +446,7 @@ func run(c *cli.Context) error { c.String("base-image-username"), c.String("base-image-password"), c.String("base-image-registry"), + authorityHost, noPush, ) if err != nil { @@ -516,40 +536,66 @@ func run(c *cli.Context) error { return plugin.Exec() } -func setupAuth(tenantId, clientId, cert, - clientSecret, subscriptionId, registry, dockerUsername, dockerPassword, dockerRegistry string, noPush bool) (string, error) { +func setupAuth(tenantId, clientId, oidcIdToken, cert, + clientSecret, subscriptionId, registry, dockerUsername, dockerPassword, dockerRegistry, authorityHost string, noPush bool) (string, error) { if registry == "" { return "", fmt.Errorf("registry must be specified") } - // case of client secret or cert based auth - if clientId != "" { - // only setup auth when pushing or credentials are defined + // Determine auth path: OIDC or Service Principal (secret/cert) + if tenantId == "" || clientId == "" { + if noPush { + logrus.Warnf("NO_PUSH mode: tenantId or clientId not provided") + return "", nil + } + return "", fmt.Errorf("tenantId and clientId must be provided") + } - token, publicUrl, err := getACRToken(subscriptionId, tenantId, clientId, clientSecret, cert, registry) + var aadAccessToken string + var acrToken string + var publicUrl string + var err error + + if oidcIdToken != "" { + // Exchange OIDC ID token for AAD access token via client_assertion + aadAccessToken, err = azureutil.GetAADAccessTokenViaClientAssertion(context.Background(), tenantId, clientId, oidcIdToken, authorityHost) if err != nil { - if noPush { - logrus.Warnf("NO_PUSH mode: failed to fetch ACR Token: %v", err) - return "", nil - } - return "", errors.Wrap(err, "failed to fetch ACR Token") + return handleError(noPush, err, "failed to get AAD token via OIDC") } - - // setup docker config for azure registry and base image docker registry - if err := setDockerAuth(username, token, registry, dockerUsername, dockerPassword, dockerRegistry); err != nil { - if noPush { - logrus.Warnf("NO_PUSH mode: failed to create docker config: %v", err) - return "", nil - } - return "", errors.Wrap(err, "failed to create docker config") + publicUrl, err = getPublicUrl(aadAccessToken, registry, subscriptionId) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get public url with error: %s\n", err) + } + // Exchange AAD access token to ACR refresh token + acrToken, err = fetchACRToken(tenantId, aadAccessToken, registry) + if err != nil { + return handleError(noPush, err, "failed to fetch ACR token") + } + } else if clientSecret != "" || cert != "" { + acrToken, publicUrl, err = getACRToken(subscriptionId, tenantId, clientId, clientSecret, cert, registry) + if err != nil { + return handleError(noPush, err, "failed to fetch ACR Token") } - return publicUrl, nil } else { if noPush { return "", nil } return "", fmt.Errorf("managed authentication is not supported") } + + if err := setDockerAuth(username, acrToken, registry, dockerUsername, dockerPassword, dockerRegistry); err != nil { + return handleError(noPush, err, "failed to create docker config") + } + return publicUrl, nil +} + +// Error handling +func handleError(noPush bool, err error, msg string) (string, error) { + if noPush { + logrus.Warnf("NO_PUSH mode: %s: %v", msg, err) + return "", nil + } + return "", errors.Wrap(err, msg) } func getACRToken(subscriptionId, tenantId, clientId, clientSecret, cert, registry string) (string, string, error) { @@ -762,10 +808,18 @@ func handlePushOnly(c *cli.Context) error { return fmt.Errorf("repository and registry must be specified for push-only operation") } - // Setup ACR authentication - publicUrl, err := setupAuth( - c.String("tenant-id"), - c.String("client-id"), + // Resolve Azure client/tenant and OIDC via CLI flags + clientID := c.String("client-id") + tenantID := c.String("tenant-id") + oidcIdToken := c.String("oidc-token-id") + authorityHost := c.String("azure-authority-host") + + var publicUrl string + var err error + publicUrl, err = setupAuth( + tenantID, + clientID, + oidcIdToken, c.String("client-cert"), c.String("client-secret"), c.String("subscription-id"), @@ -773,7 +827,8 @@ func handlePushOnly(c *cli.Context) error { c.String("base-image-username"), c.String("base-image-password"), c.String("base-image-registry"), - false, // We want to push in push-only mode + authorityHost, + false, ) if err != nil { return err diff --git a/cmd/kaniko-acr/main_test.go b/cmd/kaniko-acr/main_test.go index 646d5b4..c3727c2 100644 --- a/cmd/kaniko-acr/main_test.go +++ b/cmd/kaniko-acr/main_test.go @@ -367,3 +367,23 @@ func TestACRAuthenticationFlow(t *testing.T) { }) } } + +func TestSetupAuth_RegistryMustBeSpecified(t *testing.T) { + pub, err := setupAuth("tenant", "client", "", "", "", "sub", "", "", "", "", "", false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "registry must be specified") + assert.Equal(t, "", pub) +} + +func TestSetupAuth_MissingTenantOrClient(t *testing.T) { + pub, err := setupAuth("tenant", "", "", "", "", "sub", "myregistry.azurecr.io", "", "", "", "", false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "tenantId and clientId must be provided") + assert.Equal(t, "", pub) +} + +func TestSetupAuth_NoCreds_NoPushTrue(t *testing.T) { + pub, err := setupAuth("tenant", "client", "", "", "", "sub", "myregistry.azurecr.io", "", "", "", "", true) + assert.NoError(t, err) + assert.Equal(t, "", pub) +} diff --git a/internal/azure/tokenutil.go b/internal/azure/tokenutil.go new file mode 100644 index 0000000..a23395c --- /dev/null +++ b/internal/azure/tokenutil.go @@ -0,0 +1,72 @@ +package azure + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const DefaultResource = "https://management.azure.com/" +const defaultAuthorityHost = "https://login.microsoftonline.com" +const defaultHTTPTimeout = 30 * time.Second + +// GetAADAccessTokenViaClientAssertion exchanges an external OIDC ID token for an Azure AD access token + +func GetAADAccessTokenViaClientAssertion(ctx context.Context, tenantID, clientID, oidcToken, authorityHost string) (string, error) { + resource := DefaultResource + + form := url.Values{ + "client_id": {clientID}, + "scope": {resource + ".default"}, + "grant_type": {"client_credentials"}, + "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, + "client_assertion": {oidcToken}, + } + base := authorityHost + if strings.TrimSpace(base) == "" { + base = defaultAuthorityHost + } + base = strings.TrimRight(base, "/") + endpoint := fmt.Sprintf("%s/%s/oauth2/v2.0/token", base, tenantID) + client := &http.Client{Timeout: defaultHTTPTimeout} + req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(form.Encode())) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var aadErr struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + limited := io.LimitedReader{R: resp.Body, N: 4096} + _ = json.NewDecoder(&limited).Decode(&aadErr) + if aadErr.Error != "" { + return "", fmt.Errorf("AAD token request failed: status=%d, error=%s", resp.StatusCode, aadErr.Error) + } + return "", fmt.Errorf("AAD token request failed: status=%d", resp.StatusCode) + } + var payload struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "", err + } + if payload.AccessToken == "" { + return "", fmt.Errorf("AAD token response missing access_token") + } + return payload.AccessToken, nil +} diff --git a/internal/azure/tokenutil_test.go b/internal/azure/tokenutil_test.go new file mode 100644 index 0000000..667c1cd --- /dev/null +++ b/internal/azure/tokenutil_test.go @@ -0,0 +1,103 @@ +package azure + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestGetAADAccessTokenViaClientAssertion_Success(t *testing.T) { + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); !strings.Contains(ct, "application/x-www-form-urlencoded") { + t.Fatalf("expected form content-type, got %s", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("failed parsing form: %v", err) + } + assertEq(t, r.Form.Get("client_id"), "client") + assertEq(t, r.Form.Get("grant_type"), "client_credentials") + assertEq(t, r.Form.Get("client_assertion_type"), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + assertEq(t, r.Form.Get("client_assertion"), "idtoken") + assertEq(t, r.Form.Get("scope"), DefaultResource+".default") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"AT","token_type":"Bearer","expires_in":3600}`)) + })) + defer ts.Close() + + tok, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tok != "AT" { + t.Fatalf("expected access token AT, got %q", tok) + } +} + +func TestGetAADAccessTokenViaClientAssertion_400WithErrorField(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"bad"}`)) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil || !strings.Contains(err.Error(), "status=400") || !strings.Contains(err.Error(), "invalid_client") { + t.Fatalf("expected 400 with invalid_client error, got %v", err) + } +} + +func TestGetAADAccessTokenViaClientAssertion_400WithoutErrorField(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil || !strings.Contains(err.Error(), "status=400") { + t.Fatalf("expected 400 error, got %v", err) + } +} + +func TestGetAADAccessTokenViaClientAssertion_MalformedJSON(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not-json")) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil { + t.Fatalf("expected JSON decode error, got nil") + } +} + +func TestGetAADAccessTokenViaClientAssertion_MissingAccessToken(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"token_type":"Bearer","expires_in":3600}`)) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil || !strings.Contains(err.Error(), "missing access_token") { + t.Fatalf("expected missing access_token error, got %v", err) + } +} + +func assertEq(t *testing.T, got, want string) { + t.Helper() + if got != want { + t.Fatalf("mismatch: got=%q want=%q", got, want) + } +}