mirror of
https://github.com/drone-plugins/drone-docker.git
synced 2026-06-04 18:24:24 +08:00
feat: [CI-19349]: Added oidc support for azure connector (#496)
* feat: [CI-19349]: Added oidc support for azure connector * feat: [CI-19349]: Added env variables * feat: [CI-19349]: Added tests * Update cmd/drone-acr/main.go * Update cmd/drone-acr/main.go * feat: [CI-19349]: Added Debug statements --------- Co-authored-by: OP (oppenheimer) <21008429+Ompragash@users.noreply.github.com>
This commit is contained in:
+32
-9
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
docker "github.com/drone-plugins/drone-docker"
|
docker "github.com/drone-plugins/drone-docker"
|
||||||
|
azureutil "github.com/drone-plugins/drone-docker/internal/azure"
|
||||||
)
|
)
|
||||||
|
|
||||||
type subscriptionUrlResponse struct {
|
type subscriptionUrlResponse struct {
|
||||||
@@ -62,12 +63,14 @@ func main() {
|
|||||||
password = getenv("SERVICE_PRINCIPAL_CLIENT_SECRET")
|
password = getenv("SERVICE_PRINCIPAL_CLIENT_SECRET")
|
||||||
|
|
||||||
// Service principal credentials
|
// Service principal credentials
|
||||||
clientId = getenv("CLIENT_ID")
|
clientId = getenv("CLIENT_ID", "AZURE_CLIENT_ID", "AZURE_APP_ID", "PLUGIN_CLIENT_ID")
|
||||||
clientSecret = getenv("CLIENT_SECRET")
|
clientSecret = getenv("CLIENT_SECRET", "PLUGIN_CLIENT_SECRET")
|
||||||
clientCert = getenv("CLIENT_CERTIFICATE")
|
clientCert = getenv("CLIENT_CERTIFICATE", "PLUGIN_CLIENT_CERTIFICATE")
|
||||||
tenantId = getenv("TENANT_ID")
|
tenantId = getenv("TENANT_ID", "AZURE_TENANT_ID", "PLUGIN_TENANT_ID")
|
||||||
subscriptionId = getenv("SUBSCRIPTION_ID")
|
subscriptionId = getenv("SUBSCRIPTION_ID", "PLUGIN_SUBSCRIPTION_ID")
|
||||||
publicUrl = getenv("DAEMON_REGISTRY")
|
publicUrl = getenv("DAEMON_REGISTRY", "PLUGIN_DAEMON_REGISTRY")
|
||||||
|
authorityHost = getenv("AZURE_AUTHORITY_HOST", "PLUGIN_AZURE_AUTHORITY_HOST")
|
||||||
|
idToken = getenv("PLUGIN_OIDC_TOKEN_ID")
|
||||||
)
|
)
|
||||||
|
|
||||||
// default registry value
|
// default registry value
|
||||||
@@ -80,9 +83,29 @@ func main() {
|
|||||||
// docker login credentials are not provided
|
// docker login credentials are not provided
|
||||||
var err error
|
var err error
|
||||||
username = defaultUsername
|
username = defaultUsername
|
||||||
password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry)
|
if idToken != "" && clientId != "" && tenantId != "" {
|
||||||
if err != nil {
|
logrus.Debug("Using OIDC authentication flow")
|
||||||
logrus.Fatal(err)
|
var aadToken string
|
||||||
|
aadToken, err = azureutil.GetAADAccessTokenViaClientAssertion(context.Background(), tenantId, clientId, idToken, authorityHost)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
var p string
|
||||||
|
p, err = getPublicUrl(aadToken, registry, subscriptionId)
|
||||||
|
if err == nil {
|
||||||
|
publicUrl = p
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "failed to get public url with error: %s\n", err)
|
||||||
|
}
|
||||||
|
password, err = fetchACRToken(tenantId, aadToken, registry)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetAuthInputValidation(t *testing.T) {
|
||||||
|
// missing tenant
|
||||||
|
if _, _, err := getAuth("client", "secret", "", "", "sub", "registry.azurecr.io"); err == nil {
|
||||||
|
t.Fatalf("expected error for missing tenantId")
|
||||||
|
}
|
||||||
|
// missing clientId
|
||||||
|
if _, _, err := getAuth("", "secret", "", "tenant", "sub", "registry.azurecr.io"); err == nil {
|
||||||
|
t.Fatalf("expected error for missing clientId")
|
||||||
|
}
|
||||||
|
// missing both secret and cert
|
||||||
|
if _, _, err := getAuth("client", "", "", "tenant", "sub", "registry.azurecr.io"); err == nil {
|
||||||
|
t.Fatalf("expected error for missing credentials")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetenvAuthorityHost(t *testing.T) {
|
||||||
|
os.Setenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.us")
|
||||||
|
defer os.Unsetenv("AZURE_AUTHORITY_HOST")
|
||||||
|
|
||||||
|
got := getenv("AZURE_AUTHORITY_HOST")
|
||||||
|
if got != "https://login.microsoftonline.us" {
|
||||||
|
t.Fatalf("expected AZURE_AUTHORITY_HOST to be returned, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package azure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultResource = "https://management.azure.com/"
|
||||||
|
const defaultAuthorityHost = "https://login.microsoftonline.com"
|
||||||
|
const defaultHTTPTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// GetAADAccessTokenViaClientAssertion exchanges an external OIDC ID token for an Azure AD access token
|
||||||
|
|
||||||
|
func GetAADAccessTokenViaClientAssertion(ctx context.Context, tenantID, clientID, oidcToken, authorityHost string) (string, error) {
|
||||||
|
resource := DefaultResource
|
||||||
|
|
||||||
|
form := url.Values{
|
||||||
|
"client_id": {clientID},
|
||||||
|
"scope": {resource + ".default"},
|
||||||
|
"grant_type": {"client_credentials"},
|
||||||
|
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
|
||||||
|
"client_assertion": {oidcToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
base := authorityHost
|
||||||
|
if strings.TrimSpace(base) == "" {
|
||||||
|
base = defaultAuthorityHost
|
||||||
|
}
|
||||||
|
base = strings.TrimRight(base, "/")
|
||||||
|
endpoint := fmt.Sprintf("%s/%s/oauth2/v2.0/token", base, tenantID)
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: defaultHTTPTimeout}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
var aadErr struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
}
|
||||||
|
limited := io.LimitedReader{R: resp.Body, N: 4096}
|
||||||
|
_ = json.NewDecoder(&limited).Decode(&aadErr)
|
||||||
|
if aadErr.Error != "" {
|
||||||
|
return "", fmt.Errorf("AAD token request failed: status=%d, error=%s", resp.StatusCode, aadErr.Error)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("AAD token request failed: status=%d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if payload.AccessToken == "" {
|
||||||
|
return "", fmt.Errorf("AAD token response missing access_token")
|
||||||
|
}
|
||||||
|
return payload.AccessToken, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package azure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetAADAccessTokenViaClientAssertion_Success(t *testing.T) {
|
||||||
|
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Fatalf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if ct := r.Header.Get("Content-Type"); !strings.Contains(ct, "application/x-www-form-urlencoded") {
|
||||||
|
t.Fatalf("expected form content-type, got %s", ct)
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("failed parsing form: %v", err)
|
||||||
|
}
|
||||||
|
assertEq(t, r.Form.Get("client_id"), "client")
|
||||||
|
assertEq(t, r.Form.Get("grant_type"), "client_credentials")
|
||||||
|
assertEq(t, r.Form.Get("client_assertion_type"), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||||
|
assertEq(t, r.Form.Get("client_assertion"), "idtoken")
|
||||||
|
assertEq(t, r.Form.Get("scope"), DefaultResource+".default")
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(`{"access_token":"AT","token_type":"Bearer","expires_in":3600}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
tok, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if tok != "AT" {
|
||||||
|
t.Fatalf("expected access token AT, got %q", tok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAADAccessTokenViaClientAssertion_400WithErrorField(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"bad"}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "status=400") || !strings.Contains(err.Error(), "invalid_client") {
|
||||||
|
t.Fatalf("expected 400 with invalid_client error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAADAccessTokenViaClientAssertion_400WithoutErrorField(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte("{}"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "status=400") {
|
||||||
|
t.Fatalf("expected 400 error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAADAccessTokenViaClientAssertion_MalformedJSON(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("not-json"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected JSON decode error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAADAccessTokenViaClientAssertion_MissingAccessToken(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(`{"token_type":"Bearer","expires_in":3600}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "missing access_token") {
|
||||||
|
t.Fatalf("expected missing access_token error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEq(t *testing.T, got, want string) {
|
||||||
|
t.Helper()
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("mismatch: got=%q want=%q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user