mirror of
https://github.com/drone/drone-kaniko.git
synced 2026-06-04 10:14:55 +08:00
feat: [CI-19349]: Added oidc support for ACR (#154)
* feat: [CI-18693]: Added oidc support for ACR * feat: [CI-19349]: error handling * feat: [CI-19349]: Refactored the code and added cli flags * feat: [CI-19349]: Added test cases * Update cmd/kaniko-acr/main.go * Update cmd/kaniko-acr/main.go * Update cmd/kaniko-acr/main.go * feat: [CI-19349]: changed the variable names * feat: [CI-18693]: Added error handling * feat: [CI-19349]: removed redundant code --------- Co-authored-by: OP (oppenheimer) <21008429+Ompragash@users.noreply.github.com>
This commit is contained in:
+86
-31
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/urfave/cli"
|
||||
|
||||
kaniko "github.com/drone/drone-kaniko"
|
||||
azureutil "github.com/drone/drone-kaniko/internal/azure"
|
||||
"github.com/drone/drone-kaniko/pkg/artifact"
|
||||
"github.com/drone/drone-kaniko/pkg/docker"
|
||||
"github.com/drone/drone-kaniko/pkg/utils"
|
||||
@@ -168,7 +169,7 @@ func main() {
|
||||
cli.StringFlag{
|
||||
Name: "tenant-id",
|
||||
Usage: "Azure Tenant Id",
|
||||
EnvVar: "TENANT_ID",
|
||||
EnvVar: "TENANT_ID,AZURE_TENANT_ID,PLUGIN_TENANT_ID",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "subscription-id",
|
||||
@@ -177,8 +178,18 @@ func main() {
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "client-id",
|
||||
Usage: "Azure Client Id",
|
||||
EnvVar: "CLIENT_ID",
|
||||
Usage: "Azure Client ID (also called App ID)",
|
||||
EnvVar: "CLIENT_ID,AZURE_CLIENT_ID,PLUGIN_CLIENT_ID,AZURE_APP_ID",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "oidc-token-id",
|
||||
Usage: "OIDC ID token to exchange for Azure AD access token (federated credentials)",
|
||||
EnvVar: "PLUGIN_OIDC_TOKEN_ID",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "azure-authority-host",
|
||||
Usage: "Azure authority host base URL (e.g., https://login.microsoftonline.com, https://login.microsoftonline.us)",
|
||||
EnvVar: "AZURE_AUTHORITY_HOST",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "snapshot-mode",
|
||||
@@ -417,9 +428,17 @@ func run(c *cli.Context) error {
|
||||
registry := c.String("registry")
|
||||
noPush := c.Bool("no-push")
|
||||
|
||||
publicUrl, err := setupAuth(
|
||||
c.String("tenant-id"),
|
||||
c.String("client-id"),
|
||||
clientID := c.String("client-id")
|
||||
tenantID := c.String("tenant-id")
|
||||
oidcIdToken := c.String("oidc-token-id")
|
||||
authorityHost := c.String("azure-authority-host")
|
||||
|
||||
var publicUrl string
|
||||
var err error
|
||||
publicUrl, err = setupAuth(
|
||||
tenantID,
|
||||
clientID,
|
||||
oidcIdToken,
|
||||
c.String("client-cert"),
|
||||
c.String("client-secret"),
|
||||
c.String("subscription-id"),
|
||||
@@ -427,6 +446,7 @@ func run(c *cli.Context) error {
|
||||
c.String("base-image-username"),
|
||||
c.String("base-image-password"),
|
||||
c.String("base-image-registry"),
|
||||
authorityHost,
|
||||
noPush,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -516,40 +536,66 @@ func run(c *cli.Context) error {
|
||||
return plugin.Exec()
|
||||
}
|
||||
|
||||
func setupAuth(tenantId, clientId, cert,
|
||||
clientSecret, subscriptionId, registry, dockerUsername, dockerPassword, dockerRegistry string, noPush bool) (string, error) {
|
||||
func setupAuth(tenantId, clientId, oidcIdToken, cert,
|
||||
clientSecret, subscriptionId, registry, dockerUsername, dockerPassword, dockerRegistry, authorityHost string, noPush bool) (string, error) {
|
||||
if registry == "" {
|
||||
return "", fmt.Errorf("registry must be specified")
|
||||
}
|
||||
|
||||
// case of client secret or cert based auth
|
||||
if clientId != "" {
|
||||
// only setup auth when pushing or credentials are defined
|
||||
// Determine auth path: OIDC or Service Principal (secret/cert)
|
||||
if tenantId == "" || clientId == "" {
|
||||
if noPush {
|
||||
logrus.Warnf("NO_PUSH mode: tenantId or clientId not provided")
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("tenantId and clientId must be provided")
|
||||
}
|
||||
|
||||
token, publicUrl, err := getACRToken(subscriptionId, tenantId, clientId, clientSecret, cert, registry)
|
||||
var aadAccessToken string
|
||||
var acrToken string
|
||||
var publicUrl string
|
||||
var err error
|
||||
|
||||
if oidcIdToken != "" {
|
||||
// Exchange OIDC ID token for AAD access token via client_assertion
|
||||
aadAccessToken, err = azureutil.GetAADAccessTokenViaClientAssertion(context.Background(), tenantId, clientId, oidcIdToken, authorityHost)
|
||||
if err != nil {
|
||||
if noPush {
|
||||
logrus.Warnf("NO_PUSH mode: failed to fetch ACR Token: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
return "", errors.Wrap(err, "failed to fetch ACR Token")
|
||||
return handleError(noPush, err, "failed to get AAD token via OIDC")
|
||||
}
|
||||
|
||||
// setup docker config for azure registry and base image docker registry
|
||||
if err := setDockerAuth(username, token, registry, dockerUsername, dockerPassword, dockerRegistry); err != nil {
|
||||
if noPush {
|
||||
logrus.Warnf("NO_PUSH mode: failed to create docker config: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
return "", errors.Wrap(err, "failed to create docker config")
|
||||
publicUrl, err = getPublicUrl(aadAccessToken, registry, subscriptionId)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to get public url with error: %s\n", err)
|
||||
}
|
||||
// Exchange AAD access token to ACR refresh token
|
||||
acrToken, err = fetchACRToken(tenantId, aadAccessToken, registry)
|
||||
if err != nil {
|
||||
return handleError(noPush, err, "failed to fetch ACR token")
|
||||
}
|
||||
} else if clientSecret != "" || cert != "" {
|
||||
acrToken, publicUrl, err = getACRToken(subscriptionId, tenantId, clientId, clientSecret, cert, registry)
|
||||
if err != nil {
|
||||
return handleError(noPush, err, "failed to fetch ACR Token")
|
||||
}
|
||||
return publicUrl, nil
|
||||
} else {
|
||||
if noPush {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("managed authentication is not supported")
|
||||
}
|
||||
|
||||
if err := setDockerAuth(username, acrToken, registry, dockerUsername, dockerPassword, dockerRegistry); err != nil {
|
||||
return handleError(noPush, err, "failed to create docker config")
|
||||
}
|
||||
return publicUrl, nil
|
||||
}
|
||||
|
||||
// Error handling
|
||||
func handleError(noPush bool, err error, msg string) (string, error) {
|
||||
if noPush {
|
||||
logrus.Warnf("NO_PUSH mode: %s: %v", msg, err)
|
||||
return "", nil
|
||||
}
|
||||
return "", errors.Wrap(err, msg)
|
||||
}
|
||||
|
||||
func getACRToken(subscriptionId, tenantId, clientId, clientSecret, cert, registry string) (string, string, error) {
|
||||
@@ -762,10 +808,18 @@ func handlePushOnly(c *cli.Context) error {
|
||||
return fmt.Errorf("repository and registry must be specified for push-only operation")
|
||||
}
|
||||
|
||||
// Setup ACR authentication
|
||||
publicUrl, err := setupAuth(
|
||||
c.String("tenant-id"),
|
||||
c.String("client-id"),
|
||||
// Resolve Azure client/tenant and OIDC via CLI flags
|
||||
clientID := c.String("client-id")
|
||||
tenantID := c.String("tenant-id")
|
||||
oidcIdToken := c.String("oidc-token-id")
|
||||
authorityHost := c.String("azure-authority-host")
|
||||
|
||||
var publicUrl string
|
||||
var err error
|
||||
publicUrl, err = setupAuth(
|
||||
tenantID,
|
||||
clientID,
|
||||
oidcIdToken,
|
||||
c.String("client-cert"),
|
||||
c.String("client-secret"),
|
||||
c.String("subscription-id"),
|
||||
@@ -773,7 +827,8 @@ func handlePushOnly(c *cli.Context) error {
|
||||
c.String("base-image-username"),
|
||||
c.String("base-image-password"),
|
||||
c.String("base-image-registry"),
|
||||
false, // We want to push in push-only mode
|
||||
authorityHost,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -367,3 +367,23 @@ func TestACRAuthenticationFlow(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuth_RegistryMustBeSpecified(t *testing.T) {
|
||||
pub, err := setupAuth("tenant", "client", "", "", "", "sub", "", "", "", "", "", false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "registry must be specified")
|
||||
assert.Equal(t, "", pub)
|
||||
}
|
||||
|
||||
func TestSetupAuth_MissingTenantOrClient(t *testing.T) {
|
||||
pub, err := setupAuth("tenant", "", "", "", "", "sub", "myregistry.azurecr.io", "", "", "", "", false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "tenantId and clientId must be provided")
|
||||
assert.Equal(t, "", pub)
|
||||
}
|
||||
|
||||
func TestSetupAuth_NoCreds_NoPushTrue(t *testing.T) {
|
||||
pub, err := setupAuth("tenant", "client", "", "", "", "sub", "myregistry.azurecr.io", "", "", "", "", true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", pub)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultResource = "https://management.azure.com/"
|
||||
const defaultAuthorityHost = "https://login.microsoftonline.com"
|
||||
const defaultHTTPTimeout = 30 * time.Second
|
||||
|
||||
// GetAADAccessTokenViaClientAssertion exchanges an external OIDC ID token for an Azure AD access token
|
||||
|
||||
func GetAADAccessTokenViaClientAssertion(ctx context.Context, tenantID, clientID, oidcToken, authorityHost string) (string, error) {
|
||||
resource := DefaultResource
|
||||
|
||||
form := url.Values{
|
||||
"client_id": {clientID},
|
||||
"scope": {resource + ".default"},
|
||||
"grant_type": {"client_credentials"},
|
||||
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
|
||||
"client_assertion": {oidcToken},
|
||||
}
|
||||
base := authorityHost
|
||||
if strings.TrimSpace(base) == "" {
|
||||
base = defaultAuthorityHost
|
||||
}
|
||||
base = strings.TrimRight(base, "/")
|
||||
endpoint := fmt.Sprintf("%s/%s/oauth2/v2.0/token", base, tenantID)
|
||||
client := &http.Client{Timeout: defaultHTTPTimeout}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
var aadErr struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
}
|
||||
limited := io.LimitedReader{R: resp.Body, N: 4096}
|
||||
_ = json.NewDecoder(&limited).Decode(&aadErr)
|
||||
if aadErr.Error != "" {
|
||||
return "", fmt.Errorf("AAD token request failed: status=%d, error=%s", resp.StatusCode, aadErr.Error)
|
||||
}
|
||||
return "", fmt.Errorf("AAD token request failed: status=%d", resp.StatusCode)
|
||||
}
|
||||
var payload struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if payload.AccessToken == "" {
|
||||
return "", fmt.Errorf("AAD token response missing access_token")
|
||||
}
|
||||
return payload.AccessToken, nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetAADAccessTokenViaClientAssertion_Success(t *testing.T) {
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); !strings.Contains(ct, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("expected form content-type, got %s", ct)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("failed parsing form: %v", err)
|
||||
}
|
||||
assertEq(t, r.Form.Get("client_id"), "client")
|
||||
assertEq(t, r.Form.Get("grant_type"), "client_credentials")
|
||||
assertEq(t, r.Form.Get("client_assertion_type"), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
assertEq(t, r.Form.Get("client_assertion"), "idtoken")
|
||||
assertEq(t, r.Form.Get("scope"), DefaultResource+".default")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"access_token":"AT","token_type":"Bearer","expires_in":3600}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
tok, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if tok != "AT" {
|
||||
t.Fatalf("expected access token AT, got %q", tok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAADAccessTokenViaClientAssertion_400WithErrorField(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"bad"}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||
if err == nil || !strings.Contains(err.Error(), "status=400") || !strings.Contains(err.Error(), "invalid_client") {
|
||||
t.Fatalf("expected 400 with invalid_client error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAADAccessTokenViaClientAssertion_400WithoutErrorField(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("{}"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||
if err == nil || !strings.Contains(err.Error(), "status=400") {
|
||||
t.Fatalf("expected 400 error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAADAccessTokenViaClientAssertion_MalformedJSON(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not-json"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||
if err == nil {
|
||||
t.Fatalf("expected JSON decode error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAADAccessTokenViaClientAssertion_MissingAccessToken(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"token_type":"Bearer","expires_in":3600}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||
if err == nil || !strings.Contains(err.Error(), "missing access_token") {
|
||||
t.Fatalf("expected missing access_token error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEq(t *testing.T, got, want string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Fatalf("mismatch: got=%q want=%q", got, want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user