From ff3df7374f4e7d04d0cd9b7a98dd3d8353711bd2 Mon Sep 17 00:00:00 2001 From: Shin Fan Date: Mon, 14 Jun 2021 17:00:17 -0700 Subject: [PATCH] google: support scopes for JWT access token --- google/jwt.go | 33 +++++++++++++++++---- google/jwt_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 5 deletions(-) diff --git a/google/jwt.go b/google/jwt.go index b0fdb3a..efbd46b 100644 --- a/google/jwt.go +++ b/google/jwt.go @@ -7,6 +7,7 @@ package google import ( "crypto/rsa" "fmt" + "strings" "time" "golang.org/x/oauth2" @@ -24,6 +25,24 @@ import ( // optimization supported by a few Google services. // Unless you know otherwise, you should use JWTConfigFromJSON instead. func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) { + return newJWTSource(jsonKey, audience, nil) +} + +// JWTAccessTokenSourceWithScope uses a Google Developers service account JSON +// key file to read the credentials that authorize and authenticate the +// requests, and returns a TokenSource that does not use any OAuth2 flow but +// instead creates a JWT and sends that as the access token. +// The scopes is typically a list of URLs that specifies the scope of the +// credentials. +// +// Note that this is not a standard OAuth flow, but rather an +// optimization supported by a few Google services. +// Unless you know otherwise, you should use JWTConfigFromJSON instead. +func JWTAccessTokenSourceWithScope(jsonKey []byte, scopes []string) (oauth2.TokenSource, error) { + return newJWTSource(jsonKey, "", scopes) +} + +func newJWTSource(jsonKey []byte, audience string, scopes []string) (oauth2.TokenSource, error) { cfg, err := JWTConfigFromJSON(jsonKey) if err != nil { return nil, fmt.Errorf("google: could not parse JSON key: %v", err) @@ -35,6 +54,7 @@ func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.Token ts := &jwtAccessTokenSource{ email: cfg.Email, audience: audience, + scopes: scopes, pk: pk, pkID: cfg.PrivateKeyID, } @@ -47,6 +67,7 @@ func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.Token type jwtAccessTokenSource struct { email, audience string + scopes []string pk *rsa.PrivateKey pkID string } @@ -54,12 +75,14 @@ type jwtAccessTokenSource struct { func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) { iat := time.Now() exp := iat.Add(time.Hour) + scope := strings.Join(ts.scopes, " ") cs := &jws.ClaimSet{ - Iss: ts.email, - Sub: ts.email, - Aud: ts.audience, - Iat: iat.Unix(), - Exp: exp.Unix(), + Iss: ts.email, + Sub: ts.email, + Aud: ts.audience, + Scope: scope, + Iat: iat.Unix(), + Exp: exp.Unix(), } hdr := &jws.Header{ Algorithm: "RS256", diff --git a/google/jwt_test.go b/google/jwt_test.go index f844436..28be9fa 100644 --- a/google/jwt_test.go +++ b/google/jwt_test.go @@ -89,3 +89,74 @@ func TestJWTAccessTokenSourceFromJSON(t *testing.T) { t.Errorf("Header KeyID = %q, want %q", got, want) } } + +func TestJWTAccessTokenSourceWithScope(t *testing.T) { + // Generate a key we can use in the test data. + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + // Encode the key and substitute into our example JSON. + enc := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + enc, err = json.Marshal(string(enc)) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + jsonKey := bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1) + + ts, err := JWTAccessTokenSourceWithScope(jsonKey, []string{"scope1", "scope2"}) + if err != nil { + t.Fatalf("JWTAccessTokenSourceWithScope: %v\nJSON: %s", err, string(jsonKey)) + } + + tok, err := ts.Token() + if err != nil { + t.Fatalf("Token: %v", err) + } + + if got, want := tok.TokenType, "Bearer"; got != want { + t.Errorf("TokenType = %q, want %q", got, want) + } + if got := tok.Expiry; tok.Expiry.Before(time.Now()) { + t.Errorf("Expiry = %v, should not be expired", got) + } + + err = jws.Verify(tok.AccessToken, &privateKey.PublicKey) + if err != nil { + t.Errorf("jws.Verify on AccessToken: %v", err) + } + + claim, err := jws.Decode(tok.AccessToken) + if err != nil { + t.Fatalf("jws.Decode on AccessToken: %v", err) + } + + if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want { + t.Errorf("Iss = %q, want %q", got, want) + } + if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want { + t.Errorf("Sub = %q, want %q", got, want) + } + if got, want := claim.Scope, "scope1 scope2"; got != want { + t.Errorf("Aud = %q, want %q", got, want) + } + + // Finally, check the header private key. + parts := strings.Split(tok.AccessToken, ".") + hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0]) + } + var hdr jws.Header + if err := json.Unmarshal([]byte(hdrJSON), &hdr); err != nil { + t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON) + } + + if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want { + t.Errorf("Header KeyID = %q, want %q", got, want) + } +}