19 Commits

Author SHA1 Message Date
Gopher Robot
e067960af8 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Id1413f67816220ef8039fb933088f4b7f50d70e5
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/540817
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
2023-11-08 20:28:19 +00:00
Leo
4c91c17b32 google: adds header to security considerations section
Change-Id: I29b93715876f233ae52687c8223fd8733a2a3b80
GitHub-Last-Rev: f15c4cf1a5
GitHub-Pull-Request: golang/oauth2#677
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/535895
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-17 20:42:42 +00:00
Gopher Robot
3c5dbf08cc go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I39a72a7dbb2205a6638a154892c69948ee2deb0d
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/533241
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2023-10-06 08:33:24 +00:00
Chris Smith
11625ccb95 google: add authorized_user conditional to Credentials.UniverseDomain
Return default universe domain if credentials type is authorized_user.

Change-Id: I20a9b5fafa562fcec84717914a236d081f630591
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/532196
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-02 22:04:54 +00:00
Chris Smith
8d6d45b6cd google: add Credentials.UniverseDomain to support TPC
Read and expose universe_domain from service account JSON files in
CredentialsFromJSONWithParams to support TPC in 1p clients.

Change-Id: I3518a0ec8be5ff7235b946cffd88b26ac8d303cf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531715
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-29 14:33:30 +00:00
Jin Qin
43b6a7ba19 google: adding support for external account authorized user
To support a new type of credential: `ExternalAccountAuthorizedUser`

* Refactor the common dependency STS to a separate package.
* Adding the `externalaccountauthorizeduser` package.

Change-Id: I9b9624f912d216b67a0d31945a50f057f747710b
GitHub-Last-Rev: 6e2aaff345
GitHub-Pull-Request: golang/oauth2#671
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531095
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-09-28 22:24:46 +00:00
M Hickford
14b275c918 oauth2: workaround misspelling of verification_uri
Some servers misspell verification_uri as verification_url, contrary to spec RFC 8628

Example server https://issuetracker.google.com/issues/151238144

Fixes #666

Change-Id: I89e354368bbb0a4e3b979bb547b4cb37bbe1cc02
GitHub-Last-Rev: bbf169b52d
GitHub-Pull-Request: golang/oauth2#667
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/527835
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nikolay Turpitko <nick.turpitko@gmail.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
2023-09-22 21:51:39 +00:00
aeitzman
18352fc433 google/internal/externalaccount: adding BYOID Metrics
Adds framework for sending BYOID metrics via the x-goog-api-client header on outgoing sts requests. Also adds a header file for getting the current version of GoLang

Change-Id: Id5431def96f4cfc03e4ada01d5fb8cac8cfa56a9
GitHub-Last-Rev: c93cd478e5
GitHub-Pull-Request: golang/oauth2#661
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/523595
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-22 20:39:34 +00:00
M Hickford
9095a51613 oauth2: clarify error if endpoint missing DeviceAuthURL
Change-Id: I36eb5eb66099161785160f4f39ea1c7f64ad6e74
GitHub-Last-Rev: 31cfe8150f
GitHub-Pull-Request: golang/oauth2#664
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/526302
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
2023-09-22 16:24:29 +00:00
Jin Qin
2d9e4a2adf oauth2/google: remove meta validations for aws external credentials
Remove the url validations to keep a consistency with other libraries.

Change-Id: Icb1767edc000d9695db3f0c7ca271918fb2083f5
GitHub-Last-Rev: af89ee0c72
GitHub-Pull-Request: golang/oauth2#660
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/522395
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
2023-09-12 16:01:49 +00:00
M Hickford
55cd552a36 oauth2: support PKCE
Fixes #603

Fixes golang/go#59835

Change-Id: Ica0cfef975ba9511e00f097498d33ba27dafca0d
GitHub-Last-Rev: f01f7593a3
GitHub-Pull-Request: golang/oauth2#625
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/463979
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
2023-09-07 17:49:42 +00:00
M Hickford
e3fb0fb3af oauth2: support device flow
Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628

Tested with GitHub

Fixes #418

Fixes golang/go#58126

Co-authored-by: cmP <centimitr@gmail.com>

Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14
GitHub-Last-Rev: 9a126d7b53
GitHub-Pull-Request: golang/oauth2#609
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155
Commit-Queue: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
2023-09-06 16:35:20 +00:00
Gopher Robot
07085280e4 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I2fb95ca59417e20377bc315094221fa7165128c8
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/525675
Reviewed-by: Heschi Kreinick <heschi@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
2023-09-05 16:42:47 +00:00
Brad Fitzpatrick
a835fc4358 oauth2: move global auth style cache to be per-Config
In 80673b4a4 (https://go.dev/cl/157820) I added a never-shrinking
package-global cache to remember which auto-detected auth style (HTTP
headers vs POST) was supported by a certain OAuth2 server, keyed by
its URL.

Unfortunately, some multi-tenant SaaS OIDC servers behave poorly and
have one global OpenID configuration document for all of their
customers which says ("we support all auth styles! you pick!") but
then give each customer control of which style they specifically
accept. This is bogus behavior on their part, but the oauth2 package's
global caching per URL isn't helping. (It's also bad to have a
package-global cache that can never be GC'ed)

So, this change moves the cache to hang off the oauth *Configs
instead. Unfortunately, it does so with some backwards compatiblity
compromises (an atomic.Value hack), lest people are using old versions
of Go still or copying a Config by value, both of which this package
previously accidentally supported, even though they weren't tested.

This change also means that anybody that's repeatedly making ephemeral
oauth.Configs without an explicit auth style will be losing &
reinitializing their cache on any auth style failures + fallbacks to
the other style. I think that should be pretty rare. People seem to
make an oauth2.Config once earlier and stash it away somewhere (often
deep in a token fetcher or HTTP client/transport).

Change-Id: I91f107368ab3c3d77bc425eeef65372a589feb7b
Signed-off-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/515675
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
2023-08-09 17:53:10 +00:00
Gopher Robot
2e4a4e2bfb go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I953aeb97bb9ed634f69dc93cf1f21392261c930c
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/516037
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-08-04 23:51:27 +00:00
Dmitri Shuralyov
ac6658e9cb all: update go version to 1.18
Go versions 1.16 and 1.17 are long since unsupported per Go release
policy (https://go.dev/doc/devel/release#policy).

Updating go.mod's go statement to 1.18 makes it so that 'go mod tidy'
doesn't include checksums needed for the full module graph loaded by
Go 1.16¹ that were recently added in CL 507840.

It also makes go fix remove the now-obsolete // +build lines².

Done using cmd/go at go1.21rc2:

$ go get go@1.18
go: upgraded go 1.17 => 1.18
$ go mod tidy
$ go fix ./...
google/appengine_gen1.go: fixed buildtag
google/appengine_gen2_flex.go: fixed buildtag
internal/client_appengine.go: fixed buildtag

¹ https://go.dev/ref/mod#graph-pruning
² https://go.dev/doc/go1.18#go-build-lines

Change-Id: I6c6295adef1f5c64a196c2e66005763893efe5e7
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/507878
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-07-05 21:55:59 +00:00
Gopher Robot
ec5679f607 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I6b389549fe4bc53a62cb383c5fb10156ccfcffba
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/507840
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
2023-07-05 21:07:49 +00:00
Bryan C. Mills
989acb1bfe all: update dependencies to their latest versions
This change was prepared by running:
	go1.21rc2 get -u -t ./...
	go1.21rc2 mod tidy -compat=1.17

Change-Id: I533c4361aae073b7a5280aad2c2e5eea752df62a
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/506296
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Bryan Mills <bcmills@google.com>
2023-06-26 19:20:11 +00:00
Gopher Robot
2323c81c8d go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I7a693f42e110b957194337a0d355dd1f2a5e14ca
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/502797
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Run-TryBot: Gopher Robot <gobot@golang.org>
2023-06-13 14:28:47 +00:00
41 changed files with 1370 additions and 440 deletions

View File

@@ -47,6 +47,10 @@ type Config struct {
// client ID & client secret sent. The zero value means to
// auto-detect.
AuthStyle oauth2.AuthStyle
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
}
// Token uses client credentials to retrieve a token.
@@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v[k] = p
}
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*oauth2.RetrieveError)(rErr)

View File

@@ -12,8 +12,6 @@ import (
"net/http/httptest"
"net/url"
"testing"
"golang.org/x/oauth2/internal"
)
func newConf(serverURL string) *Config {
@@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
}
func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return

198
deviceauth.go Normal file
View File

@@ -0,0 +1,198 @@
package oauth2
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
)
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
errAuthorizationPending = "authorization_pending"
errSlowDown = "slow_down"
errAccessDenied = "access_denied"
errExpiredToken = "expired_token"
)
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry time.Time `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
type Alias DeviceAuthResponse
var expiresIn int64
if !d.Expiry.IsZero() {
expiresIn = int64(time.Until(d.Expiry).Seconds())
}
return json.Marshal(&struct {
ExpiresIn int64 `json:"expires_in,omitempty"`
*Alias
}{
ExpiresIn: expiresIn,
Alias: (*Alias)(&d),
})
}
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
type Alias DeviceAuthResponse
aux := &struct {
ExpiresIn int64 `json:"expires_in"`
// workaround misspelling of verification_uri
VerificationURL string `json:"verification_url"`
*Alias
}{
Alias: (*Alias)(c),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ExpiresIn != 0 {
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
}
if c.VerificationURI == "" {
c.VerificationURI = aux.VerificationURL
}
return nil
}
// DeviceAuth returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
v := url.Values{
"client_id": {c.ClientID},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveDeviceAuth(ctx, c, v)
}
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}
da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}
if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}
return da, nil
}
// DeviceAccessToken polls the server to exchange a device code for a token.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
if !da.Expiry.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
defer cancel()
}
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
v := url.Values{
"client_id": {c.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {da.DeviceCode},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
// "If no value is provided, clients MUST use 5 as the default."
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
interval := da.Interval
if interval == 0 {
interval = 5
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
tok, err := retrieveToken(ctx, c, v)
if err == nil {
return tok, nil
}
e, ok := err.(*RetrieveError)
if !ok {
return nil, err
}
switch e.ErrorCode {
case errSlowDown:
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
interval += 5
ticker.Reset(time.Duration(interval) * time.Second)
case errAuthorizationPending:
// Do nothing.
case errAccessDenied, errExpiredToken:
fallthrough
default:
return tok, err
}
}
}
}

97
deviceauth_test.go Normal file
View File

@@ -0,0 +1,97 @@
package oauth2
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
tests := []struct {
name string
response DeviceAuthResponse
want string
}{
{
name: "empty",
response: DeviceAuthResponse{},
want: `{"device_code":"","user_code":"","verification_uri":""}`,
},
{
name: "soon",
response: DeviceAuthResponse{
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
},
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
gotBytes, err := json.Marshal(tc.response)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
t.Skip("test ran too slowly to compare `expires_in`")
}
got := string(gotBytes)
if got != tc.want {
t.Errorf("want=%s, got=%s", tc.want, got)
}
})
}
}
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
tests := []struct {
name string
data string
want DeviceAuthResponse
}{
{
name: "empty",
data: `{}`,
want: DeviceAuthResponse{},
},
{
name: "soon",
data: `{"expires_in":100}`,
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
got := DeviceAuthResponse{}
err := json.Unmarshal([]byte(tc.data), &got)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
t.Errorf("want=%#v, got=%#v", tc.want, got)
}
})
}
}
func ExampleConfig_DeviceAuth() {
var config Config
ctx := context.Background()
response, err := config.DeviceAuth(ctx)
if err != nil {
panic(err)
}
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
token, err := config.DeviceAccessToken(ctx, response)
if err != nil {
panic(err)
}
fmt.Println(token)
}

View File

@@ -57,6 +57,7 @@ var Fitbit = oauth2.Endpoint{
var GitHub = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
DeviceAuthURL: "https://github.com/login/device/code",
}
// GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
var Google = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
}
// Heroku is the endpoint for Heroku.

View File

@@ -26,9 +26,13 @@ func ExampleConfig() {
},
}
// use PKCE to protect against CSRF attacks
// https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6
verifier := oauth2.GenerateVerifier()
// Redirect user to consent page to ask for permission
// for the scopes specified above.
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Use the authorization code that is pushed to the redirect
@@ -39,7 +43,7 @@ func ExampleConfig() {
if _, err := fmt.Scan(&code); err != nil {
log.Fatal(err)
}
tok, err := conf.Exchange(ctx, code)
tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
if err != nil {
log.Fatal(err)
}

View File

@@ -6,11 +6,8 @@
package github // import "golang.org/x/oauth2/github"
import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
// Endpoint is Github's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
}
var Endpoint = endpoints.GitHub

13
go.mod
View File

@@ -1,15 +1,16 @@
module golang.org/x/oauth2
go 1.17
go 1.18
require (
cloud.google.com/go/compute/metadata v0.2.0
github.com/google/go-cmp v0.5.8
cloud.google.com/go/compute/metadata v0.2.3
github.com/google/go-cmp v0.5.9
google.golang.org/appengine v1.6.7
)
require (
github.com/golang/protobuf v1.5.2 // indirect
golang.org/x/net v0.10.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
cloud.google.com/go/compute v1.20.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
golang.org/x/net v0.18.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
)

51
go.sum
View File

@@ -1,53 +1,26 @@
cloud.google.com/go/compute/metadata v0.2.0 h1:nBbNSZyDpkNlo3DepaaLKVuO7ClyifSAmNloSCZrHnQ=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZNbg=
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg=
golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

View File

@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build appengine
// +build appengine
// This file applies to App Engine first generation runtimes (<= Go 1.9).

View File

@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !appengine
// +build !appengine
// This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible.

View File

@@ -19,7 +19,10 @@ import (
"golang.org/x/oauth2/authhandler"
)
const adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
const (
adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
universeDomainDefault = "googleapis.com"
)
// Credentials holds Google credentials, including "Application Default Credentials".
// For more details, see:
@@ -37,6 +40,18 @@ type Credentials struct {
// environment and not with a credentials file, e.g. when code is
// running on Google Cloud Platform.
JSON []byte
// universeDomain is the default service domain for a given Cloud universe.
universeDomain string
}
// UniverseDomain returns the default service domain for a given Cloud universe.
// The default value is "googleapis.com".
func (c *Credentials) UniverseDomain() string {
if c.universeDomain == "" {
return universeDomainDefault
}
return c.universeDomain
}
// DefaultCredentials is the old name of Credentials.
@@ -200,6 +215,13 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
if err := json.Unmarshal(jsonData, &f); err != nil {
return nil, err
}
universeDomain := f.UniverseDomain
// Authorized user credentials are only supported in the googleapis.com universe.
if f.Type == userCredentialsKey {
universeDomain = universeDomainDefault
}
ts, err := f.tokenSource(ctx, params)
if err != nil {
return nil, err
@@ -209,6 +231,7 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
ProjectID: f.ProjectID,
TokenSource: ts,
JSON: jsonData,
universeDomain: universeDomain,
}, nil
}

124
google/default_test.go Normal file
View File

@@ -0,0 +1,124 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"context"
"testing"
)
var saJSONJWT = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var saJSONJWTUniverseDomain = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"universe_domain": "example.com",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var userJSON = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2"
}`)
var userJSONUniverseDomain = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2",
"universe_domain": "example.com"
}`)
func TestCredentialsFromJSONWithParams_SA(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWT, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWTUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if want := "example.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_User(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSON, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSONUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}

View File

@@ -101,6 +101,8 @@
// executable-sourced credentials), please check out:
// https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in
//
// # Security considerations
//
// Note that this library does not perform any validation on the token_url, token_info_url,
// or service_account_impersonation_url fields of the credential configuration.
// It is not recommended to use a credential configuration that you did not generate with

View File

@@ -16,6 +16,7 @@ import (
"cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/externalaccount"
"golang.org/x/oauth2/google/internal/externalaccountauthorizeduser"
"golang.org/x/oauth2/jwt"
)
@@ -23,6 +24,7 @@ import (
var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
AuthStyle: oauth2.AuthStyleInParams,
}
@@ -98,6 +100,7 @@ const (
serviceAccountKey = "service_account"
userCredentialsKey = "authorized_user"
externalAccountKey = "external_account"
externalAccountAuthorizedUserKey = "external_account_authorized_user"
impersonatedServiceAccount = "impersonated_service_account"
)
@@ -112,6 +115,7 @@ type credentialsFile struct {
AuthURL string `json:"auth_uri"`
TokenURL string `json:"token_uri"`
ProjectID string `json:"project_id"`
UniverseDomain string `json:"universe_domain"`
// User Credential fields
// (These typically come from gcloud auth.)
@@ -131,6 +135,9 @@ type credentialsFile struct {
QuotaProjectID string `json:"quota_project_id"`
WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
// External Account Authorized User fields
RevokeURL string `json:"revoke_url"`
// Service account impersonation
SourceCredentials *credentialsFile `json:"source_credentials"`
}
@@ -199,6 +206,19 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
WorkforcePoolUserProject: f.WorkforcePoolUserProject,
}
return cfg.TokenSource(ctx)
case externalAccountAuthorizedUserKey:
cfg := &externalaccountauthorizeduser.Config{
Audience: f.Audience,
RefreshToken: f.RefreshToken,
TokenURL: f.TokenURLExternal,
TokenInfoURL: f.TokenInfoURL,
ClientID: f.ClientID,
ClientSecret: f.ClientSecret,
RevokeURL: f.RevokeURL,
QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes,
}
return cfg.TokenSource(ctx)
case impersonatedServiceAccount:
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")

View File

@@ -274,49 +274,6 @@ type awsRequest struct {
Headers []awsRequestHeader `json:"headers"`
}
func (cs awsCredentialSource) validateMetadataServers() error {
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
return err
}
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
return err
}
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
}
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
if metadataUrl == "" {
// Zero value means use default, which is valid.
return true
}
u, err := url.Parse(metadataUrl)
if err != nil {
// Unparseable URL means invalid
return false
}
for _, validHostname := range validHostnames {
if u.Hostname() == validHostname {
// If it's one of the valid hostnames, everything is good
return true
}
}
// hostname not found in our allowlist, so not valid
return false
}
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
if !cs.isValidMetadataServer(metadataUrl) {
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
}
return nil
}
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil)
@@ -339,6 +296,10 @@ func shouldUseMetadataServer() bool {
return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
}
func (cs awsCredentialSource) credentialSourceType() string {
return "aws"
}
func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil {
headers := make(map[string]string)

View File

@@ -585,25 +585,18 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -631,25 +624,18 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
func TestAWSCredential_IMDSv2(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -677,10 +663,6 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
delete(server.Credentials, "Token")
tfc := testFileConfig
@@ -688,15 +670,12 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -724,21 +703,15 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -746,7 +719,6 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -774,21 +746,15 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -796,7 +762,6 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -823,21 +788,15 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -846,7 +805,6 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
"AWS_DEFAULT_REGION": "us-east-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -873,25 +831,18 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.EnvironmentID = "aws3"
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
_, err = tfc.parse(context.Background())
_, err := tfc.parse(context.Background())
if err == nil {
t.Fatalf("parse() should have failed")
}
@@ -903,23 +854,16 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.RegionURL = ""
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -939,23 +883,17 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRegion = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -975,10 +913,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}"))
}
@@ -987,13 +922,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1013,10 +945,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
}
@@ -1025,13 +954,10 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1051,23 +977,16 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.URL = ""
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1087,23 +1006,16 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRolename = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1123,23 +1035,16 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1159,10 +1064,6 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Metadata server should not have been called.")
@@ -1174,11 +1075,9 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -1186,7 +1085,6 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1214,28 +1112,21 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": accessKeyID,
"AWS_SECRET_ACCESS_KEY": secretAccessKey,
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1263,28 +1154,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1312,28 +1196,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1358,87 +1235,19 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testin
}
}
func TestAWSCredential_Validations(t *testing.T) {
var metadataServerValidityTests = []struct {
name string
credSource CredentialSource
errText string
}{
{
name: "No Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "",
URL: "",
IMDSv2SessionTokenURL: "",
},
}, {
name: "IPv4 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
}, {
name: "IPv6 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
},
}, {
name: "Faulty RegionURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url",
}, {
name: "Faulty CredVerificationURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://abc.com/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url",
}, {
name: "Faulty IMDSv2SessionTokenURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://abc.com/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url",
},
}
func TestAwsCredential_CredentialSourceType(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
for _, tt := range metadataServerValidityTests {
t.Run(tt.name, func(t *testing.T) {
tfc := testFileConfig
tfc.CredentialSource = tt.credSource
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(context.Background())
base, err := tfc.parse(context.Background())
if err != nil {
if tt.errText == "" {
t.Errorf("Didn't expect an error, but got %v", err)
} else if tt.errText != err.Error() {
t.Errorf("Expected %v, but got %v", tt.errText, err)
t.Fatalf("parse() failed %v", err)
}
} else {
if tt.errText != "" {
t.Errorf("Expected error %v, but got none", tt.errText)
}
}
})
if got, want := base.credentialSourceType(), "aws"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}

View File

@@ -8,13 +8,12 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing
@@ -63,31 +62,10 @@ type Config struct {
WorkforcePoolUserProject string
}
// Each element consists of a list of patterns. validateURLs checks for matches
// that include all elements in a given list, in that order.
var (
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
)
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
parsed, err := url.Parse(input)
if err != nil {
return false
}
if !strings.EqualFold(parsed.Scheme, scheme) {
return false
}
toTest := parsed.Host
for _, pattern := range patterns {
if pattern.MatchString(toTest) {
return true
}
}
return false
}
func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input)
}
@@ -185,10 +163,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
}
if err := awsCredSource.validateMetadataServers(); err != nil {
return nil, err
}
return awsCredSource, nil
}
} else if c.CredentialSource.File != "" {
@@ -202,6 +176,7 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
}
type baseCredentialSource interface {
credentialSourceType() string
subjectToken() (string, error)
}
@@ -211,6 +186,15 @@ type tokenSource struct {
conf *Config
}
func getMetricsHeaderValue(conf *Config, credSource baseCredentialSource) string {
return fmt.Sprintf("gl-go/%s auth/%s google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t",
goVersion(),
"unknown",
credSource.credentialSourceType(),
conf.ServiceAccountImpersonationURL != "",
conf.ServiceAccountImpersonationLifetimeSeconds != 0)
}
// Token allows tokenSource to conform to the oauth2.TokenSource interface.
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
@@ -224,7 +208,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, err
}
stsRequest := stsTokenExchangeRequest{
stsRequest := stsexchange.TokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience,
Scope: conf.Scopes,
@@ -234,7 +218,8 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
}
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
clientAuth := clientAuthentication{
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
@@ -247,7 +232,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
"userProject": conf.WorkforcePoolUserProject,
}
}
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ package externalaccount
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@@ -51,6 +52,7 @@ type testExchangeTokenServer struct {
url string
authorization string
contentType string
metricsHeader string
body string
response string
}
@@ -68,6 +70,10 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T
if got, want := headerContentType, tets.contentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, tets.metricsHeader; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
@@ -106,6 +112,10 @@ func validateToken(t *testing.T, tok *oauth2.Token) {
}
}
func getExpectedMetricsHeader(source string, saImpersonation bool, configLifetime bool) string {
return fmt.Sprintf("gl-go/%s auth/unknown google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t", goVersion(), source, saImpersonation, configLifetime)
}
func TestToken(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
@@ -120,6 +130,7 @@ func TestToken(t *testing.T) {
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: baseCredsRequestBody,
response: baseCredsResponseBody,
}
@@ -147,6 +158,7 @@ func TestWorkforcePoolTokenWithClientID(t *testing.T) {
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithClientId,
response: baseCredsResponseBody,
}
@@ -173,6 +185,7 @@ func TestWorkforcePoolTokenWithoutClientID(t *testing.T) {
url: "/",
authorization: "",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithoutClientId,
response: baseCredsResponseBody,
}

View File

@@ -233,6 +233,10 @@ func (cs executableCredentialSource) parseSubjectTokenFromSource(response []byte
return "", tokenTypeError(source)
}
func (cs executableCredentialSource) credentialSourceType() string {
return "executable"
}
func (cs executableCredentialSource) subjectToken() (string, error) {
if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil {
return token, err

View File

@@ -150,6 +150,9 @@ func TestCreateExecutableCredential(t *testing.T) {
if ecs.Timeout != tt.expectedTimeout {
t.Errorf("ecs.Timeout got %v but want %v", ecs.Timeout, tt.expectedTimeout)
}
if ecs.credentialSourceType() != "executable" {
t.Errorf("ecs.CredentialSourceType() got %s but want executable", ecs.credentialSourceType())
}
}
})
}

View File

@@ -19,6 +19,10 @@ type fileCredentialSource struct {
Format format
}
func (cs fileCredentialSource) credentialSourceType() string {
return "file"
}
func (cs fileCredentialSource) subjectToken() (string, error) {
tokenFile, err := os.Open(cs.File)
if err != nil {

View File

@@ -68,6 +68,9 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
t.Errorf("got %v but want %v", out, test.want)
}
if got, want := base.credentialSourceType(), "file"; got != want {
t.Errorf("got %v but want %v", got, want)
}
})
}
}

View File

@@ -0,0 +1,64 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"strings"
"unicode"
)
var (
// version is a package internal global variable for testing purposes.
version = runtime.Version
)
// versionUnknown is only used when the runtime version cannot be determined.
const versionUnknown = "UNKNOWN"
// goVersion returns a Go runtime version derived from the runtime environment
// that is modified to be suitable for reporting in a header, meaning it has no
// whitespace. If it is unable to determine the Go runtime version, it returns
// versionUnknown.
func goVersion() string {
const develPrefix = "devel +"
s := version()
if strings.HasPrefix(s, develPrefix) {
s = s[len(develPrefix):]
if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
return s
} else if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
notSemverRune := func(r rune) bool {
return !strings.ContainsRune("0123456789.", r)
}
if strings.HasPrefix(s, "go1") {
s = s[2:]
var prerelease string
if p := strings.IndexFunc(s, notSemverRune); p >= 0 {
s, prerelease = s[:p], s[p:]
}
if strings.HasSuffix(s, ".") {
s += "0"
} else if strings.Count(s, ".") < 2 {
s += ".0"
}
if prerelease != "" {
// Some release candidates already have a dash in them.
if !strings.HasPrefix(prerelease, "-") {
prerelease = "-" + prerelease
}
s += prerelease
}
return s
}
return "UNKNOWN"
}

View File

@@ -0,0 +1,48 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGoVersion(t *testing.T) {
testVersion := func(v string) func() string {
return func() string {
return v
}
}
for _, tst := range []struct {
v func() string
want string
}{
{
testVersion("go1.19"),
"1.19.0",
},
{
testVersion("go1.21-20230317-RC01"),
"1.21.0-20230317-RC01",
},
{
testVersion("devel +abc1234"),
"abc1234",
},
{
testVersion("this should be unknown"),
versionUnknown,
},
} {
version = tst.v
got := goVersion()
if diff := cmp.Diff(got, tst.want); diff != "" {
t.Errorf("got(-),want(+):\n%s", diff)
}
}
version = runtime.Version
}

View File

@@ -42,7 +42,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin
}))
}
func createTargetServer(t *testing.T) *httptest.Server {
func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), "/"; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
@@ -55,6 +55,10 @@ func createTargetServer(t *testing.T) *httptest.Server {
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, metricsHeaderWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
@@ -71,6 +75,7 @@ var impersonationTests = []struct {
name string
config Config
expectedImpersonationBody string
expectedMetricsHeader string
}{
{
name: "Base Impersonation",
@@ -84,6 +89,7 @@ var impersonationTests = []struct {
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
},
expectedImpersonationBody: "{\"lifetime\":\"3600s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, false),
},
{
name: "With TokenLifetime Set",
@@ -98,6 +104,7 @@ var impersonationTests = []struct {
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
expectedImpersonationBody: "{\"lifetime\":\"10000s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, true),
},
}
@@ -109,7 +116,7 @@ func TestImpersonation(t *testing.T) {
defer impersonateServer.Close()
testImpersonateConfig.ServiceAccountImpersonationURL = impersonateServer.URL
targetServer := createTargetServer(t)
targetServer := createTargetServer(tt.expectedMetricsHeader, t)
defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL

View File

@@ -23,6 +23,10 @@ type urlCredentialSource struct {
ctx context.Context
}
func (cs urlCredentialSource) credentialSourceType() string {
return "url"
}
func (cs urlCredentialSource) subjectToken() (string, error) {
client := oauth2.NewClient(cs.ctx, nil)
req, err := http.NewRequest("GET", cs.URL, nil)

View File

@@ -111,3 +111,21 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
t.Errorf("got %v but want %v", out, myURLToken)
}
}
func TestURLCredential_CredentialSourceType(t *testing.T) {
cs := CredentialSource{
URL: "http://example.com",
Format: format{Type: fileTypeText},
}
tfc := testFileConfig
tfc.CredentialSource = cs
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
if got, want := base.credentialSourceType(), "url"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}

View File

@@ -0,0 +1,114 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"errors"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing.
var now = func() time.Time {
return time.Now().UTC()
}
var tokenValid = func(token oauth2.Token) bool {
return token.Valid()
}
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
// the provider identifier in that pool.
Audience string
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
RefreshToken string
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
// None if the token can not be refreshed.
TokenURL string
// TokenInfoURL is the optional STS endpoint URL for token introspection.
TokenInfoURL string
// ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
// access token. When provided, STS will be called with additional basic authentication using client_id as username
// and client_secret as password.
ClientSecret string
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
Token string
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
Expiry time.Time
// RevokeURL is the optional STS endpoint URL for revoking tokens.
RevokeURL string
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
// project used to create the credentials.
QuotaProjectID string
Scopes []string
}
func (c *Config) canRefresh() bool {
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
}
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
var token oauth2.Token
if c.Token != "" && !c.Expiry.IsZero() {
token = oauth2.Token{
AccessToken: c.Token,
Expiry: c.Expiry,
TokenType: "Bearer",
}
}
if !tokenValid(token) && !c.canRefresh() {
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
return oauth2.ReuseTokenSource(&token, ts), nil
}
type tokenSource struct {
ctx context.Context
conf *Config
}
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
if !conf.canRefresh() {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}
if stsResponse.ExpiresIn < 0 {
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
}
if stsResponse.RefreshToken != "" {
conf.RefreshToken = stsResponse.RefreshToken
}
token := &oauth2.Token{
AccessToken: stsResponse.AccessToken,
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
TokenType: "Bearer",
}
return token, nil
}

View File

@@ -0,0 +1,259 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
const expiryDelta = 10 * time.Second
var (
expiry = time.Unix(234852, 0)
testNow = func() time.Time { return expiry }
testValid = func(t oauth2.Token) bool {
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
}
)
type testRefreshTokenServer struct {
URL string
Authorization string
ContentType string
Body string
ResponsePayload *stsexchange.Response
Response string
server *httptest.Server
}
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expiry: now().Add(time.Hour),
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
if config.RefreshToken != "CCCCCCC" {
t.Fatalf("Refresh token not updated")
}
}
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
testCases := []struct {
name string
config Config
}{
{
name: "empty config",
config: Config{},
},
{
name: "missing refresh token",
config: Config{
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing token url",
config: Config{
RefreshToken: "BBBBBBBBB",
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client id",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client secrect",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
_, err := tc.config.TokenSource((context.Background()))
if err == nil {
t.Fatalf("Expected error, but received none")
}
if got := err.Error(); got != expectErrMsg {
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
}
})
}
}
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
t.Helper()
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) close(t *testing.T) error {
t.Helper()
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"encoding/base64"
@@ -12,8 +12,8 @@ import (
"golang.org/x/oauth2"
)
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type clientAuthentication struct {
// ClientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type ClientAuthentication struct {
// AuthStyle can be either basic or request-body
AuthStyle oauth2.AuthStyle
ClientID string
@@ -23,7 +23,7 @@ type clientAuthentication struct {
// InjectAuthentication is used to add authentication to a Secure Token Service exchange
// request. It modifies either the passed url.Values or http.Header depending on the desired
// authentication format.
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
return
}

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"net/http"
@@ -38,7 +38,7 @@ func TestClientAuthentication_InjectHeaderAuthentication(t *testing.T) {
"Content-Type": ContentType,
}
headerAuthentication := clientAuthentication{
headerAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID,
ClientSecret: clientSecret,
@@ -80,7 +80,7 @@ func TestClientAuthentication_ParamsAuthentication(t *testing.T) {
headerP := http.Header{
"Content-Type": ContentType,
}
paramsAuthentication := clientAuthentication{
paramsAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInParams,
ClientID: clientID,
ClientSecret: clientSecret,

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"context"
@@ -18,14 +18,17 @@ import (
"golang.org/x/oauth2"
)
// exchangeToken performs an oauth2 token exchange with the provided endpoint.
func defaultHeader() http.Header {
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
return header
}
// ExchangeToken performs an oauth2 token exchange with the provided endpoint.
// The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server.
func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchangeRequest, authentication clientAuthentication, headers http.Header, options map[string]interface{}) (*stsTokenExchangeResponse, error) {
client := oauth2.NewClient(ctx, nil)
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
data := url.Values{}
data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@@ -41,13 +44,28 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
data.Set("options", string(opts))
}
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
if headers == nil {
headers = defaultHeader()
}
client := oauth2.NewClient(ctx, nil)
authentication.InjectAuthentication(data, headers)
encodedData := data.Encode()
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
}
req = req.WithContext(ctx)
for key, list := range headers {
@@ -71,7 +89,7 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
}
var stsResp stsTokenExchangeResponse
var stsResp Response
err = json.Unmarshal(body, &stsResp)
if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
@@ -81,8 +99,8 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
return &stsResp, nil
}
// stsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type stsTokenExchangeRequest struct {
// TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type TokenExchangeRequest struct {
ActingParty struct {
ActorToken string
ActorTokenType string
@@ -96,8 +114,8 @@ type stsTokenExchangeRequest struct {
SubjectTokenType string
}
// stsTokenExchangeResponse is used to decode the remote server response during an oauth2 token exchange.
type stsTokenExchangeResponse struct {
// Response is used to decode the remote server response during an oauth2 token exchange.
type Response struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"context"
@@ -16,13 +16,13 @@ import (
"golang.org/x/oauth2"
)
var auth = clientAuthentication{
var auth = ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID,
ClientSecret: clientSecret,
}
var tokenRequest = stsTokenExchangeRequest{
var exchangeTokenRequest = TokenExchangeRequest{
ActingParty: struct {
ActorToken string
ActorTokenType string
@@ -36,9 +36,9 @@ var tokenRequest = stsTokenExchangeRequest{
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
}
var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedToken = stsTokenExchangeResponse{
var exchangeRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var exchangeResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedExchangeToken = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
@@ -47,6 +47,18 @@ var expectedToken = stsTokenExchangeResponse{
RefreshToken: "",
}
var refreshToken = "ReFrEsHtOkEn"
var refreshRequestBody = "grant_type=refresh_token&refresh_token=" + refreshToken
var refreshResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform","refresh_token":"REFRESHED_REFRESH"}`
var expectedRefreshResponse = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: "https://www.googleapis.com/auth/cloud-platform",
RefreshToken: "REFRESHED_REFRESH",
}
func TestExchangeToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
@@ -65,26 +77,34 @@ func TestExchangeToken(t *testing.T) {
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), requestbody; got != want {
if got, want := string(body), exchangeRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody))
w.Write([]byte(exchangeResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
resp, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedToken, *resp)
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
resp, err = ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, nil, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
}
func TestExchangeToken_Err(t *testing.T) {
@@ -96,7 +116,7 @@ func TestExchangeToken_Err(t *testing.T) {
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
_, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
@@ -171,7 +191,7 @@ func TestExchangeToken_Opts(t *testing.T) {
// Send a proper reply so that no other errors crop up.
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody))
w.Write([]byte(exchangeResponseBody))
}))
defer ts.Close()
@@ -183,5 +203,69 @@ func TestExchangeToken_Opts(t *testing.T) {
inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption
inputOpts["two"] = secondOption
exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, inputOpts)
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
}
func TestRefreshToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Unexpected request method, %v is found", r.Method)
}
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
if got, want := r.Header.Get("Authorization"), "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("Unexpected authorization header, got %v, want %v", got, want)
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), refreshRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(refreshResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
resp, err = RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
}
func TestRefreshToken_Err(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("what's wrong with this response?"))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
}

View File

@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build appengine
// +build appengine
package internal

View File

@@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
@@ -115,41 +116,60 @@ const (
AuthStyleInHeader AuthStyle = 2
)
// authStyleCache is the set of tokenURLs we've successfully used via
// LazyAuthStyleCache is a backwards compatibility compromise to let Configs
// have a lazily-initialized AuthStyleCache.
//
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
// both would ideally just embed an unexported AuthStyleCache but because both
// were historically allowed to be copied by value we can't retroactively add an
// uncopyable Mutex to them.
//
// We could use an atomic.Pointer, but that was added recently enough (in Go
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
// still pass. By using an atomic.Value, it supports both Go 1.17 and
// copying by value, even if that's not ideal.
type LazyAuthStyleCache struct {
v atomic.Value // of *AuthStyleCache
}
func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
return c
}
c := new(AuthStyleCache)
if !lc.v.CompareAndSwap(nil, c) {
c = lc.v.Load().(*AuthStyleCache)
}
return c
}
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
// the set of OAuth2 servers a program contacts over time is fixed and
// small.
var authStyleCache struct {
sync.Mutex
type AuthStyleCache struct {
mu sync.Mutex
m map[string]AuthStyle // keyed by tokenURL
}
// ResetAuthCache resets the global authentication style cache used
// for AuthStyleUnknown token requests.
func ResetAuthCache() {
authStyleCache.Lock()
defer authStyleCache.Unlock()
authStyleCache.m = nil
}
// lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so.
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
authStyleCache.Lock()
defer authStyleCache.Unlock()
style, ok = authStyleCache.m[tokenURL]
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
style, ok = c.m[tokenURL]
return
}
// setAuthStyle adds an entry to authStyleCache, documented above.
func setAuthStyle(tokenURL string, v AuthStyle) {
authStyleCache.Lock()
defer authStyleCache.Unlock()
if authStyleCache.m == nil {
authStyleCache.m = make(map[string]AuthStyle)
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
c.mu.Lock()
defer c.mu.Unlock()
if c.m == nil {
c.m = make(map[string]AuthStyle)
}
authStyleCache.m[tokenURL] = v
c.m[tokenURL] = v
}
// newTokenRequest returns a new *http.Request to retrieve a new token
@@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
return v2
}
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe {
if style, ok := lookupAuthStyle(tokenURL); ok {
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
@@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req)
}
if needsAuthStyleProbe && err == nil {
setAuthStyle(tokenURL, authStyle)
styleCache.setAuthStyle(tokenURL, authStyle)
}
// Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request.

View File

@@ -16,7 +16,7 @@ import (
)
func TestRetrieveToken_InParams(t *testing.T) {
ResetAuthCache()
styleCache := new(AuthStyleCache)
const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want {
@@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
}))
defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil {
t.Errorf("RetrieveToken = %v; want no error", err)
}
}
func TestRetrieveTokenWithContexts(t *testing.T) {
ResetAuthCache()
styleCache := new(AuthStyleCache)
const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
}))
defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil {
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
}
@@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
close(retrieved)
if err == nil {
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")

View File

@@ -58,6 +58,10 @@ type Config struct {
// Scope specifies optional requested permissions.
Scopes []string
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
}
// A TokenSource is anything that can return a token.
@@ -72,6 +76,7 @@ type TokenSource interface {
// endpoint URLs.
type Endpoint struct {
AuthURL string
DeviceAuthURL string
TokenURL string
// AuthStyle optionally specifies how the endpoint wants the
@@ -139,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// State is an opaque value used by the client to maintain state between the
// request and callback. The authorization server includes this value when
// redirecting the user agent back to the client.
//
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce.
// It can also be used to pass the PKCE challenge.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
//
// To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
// generate a random state parameter and verify it after exchange.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL)
@@ -162,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
v.Set("scope", strings.Join(c.Scopes, " "))
}
if state != "" {
// TODO(light): Docs say never to omit state; don't allow empty.
v.Set("state", state)
}
for _, opt := range opts {
@@ -207,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
//
// The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state").
// calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks.
//
// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
// If using PKCE to protect against CSRF attacks, opts should include a
// VerifierOption.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{
"grant_type": {"authorization_code"},

View File

@@ -15,8 +15,6 @@ import (
"net/url"
"testing"
"time"
"golang.org/x/oauth2/internal"
)
type mockTransport struct {
@@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
}
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
internal.ResetAuthCache()
tr := &mockTransport{
rt: func(r *http.Request) (w *http.Response, err error) {
headerAuth := r.Header.Get("Authorization")
@@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
}
func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return

68
pkce.go Normal file
View File

@@ -0,0 +1,68 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/url"
)
const (
codeChallengeKey = "code_challenge"
codeChallengeMethodKey = "code_challenge_method"
codeVerifierKey = "code_verifier"
)
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
// This follows recommendations in RFC 7636.
//
// A fresh verifier should be generated for each authorization.
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
// (or Config.DeviceAccessToken).
func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then
// base64url-encoded to produce a 43-octet URL-safe string to use as the
// code verifier."
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
data := make([]byte, 32)
if _, err := rand.Read(data); err != nil {
panic(err)
}
return base64.RawURLEncoding.EncodeToString(data)
}
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier}
}
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
//
// Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:])
}
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
// only.
func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{
challenge_method: "S256",
challenge: S256ChallengeFromVerifier(verifier),
}
}
type challengeOption struct{ challenge_method, challenge string }
func (p challengeOption) setValue(m url.Values) {
m.Set(codeChallengeMethodKey, p.challenge_method)
m.Set(codeChallengeKey, p.challenge)
}

View File

@@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle))
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)