Restructured service account impersonation flow.

Change-Id: I17c0283f053711f44abaf5620f2642eea08aca62
This commit is contained in:
Patrick Jones
2021-01-25 00:02:46 -08:00
parent 85db953d34
commit 975d0951de
3 changed files with 38 additions and 29 deletions

View File

@@ -31,11 +31,26 @@ type Config struct {
// TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials.
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
if c.ServiceAccountImpersonationURL == "" {
ts := tokenSource{ ts := tokenSource{
ctx: ctx, ctx: ctx,
conf: c, conf: c,
} }
return oauth2.ReuseTokenSource(nil, ts) return oauth2.ReuseTokenSource(nil, ts)
}
imp := impersonateTokenSource{
ctx: ctx,
url: c.ServiceAccountImpersonationURL,
scopes: c.Scopes,
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
ts.conf.ServiceAccountImpersonationURL = ""
ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
imp.ts = oauth2.ReuseTokenSource(nil, ts)
return oauth2.ReuseTokenSource(nil, imp)
} }
// Subject token file types. // Subject token file types.
@@ -89,14 +104,6 @@ type tokenSource struct {
func (ts tokenSource) Token() (*oauth2.Token, error) { func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf conf := ts.conf
if conf.ServiceAccountImpersonationURL != "" {
token, err := ts.impersonate()
if err != nil {
return nil, err
}
return token, err
}
credSource := conf.parse(ts.ctx) credSource := conf.parse(ts.ctx)
if credSource == nil { if credSource == nil {
return nil, fmt.Errorf("oauth2/google: unable to parse credential source") return nil, fmt.Errorf("oauth2/google: unable to parse credential source")

View File

@@ -6,6 +6,7 @@ package externalaccount
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -27,49 +28,53 @@ type impersonateTokenResponse struct {
ExpireTime string `json:"expireTime"` ExpireTime string `json:"expireTime"`
} }
type impersonateTokenSource struct {
ctx context.Context
ts oauth2.TokenSource
url string
scopes []string
}
// impersonate performs the exchange to get a temporary service account // impersonate performs the exchange to get a temporary service account
func (ts tokenSource) impersonate() (*oauth2.Token, error) { func (its impersonateTokenSource) Token() (*oauth2.Token, error) {
reqBody := generateAccessTokenReq{ reqBody := generateAccessTokenReq{
Lifetime: "3600s", Lifetime: "3600s",
Scope: ts.conf.Scopes, Scope: its.scopes,
} }
b, err := json.Marshal(reqBody) b, err := json.Marshal(reqBody)
serviceAccountImpersonationURL := ts.conf.ServiceAccountImpersonationURL client := oauth2.NewClient(its.ctx, its.ts)
ts.conf.ServiceAccountImpersonationURL = ""
ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
client := oauth2.NewClient(ts.ctx, ts)
if err != nil { if err != nil {
return &oauth2.Token{}, fmt.Errorf("google: unable to marshal request: %v", err) return nil, fmt.Errorf("oauth2/google: unable to marshal request: %v", err)
} }
req, err := http.NewRequest("POST", serviceAccountImpersonationURL, bytes.NewReader(b)) req, err := http.NewRequest("POST", its.url, bytes.NewReader(b))
if err != nil { if err != nil {
return nil, fmt.Errorf("impersonate: unable to create request: %v", err) return nil, fmt.Errorf("oauth2/google: unable to create impersonation request: %v", err)
} }
req = req.WithContext(ts.ctx) req = req.WithContext(its.ctx)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("impersonate: unable to generate access token: %v", err) return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %v", err) return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err)
} }
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body) return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
} }
var accessTokenResp impersonateTokenResponse var accessTokenResp impersonateTokenResponse
if err := json.Unmarshal(body, &accessTokenResp); err != nil { if err := json.Unmarshal(body, &accessTokenResp); err != nil {
return nil, fmt.Errorf("impersonate: unable to parse response: %v", err) return nil, fmt.Errorf("oauth2/google: unable to parse response: %v", err)
} }
expiry, err := time.Parse(time.RFC3339, accessTokenResp.ExpireTime) expiry, err := time.Parse(time.RFC3339, accessTokenResp.ExpireTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("impersonate: unable to parse expiry: %v", err) return nil, fmt.Errorf("oauth2/google: unable to parse expiry: %v", err)
} }
return &oauth2.Token{ return &oauth2.Token{
AccessToken: accessTokenResp.AccessToken, AccessToken: accessTokenResp.AccessToken,

View File

@@ -77,10 +77,7 @@ func TestImpersonation(t *testing.T) {
defer targetServer.Close() defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL testImpersonateConfig.TokenURL = targetServer.URL
ourTS := tokenSource{ ourTS := testImpersonateConfig.TokenSource(context.Background())
ctx: context.Background(),
conf: &testImpersonateConfig,
}
oldNow := now oldNow := now
defer func() { now = oldNow }() defer func() { now = oldNow }()