13 Commits

Author SHA1 Message Date
Gopher Robot
6e9ec9323d go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Iad79e50dacd89c4cd0a40d966a1a7ba4cdc3d1a4
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/545176
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
2023-11-27 17:50:56 +00:00
Gopher Robot
e067960af8 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Id1413f67816220ef8039fb933088f4b7f50d70e5
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/540817
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
2023-11-08 20:28:19 +00:00
Leo
4c91c17b32 google: adds header to security considerations section
Change-Id: I29b93715876f233ae52687c8223fd8733a2a3b80
GitHub-Last-Rev: f15c4cf1a5
GitHub-Pull-Request: golang/oauth2#677
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/535895
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-17 20:42:42 +00:00
Gopher Robot
3c5dbf08cc go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I39a72a7dbb2205a6638a154892c69948ee2deb0d
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/533241
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2023-10-06 08:33:24 +00:00
Chris Smith
11625ccb95 google: add authorized_user conditional to Credentials.UniverseDomain
Return default universe domain if credentials type is authorized_user.

Change-Id: I20a9b5fafa562fcec84717914a236d081f630591
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/532196
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-02 22:04:54 +00:00
Chris Smith
8d6d45b6cd google: add Credentials.UniverseDomain to support TPC
Read and expose universe_domain from service account JSON files in
CredentialsFromJSONWithParams to support TPC in 1p clients.

Change-Id: I3518a0ec8be5ff7235b946cffd88b26ac8d303cf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531715
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-29 14:33:30 +00:00
Jin Qin
43b6a7ba19 google: adding support for external account authorized user
To support a new type of credential: `ExternalAccountAuthorizedUser`

* Refactor the common dependency STS to a separate package.
* Adding the `externalaccountauthorizeduser` package.

Change-Id: I9b9624f912d216b67a0d31945a50f057f747710b
GitHub-Last-Rev: 6e2aaff345
GitHub-Pull-Request: golang/oauth2#671
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531095
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-09-28 22:24:46 +00:00
M Hickford
14b275c918 oauth2: workaround misspelling of verification_uri
Some servers misspell verification_uri as verification_url, contrary to spec RFC 8628

Example server https://issuetracker.google.com/issues/151238144

Fixes #666

Change-Id: I89e354368bbb0a4e3b979bb547b4cb37bbe1cc02
GitHub-Last-Rev: bbf169b52d
GitHub-Pull-Request: golang/oauth2#667
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/527835
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nikolay Turpitko <nick.turpitko@gmail.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
2023-09-22 21:51:39 +00:00
aeitzman
18352fc433 google/internal/externalaccount: adding BYOID Metrics
Adds framework for sending BYOID metrics via the x-goog-api-client header on outgoing sts requests. Also adds a header file for getting the current version of GoLang

Change-Id: Id5431def96f4cfc03e4ada01d5fb8cac8cfa56a9
GitHub-Last-Rev: c93cd478e5
GitHub-Pull-Request: golang/oauth2#661
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/523595
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-22 20:39:34 +00:00
M Hickford
9095a51613 oauth2: clarify error if endpoint missing DeviceAuthURL
Change-Id: I36eb5eb66099161785160f4f39ea1c7f64ad6e74
GitHub-Last-Rev: 31cfe8150f
GitHub-Pull-Request: golang/oauth2#664
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/526302
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
2023-09-22 16:24:29 +00:00
Jin Qin
2d9e4a2adf oauth2/google: remove meta validations for aws external credentials
Remove the url validations to keep a consistency with other libraries.

Change-Id: Icb1767edc000d9695db3f0c7ca271918fb2083f5
GitHub-Last-Rev: af89ee0c72
GitHub-Pull-Request: golang/oauth2#660
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/522395
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
2023-09-12 16:01:49 +00:00
M Hickford
55cd552a36 oauth2: support PKCE
Fixes #603

Fixes golang/go#59835

Change-Id: Ica0cfef975ba9511e00f097498d33ba27dafca0d
GitHub-Last-Rev: f01f7593a3
GitHub-Pull-Request: golang/oauth2#625
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/463979
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
2023-09-07 17:49:42 +00:00
M Hickford
e3fb0fb3af oauth2: support device flow
Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628

Tested with GitHub

Fixes #418

Fixes golang/go#58126

Co-authored-by: cmP <centimitr@gmail.com>

Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14
GitHub-Last-Rev: 9a126d7b53
GitHub-Pull-Request: golang/oauth2#609
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155
Commit-Queue: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
2023-09-06 16:35:20 +00:00
32 changed files with 1294 additions and 356 deletions

198
deviceauth.go Normal file
View File

@@ -0,0 +1,198 @@
package oauth2
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
)
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
errAuthorizationPending = "authorization_pending"
errSlowDown = "slow_down"
errAccessDenied = "access_denied"
errExpiredToken = "expired_token"
)
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry time.Time `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
type Alias DeviceAuthResponse
var expiresIn int64
if !d.Expiry.IsZero() {
expiresIn = int64(time.Until(d.Expiry).Seconds())
}
return json.Marshal(&struct {
ExpiresIn int64 `json:"expires_in,omitempty"`
*Alias
}{
ExpiresIn: expiresIn,
Alias: (*Alias)(&d),
})
}
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
type Alias DeviceAuthResponse
aux := &struct {
ExpiresIn int64 `json:"expires_in"`
// workaround misspelling of verification_uri
VerificationURL string `json:"verification_url"`
*Alias
}{
Alias: (*Alias)(c),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ExpiresIn != 0 {
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
}
if c.VerificationURI == "" {
c.VerificationURI = aux.VerificationURL
}
return nil
}
// DeviceAuth returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
v := url.Values{
"client_id": {c.ClientID},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveDeviceAuth(ctx, c, v)
}
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}
da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}
if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}
return da, nil
}
// DeviceAccessToken polls the server to exchange a device code for a token.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
if !da.Expiry.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
defer cancel()
}
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
v := url.Values{
"client_id": {c.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {da.DeviceCode},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
// "If no value is provided, clients MUST use 5 as the default."
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
interval := da.Interval
if interval == 0 {
interval = 5
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
tok, err := retrieveToken(ctx, c, v)
if err == nil {
return tok, nil
}
e, ok := err.(*RetrieveError)
if !ok {
return nil, err
}
switch e.ErrorCode {
case errSlowDown:
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
interval += 5
ticker.Reset(time.Duration(interval) * time.Second)
case errAuthorizationPending:
// Do nothing.
case errAccessDenied, errExpiredToken:
fallthrough
default:
return tok, err
}
}
}
}

97
deviceauth_test.go Normal file
View File

@@ -0,0 +1,97 @@
package oauth2
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
tests := []struct {
name string
response DeviceAuthResponse
want string
}{
{
name: "empty",
response: DeviceAuthResponse{},
want: `{"device_code":"","user_code":"","verification_uri":""}`,
},
{
name: "soon",
response: DeviceAuthResponse{
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
},
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
gotBytes, err := json.Marshal(tc.response)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
t.Skip("test ran too slowly to compare `expires_in`")
}
got := string(gotBytes)
if got != tc.want {
t.Errorf("want=%s, got=%s", tc.want, got)
}
})
}
}
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
tests := []struct {
name string
data string
want DeviceAuthResponse
}{
{
name: "empty",
data: `{}`,
want: DeviceAuthResponse{},
},
{
name: "soon",
data: `{"expires_in":100}`,
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
got := DeviceAuthResponse{}
err := json.Unmarshal([]byte(tc.data), &got)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
t.Errorf("want=%#v, got=%#v", tc.want, got)
}
})
}
}
func ExampleConfig_DeviceAuth() {
var config Config
ctx := context.Background()
response, err := config.DeviceAuth(ctx)
if err != nil {
panic(err)
}
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
token, err := config.DeviceAccessToken(ctx, response)
if err != nil {
panic(err)
}
fmt.Println(token)
}

View File

@@ -57,6 +57,7 @@ var Fitbit = oauth2.Endpoint{
var GitHub = oauth2.Endpoint{ var GitHub = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize", AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token", TokenURL: "https://github.com/login/oauth/access_token",
DeviceAuthURL: "https://github.com/login/device/code",
} }
// GitLab is the endpoint for GitLab. // GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
var Google = oauth2.Endpoint{ var Google = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth", AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token", TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
} }
// Heroku is the endpoint for Heroku. // Heroku is the endpoint for Heroku.

View File

@@ -26,9 +26,13 @@ func ExampleConfig() {
}, },
} }
// use PKCE to protect against CSRF attacks
// https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6
verifier := oauth2.GenerateVerifier()
// Redirect user to consent page to ask for permission // Redirect user to consent page to ask for permission
// for the scopes specified above. // for the scopes specified above.
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline) url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
fmt.Printf("Visit the URL for the auth dialog: %v", url) fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Use the authorization code that is pushed to the redirect // Use the authorization code that is pushed to the redirect
@@ -39,7 +43,7 @@ func ExampleConfig() {
if _, err := fmt.Scan(&code); err != nil { if _, err := fmt.Scan(&code); err != nil {
log.Fatal(err) log.Fatal(err)
} }
tok, err := conf.Exchange(ctx, code) tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@@ -6,11 +6,8 @@
package github // import "golang.org/x/oauth2/github" package github // import "golang.org/x/oauth2/github"
import ( import (
"golang.org/x/oauth2" "golang.org/x/oauth2/endpoints"
) )
// Endpoint is Github's OAuth 2.0 endpoint. // Endpoint is Github's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{ var Endpoint = endpoints.GitHub
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
}

2
go.mod
View File

@@ -11,6 +11,6 @@ require (
require ( require (
cloud.google.com/go/compute v1.20.1 // indirect cloud.google.com/go/compute v1.20.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
golang.org/x/net v0.15.0 // indirect golang.org/x/net v0.19.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
) )

4
go.sum
View File

@@ -11,8 +11,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=

View File

@@ -19,7 +19,10 @@ import (
"golang.org/x/oauth2/authhandler" "golang.org/x/oauth2/authhandler"
) )
const adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc" const (
adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
universeDomainDefault = "googleapis.com"
)
// Credentials holds Google credentials, including "Application Default Credentials". // Credentials holds Google credentials, including "Application Default Credentials".
// For more details, see: // For more details, see:
@@ -37,6 +40,18 @@ type Credentials struct {
// environment and not with a credentials file, e.g. when code is // environment and not with a credentials file, e.g. when code is
// running on Google Cloud Platform. // running on Google Cloud Platform.
JSON []byte JSON []byte
// universeDomain is the default service domain for a given Cloud universe.
universeDomain string
}
// UniverseDomain returns the default service domain for a given Cloud universe.
// The default value is "googleapis.com".
func (c *Credentials) UniverseDomain() string {
if c.universeDomain == "" {
return universeDomainDefault
}
return c.universeDomain
} }
// DefaultCredentials is the old name of Credentials. // DefaultCredentials is the old name of Credentials.
@@ -200,6 +215,13 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
if err := json.Unmarshal(jsonData, &f); err != nil { if err := json.Unmarshal(jsonData, &f); err != nil {
return nil, err return nil, err
} }
universeDomain := f.UniverseDomain
// Authorized user credentials are only supported in the googleapis.com universe.
if f.Type == userCredentialsKey {
universeDomain = universeDomainDefault
}
ts, err := f.tokenSource(ctx, params) ts, err := f.tokenSource(ctx, params)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -209,6 +231,7 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
ProjectID: f.ProjectID, ProjectID: f.ProjectID,
TokenSource: ts, TokenSource: ts,
JSON: jsonData, JSON: jsonData,
universeDomain: universeDomain,
}, nil }, nil
} }

124
google/default_test.go Normal file
View File

@@ -0,0 +1,124 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"context"
"testing"
)
var saJSONJWT = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var saJSONJWTUniverseDomain = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"universe_domain": "example.com",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var userJSON = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2"
}`)
var userJSONUniverseDomain = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2",
"universe_domain": "example.com"
}`)
func TestCredentialsFromJSONWithParams_SA(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWT, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWTUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if want := "example.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_User(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSON, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSONUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}

View File

@@ -101,6 +101,8 @@
// executable-sourced credentials), please check out: // executable-sourced credentials), please check out:
// https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in // https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in
// //
// # Security considerations
//
// Note that this library does not perform any validation on the token_url, token_info_url, // Note that this library does not perform any validation on the token_url, token_info_url,
// or service_account_impersonation_url fields of the credential configuration. // or service_account_impersonation_url fields of the credential configuration.
// It is not recommended to use a credential configuration that you did not generate with // It is not recommended to use a credential configuration that you did not generate with

View File

@@ -16,6 +16,7 @@ import (
"cloud.google.com/go/compute/metadata" "cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/externalaccount" "golang.org/x/oauth2/google/internal/externalaccount"
"golang.org/x/oauth2/google/internal/externalaccountauthorizeduser"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
) )
@@ -23,6 +24,7 @@ import (
var Endpoint = oauth2.Endpoint{ var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth", AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token", TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
} }
@@ -98,6 +100,7 @@ const (
serviceAccountKey = "service_account" serviceAccountKey = "service_account"
userCredentialsKey = "authorized_user" userCredentialsKey = "authorized_user"
externalAccountKey = "external_account" externalAccountKey = "external_account"
externalAccountAuthorizedUserKey = "external_account_authorized_user"
impersonatedServiceAccount = "impersonated_service_account" impersonatedServiceAccount = "impersonated_service_account"
) )
@@ -112,6 +115,7 @@ type credentialsFile struct {
AuthURL string `json:"auth_uri"` AuthURL string `json:"auth_uri"`
TokenURL string `json:"token_uri"` TokenURL string `json:"token_uri"`
ProjectID string `json:"project_id"` ProjectID string `json:"project_id"`
UniverseDomain string `json:"universe_domain"`
// User Credential fields // User Credential fields
// (These typically come from gcloud auth.) // (These typically come from gcloud auth.)
@@ -131,6 +135,9 @@ type credentialsFile struct {
QuotaProjectID string `json:"quota_project_id"` QuotaProjectID string `json:"quota_project_id"`
WorkforcePoolUserProject string `json:"workforce_pool_user_project"` WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
// External Account Authorized User fields
RevokeURL string `json:"revoke_url"`
// Service account impersonation // Service account impersonation
SourceCredentials *credentialsFile `json:"source_credentials"` SourceCredentials *credentialsFile `json:"source_credentials"`
} }
@@ -199,6 +206,19 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
WorkforcePoolUserProject: f.WorkforcePoolUserProject, WorkforcePoolUserProject: f.WorkforcePoolUserProject,
} }
return cfg.TokenSource(ctx) return cfg.TokenSource(ctx)
case externalAccountAuthorizedUserKey:
cfg := &externalaccountauthorizeduser.Config{
Audience: f.Audience,
RefreshToken: f.RefreshToken,
TokenURL: f.TokenURLExternal,
TokenInfoURL: f.TokenInfoURL,
ClientID: f.ClientID,
ClientSecret: f.ClientSecret,
RevokeURL: f.RevokeURL,
QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes,
}
return cfg.TokenSource(ctx)
case impersonatedServiceAccount: case impersonatedServiceAccount:
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil { if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials") return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")

View File

@@ -274,49 +274,6 @@ type awsRequest struct {
Headers []awsRequestHeader `json:"headers"` Headers []awsRequestHeader `json:"headers"`
} }
func (cs awsCredentialSource) validateMetadataServers() error {
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
return err
}
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
return err
}
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
}
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
if metadataUrl == "" {
// Zero value means use default, which is valid.
return true
}
u, err := url.Parse(metadataUrl)
if err != nil {
// Unparseable URL means invalid
return false
}
for _, validHostname := range validHostnames {
if u.Hostname() == validHostname {
// If it's one of the valid hostnames, everything is good
return true
}
}
// hostname not found in our allowlist, so not valid
return false
}
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
if !cs.isValidMetadataServer(metadataUrl) {
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
}
return nil
}
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
if cs.client == nil { if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil) cs.client = oauth2.NewClient(cs.ctx, nil)
@@ -339,6 +296,10 @@ func shouldUseMetadataServer() bool {
return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
} }
func (cs awsCredentialSource) credentialSourceType() string {
return "aws"
}
func (cs awsCredentialSource) subjectToken() (string, error) { func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil { if cs.requestSigner == nil {
headers := make(map[string]string) headers := make(map[string]string)

View File

@@ -585,25 +585,18 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
func TestAWSCredential_BasicRequest(t *testing.T) { func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -631,25 +624,18 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
func TestAWSCredential_IMDSv2(t *testing.T) { func TestAWSCredential_IMDSv2(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -677,10 +663,6 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
delete(server.Credentials, "Token") delete(server.Credentials, "Token")
tfc := testFileConfig tfc := testFileConfig
@@ -688,15 +670,12 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -724,21 +703,15 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -746,7 +719,6 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -774,21 +746,15 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -796,7 +762,6 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -823,21 +788,15 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -846,7 +805,6 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
"AWS_DEFAULT_REGION": "us-east-1", "AWS_DEFAULT_REGION": "us-east-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -873,25 +831,18 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
func TestAWSCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.EnvironmentID = "aws3" tfc.CredentialSource.EnvironmentID = "aws3"
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
_, err = tfc.parse(context.Background()) _, err := tfc.parse(context.Background())
if err == nil { if err == nil {
t.Fatalf("parse() should have failed") t.Fatalf("parse() should have failed")
} }
@@ -903,23 +854,16 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.RegionURL = "" tfc.CredentialSource.RegionURL = ""
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -939,23 +883,17 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRegion = notFound server.WriteRegion = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -975,10 +913,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}")) w.Write([]byte("{}"))
} }
@@ -987,13 +922,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1013,10 +945,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
} }
@@ -1025,13 +954,10 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1051,23 +977,16 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.URL = "" tfc.CredentialSource.URL = ""
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1087,23 +1006,16 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRolename = notFound server.WriteRolename = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1123,23 +1035,16 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = notFound server.WriteSecurityCredentials = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1159,10 +1064,6 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) { func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Metadata server should not have been called.") t.Error("Metadata server should not have been called.")
@@ -1174,11 +1075,9 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -1186,7 +1085,6 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1214,28 +1112,21 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) { func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": accessKeyID, "AWS_ACCESS_KEY_ID": accessKeyID,
"AWS_SECRET_ACCESS_KEY": secretAccessKey, "AWS_SECRET_ACCESS_KEY": secretAccessKey,
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1263,28 +1154,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) { func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1312,28 +1196,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) { func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1358,87 +1235,19 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testin
} }
} }
func TestAWSCredential_Validations(t *testing.T) { func TestAwsCredential_CredentialSourceType(t *testing.T) {
var metadataServerValidityTests = []struct { server := createDefaultAwsTestServer()
name string ts := httptest.NewServer(server)
credSource CredentialSource
errText string
}{
{
name: "No Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "",
URL: "",
IMDSv2SessionTokenURL: "",
},
}, {
name: "IPv4 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
}, {
name: "IPv6 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
},
}, {
name: "Faulty RegionURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url",
}, {
name: "Faulty CredVerificationURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://abc.com/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url",
}, {
name: "Faulty IMDSv2SessionTokenURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://abc.com/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url",
},
}
for _, tt := range metadataServerValidityTests {
t.Run(tt.name, func(t *testing.T) {
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = tt.credSource tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv base, err := tfc.parse(context.Background())
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(context.Background())
if err != nil { if err != nil {
if tt.errText == "" { t.Fatalf("parse() failed %v", err)
t.Errorf("Didn't expect an error, but got %v", err)
} else if tt.errText != err.Error() {
t.Errorf("Expected %v, but got %v", tt.errText, err)
} }
} else {
if tt.errText != "" { if got, want := base.credentialSourceType(), "aws"; got != want {
t.Errorf("Expected error %v, but got none", tt.errText) t.Errorf("got %v but want %v", got, want)
}
}
})
} }
} }

View File

@@ -8,13 +8,12 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
) )
// now aliases time.Now for testing // now aliases time.Now for testing
@@ -63,31 +62,10 @@ type Config struct {
WorkforcePoolUserProject string WorkforcePoolUserProject string
} }
// Each element consists of a list of patterns. validateURLs checks for matches
// that include all elements in a given list, in that order.
var ( var (
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`) validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
) )
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
parsed, err := url.Parse(input)
if err != nil {
return false
}
if !strings.EqualFold(parsed.Scheme, scheme) {
return false
}
toTest := parsed.Host
for _, pattern := range patterns {
if pattern.MatchString(toTest) {
return true
}
}
return false
}
func validateWorkforceAudience(input string) bool { func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input) return validWorkforceAudiencePattern.MatchString(input)
} }
@@ -185,10 +163,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
} }
if err := awsCredSource.validateMetadataServers(); err != nil {
return nil, err
}
return awsCredSource, nil return awsCredSource, nil
} }
} else if c.CredentialSource.File != "" { } else if c.CredentialSource.File != "" {
@@ -202,6 +176,7 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
} }
type baseCredentialSource interface { type baseCredentialSource interface {
credentialSourceType() string
subjectToken() (string, error) subjectToken() (string, error)
} }
@@ -211,6 +186,15 @@ type tokenSource struct {
conf *Config conf *Config
} }
func getMetricsHeaderValue(conf *Config, credSource baseCredentialSource) string {
return fmt.Sprintf("gl-go/%s auth/%s google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t",
goVersion(),
"unknown",
credSource.credentialSourceType(),
conf.ServiceAccountImpersonationURL != "",
conf.ServiceAccountImpersonationLifetimeSeconds != 0)
}
// Token allows tokenSource to conform to the oauth2.TokenSource interface. // Token allows tokenSource to conform to the oauth2.TokenSource interface.
func (ts tokenSource) Token() (*oauth2.Token, error) { func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf conf := ts.conf
@@ -224,7 +208,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
stsRequest := stsTokenExchangeRequest{ stsRequest := stsexchange.TokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange", GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience, Audience: conf.Audience,
Scope: conf.Scopes, Scope: conf.Scopes,
@@ -234,7 +218,8 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
} }
header := make(http.Header) header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded") header.Add("Content-Type", "application/x-www-form-urlencoded")
clientAuth := clientAuthentication{ header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID, ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret, ClientSecret: conf.ClientSecret,
@@ -247,7 +232,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
"userProject": conf.WorkforcePoolUserProject, "userProject": conf.WorkforcePoolUserProject,
} }
} }
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options) stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -6,6 +6,7 @@ package externalaccount
import ( import (
"context" "context"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -51,6 +52,7 @@ type testExchangeTokenServer struct {
url string url string
authorization string authorization string
contentType string contentType string
metricsHeader string
body string body string
response string response string
} }
@@ -68,6 +70,10 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T
if got, want := headerContentType, tets.contentType; got != want { if got, want := headerContentType, tets.contentType; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, tets.metricsHeader; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %s.", err) t.Fatalf("Failed reading request body: %s.", err)
@@ -106,6 +112,10 @@ func validateToken(t *testing.T, tok *oauth2.Token) {
} }
} }
func getExpectedMetricsHeader(source string, saImpersonation bool, configLifetime bool) string {
return fmt.Sprintf("gl-go/%s auth/unknown google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t", goVersion(), source, saImpersonation, configLifetime)
}
func TestToken(t *testing.T) { func TestToken(t *testing.T) {
config := Config{ config := Config{
Audience: "32555940559.apps.googleusercontent.com", Audience: "32555940559.apps.googleusercontent.com",
@@ -120,6 +130,7 @@ func TestToken(t *testing.T) {
url: "/", url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded", contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: baseCredsRequestBody, body: baseCredsRequestBody,
response: baseCredsResponseBody, response: baseCredsResponseBody,
} }
@@ -147,6 +158,7 @@ func TestWorkforcePoolTokenWithClientID(t *testing.T) {
url: "/", url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded", contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithClientId, body: workforcePoolRequestBodyWithClientId,
response: baseCredsResponseBody, response: baseCredsResponseBody,
} }
@@ -173,6 +185,7 @@ func TestWorkforcePoolTokenWithoutClientID(t *testing.T) {
url: "/", url: "/",
authorization: "", authorization: "",
contentType: "application/x-www-form-urlencoded", contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithoutClientId, body: workforcePoolRequestBodyWithoutClientId,
response: baseCredsResponseBody, response: baseCredsResponseBody,
} }

View File

@@ -233,6 +233,10 @@ func (cs executableCredentialSource) parseSubjectTokenFromSource(response []byte
return "", tokenTypeError(source) return "", tokenTypeError(source)
} }
func (cs executableCredentialSource) credentialSourceType() string {
return "executable"
}
func (cs executableCredentialSource) subjectToken() (string, error) { func (cs executableCredentialSource) subjectToken() (string, error) {
if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil { if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil {
return token, err return token, err

View File

@@ -150,6 +150,9 @@ func TestCreateExecutableCredential(t *testing.T) {
if ecs.Timeout != tt.expectedTimeout { if ecs.Timeout != tt.expectedTimeout {
t.Errorf("ecs.Timeout got %v but want %v", ecs.Timeout, tt.expectedTimeout) t.Errorf("ecs.Timeout got %v but want %v", ecs.Timeout, tt.expectedTimeout)
} }
if ecs.credentialSourceType() != "executable" {
t.Errorf("ecs.CredentialSourceType() got %s but want executable", ecs.credentialSourceType())
}
} }
}) })
} }

View File

@@ -19,6 +19,10 @@ type fileCredentialSource struct {
Format format Format format
} }
func (cs fileCredentialSource) credentialSourceType() string {
return "file"
}
func (cs fileCredentialSource) subjectToken() (string, error) { func (cs fileCredentialSource) subjectToken() (string, error) {
tokenFile, err := os.Open(cs.File) tokenFile, err := os.Open(cs.File)
if err != nil { if err != nil {

View File

@@ -68,6 +68,9 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
t.Errorf("got %v but want %v", out, test.want) t.Errorf("got %v but want %v", out, test.want)
} }
if got, want := base.credentialSourceType(), "file"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}) })
} }
} }

View File

@@ -0,0 +1,64 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"strings"
"unicode"
)
var (
// version is a package internal global variable for testing purposes.
version = runtime.Version
)
// versionUnknown is only used when the runtime version cannot be determined.
const versionUnknown = "UNKNOWN"
// goVersion returns a Go runtime version derived from the runtime environment
// that is modified to be suitable for reporting in a header, meaning it has no
// whitespace. If it is unable to determine the Go runtime version, it returns
// versionUnknown.
func goVersion() string {
const develPrefix = "devel +"
s := version()
if strings.HasPrefix(s, develPrefix) {
s = s[len(develPrefix):]
if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
return s
} else if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
notSemverRune := func(r rune) bool {
return !strings.ContainsRune("0123456789.", r)
}
if strings.HasPrefix(s, "go1") {
s = s[2:]
var prerelease string
if p := strings.IndexFunc(s, notSemverRune); p >= 0 {
s, prerelease = s[:p], s[p:]
}
if strings.HasSuffix(s, ".") {
s += "0"
} else if strings.Count(s, ".") < 2 {
s += ".0"
}
if prerelease != "" {
// Some release candidates already have a dash in them.
if !strings.HasPrefix(prerelease, "-") {
prerelease = "-" + prerelease
}
s += prerelease
}
return s
}
return "UNKNOWN"
}

View File

@@ -0,0 +1,48 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGoVersion(t *testing.T) {
testVersion := func(v string) func() string {
return func() string {
return v
}
}
for _, tst := range []struct {
v func() string
want string
}{
{
testVersion("go1.19"),
"1.19.0",
},
{
testVersion("go1.21-20230317-RC01"),
"1.21.0-20230317-RC01",
},
{
testVersion("devel +abc1234"),
"abc1234",
},
{
testVersion("this should be unknown"),
versionUnknown,
},
} {
version = tst.v
got := goVersion()
if diff := cmp.Diff(got, tst.want); diff != "" {
t.Errorf("got(-),want(+):\n%s", diff)
}
}
version = runtime.Version
}

View File

@@ -42,7 +42,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin
})) }))
} }
func createTargetServer(t *testing.T) *httptest.Server { func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), "/"; got != want { if got, want := r.URL.String(), "/"; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want) t.Errorf("URL.String(): got %v but want %v", got, want)
@@ -55,6 +55,10 @@ func createTargetServer(t *testing.T) *httptest.Server {
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, metricsHeaderWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %v.", err) t.Fatalf("Failed reading request body: %v.", err)
@@ -71,6 +75,7 @@ var impersonationTests = []struct {
name string name string
config Config config Config
expectedImpersonationBody string expectedImpersonationBody string
expectedMetricsHeader string
}{ }{
{ {
name: "Base Impersonation", name: "Base Impersonation",
@@ -84,6 +89,7 @@ var impersonationTests = []struct {
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}, },
expectedImpersonationBody: "{\"lifetime\":\"3600s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}", expectedImpersonationBody: "{\"lifetime\":\"3600s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, false),
}, },
{ {
name: "With TokenLifetime Set", name: "With TokenLifetime Set",
@@ -98,6 +104,7 @@ var impersonationTests = []struct {
ServiceAccountImpersonationLifetimeSeconds: 10000, ServiceAccountImpersonationLifetimeSeconds: 10000,
}, },
expectedImpersonationBody: "{\"lifetime\":\"10000s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}", expectedImpersonationBody: "{\"lifetime\":\"10000s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, true),
}, },
} }
@@ -109,7 +116,7 @@ func TestImpersonation(t *testing.T) {
defer impersonateServer.Close() defer impersonateServer.Close()
testImpersonateConfig.ServiceAccountImpersonationURL = impersonateServer.URL testImpersonateConfig.ServiceAccountImpersonationURL = impersonateServer.URL
targetServer := createTargetServer(t) targetServer := createTargetServer(tt.expectedMetricsHeader, t)
defer targetServer.Close() defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL testImpersonateConfig.TokenURL = targetServer.URL

View File

@@ -23,6 +23,10 @@ type urlCredentialSource struct {
ctx context.Context ctx context.Context
} }
func (cs urlCredentialSource) credentialSourceType() string {
return "url"
}
func (cs urlCredentialSource) subjectToken() (string, error) { func (cs urlCredentialSource) subjectToken() (string, error) {
client := oauth2.NewClient(cs.ctx, nil) client := oauth2.NewClient(cs.ctx, nil)
req, err := http.NewRequest("GET", cs.URL, nil) req, err := http.NewRequest("GET", cs.URL, nil)

View File

@@ -111,3 +111,21 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
t.Errorf("got %v but want %v", out, myURLToken) t.Errorf("got %v but want %v", out, myURLToken)
} }
} }
func TestURLCredential_CredentialSourceType(t *testing.T) {
cs := CredentialSource{
URL: "http://example.com",
Format: format{Type: fileTypeText},
}
tfc := testFileConfig
tfc.CredentialSource = cs
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
if got, want := base.credentialSourceType(), "url"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}

View File

@@ -0,0 +1,114 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"errors"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing.
var now = func() time.Time {
return time.Now().UTC()
}
var tokenValid = func(token oauth2.Token) bool {
return token.Valid()
}
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
// the provider identifier in that pool.
Audience string
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
RefreshToken string
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
// None if the token can not be refreshed.
TokenURL string
// TokenInfoURL is the optional STS endpoint URL for token introspection.
TokenInfoURL string
// ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
// access token. When provided, STS will be called with additional basic authentication using client_id as username
// and client_secret as password.
ClientSecret string
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
Token string
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
Expiry time.Time
// RevokeURL is the optional STS endpoint URL for revoking tokens.
RevokeURL string
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
// project used to create the credentials.
QuotaProjectID string
Scopes []string
}
func (c *Config) canRefresh() bool {
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
}
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
var token oauth2.Token
if c.Token != "" && !c.Expiry.IsZero() {
token = oauth2.Token{
AccessToken: c.Token,
Expiry: c.Expiry,
TokenType: "Bearer",
}
}
if !tokenValid(token) && !c.canRefresh() {
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
return oauth2.ReuseTokenSource(&token, ts), nil
}
type tokenSource struct {
ctx context.Context
conf *Config
}
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
if !conf.canRefresh() {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}
if stsResponse.ExpiresIn < 0 {
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
}
if stsResponse.RefreshToken != "" {
conf.RefreshToken = stsResponse.RefreshToken
}
token := &oauth2.Token{
AccessToken: stsResponse.AccessToken,
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
TokenType: "Bearer",
}
return token, nil
}

View File

@@ -0,0 +1,259 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
const expiryDelta = 10 * time.Second
var (
expiry = time.Unix(234852, 0)
testNow = func() time.Time { return expiry }
testValid = func(t oauth2.Token) bool {
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
}
)
type testRefreshTokenServer struct {
URL string
Authorization string
ContentType string
Body string
ResponsePayload *stsexchange.Response
Response string
server *httptest.Server
}
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expiry: now().Add(time.Hour),
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
if config.RefreshToken != "CCCCCCC" {
t.Fatalf("Refresh token not updated")
}
}
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
testCases := []struct {
name string
config Config
}{
{
name: "empty config",
config: Config{},
},
{
name: "missing refresh token",
config: Config{
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing token url",
config: Config{
RefreshToken: "BBBBBBBBB",
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client id",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client secrect",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
_, err := tc.config.TokenSource((context.Background()))
if err == nil {
t.Fatalf("Expected error, but received none")
}
if got := err.Error(); got != expectErrMsg {
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
}
})
}
}
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
t.Helper()
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) close(t *testing.T) error {
t.Helper()
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"encoding/base64" "encoding/base64"
@@ -12,8 +12,8 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. // ClientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type clientAuthentication struct { type ClientAuthentication struct {
// AuthStyle can be either basic or request-body // AuthStyle can be either basic or request-body
AuthStyle oauth2.AuthStyle AuthStyle oauth2.AuthStyle
ClientID string ClientID string
@@ -23,7 +23,7 @@ type clientAuthentication struct {
// InjectAuthentication is used to add authentication to a Secure Token Service exchange // InjectAuthentication is used to add authentication to a Secure Token Service exchange
// request. It modifies either the passed url.Values or http.Header depending on the desired // request. It modifies either the passed url.Values or http.Header depending on the desired
// authentication format. // authentication format.
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) { func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil { if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
return return
} }

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"net/http" "net/http"
@@ -38,7 +38,7 @@ func TestClientAuthentication_InjectHeaderAuthentication(t *testing.T) {
"Content-Type": ContentType, "Content-Type": ContentType,
} }
headerAuthentication := clientAuthentication{ headerAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
@@ -80,7 +80,7 @@ func TestClientAuthentication_ParamsAuthentication(t *testing.T) {
headerP := http.Header{ headerP := http.Header{
"Content-Type": ContentType, "Content-Type": ContentType,
} }
paramsAuthentication := clientAuthentication{ paramsAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"context" "context"
@@ -18,14 +18,17 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// exchangeToken performs an oauth2 token exchange with the provided endpoint. func defaultHeader() http.Header {
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
return header
}
// ExchangeToken performs an oauth2 token exchange with the provided endpoint.
// The first 4 fields are all mandatory. headers can be used to pass additional // The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can // headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server. // be used to pass additional JSON-structured options to the remote server.
func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchangeRequest, authentication clientAuthentication, headers http.Header, options map[string]interface{}) (*stsTokenExchangeResponse, error) { func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
client := oauth2.NewClient(ctx, nil)
data := url.Values{} data := url.Values{}
data.Set("audience", request.Audience) data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@@ -41,13 +44,28 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
data.Set("options", string(opts)) data.Set("options", string(opts))
} }
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
if headers == nil {
headers = defaultHeader()
}
client := oauth2.NewClient(ctx, nil)
authentication.InjectAuthentication(data, headers) authentication.InjectAuthentication(data, headers)
encodedData := data.Encode() encodedData := data.Encode()
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData)) req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err) return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
for key, list := range headers { for key, list := range headers {
@@ -71,7 +89,7 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
} }
var stsResp stsTokenExchangeResponse var stsResp Response
err = json.Unmarshal(body, &stsResp) err = json.Unmarshal(body, &stsResp)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err) return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
@@ -81,8 +99,8 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
return &stsResp, nil return &stsResp, nil
} }
// stsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange. // TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type stsTokenExchangeRequest struct { type TokenExchangeRequest struct {
ActingParty struct { ActingParty struct {
ActorToken string ActorToken string
ActorTokenType string ActorTokenType string
@@ -96,8 +114,8 @@ type stsTokenExchangeRequest struct {
SubjectTokenType string SubjectTokenType string
} }
// stsTokenExchangeResponse is used to decode the remote server response during an oauth2 token exchange. // Response is used to decode the remote server response during an oauth2 token exchange.
type stsTokenExchangeResponse struct { type Response struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"` IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"context" "context"
@@ -16,13 +16,13 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
var auth = clientAuthentication{ var auth = ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
} }
var tokenRequest = stsTokenExchangeRequest{ var exchangeTokenRequest = TokenExchangeRequest{
ActingParty: struct { ActingParty: struct {
ActorToken string ActorToken string
ActorTokenType string ActorTokenType string
@@ -36,9 +36,9 @@ var tokenRequest = stsTokenExchangeRequest{
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
} }
var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" var exchangeRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` var exchangeResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedToken = stsTokenExchangeResponse{ var expectedExchangeToken = Response{
AccessToken: "Sample.Access.Token", AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer", TokenType: "Bearer",
@@ -47,6 +47,18 @@ var expectedToken = stsTokenExchangeResponse{
RefreshToken: "", RefreshToken: "",
} }
var refreshToken = "ReFrEsHtOkEn"
var refreshRequestBody = "grant_type=refresh_token&refresh_token=" + refreshToken
var refreshResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform","refresh_token":"REFRESHED_REFRESH"}`
var expectedRefreshResponse = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: "https://www.googleapis.com/auth/cloud-platform",
RefreshToken: "REFRESHED_REFRESH",
}
func TestExchangeToken(t *testing.T) { func TestExchangeToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
@@ -65,26 +77,34 @@ func TestExchangeToken(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %v.", err) t.Errorf("Failed reading request body: %v.", err)
} }
if got, want := string(body), requestbody; got != want { if got, want := string(body), exchangeRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want) t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody)) w.Write([]byte(exchangeResponseBody))
})) }))
defer ts.Close() defer ts.Close()
headers := http.Header{} headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded") headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) resp, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err != nil { if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err) t.Fatalf("exchangeToken failed with error: %v", err)
} }
if expectedToken != *resp { if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedToken, *resp) t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
} }
resp, err = ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, nil, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
} }
func TestExchangeToken_Err(t *testing.T) { func TestExchangeToken_Err(t *testing.T) {
@@ -96,7 +116,7 @@ func TestExchangeToken_Err(t *testing.T) {
headers := http.Header{} headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded") headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) _, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err == nil { if err == nil {
t.Errorf("Expected handled error; instead got nil.") t.Errorf("Expected handled error; instead got nil.")
} }
@@ -171,7 +191,7 @@ func TestExchangeToken_Opts(t *testing.T) {
// Send a proper reply so that no other errors crop up. // Send a proper reply so that no other errors crop up.
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody)) w.Write([]byte(exchangeResponseBody))
})) }))
defer ts.Close() defer ts.Close()
@@ -183,5 +203,69 @@ func TestExchangeToken_Opts(t *testing.T) {
inputOpts := make(map[string]interface{}) inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption inputOpts["one"] = firstOption
inputOpts["two"] = secondOption inputOpts["two"] = secondOption
exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, inputOpts) ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
}
func TestRefreshToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Unexpected request method, %v is found", r.Method)
}
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
if got, want := r.Header.Get("Authorization"), "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("Unexpected authorization header, got %v, want %v", got, want)
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), refreshRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(refreshResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
resp, err = RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
}
func TestRefreshToken_Err(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("what's wrong with this response?"))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
} }

View File

@@ -76,6 +76,7 @@ type TokenSource interface {
// endpoint URLs. // endpoint URLs.
type Endpoint struct { type Endpoint struct {
AuthURL string AuthURL string
DeviceAuthURL string
TokenURL string TokenURL string
// AuthStyle optionally specifies how the endpoint wants the // AuthStyle optionally specifies how the endpoint wants the
@@ -143,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page // AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly. // that asks for permissions for the required scopes explicitly.
// //
// State is a token to protect the user from CSRF attacks. You must // State is an opaque value used by the client to maintain state between the
// always provide a non-empty string and validate that it matches the // request and callback. The authorization server includes this value when
// state query parameter on your redirect callback. // redirecting the user agent back to the client.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// //
// Opts may include AccessTypeOnline or AccessTypeOffline, as well // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce. // as ApprovalForce.
// It can also be used to pass the PKCE challenge. //
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. // To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
// generate a random state parameter and verify it after exchange.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL) buf.WriteString(c.Endpoint.AuthURL)
@@ -166,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
v.Set("scope", strings.Join(c.Scopes, " ")) v.Set("scope", strings.Join(c.Scopes, " "))
} }
if state != "" { if state != "" {
// TODO(light): Docs say never to omit state; don't allow empty.
v.Set("state", state) v.Set("state", state)
} }
for _, opt := range opts { for _, opt := range opts {
@@ -211,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. // The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
// //
// The code will be in the *http.Request.FormValue("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state"). // calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks.
// //
// Opts may include the PKCE verifier code if previously used in AuthCodeURL. // If using PKCE to protect against CSRF attacks, opts should include a
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. // VerifierOption.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) { func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{ v := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},

68
pkce.go Normal file
View File

@@ -0,0 +1,68 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/url"
)
const (
codeChallengeKey = "code_challenge"
codeChallengeMethodKey = "code_challenge_method"
codeVerifierKey = "code_verifier"
)
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
// This follows recommendations in RFC 7636.
//
// A fresh verifier should be generated for each authorization.
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
// (or Config.DeviceAccessToken).
func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then
// base64url-encoded to produce a 43-octet URL-safe string to use as the
// code verifier."
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
data := make([]byte, 32)
if _, err := rand.Read(data); err != nil {
panic(err)
}
return base64.RawURLEncoding.EncodeToString(data)
}
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier}
}
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
//
// Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:])
}
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
// only.
func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{
challenge_method: "S256",
challenge: S256ChallengeFromVerifier(verifier),
}
}
type challengeOption struct{ challenge_method, challenge string }
func (p challengeOption) setValue(m url.Values) {
m.Set(codeChallengeMethodKey, p.challenge_method)
m.Set(codeChallengeKey, p.challenge)
}