45 Commits

Author SHA1 Message Date
4ad99b3ed5 remove usage of appengine to get rid of unsafe imports 2024-04-23 18:30:17 +02:00
Chris Smith
d0e617c58c google: add Credentials.UniverseDomainProvider
* move MDS universe retrieval within Compute credentials

Change-Id: I847d2075ca11bde998a06220307626e902230c23
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/575936
Reviewed-by: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2024-04-03 20:36:14 +00:00
Jin Qin
3c9c1f6d00 oauth2/google: fix the logic of sts 0 value of expires_in
The sts response contains an optional field of `expires_in` and the value can be any integer.

https://github.com/golang/oauth2/blob/master/google/internal/externalaccount/basecredentials.go#L246-L248

In the case of less than `0`, we are going to throw an error. But in the case of equals to `0` practically it means "never expire" instead of "instantly expire" which doesn't make sense.

So we need to not set the expiration value for Token object. The current else if greater or equal is wrong.

It's never triggered only because we are sending positive `3600` in sts response.

Change-Id: Id227ca71130855235572b65ab178681e80d0da3a
GitHub-Last-Rev: a95c923d6a5d256fa92629a1fcb908495d7b1338
GitHub-Pull-Request: golang/oauth2#687
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/545895
Reviewed-by: Shin Fan <shinfan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
2024-03-12 20:05:50 +00:00
Jin Qin
5a05c654f9 oauth2/google: fix remove content-type header from idms get requests
This is a fix on the https://github.com/googleapis/google-cloud-go/pull/9508.
The aws provider in that library is a ported dependency from here.

Change-Id: I28e1efa4fdb8292210b695a164a55060c83dae88
GitHub-Last-Rev: c425f2d3b12082bdd477100648a9e46cab026da0
GitHub-Pull-Request: golang/oauth2#711
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/570875
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Chris Smith <chrisdsmith@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2024-03-12 14:54:40 +00:00
Jordan Liggitt
3a6776ada7 appengine: drop obsolete code for AppEngine envs <=Go 1.11
This library no longer builds on Go versions prior to Go 1.17,
so no longer needs to support compilation specific to AppEngine
environments on Go versions prior to Go 1.11

Related to #615

Change-Id: Ia9579ea2091cb86ee96065affb920370c4ba33ea
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/570595
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2024-03-11 19:47:38 +00:00
Gopher Robot
85231f99d6 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I993c77edbea8426f558ab84c4ba769e0bdf6406d
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/568935
Reviewed-by: Than McIntosh <thanm@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
2024-03-04 22:41:57 +00:00
Chris Smith
34a7afaa85 google/externalaccount: add Config.UniverseDomain
Change-Id: Ia1caee246da68c01addd06e1367ed1e43645826b
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/568216
Reviewed-by: Alex Eitzman <eitzman@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2024-03-04 19:42:12 +00:00
aeitzman
95bec95381 google/externalaccount: moves externalaccount package out of internal and exports it
go/programmable-auth-design for context. Adds support for user defined
 supplier methods to return subject tokens and AWS security credentials.

Change-Id: I7bc41f8c5202ae933fce516632f5049bbeb3d378
GitHub-Last-Rev: ac519b242f8315df572f1b205b0670f139bfc6c3
GitHub-Pull-Request: golang/oauth2#690
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/550835
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Chris Smith <chrisdsmith@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2024-02-27 21:55:11 +00:00
Gopher Robot
ebe81ad837 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I8228a126b322fb14250bbb5933199ce45e8584d3
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/562496
Reviewed-by: Than McIntosh <thanm@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
2024-02-08 13:19:31 +00:00
Chris Smith
adffd94437 google/internal/externalaccount: update serviceAccountImpersonationRE to support universe domain
Change-Id: Iafe35c293209bd88997c876341ebde7ac9ecda93
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/557195
TryBot-Bypass: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
2024-01-19 20:50:34 +00:00
Chris Smith
deefa7e836 google/downscope: add DownscopingConfig.UniverseDomain to support TPC
Change-Id: I3669352b382414ea640ca176afa4071995fc5ff1
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/557135
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Bypass: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
2024-01-19 18:57:04 +00:00
Gopher Robot
39adbb7807 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Icf68cb33585a13df206afacdb79832ea76f82346
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/554676
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
2024-01-08 18:34:15 +00:00
Chris Smith
4ce7bbb2ff google: add Credentials.GetUniverseDomain with GCE MDS support
* Deprecate Credentials.UniverseDomain

Change-Id: I1cbc842fbfce35540c8dff99fec09e036b9e2cdf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/554215
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Viacheslav Rostovtsev <virost@google.com>
2024-01-05 14:38:43 +00:00
Chris Smith
1e6999b1be google: add UniverseDomain to CredentialsParams
Change-Id: I7925b8341e1f047d0115acd7a01a34679a489ee0
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/552716
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Viacheslav Rostovtsev <virost@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2024-01-04 15:11:51 +00:00
Gopher Robot
6e9ec9323d go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Iad79e50dacd89c4cd0a40d966a1a7ba4cdc3d1a4
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/545176
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
2023-11-27 17:50:56 +00:00
Gopher Robot
e067960af8 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Id1413f67816220ef8039fb933088f4b7f50d70e5
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/540817
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
2023-11-08 20:28:19 +00:00
Leo
4c91c17b32 google: adds header to security considerations section
Change-Id: I29b93715876f233ae52687c8223fd8733a2a3b80
GitHub-Last-Rev: f15c4cf1a5
GitHub-Pull-Request: golang/oauth2#677
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/535895
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-17 20:42:42 +00:00
Gopher Robot
3c5dbf08cc go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I39a72a7dbb2205a6638a154892c69948ee2deb0d
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/533241
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2023-10-06 08:33:24 +00:00
Chris Smith
11625ccb95 google: add authorized_user conditional to Credentials.UniverseDomain
Return default universe domain if credentials type is authorized_user.

Change-Id: I20a9b5fafa562fcec84717914a236d081f630591
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/532196
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-02 22:04:54 +00:00
Chris Smith
8d6d45b6cd google: add Credentials.UniverseDomain to support TPC
Read and expose universe_domain from service account JSON files in
CredentialsFromJSONWithParams to support TPC in 1p clients.

Change-Id: I3518a0ec8be5ff7235b946cffd88b26ac8d303cf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531715
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-29 14:33:30 +00:00
Jin Qin
43b6a7ba19 google: adding support for external account authorized user
To support a new type of credential: `ExternalAccountAuthorizedUser`

* Refactor the common dependency STS to a separate package.
* Adding the `externalaccountauthorizeduser` package.

Change-Id: I9b9624f912d216b67a0d31945a50f057f747710b
GitHub-Last-Rev: 6e2aaff345
GitHub-Pull-Request: golang/oauth2#671
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531095
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-09-28 22:24:46 +00:00
M Hickford
14b275c918 oauth2: workaround misspelling of verification_uri
Some servers misspell verification_uri as verification_url, contrary to spec RFC 8628

Example server https://issuetracker.google.com/issues/151238144

Fixes #666

Change-Id: I89e354368bbb0a4e3b979bb547b4cb37bbe1cc02
GitHub-Last-Rev: bbf169b52d
GitHub-Pull-Request: golang/oauth2#667
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/527835
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nikolay Turpitko <nick.turpitko@gmail.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
2023-09-22 21:51:39 +00:00
aeitzman
18352fc433 google/internal/externalaccount: adding BYOID Metrics
Adds framework for sending BYOID metrics via the x-goog-api-client header on outgoing sts requests. Also adds a header file for getting the current version of GoLang

Change-Id: Id5431def96f4cfc03e4ada01d5fb8cac8cfa56a9
GitHub-Last-Rev: c93cd478e5
GitHub-Pull-Request: golang/oauth2#661
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/523595
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-22 20:39:34 +00:00
M Hickford
9095a51613 oauth2: clarify error if endpoint missing DeviceAuthURL
Change-Id: I36eb5eb66099161785160f4f39ea1c7f64ad6e74
GitHub-Last-Rev: 31cfe8150f
GitHub-Pull-Request: golang/oauth2#664
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/526302
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
2023-09-22 16:24:29 +00:00
Jin Qin
2d9e4a2adf oauth2/google: remove meta validations for aws external credentials
Remove the url validations to keep a consistency with other libraries.

Change-Id: Icb1767edc000d9695db3f0c7ca271918fb2083f5
GitHub-Last-Rev: af89ee0c72
GitHub-Pull-Request: golang/oauth2#660
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/522395
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
2023-09-12 16:01:49 +00:00
M Hickford
55cd552a36 oauth2: support PKCE
Fixes #603

Fixes golang/go#59835

Change-Id: Ica0cfef975ba9511e00f097498d33ba27dafca0d
GitHub-Last-Rev: f01f7593a3
GitHub-Pull-Request: golang/oauth2#625
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/463979
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
2023-09-07 17:49:42 +00:00
M Hickford
e3fb0fb3af oauth2: support device flow
Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628

Tested with GitHub

Fixes #418

Fixes golang/go#58126

Co-authored-by: cmP <centimitr@gmail.com>

Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14
GitHub-Last-Rev: 9a126d7b53
GitHub-Pull-Request: golang/oauth2#609
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155
Commit-Queue: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
2023-09-06 16:35:20 +00:00
Gopher Robot
07085280e4 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I2fb95ca59417e20377bc315094221fa7165128c8
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/525675
Reviewed-by: Heschi Kreinick <heschi@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
2023-09-05 16:42:47 +00:00
Brad Fitzpatrick
a835fc4358 oauth2: move global auth style cache to be per-Config
In 80673b4a4 (https://go.dev/cl/157820) I added a never-shrinking
package-global cache to remember which auto-detected auth style (HTTP
headers vs POST) was supported by a certain OAuth2 server, keyed by
its URL.

Unfortunately, some multi-tenant SaaS OIDC servers behave poorly and
have one global OpenID configuration document for all of their
customers which says ("we support all auth styles! you pick!") but
then give each customer control of which style they specifically
accept. This is bogus behavior on their part, but the oauth2 package's
global caching per URL isn't helping. (It's also bad to have a
package-global cache that can never be GC'ed)

So, this change moves the cache to hang off the oauth *Configs
instead. Unfortunately, it does so with some backwards compatiblity
compromises (an atomic.Value hack), lest people are using old versions
of Go still or copying a Config by value, both of which this package
previously accidentally supported, even though they weren't tested.

This change also means that anybody that's repeatedly making ephemeral
oauth.Configs without an explicit auth style will be losing &
reinitializing their cache on any auth style failures + fallbacks to
the other style. I think that should be pretty rare. People seem to
make an oauth2.Config once earlier and stash it away somewhere (often
deep in a token fetcher or HTTP client/transport).

Change-Id: I91f107368ab3c3d77bc425eeef65372a589feb7b
Signed-off-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/515675
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
2023-08-09 17:53:10 +00:00
Gopher Robot
2e4a4e2bfb go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I953aeb97bb9ed634f69dc93cf1f21392261c930c
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/516037
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-08-04 23:51:27 +00:00
Dmitri Shuralyov
ac6658e9cb all: update go version to 1.18
Go versions 1.16 and 1.17 are long since unsupported per Go release
policy (https://go.dev/doc/devel/release#policy).

Updating go.mod's go statement to 1.18 makes it so that 'go mod tidy'
doesn't include checksums needed for the full module graph loaded by
Go 1.16¹ that were recently added in CL 507840.

It also makes go fix remove the now-obsolete // +build lines².

Done using cmd/go at go1.21rc2:

$ go get go@1.18
go: upgraded go 1.17 => 1.18
$ go mod tidy
$ go fix ./...
google/appengine_gen1.go: fixed buildtag
google/appengine_gen2_flex.go: fixed buildtag
internal/client_appengine.go: fixed buildtag

¹ https://go.dev/ref/mod#graph-pruning
² https://go.dev/doc/go1.18#go-build-lines

Change-Id: I6c6295adef1f5c64a196c2e66005763893efe5e7
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/507878
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-07-05 21:55:59 +00:00
Gopher Robot
ec5679f607 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I6b389549fe4bc53a62cb383c5fb10156ccfcffba
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/507840
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
2023-07-05 21:07:49 +00:00
Bryan C. Mills
989acb1bfe all: update dependencies to their latest versions
This change was prepared by running:
	go1.21rc2 get -u -t ./...
	go1.21rc2 mod tidy -compat=1.17

Change-Id: I533c4361aae073b7a5280aad2c2e5eea752df62a
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/506296
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Bryan Mills <bcmills@google.com>
2023-06-26 19:20:11 +00:00
Gopher Robot
2323c81c8d go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I7a693f42e110b957194337a0d355dd1f2a5e14ca
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/502797
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Run-TryBot: Gopher Robot <gobot@golang.org>
2023-06-13 14:28:47 +00:00
Cody Oss
839de2255f google: don't check for IsNotExist for well-known file
There are cases when reading this file that a ENOTDIR is returned.
Because of this it is safer to just fall-back when any error
happens from reading the gcloud file.

Change-Id: Ie8e45ad508643e900adb5c9787907aaa50cceb5d
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/493695
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Russ Cox <rsc@golang.org>
Auto-Submit: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-05-08 21:24:50 +00:00
Gopher Robot
0690208dba go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: I97dfa241b763dfba4fc0c02da2f241255e2f53d1
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/493576
Reviewed-by: Heschi Kreinick <heschi@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Gopher Robot <gobot@golang.org>
2023-05-08 17:08:26 +00:00
cui fliter
451d5d662f internal: remove repeated definite articles
Change-Id: I0ce35bd2b7b870de9c0ffd898f245b49edbe55f7
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/489715
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: shuang cui <imcusg@gmail.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
2023-05-04 16:27:46 +00:00
M Hickford
cfe200d5bb oauth2: parse RFC 6749 error response
Parse error response described in https://datatracker.ietf.org/doc/html/rfc6749#section-5.2

Handle unorthodox servers responding 200 in error case.

Implements API changes in accepted proposal https://github.com/golang/go/issues/58125

Fixes #441
Fixes #274
Updates #173

Change-Id: If9399c3f952ac0501edbeefeb3a71ed057ca8d37
GitHub-Last-Rev: 0030e27422
GitHub-Pull-Request: golang/oauth2#610
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/451076
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-04-11 16:15:57 +00:00
Gopher Robot
36075149c5 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.
Once this CL is submitted, and post-submit testing succeeds on all
first-class ports across all supported Go versions, this repository
will be tagged with its next minor version.

Change-Id: If1689e1b37e36e8e8dd1cfc37fe9cb94bd49c807
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/482856
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Run-TryBot: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
2023-04-06 17:54:20 +00:00
Cody Oss
4abfd87339 google: add CredentialsParams.EarlyTokenRefresh
This option is a followup to to cl/479676 where an option was added
to configure the preemptive token refresh. Currently the option
in this package is only being used by compute credentials. In the
future we can support more/all auth flows but that would require
a lot of new surfaces to be added. Compute credentials are currently
the only case where we are expirencing the need to configure this
setting.

Change-Id: Ib78ca4beec44d0fe030ae81e84c8fcc4924793ba
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/479956
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
2023-03-29 20:00:17 +00:00
Roland Shoemaker
1e7f329364 oauth2: add ReuseTokenSourceWithExpiry
Add a constructor which allows for the configuration of the expiryDelta
buffer. Due to the construction of reuseTokenSource and Token we need
to store the new delta in both places, so the behavior of Valid is
consistent regardless of where it is called from.

Fixes #623

Change-Id: I89f9c206a9cc16bb473b8c619605c8410a82fff0
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/479676
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-03-28 18:36:12 +00:00
thomas-goncalves
86850e0723 oauth2: fix typo
Change-Id: I515f8897cc79c58a8a49df84ccddc5acd9536d87
GitHub-Last-Rev: 5acbebb81b
GitHub-Pull-Request: golang/oauth2#616
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/459695
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
2023-03-24 18:42:48 +00:00
aeitzman
a6e37e7441 google: Updating 3pi documentation
Fixing dead links in workload docs, adds workforce documentation

Change-Id: Ifad86e1937997f96ef577f5469d1e6fe496197b5
GitHub-Last-Rev: af288081ce
GitHub-Pull-Request: golang/oauth2#638
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/478555
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-03-22 21:13:33 +00:00
Cody Oss
54b70c833f google: update missing auth help URL
Update the URL to a newer page that better describes how to set
up credentials in different environments.

Change-Id: Ic0726fe298c543265d333cda60d62c235e4e2293
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/473735
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Auto-Submit: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-03-17 18:50:21 +00:00
M Hickford
2fc4ef5a6f README: encourage issues and proposals before changes
Text verbatim from https://go.dev/doc/contribute

Change-Id: Iefdcf9e9f771b9e55601bf9c9b59e20593b4573a
GitHub-Last-Rev: ba45caadaf
GitHub-Pull-Request: golang/oauth2#632
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/471281
Run-TryBot: Matthew Hickford <hickford@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Reviewed-by: Matthew Hickford <hickford@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-03-10 21:26:16 +00:00
58 changed files with 3450 additions and 1483 deletions

View File

@@ -19,7 +19,7 @@ See pkg.go.dev for further documentation and examples.
* [pkg.go.dev/golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) * [pkg.go.dev/golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2)
* [pkg.go.dev/golang.org/x/oauth2/google](https://pkg.go.dev/golang.org/x/oauth2/google) * [pkg.go.dev/golang.org/x/oauth2/google](https://pkg.go.dev/golang.org/x/oauth2/google)
## Policy for new packages ## Policy for new endpoints
We no longer accept new provider-specific packages in this repo if all We no longer accept new provider-specific packages in this repo if all
they do is add a single endpoint variable. If you just want to add a they do is add a single endpoint variable. If you just want to add a
@@ -29,8 +29,12 @@ package.
## Report Issues / Send Patches ## Report Issues / Send Patches
This repository uses Gerrit for code changes. To learn how to submit changes to
this repository, see https://golang.org/doc/contribute.html.
The main issue tracker for the oauth2 repository is located at The main issue tracker for the oauth2 repository is located at
https://github.com/golang/oauth2/issues. https://github.com/golang/oauth2/issues.
This repository uses Gerrit for code changes. To learn how to submit changes to
this repository, see https://golang.org/doc/contribute.html. In particular:
* Excluding trivial changes, all contributions should be connected to an existing issue.
* API changes must go through the [change proposal process](https://go.dev/s/proposal-process) before they can be accepted.
* The code owners are listed at [dev.golang.org/owners](https://dev.golang.org/owners#:~:text=x/oauth2).

View File

@@ -47,6 +47,10 @@ type Config struct {
// client ID & client secret sent. The zero value means to // client ID & client secret sent. The zero value means to
// auto-detect. // auto-detect.
AuthStyle oauth2.AuthStyle AuthStyle oauth2.AuthStyle
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
} }
// Token uses client credentials to retrieve a token. // Token uses client credentials to retrieve a token.
@@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v[k] = p v[k] = p
} }
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
if err != nil { if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok { if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*oauth2.RetrieveError)(rErr) return nil, (*oauth2.RetrieveError)(rErr)

View File

@@ -12,8 +12,6 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"golang.org/x/oauth2/internal"
) )
func newConf(serverURL string) *Config { func newConf(serverURL string) *Config {
@@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
} }
func TestTokenRefreshRequest(t *testing.T) { func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" { if r.URL.String() == "/somethingelse" {
return return

198
deviceauth.go Normal file
View File

@@ -0,0 +1,198 @@
package oauth2
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
)
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
errAuthorizationPending = "authorization_pending"
errSlowDown = "slow_down"
errAccessDenied = "access_denied"
errExpiredToken = "expired_token"
)
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry time.Time `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
type Alias DeviceAuthResponse
var expiresIn int64
if !d.Expiry.IsZero() {
expiresIn = int64(time.Until(d.Expiry).Seconds())
}
return json.Marshal(&struct {
ExpiresIn int64 `json:"expires_in,omitempty"`
*Alias
}{
ExpiresIn: expiresIn,
Alias: (*Alias)(&d),
})
}
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
type Alias DeviceAuthResponse
aux := &struct {
ExpiresIn int64 `json:"expires_in"`
// workaround misspelling of verification_uri
VerificationURL string `json:"verification_url"`
*Alias
}{
Alias: (*Alias)(c),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ExpiresIn != 0 {
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
}
if c.VerificationURI == "" {
c.VerificationURI = aux.VerificationURL
}
return nil
}
// DeviceAuth returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
v := url.Values{
"client_id": {c.ClientID},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveDeviceAuth(ctx, c, v)
}
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}
da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}
if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}
return da, nil
}
// DeviceAccessToken polls the server to exchange a device code for a token.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
if !da.Expiry.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
defer cancel()
}
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
v := url.Values{
"client_id": {c.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {da.DeviceCode},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
// "If no value is provided, clients MUST use 5 as the default."
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
interval := da.Interval
if interval == 0 {
interval = 5
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
tok, err := retrieveToken(ctx, c, v)
if err == nil {
return tok, nil
}
e, ok := err.(*RetrieveError)
if !ok {
return nil, err
}
switch e.ErrorCode {
case errSlowDown:
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
interval += 5
ticker.Reset(time.Duration(interval) * time.Second)
case errAuthorizationPending:
// Do nothing.
case errAccessDenied, errExpiredToken:
fallthrough
default:
return tok, err
}
}
}
}

97
deviceauth_test.go Normal file
View File

@@ -0,0 +1,97 @@
package oauth2
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
tests := []struct {
name string
response DeviceAuthResponse
want string
}{
{
name: "empty",
response: DeviceAuthResponse{},
want: `{"device_code":"","user_code":"","verification_uri":""}`,
},
{
name: "soon",
response: DeviceAuthResponse{
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
},
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
gotBytes, err := json.Marshal(tc.response)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
t.Skip("test ran too slowly to compare `expires_in`")
}
got := string(gotBytes)
if got != tc.want {
t.Errorf("want=%s, got=%s", tc.want, got)
}
})
}
}
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
tests := []struct {
name string
data string
want DeviceAuthResponse
}{
{
name: "empty",
data: `{}`,
want: DeviceAuthResponse{},
},
{
name: "soon",
data: `{"expires_in":100}`,
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
got := DeviceAuthResponse{}
err := json.Unmarshal([]byte(tc.data), &got)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
t.Errorf("want=%#v, got=%#v", tc.want, got)
}
})
}
}
func ExampleConfig_DeviceAuth() {
var config Config
ctx := context.Background()
response, err := config.DeviceAuth(ctx)
if err != nil {
panic(err)
}
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
token, err := config.DeviceAccessToken(ctx, response)
if err != nil {
panic(err)
}
fmt.Println(token)
}

View File

@@ -57,6 +57,7 @@ var Fitbit = oauth2.Endpoint{
var GitHub = oauth2.Endpoint{ var GitHub = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize", AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token", TokenURL: "https://github.com/login/oauth/access_token",
DeviceAuthURL: "https://github.com/login/device/code",
} }
// GitLab is the endpoint for GitLab. // GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
var Google = oauth2.Endpoint{ var Google = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth", AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token", TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
} }
// Heroku is the endpoint for Heroku. // Heroku is the endpoint for Heroku.

View File

@@ -26,9 +26,13 @@ func ExampleConfig() {
}, },
} }
// use PKCE to protect against CSRF attacks
// https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6
verifier := oauth2.GenerateVerifier()
// Redirect user to consent page to ask for permission // Redirect user to consent page to ask for permission
// for the scopes specified above. // for the scopes specified above.
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline) url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
fmt.Printf("Visit the URL for the auth dialog: %v", url) fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Use the authorization code that is pushed to the redirect // Use the authorization code that is pushed to the redirect
@@ -39,7 +43,7 @@ func ExampleConfig() {
if _, err := fmt.Scan(&code); err != nil { if _, err := fmt.Scan(&code); err != nil {
log.Fatal(err) log.Fatal(err)
} }
tok, err := conf.Exchange(ctx, code) tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@@ -6,11 +6,8 @@
package github // import "golang.org/x/oauth2/github" package github // import "golang.org/x/oauth2/github"
import ( import (
"golang.org/x/oauth2" "golang.org/x/oauth2/endpoints"
) )
// Endpoint is Github's OAuth 2.0 endpoint. // Endpoint is Github's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{ var Endpoint = endpoints.GitHub
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
}

13
go.mod
View File

@@ -1,15 +1,10 @@
module golang.org/x/oauth2 module golang.org/x/oauth2
go 1.17 go 1.18
require ( require (
cloud.google.com/go/compute/metadata v0.2.0 cloud.google.com/go/compute/metadata v0.2.3
github.com/google/go-cmp v0.5.8 github.com/google/go-cmp v0.5.9
google.golang.org/appengine v1.6.7
) )
require ( require cloud.google.com/go/compute v1.20.1 // indirect
github.com/golang/protobuf v1.5.2 // indirect
golang.org/x/net v0.8.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
)

59
go.sum
View File

@@ -1,53 +1,6 @@
cloud.google.com/go/compute/metadata v0.2.0 h1:nBbNSZyDpkNlo3DepaaLKVuO7ClyifSAmNloSCZrHnQ= cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZNbg=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

View File

@@ -1,38 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"context"
"time"
"golang.org/x/oauth2"
)
// Set at init time by appengine_gen1.go. If nil, we're not on App Engine standard first generation (<= Go 1.9) or App Engine flexible.
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error)
// Set at init time by appengine_gen1.go. If nil, we're not on App Engine standard first generation (<= Go 1.9) or App Engine flexible.
var appengineAppIDFunc func(c context.Context) string
// AppEngineTokenSource returns a token source that fetches tokens from either
// the current application's service account or from the metadata server,
// depending on the App Engine environment. See below for environment-specific
// details. If you are implementing a 3-legged OAuth 2.0 flow on App Engine that
// involves user accounts, see oauth2.Config instead.
//
// First generation App Engine runtimes (<= Go 1.9):
// AppEngineTokenSource returns a token source that fetches tokens issued to the
// current App Engine application's service account. The provided context must have
// come from appengine.NewContext.
//
// Second generation App Engine runtimes (>= Go 1.11) and App Engine flexible:
// AppEngineTokenSource is DEPRECATED on second generation runtimes and on the
// flexible environment. It delegates to ComputeTokenSource, and the provided
// context and scopes are not used. Please use DefaultTokenSource (or ComputeTokenSource,
// which DefaultTokenSource will use in this case) instead.
func AppEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
return appEngineTokenSource(ctx, scope...)
}

View File

@@ -1,78 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build appengine
// +build appengine
// This file applies to App Engine first generation runtimes (<= Go 1.9).
package google
import (
"context"
"sort"
"strings"
"sync"
"golang.org/x/oauth2"
"google.golang.org/appengine"
)
func init() {
appengineTokenFunc = appengine.AccessToken
appengineAppIDFunc = appengine.AppID
}
// See comment on AppEngineTokenSource in appengine.go.
func appEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
scopes := append([]string{}, scope...)
sort.Strings(scopes)
return &gaeTokenSource{
ctx: ctx,
scopes: scopes,
key: strings.Join(scopes, " "),
}
}
// aeTokens helps the fetched tokens to be reused until their expiration.
var (
aeTokensMu sync.Mutex
aeTokens = make(map[string]*tokenLock) // key is space-separated scopes
)
type tokenLock struct {
mu sync.Mutex // guards t; held while fetching or updating t
t *oauth2.Token
}
type gaeTokenSource struct {
ctx context.Context
scopes []string
key string // to aeTokens map; space-separated scopes
}
func (ts *gaeTokenSource) Token() (*oauth2.Token, error) {
aeTokensMu.Lock()
tok, ok := aeTokens[ts.key]
if !ok {
tok = &tokenLock{}
aeTokens[ts.key] = tok
}
aeTokensMu.Unlock()
tok.mu.Lock()
defer tok.mu.Unlock()
if tok.t.Valid() {
return tok.t, nil
}
access, exp, err := appengineTokenFunc(ts.ctx, ts.scopes...)
if err != nil {
return nil, err
}
tok.t = &oauth2.Token{
AccessToken: access,
Expiry: exp,
}
return tok.t, nil
}

View File

@@ -1,28 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !appengine
// +build !appengine
// This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible.
package google
import (
"context"
"log"
"sync"
"golang.org/x/oauth2"
)
var logOnce sync.Once // only spam about deprecation once
// See comment on AppEngineTokenSource in appengine.go.
func appEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
logOnce.Do(func() {
log.Print("google: AppEngineTokenSource is deprecated on App Engine standard second generation runtimes (>= Go 1.11) and App Engine flexible. Please use DefaultTokenSource or ComputeTokenSource.")
})
return ComputeTokenSource("")
}

View File

@@ -8,17 +8,23 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sync"
"time"
"cloud.google.com/go/compute/metadata" "cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/authhandler" "golang.org/x/oauth2/authhandler"
) )
const (
adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
defaultUniverseDomain = "googleapis.com"
)
// Credentials holds Google credentials, including "Application Default Credentials". // Credentials holds Google credentials, including "Application Default Credentials".
// For more details, see: // For more details, see:
// https://developers.google.com/accounts/docs/application-default-credentials // https://developers.google.com/accounts/docs/application-default-credentials
@@ -35,6 +41,64 @@ type Credentials struct {
// environment and not with a credentials file, e.g. when code is // environment and not with a credentials file, e.g. when code is
// running on Google Cloud Platform. // running on Google Cloud Platform.
JSON []byte JSON []byte
// UniverseDomainProvider returns the default service domain for a given
// Cloud universe. Optional.
//
// On GCE, UniverseDomainProvider should return the universe domain value
// from Google Compute Engine (GCE)'s metadata server. See also [The attached service
// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
// If the GCE metadata server returns a 404 error, the default universe
// domain value should be returned. If the GCE metadata server returns an
// error other than 404, the error should be returned.
UniverseDomainProvider func() (string, error)
udMu sync.Mutex // guards universeDomain
// universeDomain is the default service domain for a given Cloud universe.
universeDomain string
}
// UniverseDomain returns the default service domain for a given Cloud universe.
//
// The default value is "googleapis.com".
//
// Deprecated: Use instead (*Credentials).GetUniverseDomain(), which supports
// obtaining the universe domain when authenticating via the GCE metadata server.
// Unlike GetUniverseDomain, this method, UniverseDomain, will always return the
// default value when authenticating via the GCE metadata server.
// See also [The attached service account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
func (c *Credentials) UniverseDomain() string {
if c.universeDomain == "" {
return defaultUniverseDomain
}
return c.universeDomain
}
// GetUniverseDomain returns the default service domain for a given Cloud
// universe. If present, UniverseDomainProvider will be invoked and its return
// value will be cached.
//
// The default value is "googleapis.com".
func (c *Credentials) GetUniverseDomain() (string, error) {
c.udMu.Lock()
defer c.udMu.Unlock()
if c.universeDomain == "" && c.UniverseDomainProvider != nil {
// On Google Compute Engine, an App Engine standard second generation
// runtime, or App Engine flexible, use an externally provided function
// to request the universe domain from the metadata server.
ud, err := c.UniverseDomainProvider()
if err != nil {
return "", err
}
c.universeDomain = ud
}
// If no UniverseDomainProvider (meaning not on Google Compute Engine), or
// in case of any (non-error) empty return value from
// UniverseDomainProvider, set the default universe domain.
if c.universeDomain == "" {
c.universeDomain = defaultUniverseDomain
}
return c.universeDomain, nil
} }
// DefaultCredentials is the old name of Credentials. // DefaultCredentials is the old name of Credentials.
@@ -66,6 +130,20 @@ type CredentialsParams struct {
// The OAuth2 TokenURL default override. This value overrides the default TokenURL, // The OAuth2 TokenURL default override. This value overrides the default TokenURL,
// unless explicitly specified by the credentials config file. Optional. // unless explicitly specified by the credentials config file. Optional.
TokenURL string TokenURL string
// EarlyTokenRefresh is the amount of time before a token expires that a new
// token will be preemptively fetched. If unset the default value is 10
// seconds.
//
// Note: This option is currently only respected when using credentials
// fetched from the GCE metadata server.
EarlyTokenRefresh time.Duration
// UniverseDomain is the default service domain for a given Cloud universe.
// Only supported in authentication flows that support universe domains.
// This value takes precedence over a universe domain explicitly specified
// in a credentials config file or by the GCE metadata server. Optional.
UniverseDomain string
} }
func (params CredentialsParams) deepCopy() CredentialsParams { func (params CredentialsParams) deepCopy() CredentialsParams {
@@ -110,9 +188,7 @@ func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSourc
// 2. A JSON file in a location known to the gcloud command-line tool. // 2. A JSON file in a location known to the gcloud command-line tool.
// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json. // On Windows, this is %APPDATA%/gcloud/application_default_credentials.json.
// On other systems, $HOME/.config/gcloud/application_default_credentials.json. // On other systems, $HOME/.config/gcloud/application_default_credentials.json.
// 3. On Google App Engine standard first generation runtimes (<= Go 1.9) it uses // 3. On Google Compute Engine, Google App Engine standard second generation runtimes
// the appengine.AccessToken function.
// 4. On Google Compute Engine, Google App Engine standard second generation runtimes
// (>= Go 1.11), and Google App Engine flexible environment, it fetches // (>= Go 1.11), and Google App Engine flexible environment, it fetches
// credentials from the metadata server. // credentials from the metadata server.
func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsParams) (*Credentials, error) { func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsParams) (*Credentials, error) {
@@ -131,35 +207,36 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar
// Second, try a well-known file. // Second, try a well-known file.
filename := wellKnownFile() filename := wellKnownFile()
if creds, err := readCredentialsFile(ctx, filename, params); err == nil { if b, err := os.ReadFile(filename); err == nil {
return creds, nil return CredentialsFromJSONWithParams(ctx, b, params)
} else if !os.IsNotExist(err) {
return nil, fmt.Errorf("google: error getting credentials using well-known file (%v): %v", filename, err)
} }
// Third, if we're on a Google App Engine standard first generation runtime (<= Go 1.9) // Third, if we're on Google Compute Engine, an App Engine standard second generation runtime,
// use those credentials. App Engine standard second generation runtimes (>= Go 1.11)
// and App Engine flexible use ComputeTokenSource and the metadata server.
if appengineTokenFunc != nil {
return &Credentials{
ProjectID: appengineAppIDFunc(ctx),
TokenSource: AppEngineTokenSource(ctx, params.Scopes...),
}, nil
}
// Fourth, if we're on Google Compute Engine, an App Engine standard second generation runtime,
// or App Engine flexible, use the metadata server. // or App Engine flexible, use the metadata server.
if metadata.OnGCE() { if metadata.OnGCE() {
id, _ := metadata.ProjectID() id, _ := metadata.ProjectID()
universeDomainProvider := func() (string, error) {
universeDomain, err := metadata.Get("universe/universe_domain")
if err != nil {
if _, ok := err.(metadata.NotDefinedError); ok {
// http.StatusNotFound (404)
return defaultUniverseDomain, nil
} else {
return "", err
}
}
return universeDomain, nil
}
return &Credentials{ return &Credentials{
ProjectID: id, ProjectID: id,
TokenSource: ComputeTokenSource("", params.Scopes...), TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
UniverseDomainProvider: universeDomainProvider,
universeDomain: params.UniverseDomain,
}, nil }, nil
} }
// None are found; return helpful error. // None are found; return helpful error.
const url = "https://developers.google.com/accounts/docs/application-default-credentials" return nil, fmt.Errorf("google: could not find default credentials. See %v for more information", adcSetupURL)
return nil, fmt.Errorf("google: could not find default credentials. See %v for more information.", url)
} }
// FindDefaultCredentials invokes FindDefaultCredentialsWithParams with the specified scopes. // FindDefaultCredentials invokes FindDefaultCredentialsWithParams with the specified scopes.
@@ -193,6 +270,16 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
if err := json.Unmarshal(jsonData, &f); err != nil { if err := json.Unmarshal(jsonData, &f); err != nil {
return nil, err return nil, err
} }
universeDomain := f.UniverseDomain
if params.UniverseDomain != "" {
universeDomain = params.UniverseDomain
}
// Authorized user credentials are only supported in the googleapis.com universe.
if f.Type == userCredentialsKey {
universeDomain = defaultUniverseDomain
}
ts, err := f.tokenSource(ctx, params) ts, err := f.tokenSource(ctx, params)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -202,6 +289,7 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
ProjectID: f.ProjectID, ProjectID: f.ProjectID,
TokenSource: ts, TokenSource: ts,
JSON: jsonData, JSON: jsonData,
universeDomain: universeDomain,
}, nil }, nil
} }
@@ -221,7 +309,7 @@ func wellKnownFile() string {
} }
func readCredentialsFile(ctx context.Context, filename string, params CredentialsParams) (*Credentials, error) { func readCredentialsFile(ctx context.Context, filename string, params CredentialsParams) (*Credentials, error) {
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }

312
google/default_test.go Normal file
View File

@@ -0,0 +1,312 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"cloud.google.com/go/compute/metadata"
)
var saJSONJWT = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var saJSONJWTUniverseDomain = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"universe_domain": "example.com",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var userJSON = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2"
}`)
var userJSONUniverseDomain = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2",
"universe_domain": "example.com"
}`)
var universeDomain = "example.com"
var universeDomain2 = "apis-tpclp.goog"
func TestCredentialsFromJSONWithParams_SA(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWT, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWT, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
}
func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWTUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if creds.UniverseDomain() != universeDomain {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if got != universeDomain {
t.Fatalf("got %q, want %q", got, universeDomain)
}
}
func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWTUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if got != universeDomain2 {
t.Fatalf("got %q, want %q", got, universeDomain2)
}
}
func TestCredentialsFromJSONWithParams_User(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSON, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSON, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSONUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSONUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestComputeUniverseDomain(t *testing.T) {
universeDomainPath := "/computeMetadata/v1/universe/universe_domain"
universeDomainResponseBody := "example.com"
var requests int
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests++
if r.URL.Path != universeDomainPath {
t.Errorf("bad path, got %s, want %s", r.URL.Path, universeDomainPath)
}
if requests > 1 {
t.Errorf("too many requests, got %d, want 1", requests)
}
w.Write([]byte(universeDomainResponseBody))
}))
defer s.Close()
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
universeDomainProvider := func() (string, error) {
universeDomain, err := metadata.Get("universe/universe_domain")
if err != nil {
return "", err
}
return universeDomain, nil
}
// Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block
creds := &Credentials{
ProjectID: "fake_project",
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
UniverseDomainProvider: universeDomainProvider,
universeDomain: params.UniverseDomain, // empty
}
c := make(chan bool)
go func() {
got, err := creds.GetUniverseDomain() // First conflicting access.
if err != nil {
t.Error(err)
}
if want := universeDomainResponseBody; got != want {
t.Errorf("got %q, want %q", got, want)
}
c <- true
}()
got, err := creds.GetUniverseDomain() // Second conflicting (and potentially uncached) access.
<-c
if err != nil {
t.Error(err)
}
if want := universeDomainResponseBody; got != want {
t.Errorf("got %q, want %q", got, want)
}
}

View File

@@ -22,45 +22,9 @@
// the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or // the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or
// create an http.Client. // create an http.Client.
// //
// # Workload Identity Federation // # Workload and Workforce Identity Federation
// //
// Using workload identity federation, your application can access Google Cloud // For information on how to use Workload and Workforce Identity Federation, see [golang.org/x/oauth2/google/externalaccount].
// resources from Amazon Web Services (AWS), Microsoft Azure or any identity
// provider that supports OpenID Connect (OIDC).
// Traditionally, applications running outside Google Cloud have used service
// account keys to access Google Cloud resources. Using identity federation,
// you can allow your workload to impersonate a service account.
// This lets you access Google Cloud resources directly, eliminating the
// maintenance and security burden associated with service account keys.
//
// Follow the detailed instructions on how to configure Workload Identity Federation
// in various platforms:
//
// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws
// Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure
// OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc
//
// For OIDC and SAML providers, the library can retrieve tokens in three ways:
// from a local file location (file-sourced credentials), from a server
// (URL-sourced credentials), or from a local executable (executable-sourced
// credentials).
// For file-sourced credentials, a background process needs to be continuously
// refreshing the file location with a new OIDC token prior to expiration.
// For tokens with one hour lifetimes, the token needs to be updated in the file
// every hour. The token can be stored directly as plain text or in JSON format.
// For URL-sourced credentials, a local server needs to host a GET endpoint to
// return the OIDC token. The response can be in plain text or JSON.
// Additional required request headers can also be specified.
// For executable-sourced credentials, an application needs to be available to
// output the OIDC token and other information in a JSON format.
// For more information on how these work (and how to implement
// executable-sourced credentials), please check out:
// https://cloud.google.com/iam/docs/using-workload-identity-federation#oidc
//
// Note that this library does not perform any validation on the token_url, token_info_url,
// or service_account_impersonation_url fields of the credential configuration.
// It is not recommended to use a credential configuration that you did not generate with
// the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain.
// //
// # Credentials // # Credentials
// //
@@ -86,5 +50,4 @@
// same as the one obtained from the oauth2.Config returned from ConfigFromJSON or // same as the one obtained from the oauth2.Config returned from ConfigFromJSON or
// JWTConfigFromJSON, but the Credentials may contain additional information // JWTConfigFromJSON, but the Credentials may contain additional information
// that is useful is some circumstances. // that is useful is some circumstances.
//
package google // import "golang.org/x/oauth2/google" package google // import "golang.org/x/oauth2/google"

View File

@@ -42,13 +42,16 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
var ( const (
identityBindingEndpoint = "https://sts.googleapis.com/v1/token" universeDomainPlaceholder = "UNIVERSE_DOMAIN"
identityBindingEndpointTemplate = "https://sts.UNIVERSE_DOMAIN/v1/token"
defaultUniverseDomain = "googleapis.com"
) )
type accessBoundary struct { type accessBoundary struct {
@@ -105,6 +108,18 @@ type DownscopingConfig struct {
// access (or set of accesses) that the new token has to a given resource. // access (or set of accesses) that the new token has to a given resource.
// There can be a maximum of 10 AccessBoundaryRules. // There can be a maximum of 10 AccessBoundaryRules.
Rules []AccessBoundaryRule Rules []AccessBoundaryRule
// UniverseDomain is the default service domain for a given Cloud universe.
// The default value is "googleapis.com". Optional.
UniverseDomain string
}
// identityBindingEndpoint returns the identity binding endpoint with the
// configured universe domain.
func (dc *DownscopingConfig) identityBindingEndpoint() string {
if dc.UniverseDomain == "" {
return strings.Replace(identityBindingEndpointTemplate, universeDomainPlaceholder, defaultUniverseDomain, 1)
}
return strings.Replace(identityBindingEndpointTemplate, universeDomainPlaceholder, dc.UniverseDomain, 1)
} }
// A downscopingTokenSource is used to retrieve a downscoped token with restricted // A downscopingTokenSource is used to retrieve a downscoped token with restricted
@@ -114,6 +129,9 @@ type downscopingTokenSource struct {
ctx context.Context ctx context.Context
// config holds the information necessary to generate a downscoped Token. // config holds the information necessary to generate a downscoped Token.
config DownscopingConfig config DownscopingConfig
// identityBindingEndpoint is the identity binding endpoint with the
// configured universe domain.
identityBindingEndpoint string
} }
// NewTokenSource returns a configured downscopingTokenSource. // NewTokenSource returns a configured downscopingTokenSource.
@@ -135,7 +153,11 @@ func NewTokenSource(ctx context.Context, conf DownscopingConfig) (oauth2.TokenSo
return nil, fmt.Errorf("downscope: all rules must provide at least one permission: %+v", val) return nil, fmt.Errorf("downscope: all rules must provide at least one permission: %+v", val)
} }
} }
return downscopingTokenSource{ctx: ctx, config: conf}, nil return downscopingTokenSource{
ctx: ctx,
config: conf,
identityBindingEndpoint: conf.identityBindingEndpoint(),
}, nil
} }
// Token() uses a downscopingTokenSource to generate an oauth2 Token. // Token() uses a downscopingTokenSource to generate an oauth2 Token.
@@ -171,7 +193,7 @@ func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
form.Add("options", string(b)) form.Add("options", string(b))
myClient := oauth2.NewClient(dts.ctx, nil) myClient := oauth2.NewClient(dts.ctx, nil)
resp, err := myClient.PostForm(identityBindingEndpoint, form) resp, err := myClient.PostForm(dts.identityBindingEndpoint, form)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to generate POST Request %v", err) return nil, fmt.Errorf("unable to generate POST Request %v", err)
} }

View File

@@ -38,18 +38,43 @@ func Test_DownscopedTokenSource(t *testing.T) {
w.Write([]byte(standardRespBody)) w.Write([]byte(standardRespBody))
})) }))
new := []AccessBoundaryRule{ myTok := oauth2.Token{AccessToken: "Mellon"}
tmpSrc := oauth2.StaticTokenSource(&myTok)
rules := []AccessBoundaryRule{
{ {
AvailableResource: "test1", AvailableResource: "test1",
AvailablePermissions: []string{"Perm1", "Perm2"}, AvailablePermissions: []string{"Perm1", "Perm2"},
}, },
} }
myTok := oauth2.Token{AccessToken: "Mellon"} dts := downscopingTokenSource{
tmpSrc := oauth2.StaticTokenSource(&myTok) ctx: context.Background(),
dts := downscopingTokenSource{context.Background(), DownscopingConfig{tmpSrc, new}} config: DownscopingConfig{
identityBindingEndpoint = ts.URL RootSource: tmpSrc,
Rules: rules,
},
identityBindingEndpoint: ts.URL,
}
_, err := dts.Token() _, err := dts.Token()
if err != nil { if err != nil {
t.Fatalf("NewDownscopedTokenSource failed with error: %v", err) t.Fatalf("NewDownscopedTokenSource failed with error: %v", err)
} }
} }
func Test_DownscopingConfig(t *testing.T) {
tests := []struct {
universeDomain string
want string
}{
{"", "https://sts.googleapis.com/v1/token"},
{"googleapis.com", "https://sts.googleapis.com/v1/token"},
{"example.com", "https://sts.example.com/v1/token"},
}
for _, tt := range tests {
c := DownscopingConfig{
UniverseDomain: tt.universeDomain,
}
if got := c.identityBindingEndpoint(); got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
}
}

View File

@@ -26,22 +26,28 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type awsSecurityCredentials struct { // AwsSecurityCredentials models AWS security credentials.
type AwsSecurityCredentials struct {
// AccessKeyId is the AWS Access Key ID - Required.
AccessKeyID string `json:"AccessKeyID"` AccessKeyID string `json:"AccessKeyID"`
// SecretAccessKey is the AWS Secret Access Key - Required.
SecretAccessKey string `json:"SecretAccessKey"` SecretAccessKey string `json:"SecretAccessKey"`
SecurityToken string `json:"Token"` // SessionToken is the AWS Session token. This should be provided for temporary AWS security credentials - Optional.
SessionToken string `json:"Token"`
} }
// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature. // awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct { type awsRequestSigner struct {
RegionName string RegionName string
AwsSecurityCredentials awsSecurityCredentials AwsSecurityCredentials *AwsSecurityCredentials
} }
// getenv aliases os.Getenv for testing // getenv aliases os.Getenv for testing
var getenv = os.Getenv var getenv = os.Getenv
const ( const (
defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
// AWS Signature Version 4 signing algorithm identifier. // AWS Signature Version 4 signing algorithm identifier.
awsAlgorithm = "AWS4-HMAC-SHA256" awsAlgorithm = "AWS4-HMAC-SHA256"
@@ -197,8 +203,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
signedRequest.Header.Add("host", requestHost(req)) signedRequest.Header.Add("host", requestHost(req))
if rs.AwsSecurityCredentials.SecurityToken != "" { if rs.AwsSecurityCredentials.SessionToken != "" {
signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken) signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
} }
if signedRequest.Header.Get("date") == "" { if signedRequest.Header.Get("date") == "" {
@@ -251,16 +257,18 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
} }
type awsCredentialSource struct { type awsCredentialSource struct {
EnvironmentID string environmentID string
RegionURL string regionURL string
RegionalCredVerificationURL string regionalCredVerificationURL string
CredVerificationURL string credVerificationURL string
IMDSv2SessionTokenURL string imdsv2SessionTokenURL string
TargetResource string targetResource string
requestSigner *awsRequestSigner requestSigner *awsRequestSigner
region string region string
ctx context.Context ctx context.Context
client *http.Client client *http.Client
awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
supplierOptions SupplierOptions
} }
type awsRequestHeader struct { type awsRequestHeader struct {
@@ -274,49 +282,6 @@ type awsRequest struct {
Headers []awsRequestHeader `json:"headers"` Headers []awsRequestHeader `json:"headers"`
} }
func (cs awsCredentialSource) validateMetadataServers() error {
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
return err
}
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
return err
}
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
}
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
if metadataUrl == "" {
// Zero value means use default, which is valid.
return true
}
u, err := url.Parse(metadataUrl)
if err != nil {
// Unparseable URL means invalid
return false
}
for _, validHostname := range validHostnames {
if u.Hostname() == validHostname {
// If it's one of the valid hostnames, everything is good
return true
}
}
// hostname not found in our allowlist, so not valid
return false
}
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
if !cs.isValidMetadataServer(metadataUrl) {
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
}
return nil
}
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
if cs.client == nil { if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil) cs.client = oauth2.NewClient(cs.ctx, nil)
@@ -335,14 +300,25 @@ func canRetrieveSecurityCredentialFromEnvironment() bool {
return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
} }
func shouldUseMetadataServer() bool { func (cs awsCredentialSource) shouldUseMetadataServer() bool {
return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
}
func (cs awsCredentialSource) credentialSourceType() string {
if cs.awsSecurityCredentialsSupplier != nil {
return "programmatic"
}
return "aws"
} }
func (cs awsCredentialSource) subjectToken() (string, error) { func (cs awsCredentialSource) subjectToken() (string, error) {
// Set Defaults
if cs.regionalCredVerificationURL == "" {
cs.regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
}
if cs.requestSigner == nil { if cs.requestSigner == nil {
headers := make(map[string]string) headers := make(map[string]string)
if shouldUseMetadataServer() { if cs.shouldUseMetadataServer() {
awsSessionToken, err := cs.getAWSSessionToken() awsSessionToken, err := cs.getAWSSessionToken()
if err != nil { if err != nil {
return "", err return "", err
@@ -357,8 +333,8 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
cs.region, err = cs.getRegion(headers)
if cs.region, err = cs.getRegion(headers); err != nil { if err != nil {
return "", err return "", err
} }
@@ -370,7 +346,7 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
// Generate the signed request to AWS STS GetCallerIdentity API. // Generate the signed request to AWS STS GetCallerIdentity API.
// Use the required regional endpoint. Otherwise, the request will fail. // Use the required regional endpoint. Otherwise, the request will fail.
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil) req, err := http.NewRequest("POST", strings.Replace(cs.regionalCredVerificationURL, "{region}", cs.region, 1), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -378,8 +354,8 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
// provider, with or without the HTTPS prefix. // provider, with or without the HTTPS prefix.
// Including this header as part of the signature is recommended to // Including this header as part of the signature is recommended to
// ensure data integrity. // ensure data integrity.
if cs.TargetResource != "" { if cs.targetResource != "" {
req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource) req.Header.Add("x-goog-cloud-target-resource", cs.targetResource)
} }
cs.requestSigner.SignRequest(req) cs.requestSigner.SignRequest(req)
@@ -426,11 +402,11 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
} }
func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
if cs.IMDSv2SessionTokenURL == "" { if cs.imdsv2SessionTokenURL == "" {
return "", nil return "", nil
} }
req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil) req, err := http.NewRequest("PUT", cs.imdsv2SessionTokenURL, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -449,25 +425,29 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody)) return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS session token - %s", string(respBody))
} }
return string(respBody), nil return string(respBody), nil
} }
func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
if cs.awsSecurityCredentialsSupplier != nil {
return cs.awsSecurityCredentialsSupplier.AwsRegion(cs.ctx, cs.supplierOptions)
}
if canRetrieveRegionFromEnvironment() { if canRetrieveRegionFromEnvironment() {
if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
cs.region = envAwsRegion
return envAwsRegion, nil return envAwsRegion, nil
} }
return getenv("AWS_DEFAULT_REGION"), nil return getenv("AWS_DEFAULT_REGION"), nil
} }
if cs.RegionURL == "" { if cs.regionURL == "" {
return "", errors.New("oauth2/google: unable to determine AWS region") return "", errors.New("oauth2/google/externalaccount: unable to determine AWS region")
} }
req, err := http.NewRequest("GET", cs.RegionURL, nil) req, err := http.NewRequest("GET", cs.regionURL, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -488,7 +468,7 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody)) return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS region - %s", string(respBody))
} }
// This endpoint will return the region in format: us-east-2b. // This endpoint will return the region in format: us-east-2b.
@@ -500,12 +480,15 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
return string(respBody[:respBodyEnd]), nil return string(respBody[:respBodyEnd]), nil
} }
func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result *AwsSecurityCredentials, err error) {
if cs.awsSecurityCredentialsSupplier != nil {
return cs.awsSecurityCredentialsSupplier.AwsSecurityCredentials(cs.ctx, cs.supplierOptions)
}
if canRetrieveSecurityCredentialFromEnvironment() { if canRetrieveSecurityCredentialFromEnvironment() {
return awsSecurityCredentials{ return &AwsSecurityCredentials{
AccessKeyID: getenv(awsAccessKeyId), AccessKeyID: getenv(awsAccessKeyId),
SecretAccessKey: getenv(awsSecretAccessKey), SecretAccessKey: getenv(awsSecretAccessKey),
SecurityToken: getenv(awsSessionToken), SessionToken: getenv(awsSessionToken),
}, nil }, nil
} }
@@ -520,24 +503,23 @@ func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string)
} }
if credentials.AccessKeyID == "" { if credentials.AccessKeyID == "" {
return result, errors.New("oauth2/google: missing AccessKeyId credential") return result, errors.New("oauth2/google/externalaccount: missing AccessKeyId credential")
} }
if credentials.SecretAccessKey == "" { if credentials.SecretAccessKey == "" {
return result, errors.New("oauth2/google: missing SecretAccessKey credential") return result, errors.New("oauth2/google/externalaccount: missing SecretAccessKey credential")
} }
return credentials, nil return &credentials, nil
} }
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) { func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) {
var result awsSecurityCredentials var result AwsSecurityCredentials
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.credVerificationURL, roleName), nil)
if err != nil { if err != nil {
return result, err return result, err
} }
req.Header.Add("Content-Type", "application/json")
for name, value := range headers { for name, value := range headers {
req.Header.Add(name, value) req.Header.Add(name, value)
@@ -555,7 +537,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody)) return result, fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS security credentials - %s", string(respBody))
} }
err = json.Unmarshal(respBody, &result) err = json.Unmarshal(respBody, &result)
@@ -563,11 +545,11 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h
} }
func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) { func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
if cs.CredVerificationURL == "" { if cs.credVerificationURL == "" {
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint") return "", errors.New("oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint")
} }
req, err := http.NewRequest("GET", cs.CredVerificationURL, nil) req, err := http.NewRequest("GET", cs.credVerificationURL, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -588,7 +570,7 @@ func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (s
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody)) return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS role name - %s", string(respBody))
} }
return string(respBody), nil return string(respBody), nil

View File

@@ -7,6 +7,7 @@ package externalaccount
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -36,7 +37,7 @@ func setEnvironment(env map[string]string) func(string) string {
var defaultRequestSigner = &awsRequestSigner{ var defaultRequestSigner = &awsRequestSigner{
RegionName: "us-east-1", RegionName: "us-east-1",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: &AwsSecurityCredentials{
AccessKeyID: "AKIDEXAMPLE", AccessKeyID: "AKIDEXAMPLE",
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
}, },
@@ -50,10 +51,10 @@ const (
var requestSignerWithToken = &awsRequestSigner{ var requestSignerWithToken = &awsRequestSigner{
RegionName: "us-east-2", RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: &AwsSecurityCredentials{
AccessKeyID: accessKeyID, AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
SecurityToken: securityToken, SessionToken: securityToken,
}, },
} }
@@ -388,7 +389,7 @@ func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test
func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
var requestSigner = &awsRequestSigner{ var requestSigner = &awsRequestSigner{
RegionName: "us-east-2", RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: &AwsSecurityCredentials{
AccessKeyID: accessKeyID, AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
}, },
@@ -526,8 +527,8 @@ func notFound(w http.ResponseWriter, r *http.Request) {
func noHeaderValidation(r *http.Request) {} func noHeaderValidation(r *http.Request) {}
func (server *testAwsServer) getCredentialSource(url string) CredentialSource { func (server *testAwsServer) getCredentialSource(url string) *CredentialSource {
return CredentialSource{ return &CredentialSource{
EnvironmentID: "aws1", EnvironmentID: "aws1",
URL: url + server.url, URL: url + server.url,
RegionURL: url + server.regionURL, RegionURL: url + server.regionURL,
@@ -541,10 +542,10 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
req.Header.Add("x-goog-cloud-target-resource", testFileConfig.Audience) req.Header.Add("x-goog-cloud-target-resource", testFileConfig.Audience)
signer := &awsRequestSigner{ signer := &awsRequestSigner{
RegionName: region, RegionName: region,
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: &AwsSecurityCredentials{
AccessKeyID: accessKeyID, AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
SecurityToken: securityToken, SessionToken: securityToken,
}, },
} }
signer.SignRequest(req) signer.SignRequest(req)
@@ -585,25 +586,17 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
func TestAWSCredential_BasicRequest(t *testing.T) { func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -631,25 +624,18 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
func TestAWSCredential_IMDSv2(t *testing.T) { func TestAWSCredential_IMDSv2(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -677,10 +663,6 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
delete(server.Credentials, "Token") delete(server.Credentials, "Token")
tfc := testFileConfig tfc := testFileConfig
@@ -688,15 +670,12 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -724,21 +703,15 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -746,7 +719,6 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -774,21 +746,15 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -796,7 +762,6 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -823,21 +788,15 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -846,7 +805,6 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
"AWS_DEFAULT_REGION": "us-east-1", "AWS_DEFAULT_REGION": "us-east-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -873,29 +831,22 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
func TestAWSCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.EnvironmentID = "aws3" tfc.CredentialSource.EnvironmentID = "aws3"
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
_, err = tfc.parse(context.Background()) _, err := tfc.parse(context.Background())
if err == nil { if err == nil {
t.Fatalf("parse() should have failed") t.Fatalf("parse() should have failed")
} }
if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: aws version '3' is not supported in the current build"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -903,23 +854,16 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.RegionURL = "" tfc.CredentialSource.RegionURL = ""
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -931,7 +875,7 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: unable to determine AWS region"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: unable to determine AWS region"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -939,23 +883,17 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRegion = notFound server.WriteRegion = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -967,7 +905,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: unable to retrieve AWS region - Not Found"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: unable to retrieve AWS region - Not Found"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -975,10 +913,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}")) w.Write([]byte("{}"))
} }
@@ -987,13 +922,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1005,7 +937,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: missing AccessKeyId credential"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -1013,10 +945,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
} }
@@ -1025,13 +954,10 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1043,7 +969,7 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: missing SecretAccessKey credential"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -1051,23 +977,16 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.URL = "" tfc.CredentialSource.URL = ""
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1079,7 +998,7 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: unable to determine the AWS metadata server security credentials endpoint"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -1087,23 +1006,16 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRolename = notFound server.WriteRolename = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1115,7 +1027,7 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: unable to retrieve AWS role name - Not Found"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: unable to retrieve AWS role name - Not Found"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -1123,23 +1035,16 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = notFound server.WriteSecurityCredentials = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1151,7 +1056,7 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: unable to retrieve AWS security credentials - Not Found"; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google/externalaccount: unable to retrieve AWS security credentials - Not Found"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@@ -1159,10 +1064,6 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) { func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Metadata server should not have been called.") t.Error("Metadata server should not have been called.")
@@ -1174,11 +1075,9 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -1186,7 +1085,6 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1214,28 +1112,21 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) { func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": accessKeyID, "AWS_ACCESS_KEY_ID": accessKeyID,
"AWS_SECRET_ACCESS_KEY": secretAccessKey, "AWS_SECRET_ACCESS_KEY": secretAccessKey,
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1263,28 +1154,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) { func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1312,28 +1196,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) { func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t) server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
oldNow := now oldNow := now
oldValidHostnames := validHostnames
defer func() { defer func() {
getenv = oldGetenv getenv = oldGetenv
now = oldNow now = oldNow
validHostnames = oldValidHostnames
}() }()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1358,87 +1235,254 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testin
} }
} }
func TestAWSCredential_Validations(t *testing.T) { func TestAWSCredential_ProgrammaticAuth(t *testing.T) {
var metadataServerValidityTests = []struct {
name string
credSource CredentialSource
errText string
}{
{
name: "No Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "",
URL: "",
IMDSv2SessionTokenURL: "",
},
}, {
name: "IPv4 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
}, {
name: "IPv6 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
},
}, {
name: "Faulty RegionURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url",
}, {
name: "Faulty CredVerificationURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://abc.com/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url",
}, {
name: "Faulty IMDSv2SessionTokenURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://abc.com/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url",
},
}
for _, tt := range metadataServerValidityTests {
t.Run(tt.name, func(t *testing.T) {
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = tt.credSource securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: securityToken,
}
oldGetenv := getenv tfc.AwsSecurityCredentialsSupplier = testAwsSupplier{
defer func() { getenv = oldGetenv }() awsRegion: "us-east-2",
getenv = setEnvironment(map[string]string{}) err: nil,
credentials: &securityCredentials,
}
_, err := tfc.parse(context.Background()) oldNow := now
defer func() {
now = oldNow
}()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
if tt.errText == "" { t.Fatalf("parse() failed %v", err)
t.Errorf("Didn't expect an error, but got %v", err)
} else if tt.errText != err.Error() {
t.Errorf("Expected %v, but got %v", tt.errText, err)
} }
} else {
if tt.errText != "" { out, err := base.subjectToken()
t.Errorf("Expected error %v, but got none", tt.errText) if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
securityToken,
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
} }
} }
})
func TestAWSCredential_ProgrammaticAuthNoSessionToken(t *testing.T) {
tfc := testFileConfig
securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
}
tfc.AwsSecurityCredentialsSupplier = testAwsSupplier{
awsRegion: "us-east-2",
err: nil,
credentials: &securityCredentials,
}
oldNow := now
defer func() {
now = oldNow
}()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
} }
} }
func TestAWSCredential_ProgrammaticAuthError(t *testing.T) {
tfc := testFileConfig
testErr := errors.New("test error")
tfc.AwsSecurityCredentialsSupplier = testAwsSupplier{
awsRegion: "us-east-2",
err: testErr,
credentials: nil,
}
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil {
t.Fatalf("subjectToken() should have failed")
}
if err != testErr {
t.Errorf("error = %e, want %e", err, testErr)
}
}
func TestAWSCredential_ProgrammaticAuthRegionError(t *testing.T) {
tfc := testFileConfig
securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
}
testErr := errors.New("test")
tfc.AwsSecurityCredentialsSupplier = testAwsSupplier{
awsRegion: "",
regionErr: testErr,
credentials: &securityCredentials,
}
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil {
t.Fatalf("subjectToken() should have failed")
}
if err != testErr {
t.Errorf("error = %e, want %e", err, testErr)
}
}
func TestAWSCredential_ProgrammaticAuthOptions(t *testing.T) {
tfc := testFileConfig
securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
}
expectedOptions := SupplierOptions{Audience: tfc.Audience, SubjectTokenType: tfc.SubjectTokenType}
tfc.AwsSecurityCredentialsSupplier = testAwsSupplier{
awsRegion: "us-east-2",
credentials: &securityCredentials,
expectedOptions: &expectedOptions,
}
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err != nil {
t.Fatalf("subjectToken() failed %v", err)
}
}
func TestAWSCredential_ProgrammaticAuthContext(t *testing.T) {
tfc := testFileConfig
securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
}
ctx := context.Background()
tfc.AwsSecurityCredentialsSupplier = testAwsSupplier{
awsRegion: "us-east-2",
credentials: &securityCredentials,
expectedContext: ctx,
}
base, err := tfc.parse(ctx)
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err != nil {
t.Fatalf("subjectToken() failed %v", err)
}
}
func TestAwsCredential_CredentialSourceType(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
if got, want := base.credentialSourceType(), "aws"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}
type testAwsSupplier struct {
err error
regionErr error
awsRegion string
credentials *AwsSecurityCredentials
expectedOptions *SupplierOptions
expectedContext context.Context
}
func (supp testAwsSupplier) AwsRegion(ctx context.Context, options SupplierOptions) (string, error) {
if supp.regionErr != nil {
return "", supp.regionErr
}
if supp.expectedOptions != nil {
if supp.expectedOptions.Audience != options.Audience {
return "", errors.New("Audience does not match")
}
if supp.expectedOptions.SubjectTokenType != options.SubjectTokenType {
return "", errors.New("Audience does not match")
}
}
if supp.expectedContext != nil {
if supp.expectedContext != ctx {
return "", errors.New("Context does not match")
}
}
return supp.awsRegion, nil
}
func (supp testAwsSupplier) AwsSecurityCredentials(ctx context.Context, options SupplierOptions) (*AwsSecurityCredentials, error) {
if supp.err != nil {
return nil, supp.err
}
if supp.expectedOptions != nil {
if supp.expectedOptions.Audience != options.Audience {
return nil, errors.New("Audience does not match")
}
if supp.expectedOptions.SubjectTokenType != options.SubjectTokenType {
return nil, errors.New("Audience does not match")
}
}
if supp.expectedContext != nil {
if supp.expectedContext != ctx {
return nil, errors.New("Context does not match")
}
}
return supp.credentials, nil
}

View File

@@ -0,0 +1,485 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package externalaccount provides support for creating workload identity
federation and workforce identity federation token sources that can be
used to access Google Cloud resources from external identity providers.
# Workload Identity Federation
Using workload identity federation, your application can access Google Cloud
resources from Amazon Web Services (AWS), Microsoft Azure or any identity
provider that supports OpenID Connect (OIDC) or SAML 2.0.
Traditionally, applications running outside Google Cloud have used service
account keys to access Google Cloud resources. Using identity federation,
you can allow your workload to impersonate a service account.
This lets you access Google Cloud resources directly, eliminating the
maintenance and security burden associated with service account keys.
Follow the detailed instructions on how to configure Workload Identity Federation
in various platforms:
Amazon Web Services (AWS): https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#aws
Microsoft Azure: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#azure
OIDC identity provider: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#oidc
SAML 2.0 identity provider: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#saml
For OIDC and SAML providers, the library can retrieve tokens in fours ways:
from a local file location (file-sourced credentials), from a server
(URL-sourced credentials), from a local executable (executable-sourced
credentials), or from a user defined function that returns an OIDC or SAML token.
For file-sourced credentials, a background process needs to be continuously
refreshing the file location with a new OIDC/SAML token prior to expiration.
For tokens with one hour lifetimes, the token needs to be updated in the file
every hour. The token can be stored directly as plain text or in JSON format.
For URL-sourced credentials, a local server needs to host a GET endpoint to
return the OIDC/SAML token. The response can be in plain text or JSON.
Additional required request headers can also be specified.
For executable-sourced credentials, an application needs to be available to
output the OIDC/SAML token and other information in a JSON format.
For more information on how these work (and how to implement
executable-sourced credentials), please check out:
https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#create_a_credential_configuration
To use a custom function to supply the token, define a struct that implements the [SubjectTokenSupplier] interface for OIDC/SAML providers,
or one that implements [AwsSecurityCredentialsSupplier] for AWS providers. This can then be used when building a [Config].
The [golang.org/x/oauth2.TokenSource] created from the config using [NewTokenSource] can then be used to access Google
Cloud resources. For instance, you can create a new client from the
[cloud.google.com/go/storage] package and pass in option.WithTokenSource(yourTokenSource))
Note that this library does not perform any validation on the token_url, token_info_url,
or service_account_impersonation_url fields of the credential configuration.
It is not recommended to use a credential configuration that you did not generate with
the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain.
# Workforce Identity Federation
Workforce identity federation lets you use an external identity provider (IdP) to
authenticate and authorize a workforce—a group of users, such as employees, partners,
and contractors—using IAM, so that the users can access Google Cloud services.
Workforce identity federation extends Google Cloud's identity capabilities to support
syncless, attribute-based single sign on.
With workforce identity federation, your workforce can access Google Cloud resources
using an external identity provider (IdP) that supports OpenID Connect (OIDC) or
SAML 2.0 such as Azure Active Directory (Azure AD), Active Directory Federation
Services (AD FS), Okta, and others.
Follow the detailed instructions on how to configure Workload Identity Federation
in various platforms:
Azure AD: https://cloud.google.com/iam/docs/workforce-sign-in-azure-ad
Okta: https://cloud.google.com/iam/docs/workforce-sign-in-okta
OIDC identity provider: https://cloud.google.com/iam/docs/configuring-workforce-identity-federation#oidc
SAML 2.0 identity provider: https://cloud.google.com/iam/docs/configuring-workforce-identity-federation#saml
For workforce identity federation, the library can retrieve tokens in four ways:
from a local file location (file-sourced credentials), from a server
(URL-sourced credentials), from a local executable (executable-sourced
credentials), or from a user supplied function that returns an OIDC or SAML token.
For file-sourced credentials, a background process needs to be continuously
refreshing the file location with a new OIDC/SAML token prior to expiration.
For tokens with one hour lifetimes, the token needs to be updated in the file
every hour. The token can be stored directly as plain text or in JSON format.
For URL-sourced credentials, a local server needs to host a GET endpoint to
return the OIDC/SAML token. The response can be in plain text or JSON.
Additional required request headers can also be specified.
For executable-sourced credentials, an application needs to be available to
output the OIDC/SAML token and other information in a JSON format.
For more information on how these work (and how to implement
executable-sourced credentials), please check out:
https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in
To use a custom function to supply the token, define a struct that implements the [SubjectTokenSupplier] interface for OIDC/SAML providers.
This can then be used when building a [Config].
The [golang.org/x/oauth2.TokenSource] created from the config using [NewTokenSource] can then be used access Google
Cloud resources. For instance, you can create a new client from the
[cloud.google.com/go/storage] package and pass in option.WithTokenSource(yourTokenSource))
# Security considerations
Note that this library does not perform any validation on the token_url, token_info_url,
or service_account_impersonation_url fields of the credential configuration.
It is not recommended to use a credential configuration that you did not generate with
the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain.
*/
package externalaccount
import (
"context"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/impersonate"
"golang.org/x/oauth2/google/internal/stsexchange"
)
const (
universeDomainPlaceholder = "UNIVERSE_DOMAIN"
defaultTokenURL = "https://sts.UNIVERSE_DOMAIN/v1/token"
defaultUniverseDomain = "googleapis.com"
)
// now aliases time.Now for testing
var now = func() time.Time {
return time.Now().UTC()
}
// Config stores the configuration for fetching tokens with external credentials.
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workload
// identity pool or the workforce pool and the provider identifier in that pool. Required.
Audience string
// SubjectTokenType is the STS token type based on the Oauth2.0 token exchange spec.
// Expected values include:
// “urn:ietf:params:oauth:token-type:jwt”
// “urn:ietf:params:oauth:token-type:id-token”
// “urn:ietf:params:oauth:token-type:saml2”
// “urn:ietf:params:aws:token-type:aws4_request”
// Required.
SubjectTokenType string
// TokenURL is the STS token exchange endpoint. If not provided, will default to
// https://sts.UNIVERSE_DOMAIN/v1/token, with UNIVERSE_DOMAIN set to the
// default service domain googleapis.com unless UniverseDomain is set.
// Optional.
TokenURL string
// TokenInfoURL is the token_info endpoint used to retrieve the account related information (
// user attributes like account identifier, eg. email, username, uid, etc). This is
// needed for gCloud session account identification. Optional.
TokenInfoURL string
// ServiceAccountImpersonationURL is the URL for the service account impersonation request. This is only
// required for workload identity pools when APIs to be accessed have not integrated with UberMint. Optional.
ServiceAccountImpersonationURL string
// ServiceAccountImpersonationLifetimeSeconds is the number of seconds the service account impersonation
// token will be valid for. If not provided, it will default to 3600. Optional.
ServiceAccountImpersonationLifetimeSeconds int
// ClientSecret is currently only required if token_info endpoint also
// needs to be called with the generated GCP access token. When provided, STS will be
// called with additional basic authentication using ClientId as username and ClientSecret as password. Optional.
ClientSecret string
// ClientID is only required in conjunction with ClientSecret, as described above. Optional.
ClientID string
// CredentialSource contains the necessary information to retrieve the token itself, as well
// as some environmental information. One of SubjectTokenSupplier, AWSSecurityCredentialSupplier or
// CredentialSource must be provided. Optional.
CredentialSource *CredentialSource
// QuotaProjectID is injected by gCloud. If the value is non-empty, the Auth libraries
// will set the x-goog-user-project header which overrides the project associated with the credentials. Optional.
QuotaProjectID string
// Scopes contains the desired scopes for the returned access token. Optional.
Scopes []string
// WorkforcePoolUserProject is the workforce pool user project number when the credential
// corresponds to a workforce pool and not a workload identity pool.
// The underlying principal must still have serviceusage.services.use IAM
// permission to use the project for billing/quota. Optional.
WorkforcePoolUserProject string
// SubjectTokenSupplier is an optional token supplier for OIDC/SAML credentials.
// One of SubjectTokenSupplier, AWSSecurityCredentialSupplier or CredentialSource must be provided. Optional.
SubjectTokenSupplier SubjectTokenSupplier
// AwsSecurityCredentialsSupplier is an AWS Security Credential supplier for AWS credentials.
// One of SubjectTokenSupplier, AWSSecurityCredentialSupplier or CredentialSource must be provided. Optional.
AwsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
// UniverseDomain is the default service domain for a given Cloud universe.
// This value will be used in the default STS token URL. The default value
// is "googleapis.com". It will not be used if TokenURL is set. Optional.
UniverseDomain string
}
var (
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
)
func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input)
}
// NewTokenSource Returns an external account TokenSource using the provided external account config.
func NewTokenSource(ctx context.Context, conf Config) (oauth2.TokenSource, error) {
if conf.Audience == "" {
return nil, fmt.Errorf("oauth2/google/externalaccount: Audience must be set")
}
if conf.SubjectTokenType == "" {
return nil, fmt.Errorf("oauth2/google/externalaccount: Subject token type must be set")
}
if conf.WorkforcePoolUserProject != "" {
valid := validateWorkforceAudience(conf.Audience)
if !valid {
return nil, fmt.Errorf("oauth2/google/externalaccount: Workforce pool user project should not be set for non-workforce pool credentials")
}
}
count := 0
if conf.CredentialSource != nil {
count++
}
if conf.SubjectTokenSupplier != nil {
count++
}
if conf.AwsSecurityCredentialsSupplier != nil {
count++
}
if count == 0 {
return nil, fmt.Errorf("oauth2/google/externalaccount: One of CredentialSource, SubjectTokenSupplier, or AwsSecurityCredentialsSupplier must be set")
}
if count > 1 {
return nil, fmt.Errorf("oauth2/google/externalaccount: Only one of CredentialSource, SubjectTokenSupplier, or AwsSecurityCredentialsSupplier must be set")
}
return conf.tokenSource(ctx, "https")
}
// tokenSource is a private function that's directly called by some of the tests,
// because the unit test URLs are mocked, and would otherwise fail the
// validity check.
func (c *Config) tokenSource(ctx context.Context, scheme string) (oauth2.TokenSource, error) {
ts := tokenSource{
ctx: ctx,
conf: c,
}
if c.ServiceAccountImpersonationURL == "" {
return oauth2.ReuseTokenSource(nil, ts), nil
}
scopes := c.Scopes
ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
imp := impersonate.ImpersonateTokenSource{
Ctx: ctx,
URL: c.ServiceAccountImpersonationURL,
Scopes: scopes,
Ts: oauth2.ReuseTokenSource(nil, ts),
TokenLifetimeSeconds: c.ServiceAccountImpersonationLifetimeSeconds,
}
return oauth2.ReuseTokenSource(nil, imp), nil
}
// Subject token file types.
const (
fileTypeText = "text"
fileTypeJSON = "json"
)
// Format contains information needed to retireve a subject token for URL or File sourced credentials.
type Format struct {
// Type should be either "text" or "json". This determines whether the file or URL sourced credentials
// expect a simple text subject token or if the subject token will be contained in a JSON object.
// When not provided "text" type is assumed.
Type string `json:"type"`
// SubjectTokenFieldName is only required for JSON format. This is the field name that the credentials will check
// for the subject token in the file or URL response. This would be "access_token" for azure.
SubjectTokenFieldName string `json:"subject_token_field_name"`
}
// CredentialSource stores the information necessary to retrieve the credentials for the STS exchange.
type CredentialSource struct {
// File is the location for file sourced credentials.
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
File string `json:"file"`
// Url is the URL to call for URL sourced credentials.
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
URL string `json:"url"`
// Headers are the headers to attach to the request for URL sourced credentials.
Headers map[string]string `json:"headers"`
// Executable is the configuration object for executable sourced credentials.
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
Executable *ExecutableConfig `json:"executable"`
// EnvironmentID is the EnvironmentID used for AWS sourced credentials. This should start with "AWS".
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
EnvironmentID string `json:"environment_id"`
// RegionURL is the metadata URL to retrieve the region from for EC2 AWS credentials.
RegionURL string `json:"region_url"`
// RegionalCredVerificationURL is the AWS regional credential verification URL, will default to
// "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" if not provided."
RegionalCredVerificationURL string `json:"regional_cred_verification_url"`
// IMDSv2SessionTokenURL is the URL to retrieve the session token when using IMDSv2 in AWS.
IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"`
// Format is the format type for the subject token. Used for File and URL sourced credentials. Expected values are "text" or "json".
Format Format `json:"format"`
}
// ExecutableConfig contains information needed for executable sourced credentials.
type ExecutableConfig struct {
// Command is the the full command to run to retrieve the subject token.
// This can include arguments. Must be an absolute path for the program. Required.
Command string `json:"command"`
// TimeoutMillis is the timeout duration, in milliseconds. Defaults to 30000 milliseconds when not provided. Optional.
TimeoutMillis *int `json:"timeout_millis"`
// OutputFile is the absolute path to the output file where the executable will cache the response.
// If specified the auth libraries will first check this location before running the executable. Optional.
OutputFile string `json:"output_file"`
}
// SubjectTokenSupplier can be used to supply a subject token to exchange for a GCP access token.
type SubjectTokenSupplier interface {
// SubjectToken should return a valid subject token or an error.
// The external account token source does not cache the returned subject token, so caching
// logic should be implemented in the supplier to prevent multiple requests for the same subject token.
SubjectToken(ctx context.Context, options SupplierOptions) (string, error)
}
// AWSSecurityCredentialsSupplier can be used to supply AwsSecurityCredentials and an AWS Region to
// exchange for a GCP access token.
type AwsSecurityCredentialsSupplier interface {
// AwsRegion should return the AWS region or an error.
AwsRegion(ctx context.Context, options SupplierOptions) (string, error)
// GetAwsSecurityCredentials should return a valid set of AwsSecurityCredentials or an error.
// The external account token source does not cache the returned security credentials, so caching
// logic should be implemented in the supplier to prevent multiple requests for the same security credentials.
AwsSecurityCredentials(ctx context.Context, options SupplierOptions) (*AwsSecurityCredentials, error)
}
// SupplierOptions contains information about the requested subject token or AWS security credentials from the
// Google external account credential.
type SupplierOptions struct {
// Audience is the requested audience for the external account credential.
Audience string
// Subject token type is the requested subject token type for the external account credential. Expected values include:
// “urn:ietf:params:oauth:token-type:jwt”
// “urn:ietf:params:oauth:token-type:id-token”
// “urn:ietf:params:oauth:token-type:saml2”
// “urn:ietf:params:aws:token-type:aws4_request”
SubjectTokenType string
}
// tokenURL returns the default STS token endpoint with the configured universe
// domain.
func (c *Config) tokenURL() string {
if c.UniverseDomain == "" {
return strings.Replace(defaultTokenURL, universeDomainPlaceholder, defaultUniverseDomain, 1)
}
return strings.Replace(defaultTokenURL, universeDomainPlaceholder, c.UniverseDomain, 1)
}
// parse determines the type of CredentialSource needed.
func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
//set Defaults
if c.TokenURL == "" {
c.TokenURL = c.tokenURL()
}
supplierOptions := SupplierOptions{Audience: c.Audience, SubjectTokenType: c.SubjectTokenType}
if c.AwsSecurityCredentialsSupplier != nil {
awsCredSource := awsCredentialSource{
awsSecurityCredentialsSupplier: c.AwsSecurityCredentialsSupplier,
targetResource: c.Audience,
supplierOptions: supplierOptions,
ctx: ctx,
}
return awsCredSource, nil
} else if c.SubjectTokenSupplier != nil {
return programmaticRefreshCredentialSource{subjectTokenSupplier: c.SubjectTokenSupplier, supplierOptions: supplierOptions, ctx: ctx}, nil
} else if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
if awsVersion != 1 {
return nil, fmt.Errorf("oauth2/google/externalaccount: aws version '%d' is not supported in the current build", awsVersion)
}
awsCredSource := awsCredentialSource{
environmentID: c.CredentialSource.EnvironmentID,
regionURL: c.CredentialSource.RegionURL,
regionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
credVerificationURL: c.CredentialSource.URL,
targetResource: c.Audience,
ctx: ctx,
}
if c.CredentialSource.IMDSv2SessionTokenURL != "" {
awsCredSource.imdsv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
}
return awsCredSource, nil
}
} else if c.CredentialSource.File != "" {
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
} else if c.CredentialSource.URL != "" {
return urlCredentialSource{URL: c.CredentialSource.URL, Headers: c.CredentialSource.Headers, Format: c.CredentialSource.Format, ctx: ctx}, nil
} else if c.CredentialSource.Executable != nil {
return createExecutableCredential(ctx, c.CredentialSource.Executable, c)
}
return nil, fmt.Errorf("oauth2/google/externalaccount: unable to parse credential source")
}
type baseCredentialSource interface {
credentialSourceType() string
subjectToken() (string, error)
}
// tokenSource is the source that handles external credentials. It is used to retrieve Tokens.
type tokenSource struct {
ctx context.Context
conf *Config
}
func getMetricsHeaderValue(conf *Config, credSource baseCredentialSource) string {
return fmt.Sprintf("gl-go/%s auth/%s google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t",
goVersion(),
"unknown",
credSource.credentialSourceType(),
conf.ServiceAccountImpersonationURL != "",
conf.ServiceAccountImpersonationLifetimeSeconds != 0)
}
// Token allows tokenSource to conform to the oauth2.TokenSource interface.
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
credSource, err := conf.parse(ts.ctx)
if err != nil {
return nil, err
}
subjectToken, err := credSource.subjectToken()
if err != nil {
return nil, err
}
stsRequest := stsexchange.TokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience,
Scope: conf.Scopes,
RequestedTokenType: "urn:ietf:params:oauth:token-type:access_token",
SubjectToken: subjectToken,
SubjectTokenType: conf.SubjectTokenType,
}
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
var options map[string]interface{}
// Do not pass workforce_pool_user_project when client authentication is used.
// The client ID is sufficient for determining the user project.
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
options = map[string]interface{}{
"userProject": conf.WorkforcePoolUserProject,
}
}
stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil {
return nil, err
}
accessToken := &oauth2.Token{
AccessToken: stsResp.AccessToken,
TokenType: stsResp.TokenType,
}
// The RFC8693 doesn't define the explicit 0 of "expires_in" field behavior.
if stsResp.ExpiresIn <= 0 {
return nil, fmt.Errorf("oauth2/google/externalaccount: got invalid expiry from security token service")
}
accessToken.Expiry = now().Add(time.Duration(stsResp.ExpiresIn) * time.Second)
if stsResp.RefreshToken != "" {
accessToken.RefreshToken = stsResp.RefreshToken
}
return accessToken, nil
}

View File

@@ -0,0 +1,574 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
)
const (
textBaseCredPath = "testdata/3pi_cred.txt"
jsonBaseCredPath = "testdata/3pi_cred.json"
baseImpersonateCredsReqBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
baseImpersonateCredsRespBody = `{"accessToken":"Second.Access.Token","expireTime":"2020-12-28T15:01:23Z"}`
)
var testBaseCredSource = CredentialSource{
File: textBaseCredPath,
Format: Format{Type: fileTypeText},
}
var testConfig = Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}
var (
baseCredsRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
baseCredsResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
workforcePoolRequestBodyWithClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
workforcePoolRequestBodyWithoutClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=%7B%22userProject%22%3A%22myProject%22%7D&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
correctAT = "Sample.Access.Token"
expiry int64 = 234852
)
var (
testNow = func() time.Time { return time.Unix(expiry, 0) }
)
type testExchangeTokenServer struct {
url string
authorization string
contentType string
metricsHeader string
body string
response string
}
func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.Token, error) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), tets.url; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, tets.authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, tets.contentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, tets.metricsHeader; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), tets.body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(tets.response))
}))
defer server.Close()
config.TokenURL = server.URL
oldNow := now
defer func() { now = oldNow }()
now = testNow
ts := tokenSource{
ctx: context.Background(),
conf: config,
}
return ts.Token()
}
func validateToken(t *testing.T, tok *oauth2.Token, expectToken *oauth2.Token) {
if expectToken == nil {
return
}
if got, want := tok.AccessToken, expectToken.AccessToken; got != want {
t.Errorf("Unexpected access token: got %v, but wanted %v", got, want)
}
if got, want := tok.TokenType, expectToken.TokenType; got != want {
t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want)
}
if got, want := tok.Expiry, expectToken.Expiry; got != want {
t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want)
}
}
func createImpersonationServer(urlWanted, authWanted, bodyWanted, response string, t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), urlWanted; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, authWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, "application/json"; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
if got, want := string(body), bodyWanted; got != want {
t.Errorf("Unexpected impersonation payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
}
func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), "/"; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, metricsHeaderWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
if got, want := string(body), baseImpersonateCredsReqBody; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(baseCredsResponseBody))
}))
}
func getExpectedMetricsHeader(source string, saImpersonation bool, configLifetime bool) string {
return fmt.Sprintf("gl-go/%s auth/unknown google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t", goVersion(), source, saImpersonation, configLifetime)
}
func TestToken(t *testing.T) {
type MockSTSResponse struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
ExpiresIn int32 `json:"expires_in,omitempty"`
Scope string `json:"scopre,omitenpty"`
}
testCases := []struct {
name string
responseBody MockSTSResponse
expectToken *oauth2.Token
expectErrorMsg string
}{
{
name: "happy case",
responseBody: MockSTSResponse{
AccessToken: correctAT,
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: "https://www.googleapis.com/auth/cloud-platform",
},
expectToken: &oauth2.Token{
AccessToken: correctAT,
TokenType: "Bearer",
Expiry: testNow().Add(time.Duration(3600) * time.Second),
},
},
{
name: "no expiry time on token",
responseBody: MockSTSResponse{
AccessToken: correctAT,
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
Scope: "https://www.googleapis.com/auth/cloud-platform",
},
expectToken: nil,
expectErrorMsg: "oauth2/google/externalaccount: got invalid expiry from security token service",
},
{
name: "negative expiry time",
responseBody: MockSTSResponse{
AccessToken: correctAT,
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: -1,
Scope: "https://www.googleapis.com/auth/cloud-platform",
},
expectToken: nil,
expectErrorMsg: "oauth2/google/externalaccount: got invalid expiry from security token service",
},
}
for _, testCase := range testCases {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}
responseBody, err := json.Marshal(testCase.responseBody)
if err != nil {
t.Errorf("Invalid response received.")
}
server := testExchangeTokenServer{
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: baseCredsRequestBody,
response: string(responseBody),
}
tok, err := run(t, &config, &server)
if err != nil && err.Error() != testCase.expectErrorMsg {
t.Errorf("Error not as expected: got = %v, and want = %v", err, testCase.expectErrorMsg)
}
validateToken(t, tok, testCase.expectToken)
}
}
func TestWorkforcePoolTokenWithClientID(t *testing.T) {
config := Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
server := testExchangeTokenServer{
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithClientId,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
expectToken := oauth2.Token{
AccessToken: correctAT,
TokenType: "Bearer",
Expiry: testNow().Add(time.Duration(3600) * time.Second),
}
validateToken(t, tok, &expectToken)
}
func TestWorkforcePoolTokenWithoutClientID(t *testing.T) {
config := Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
server := testExchangeTokenServer{
url: "/",
authorization: "",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithoutClientId,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
expectToken := oauth2.Token{
AccessToken: correctAT,
TokenType: "Bearer",
Expiry: testNow().Add(time.Duration(3600) * time.Second),
}
validateToken(t, tok, &expectToken)
}
func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
TokenURL: "https://sts.googleapis.com",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
_, err := NewTokenSource(context.Background(), config)
if err == nil {
t.Fatalf("Expected error but found none")
}
if got, want := err.Error(), "oauth2/google/externalaccount: Workforce pool user project should not be set for non-workforce pool credentials"; got != want {
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
}
}
func TestWorkforcePoolCreation(t *testing.T) {
var audienceValidatyTests = []struct {
audience string
expectSuccess bool
}{
{"//iam.googleapis.com/locations/global/workforcePools/pool-id/providers/provider-id", true},
{"//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", true},
{"//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", true},
{"identitynamespace:1f12345:my_provider", false},
{"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/projects/123456/locations/eu/workloadIdentityPools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/workforcePools/providers/provider-id", false},
{"//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", false},
}
ctx := context.Background()
for _, tt := range audienceValidatyTests {
t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
config.ServiceAccountImpersonationURL = "https://iamcredentials.googleapis.com"
config.Audience = tt.audience
config.WorkforcePoolUserProject = "myProject"
_, err := NewTokenSource(ctx, config)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
}
var impersonationTests = []struct {
name string
config Config
expectedImpersonationBody string
expectedMetricsHeader string
}{
{
name: "Base Impersonation",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
},
expectedImpersonationBody: "{\"lifetime\":\"3600s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, false),
},
{
name: "With TokenLifetime Set",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
expectedImpersonationBody: "{\"lifetime\":\"10000s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, true),
},
}
func TestImpersonation(t *testing.T) {
for _, tt := range impersonationTests {
t.Run(tt.name, func(t *testing.T) {
testImpersonateConfig := tt.config
impersonateServer := createImpersonationServer("/", "Bearer Sample.Access.Token", tt.expectedImpersonationBody, baseImpersonateCredsRespBody, t)
defer impersonateServer.Close()
testImpersonateConfig.ServiceAccountImpersonationURL = impersonateServer.URL
targetServer := createTargetServer(tt.expectedMetricsHeader, t)
defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL
ourTS, err := testImpersonateConfig.tokenSource(context.Background(), "http")
if err != nil {
t.Fatalf("Failed to create TokenSource: %v", err)
}
oldNow := now
defer func() { now = oldNow }()
now = testNow
tok, err := ourTS.Token()
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
if got, want := tok.AccessToken, "Second.Access.Token"; got != want {
t.Errorf("Unexpected access token: got %v, but wanted %v", got, want)
}
if got, want := tok.TokenType, "Bearer"; got != want {
t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want)
}
})
}
}
var newTokenTests = []struct {
name string
config Config
}{
{
name: "Missing Audience",
config: Config{
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
},
{
name: "Missing Subject Token Type",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
},
{
name: "No Cred Source",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
},
{
name: "Cred Source and Supplier",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
CredentialSource: &testBaseCredSource,
AwsSecurityCredentialsSupplier: testAwsSupplier{},
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
},
}
func TestNewToken(t *testing.T) {
for _, tt := range newTokenTests {
t.Run(tt.name, func(t *testing.T) {
testConfig := tt.config
_, err := NewTokenSource(context.Background(), testConfig)
if err == nil {
t.Fatalf("expected error when calling NewToken()")
}
})
}
}
func TestConfig_TokenURL(t *testing.T) {
tests := []struct {
tokenURL string
universeDomain string
want string
}{
{
tokenURL: "https://sts.googleapis.com/v1/token",
universeDomain: "",
want: "https://sts.googleapis.com/v1/token",
},
{
tokenURL: "",
universeDomain: "",
want: "https://sts.googleapis.com/v1/token",
},
{
tokenURL: "",
universeDomain: "googleapis.com",
want: "https://sts.googleapis.com/v1/token",
},
{
tokenURL: "",
universeDomain: "example.com",
want: "https://sts.example.com/v1/token",
},
}
for _, tt := range tests {
config := &Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
CredentialSource: &testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}
config.TokenURL = tt.tokenURL
config.UniverseDomain = tt.universeDomain
config.parse(context.Background())
if got := config.TokenURL; got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
}
}

View File

@@ -19,7 +19,7 @@ import (
"time" "time"
) )
var serviceAccountImpersonationRE = regexp.MustCompile("https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/(.*@.*):generateAccessToken") var serviceAccountImpersonationRE = regexp.MustCompile("https://iamcredentials\\..+/v1/projects/-/serviceAccounts/(.*@.*):generateAccessToken")
const ( const (
executableSupportedMaxVersion = 1 executableSupportedMaxVersion = 1
@@ -39,51 +39,51 @@ func (nce nonCacheableError) Error() string {
} }
func missingFieldError(source, field string) error { func missingFieldError(source, field string) error {
return fmt.Errorf("oauth2/google: %v missing `%q` field", source, field) return fmt.Errorf("oauth2/google/externalaccount: %v missing `%q` field", source, field)
} }
func jsonParsingError(source, data string) error { func jsonParsingError(source, data string) error {
return fmt.Errorf("oauth2/google: unable to parse %v\nResponse: %v", source, data) return fmt.Errorf("oauth2/google/externalaccount: unable to parse %v\nResponse: %v", source, data)
} }
func malformedFailureError() error { func malformedFailureError() error {
return nonCacheableError{"oauth2/google: response must include `error` and `message` fields when unsuccessful"} return nonCacheableError{"oauth2/google/externalaccount: response must include `error` and `message` fields when unsuccessful"}
} }
func userDefinedError(code, message string) error { func userDefinedError(code, message string) error {
return nonCacheableError{fmt.Sprintf("oauth2/google: response contains unsuccessful response: (%v) %v", code, message)} return nonCacheableError{fmt.Sprintf("oauth2/google/externalaccount: response contains unsuccessful response: (%v) %v", code, message)}
} }
func unsupportedVersionError(source string, version int) error { func unsupportedVersionError(source string, version int) error {
return fmt.Errorf("oauth2/google: %v contains unsupported version: %v", source, version) return fmt.Errorf("oauth2/google/externalaccount: %v contains unsupported version: %v", source, version)
} }
func tokenExpiredError() error { func tokenExpiredError() error {
return nonCacheableError{"oauth2/google: the token returned by the executable is expired"} return nonCacheableError{"oauth2/google/externalaccount: the token returned by the executable is expired"}
} }
func tokenTypeError(source string) error { func tokenTypeError(source string) error {
return fmt.Errorf("oauth2/google: %v contains unsupported token type", source) return fmt.Errorf("oauth2/google/externalaccount: %v contains unsupported token type", source)
} }
func exitCodeError(exitCode int) error { func exitCodeError(exitCode int) error {
return fmt.Errorf("oauth2/google: executable command failed with exit code %v", exitCode) return fmt.Errorf("oauth2/google/externalaccount: executable command failed with exit code %v", exitCode)
} }
func executableError(err error) error { func executableError(err error) error {
return fmt.Errorf("oauth2/google: executable command failed: %v", err) return fmt.Errorf("oauth2/google/externalaccount: executable command failed: %v", err)
} }
func executablesDisallowedError() error { func executablesDisallowedError() error {
return errors.New("oauth2/google: executables need to be explicitly allowed (set GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES to '1') to run") return errors.New("oauth2/google/externalaccount: executables need to be explicitly allowed (set GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES to '1') to run")
} }
func timeoutRangeError() error { func timeoutRangeError() error {
return errors.New("oauth2/google: invalid `timeout_millis` field — executable timeout must be between 5 and 120 seconds") return errors.New("oauth2/google/externalaccount: invalid `timeout_millis` field — executable timeout must be between 5 and 120 seconds")
} }
func commandMissingError() error { func commandMissingError() error {
return errors.New("oauth2/google: missing `command` field — executable command must be provided") return errors.New("oauth2/google/externalaccount: missing `command` field — executable command must be provided")
} }
type environment interface { type environment interface {
@@ -146,7 +146,7 @@ type executableCredentialSource struct {
// CreateExecutableCredential creates an executableCredentialSource given an ExecutableConfig. // CreateExecutableCredential creates an executableCredentialSource given an ExecutableConfig.
// It also performs defaulting and type conversions. // It also performs defaulting and type conversions.
func CreateExecutableCredential(ctx context.Context, ec *ExecutableConfig, config *Config) (executableCredentialSource, error) { func createExecutableCredential(ctx context.Context, ec *ExecutableConfig, config *Config) (executableCredentialSource, error) {
if ec.Command == "" { if ec.Command == "" {
return executableCredentialSource{}, commandMissingError() return executableCredentialSource{}, commandMissingError()
} }
@@ -233,6 +233,10 @@ func (cs executableCredentialSource) parseSubjectTokenFromSource(response []byte
return "", tokenTypeError(source) return "", tokenTypeError(source)
} }
func (cs executableCredentialSource) credentialSourceType() string {
return "executable"
}
func (cs executableCredentialSource) subjectToken() (string, error) { func (cs executableCredentialSource) subjectToken() (string, error) {
if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil { if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil {
return token, err return token, err

View File

@@ -128,7 +128,7 @@ var creationTests = []struct {
func TestCreateExecutableCredential(t *testing.T) { func TestCreateExecutableCredential(t *testing.T) {
for _, tt := range creationTests { for _, tt := range creationTests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ecs, err := CreateExecutableCredential(context.Background(), &tt.executableConfig, nil) ecs, err := createExecutableCredential(context.Background(), &tt.executableConfig, nil)
if tt.expectedErr != nil { if tt.expectedErr != nil {
if err == nil { if err == nil {
t.Fatalf("Expected error but found none") t.Fatalf("Expected error but found none")
@@ -150,6 +150,9 @@ func TestCreateExecutableCredential(t *testing.T) {
if ecs.Timeout != tt.expectedTimeout { if ecs.Timeout != tt.expectedTimeout {
t.Errorf("ecs.Timeout got %v but want %v", ecs.Timeout, tt.expectedTimeout) t.Errorf("ecs.Timeout got %v but want %v", ecs.Timeout, tt.expectedTimeout)
} }
if ecs.credentialSourceType() != "executable" {
t.Errorf("ecs.CredentialSourceType() got %s but want executable", ecs.credentialSourceType())
}
} }
}) })
} }
@@ -166,7 +169,7 @@ var getEnvironmentTests = []struct {
config: Config{ config: Config{
Audience: "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/oidc", Audience: "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/oidc",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
CredentialSource: CredentialSource{ CredentialSource: &CredentialSource{
Executable: &ExecutableConfig{ Executable: &ExecutableConfig{
Command: "blarg", Command: "blarg",
}, },
@@ -190,7 +193,7 @@ var getEnvironmentTests = []struct {
Audience: "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/oidc", Audience: "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/oidc",
ServiceAccountImpersonationURL: "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test@project.iam.gserviceaccount.com:generateAccessToken", ServiceAccountImpersonationURL: "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test@project.iam.gserviceaccount.com:generateAccessToken",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
CredentialSource: CredentialSource{ CredentialSource: &CredentialSource{
Executable: &ExecutableConfig{ Executable: &ExecutableConfig{
Command: "blarg", Command: "blarg",
OutputFile: "/path/to/generated/cached/credentials", OutputFile: "/path/to/generated/cached/credentials",
@@ -217,7 +220,7 @@ var getEnvironmentTests = []struct {
Audience: "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/oidc", Audience: "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/oidc",
ServiceAccountImpersonationURL: "test@project.iam.gserviceaccount.com", ServiceAccountImpersonationURL: "test@project.iam.gserviceaccount.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
CredentialSource: CredentialSource{ CredentialSource: &CredentialSource{
Executable: &ExecutableConfig{ Executable: &ExecutableConfig{
Command: "blarg", Command: "blarg",
OutputFile: "/path/to/generated/cached/credentials", OutputFile: "/path/to/generated/cached/credentials",
@@ -244,7 +247,7 @@ func TestExecutableCredentialGetEnvironment(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
config := tt.config config := tt.config
ecs, err := CreateExecutableCredential(context.Background(), config.CredentialSource.Executable, &config) ecs, err := createExecutableCredential(context.Background(), config.CredentialSource.Executable, &config)
if err != nil { if err != nil {
t.Fatalf("creation failed %v", err) t.Fatalf("creation failed %v", err)
} }
@@ -468,7 +471,7 @@ func TestRetrieveExecutableSubjectTokenExecutableErrors(t *testing.T) {
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -575,7 +578,7 @@ func TestRetrieveExecutableSubjectTokenSuccesses(t *testing.T) {
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -626,7 +629,7 @@ func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -775,7 +778,7 @@ func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -878,7 +881,7 @@ func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -983,7 +986,7 @@ func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -1018,3 +1021,37 @@ func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
}) })
} }
} }
func TestServiceAccountImpersonationRE(t *testing.T) {
tests := []struct {
name string
serviceAccountImpersonationURL string
want string
}{
{
name: "universe domain Google Default Universe (GDU) googleapis.com",
serviceAccountImpersonationURL: "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test@project.iam.gserviceaccount.com:generateAccessToken",
want: "test@project.iam.gserviceaccount.com",
},
{
name: "email does not match",
serviceAccountImpersonationURL: "test@project.iam.gserviceaccount.com",
want: "",
},
{
name: "universe domain non-GDU",
serviceAccountImpersonationURL: "https://iamcredentials.apis-tpclp.goog/v1/projects/-/serviceAccounts/test@project.iam.gserviceaccount.com:generateAccessToken",
want: "test@project.iam.gserviceaccount.com",
},
}
for _, tt := range tests {
matches := serviceAccountImpersonationRE.FindStringSubmatch(tt.serviceAccountImpersonationURL)
if matches == nil {
if tt.want != "" {
t.Errorf("%q: got nil, want %q", tt.name, tt.want)
}
} else if matches[1] != tt.want {
t.Errorf("%q: got %q, want %q", tt.name, matches[1], tt.want)
}
}
}

View File

@@ -16,18 +16,22 @@ import (
type fileCredentialSource struct { type fileCredentialSource struct {
File string File string
Format format Format Format
}
func (cs fileCredentialSource) credentialSourceType() string {
return "file"
} }
func (cs fileCredentialSource) subjectToken() (string, error) { func (cs fileCredentialSource) subjectToken() (string, error) {
tokenFile, err := os.Open(cs.File) tokenFile, err := os.Open(cs.File)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: failed to open credential file %q", cs.File) return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File)
} }
defer tokenFile.Close() defer tokenFile.Close()
tokenBytes, err := ioutil.ReadAll(io.LimitReader(tokenFile, 1<<20)) tokenBytes, err := ioutil.ReadAll(io.LimitReader(tokenFile, 1<<20))
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: failed to read credential file: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err)
} }
tokenBytes = bytes.TrimSpace(tokenBytes) tokenBytes = bytes.TrimSpace(tokenBytes)
switch cs.Format.Type { switch cs.Format.Type {
@@ -35,15 +39,15 @@ func (cs fileCredentialSource) subjectToken() (string, error) {
jsonData := make(map[string]interface{}) jsonData := make(map[string]interface{})
err = json.Unmarshal(tokenBytes, &jsonData) err = json.Unmarshal(tokenBytes, &jsonData)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: failed to unmarshal subject token file: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)
} }
val, ok := jsonData[cs.Format.SubjectTokenFieldName] val, ok := jsonData[cs.Format.SubjectTokenFieldName]
if !ok { if !ok {
return "", errors.New("oauth2/google: provided subject_token_field_name not found in credentials") return "", errors.New("oauth2/google/externalaccount: provided subject_token_field_name not found in credentials")
} }
token, ok := val.(string) token, ok := val.(string)
if !ok { if !ok {
return "", errors.New("oauth2/google: improperly formatted subject token") return "", errors.New("oauth2/google/externalaccount: improperly formatted subject token")
} }
return token, nil return token, nil
case "text": case "text":
@@ -51,7 +55,7 @@ func (cs fileCredentialSource) subjectToken() (string, error) {
case "": case "":
return string(tokenBytes), nil return string(tokenBytes), nil
default: default:
return "", errors.New("oauth2/google: invalid credential_source file format type") return "", errors.New("oauth2/google/externalaccount: invalid credential_source file format type")
} }
} }

View File

@@ -36,7 +36,7 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
name: "TextFileSource", name: "TextFileSource",
cs: CredentialSource{ cs: CredentialSource{
File: textBaseCredPath, File: textBaseCredPath,
Format: format{Type: fileTypeText}, Format: Format{Type: fileTypeText},
}, },
want: "street123", want: "street123",
}, },
@@ -44,7 +44,7 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
name: "JSONFileSource", name: "JSONFileSource",
cs: CredentialSource{ cs: CredentialSource{
File: jsonBaseCredPath, File: jsonBaseCredPath,
Format: format{Type: fileTypeJSON, SubjectTokenFieldName: "SubjToken"}, Format: Format{Type: fileTypeJSON, SubjectTokenFieldName: "SubjToken"},
}, },
want: "321road", want: "321road",
}, },
@@ -53,7 +53,7 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
for _, test := range fileSourceTests { for _, test := range fileSourceTests {
test := test test := test
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = test.cs tfc.CredentialSource = &test.cs
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
@@ -68,6 +68,9 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
t.Errorf("got %v but want %v", out, test.want) t.Errorf("got %v but want %v", out, test.want)
} }
if got, want := base.credentialSourceType(), "file"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}) })
} }
} }

View File

@@ -0,0 +1,64 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"strings"
"unicode"
)
var (
// version is a package internal global variable for testing purposes.
version = runtime.Version
)
// versionUnknown is only used when the runtime version cannot be determined.
const versionUnknown = "UNKNOWN"
// goVersion returns a Go runtime version derived from the runtime environment
// that is modified to be suitable for reporting in a header, meaning it has no
// whitespace. If it is unable to determine the Go runtime version, it returns
// versionUnknown.
func goVersion() string {
const develPrefix = "devel +"
s := version()
if strings.HasPrefix(s, develPrefix) {
s = s[len(develPrefix):]
if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
return s
} else if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
notSemverRune := func(r rune) bool {
return !strings.ContainsRune("0123456789.", r)
}
if strings.HasPrefix(s, "go1") {
s = s[2:]
var prerelease string
if p := strings.IndexFunc(s, notSemverRune); p >= 0 {
s, prerelease = s[:p], s[p:]
}
if strings.HasSuffix(s, ".") {
s += "0"
} else if strings.Count(s, ".") < 2 {
s += ".0"
}
if prerelease != "" {
// Some release candidates already have a dash in them.
if !strings.HasPrefix(prerelease, "-") {
prerelease = "-" + prerelease
}
s += prerelease
}
return s
}
return "UNKNOWN"
}

View File

@@ -0,0 +1,48 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGoVersion(t *testing.T) {
testVersion := func(v string) func() string {
return func() string {
return v
}
}
for _, tst := range []struct {
v func() string
want string
}{
{
testVersion("go1.19"),
"1.19.0",
},
{
testVersion("go1.21-20230317-RC01"),
"1.21.0-20230317-RC01",
},
{
testVersion("devel +abc1234"),
"abc1234",
},
{
testVersion("this should be unknown"),
versionUnknown,
},
} {
version = tst.v
got := goVersion()
if diff := cmp.Diff(got, tst.want); diff != "" {
t.Errorf("got(-),want(+):\n%s", diff)
}
}
version = runtime.Version
}

View File

@@ -0,0 +1,21 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import "context"
type programmaticRefreshCredentialSource struct {
supplierOptions SupplierOptions
subjectTokenSupplier SubjectTokenSupplier
ctx context.Context
}
func (cs programmaticRefreshCredentialSource) credentialSourceType() string {
return "programmatic"
}
func (cs programmaticRefreshCredentialSource) subjectToken() (string, error) {
return cs.subjectTokenSupplier.SubjectToken(cs.ctx, cs.supplierOptions)
}

View File

@@ -0,0 +1,122 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"context"
"errors"
"testing"
)
func TestRetrieveSubjectToken_ProgrammaticAuth(t *testing.T) {
tfc := testConfig
tfc.SubjectTokenSupplier = testSubjectTokenSupplier{
subjectToken: "subjectToken",
}
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
if out != "subjectToken" {
t.Errorf("subjectToken = \n%q\n want \nSubjectToken", out)
}
}
func TestRetrieveSubjectToken_ProgrammaticAuthFails(t *testing.T) {
tfc := testConfig
testError := errors.New("test error")
tfc.SubjectTokenSupplier = testSubjectTokenSupplier{
err: testError,
}
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil {
t.Fatalf("subjectToken() should have failed")
}
if testError != err {
t.Errorf("subjectToken = %e, want %e", err, testError)
}
}
func TestRetrieveSubjectToken_ProgrammaticAuthOptions(t *testing.T) {
tfc := testConfig
expectedOptions := SupplierOptions{Audience: tfc.Audience, SubjectTokenType: tfc.SubjectTokenType}
tfc.SubjectTokenSupplier = testSubjectTokenSupplier{
subjectToken: "subjectToken",
expectedOptions: &expectedOptions,
}
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
}
func TestRetrieveSubjectToken_ProgrammaticAuthContext(t *testing.T) {
tfc := testConfig
ctx := context.Background()
tfc.SubjectTokenSupplier = testSubjectTokenSupplier{
subjectToken: "subjectToken",
expectedContext: ctx,
}
base, err := tfc.parse(ctx)
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
}
type testSubjectTokenSupplier struct {
err error
subjectToken string
expectedOptions *SupplierOptions
expectedContext context.Context
}
func (supp testSubjectTokenSupplier) SubjectToken(ctx context.Context, options SupplierOptions) (string, error) {
if supp.err != nil {
return "", supp.err
}
if supp.expectedOptions != nil {
if supp.expectedOptions.Audience != options.Audience {
return "", errors.New("Audience does not match")
}
if supp.expectedOptions.SubjectTokenType != options.SubjectTokenType {
return "", errors.New("Audience does not match")
}
}
if supp.expectedContext != nil {
if supp.expectedContext != ctx {
return "", errors.New("Context does not match")
}
}
return supp.subjectToken, nil
}

View File

@@ -19,15 +19,19 @@ import (
type urlCredentialSource struct { type urlCredentialSource struct {
URL string URL string
Headers map[string]string Headers map[string]string
Format format Format Format
ctx context.Context ctx context.Context
} }
func (cs urlCredentialSource) credentialSourceType() string {
return "url"
}
func (cs urlCredentialSource) subjectToken() (string, error) { func (cs urlCredentialSource) subjectToken() (string, error) {
client := oauth2.NewClient(cs.ctx, nil) client := oauth2.NewClient(cs.ctx, nil)
req, err := http.NewRequest("GET", cs.URL, nil) req, err := http.NewRequest("GET", cs.URL, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: HTTP request for URL-sourced credential failed: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: HTTP request for URL-sourced credential failed: %v", err)
} }
req = req.WithContext(cs.ctx) req = req.WithContext(cs.ctx)
@@ -36,16 +40,16 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: invalid response when retrieving subject token: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: invalid response when retrieving subject token: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: invalid body in subject token URL query: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err)
} }
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return "", fmt.Errorf("oauth2/google: status code %d: %s", c, respBody) return "", fmt.Errorf("oauth2/google/externalaccount: status code %d: %s", c, respBody)
} }
switch cs.Format.Type { switch cs.Format.Type {
@@ -53,15 +57,15 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
jsonData := make(map[string]interface{}) jsonData := make(map[string]interface{})
err = json.Unmarshal(respBody, &jsonData) err = json.Unmarshal(respBody, &jsonData)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google: failed to unmarshal subject token file: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)
} }
val, ok := jsonData[cs.Format.SubjectTokenFieldName] val, ok := jsonData[cs.Format.SubjectTokenFieldName]
if !ok { if !ok {
return "", errors.New("oauth2/google: provided subject_token_field_name not found in credentials") return "", errors.New("oauth2/google/externalaccount: provided subject_token_field_name not found in credentials")
} }
token, ok := val.(string) token, ok := val.(string)
if !ok { if !ok {
return "", errors.New("oauth2/google: improperly formatted subject token") return "", errors.New("oauth2/google/externalaccount: improperly formatted subject token")
} }
return token, nil return token, nil
case "text": case "text":
@@ -69,7 +73,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
case "": case "":
return string(respBody), nil return string(respBody), nil
default: default:
return "", errors.New("oauth2/google: invalid credential_source file format type") return "", errors.New("oauth2/google/externalaccount: invalid credential_source file format type")
} }
} }

View File

@@ -28,11 +28,11 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) {
heads["Metadata"] = "True" heads["Metadata"] = "True"
cs := CredentialSource{ cs := CredentialSource{
URL: ts.URL, URL: ts.URL,
Format: format{Type: fileTypeText}, Format: Format{Type: fileTypeText},
Headers: heads, Headers: heads,
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -60,7 +60,7 @@ func TestRetrieveURLSubjectToken_Untyped(t *testing.T) {
URL: ts.URL, URL: ts.URL,
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -93,10 +93,10 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
})) }))
cs := CredentialSource{ cs := CredentialSource{
URL: ts.URL, URL: ts.URL,
Format: format{Type: fileTypeJSON, SubjectTokenFieldName: "SubjToken"}, Format: Format{Type: fileTypeJSON, SubjectTokenFieldName: "SubjToken"},
} }
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@@ -111,3 +111,21 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
t.Errorf("got %v but want %v", out, myURLToken) t.Errorf("got %v but want %v", out, myURLToken)
} }
} }
func TestURLCredential_CredentialSourceType(t *testing.T) {
cs := CredentialSource{
URL: "http://example.com",
Format: Format{Type: fileTypeText},
}
tfc := testFileConfig
tfc.CredentialSource = &cs
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
if got, want := base.credentialSourceType(), "url"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}

View File

@@ -15,7 +15,9 @@ import (
"cloud.google.com/go/compute/metadata" "cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/externalaccount" "golang.org/x/oauth2/google/externalaccount"
"golang.org/x/oauth2/google/internal/externalaccountauthorizeduser"
"golang.org/x/oauth2/google/internal/impersonate"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
) )
@@ -23,6 +25,7 @@ import (
var Endpoint = oauth2.Endpoint{ var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth", AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token", TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
} }
@@ -98,6 +101,7 @@ const (
serviceAccountKey = "service_account" serviceAccountKey = "service_account"
userCredentialsKey = "authorized_user" userCredentialsKey = "authorized_user"
externalAccountKey = "external_account" externalAccountKey = "external_account"
externalAccountAuthorizedUserKey = "external_account_authorized_user"
impersonatedServiceAccount = "impersonated_service_account" impersonatedServiceAccount = "impersonated_service_account"
) )
@@ -112,6 +116,7 @@ type credentialsFile struct {
AuthURL string `json:"auth_uri"` AuthURL string `json:"auth_uri"`
TokenURL string `json:"token_uri"` TokenURL string `json:"token_uri"`
ProjectID string `json:"project_id"` ProjectID string `json:"project_id"`
UniverseDomain string `json:"universe_domain"`
// User Credential fields // User Credential fields
// (These typically come from gcloud auth.) // (These typically come from gcloud auth.)
@@ -131,6 +136,9 @@ type credentialsFile struct {
QuotaProjectID string `json:"quota_project_id"` QuotaProjectID string `json:"quota_project_id"`
WorkforcePoolUserProject string `json:"workforce_pool_user_project"` WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
// External Account Authorized User fields
RevokeURL string `json:"revoke_url"`
// Service account impersonation // Service account impersonation
SourceCredentials *credentialsFile `json:"source_credentials"` SourceCredentials *credentialsFile `json:"source_credentials"`
} }
@@ -193,11 +201,24 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
ServiceAccountImpersonationLifetimeSeconds: f.ServiceAccountImpersonation.TokenLifetimeSeconds, ServiceAccountImpersonationLifetimeSeconds: f.ServiceAccountImpersonation.TokenLifetimeSeconds,
ClientSecret: f.ClientSecret, ClientSecret: f.ClientSecret,
ClientID: f.ClientID, ClientID: f.ClientID,
CredentialSource: f.CredentialSource, CredentialSource: &f.CredentialSource,
QuotaProjectID: f.QuotaProjectID, QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes, Scopes: params.Scopes,
WorkforcePoolUserProject: f.WorkforcePoolUserProject, WorkforcePoolUserProject: f.WorkforcePoolUserProject,
} }
return externalaccount.NewTokenSource(ctx, *cfg)
case externalAccountAuthorizedUserKey:
cfg := &externalaccountauthorizeduser.Config{
Audience: f.Audience,
RefreshToken: f.RefreshToken,
TokenURL: f.TokenURLExternal,
TokenInfoURL: f.TokenInfoURL,
ClientID: f.ClientID,
ClientSecret: f.ClientSecret,
RevokeURL: f.RevokeURL,
QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes,
}
return cfg.TokenSource(ctx) return cfg.TokenSource(ctx)
case impersonatedServiceAccount: case impersonatedServiceAccount:
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil { if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
@@ -208,7 +229,7 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
if err != nil { if err != nil {
return nil, err return nil, err
} }
imp := externalaccount.ImpersonateTokenSource{ imp := impersonate.ImpersonateTokenSource{
Ctx: ctx, Ctx: ctx,
URL: f.ServiceAccountImpersonationURL, URL: f.ServiceAccountImpersonationURL,
Scopes: params.Scopes, Scopes: params.Scopes,
@@ -231,7 +252,11 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
// Further information about retrieving access tokens from the GCE metadata // Further information about retrieving access tokens from the GCE metadata
// server can be found at https://cloud.google.com/compute/docs/authentication. // server can be found at https://cloud.google.com/compute/docs/authentication.
func ComputeTokenSource(account string, scope ...string) oauth2.TokenSource { func ComputeTokenSource(account string, scope ...string) oauth2.TokenSource {
return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope}) return computeTokenSource(account, 0, scope...)
}
func computeTokenSource(account string, earlyExpiry time.Duration, scope ...string) oauth2.TokenSource {
return oauth2.ReuseTokenSourceWithExpiry(nil, computeSource{account: account, scopes: scope}, earlyExpiry)
} }
type computeSource struct { type computeSource struct {

View File

@@ -5,6 +5,8 @@
package google package google
import ( import (
"net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
) )
@@ -137,3 +139,21 @@ func TestJWTConfigFromJSONNoAudience(t *testing.T) {
t.Errorf("Audience = %q; want %q", got, want) t.Errorf("Audience = %q; want %q", got, want)
} }
} }
func TestComputeTokenSource(t *testing.T) {
tokenPath := "/computeMetadata/v1/instance/service-accounts/default/token"
tokenResponseBody := `{"access_token":"Sample.Access.Token","token_type":"Bearer","expires_in":3600}`
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != tokenPath {
t.Errorf("got %s, want %s", r.URL.Path, tokenPath)
}
w.Write([]byte(tokenResponseBody))
}))
defer s.Close()
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
ts := ComputeTokenSource("")
_, err := ts.Token()
if err != nil {
t.Errorf("ts.Token() = %v", err)
}
}

View File

@@ -1,269 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"context"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
)
// now aliases time.Now for testing
var now = func() time.Time {
return time.Now().UTC()
}
// Config stores the configuration for fetching tokens with external credentials.
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workload
// identity pool or the workforce pool and the provider identifier in that pool.
Audience string
// SubjectTokenType is the STS token type based on the Oauth2.0 token exchange spec
// e.g. `urn:ietf:params:oauth:token-type:jwt`.
SubjectTokenType string
// TokenURL is the STS token exchange endpoint.
TokenURL string
// TokenInfoURL is the token_info endpoint used to retrieve the account related information (
// user attributes like account identifier, eg. email, username, uid, etc). This is
// needed for gCloud session account identification.
TokenInfoURL string
// ServiceAccountImpersonationURL is the URL for the service account impersonation request. This is only
// required for workload identity pools when APIs to be accessed have not integrated with UberMint.
ServiceAccountImpersonationURL string
// ServiceAccountImpersonationLifetimeSeconds is the number of seconds the service account impersonation
// token will be valid for.
ServiceAccountImpersonationLifetimeSeconds int
// ClientSecret is currently only required if token_info endpoint also
// needs to be called with the generated GCP access token. When provided, STS will be
// called with additional basic authentication using client_id as username and client_secret as password.
ClientSecret string
// ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// CredentialSource contains the necessary information to retrieve the token itself, as well
// as some environmental information.
CredentialSource CredentialSource
// QuotaProjectID is injected by gCloud. If the value is non-empty, the Auth libraries
// will set the x-goog-user-project which overrides the project associated with the credentials.
QuotaProjectID string
// Scopes contains the desired scopes for the returned access token.
Scopes []string
// The optional workforce pool user project number when the credential
// corresponds to a workforce pool and not a workload identity pool.
// The underlying principal must still have serviceusage.services.use IAM
// permission to use the project for billing/quota.
WorkforcePoolUserProject string
}
// Each element consists of a list of patterns. validateURLs checks for matches
// that include all elements in a given list, in that order.
var (
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
)
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
parsed, err := url.Parse(input)
if err != nil {
return false
}
if !strings.EqualFold(parsed.Scheme, scheme) {
return false
}
toTest := parsed.Host
for _, pattern := range patterns {
if pattern.MatchString(toTest) {
return true
}
}
return false
}
func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input)
}
// TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials.
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return c.tokenSource(ctx, "https")
}
// tokenSource is a private function that's directly called by some of the tests,
// because the unit test URLs are mocked, and would otherwise fail the
// validity check.
func (c *Config) tokenSource(ctx context.Context, scheme string) (oauth2.TokenSource, error) {
if c.WorkforcePoolUserProject != "" {
valid := validateWorkforceAudience(c.Audience)
if !valid {
return nil, fmt.Errorf("oauth2/google: workforce_pool_user_project should not be set for non-workforce pool credentials")
}
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
if c.ServiceAccountImpersonationURL == "" {
return oauth2.ReuseTokenSource(nil, ts), nil
}
scopes := c.Scopes
ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
imp := ImpersonateTokenSource{
Ctx: ctx,
URL: c.ServiceAccountImpersonationURL,
Scopes: scopes,
Ts: oauth2.ReuseTokenSource(nil, ts),
TokenLifetimeSeconds: c.ServiceAccountImpersonationLifetimeSeconds,
}
return oauth2.ReuseTokenSource(nil, imp), nil
}
// Subject token file types.
const (
fileTypeText = "text"
fileTypeJSON = "json"
)
type format struct {
// Type is either "text" or "json". When not provided "text" type is assumed.
Type string `json:"type"`
// SubjectTokenFieldName is only required for JSON format. This would be "access_token" for azure.
SubjectTokenFieldName string `json:"subject_token_field_name"`
}
// CredentialSource stores the information necessary to retrieve the credentials for the STS exchange.
// One field amongst File, URL, and Executable should be filled, depending on the kind of credential in question.
// The EnvironmentID should start with AWS if being used for an AWS credential.
type CredentialSource struct {
File string `json:"file"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Executable *ExecutableConfig `json:"executable"`
EnvironmentID string `json:"environment_id"`
RegionURL string `json:"region_url"`
RegionalCredVerificationURL string `json:"regional_cred_verification_url"`
CredVerificationURL string `json:"cred_verification_url"`
IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"`
Format format `json:"format"`
}
type ExecutableConfig struct {
Command string `json:"command"`
TimeoutMillis *int `json:"timeout_millis"`
OutputFile string `json:"output_file"`
}
// parse determines the type of CredentialSource needed.
func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
if awsVersion != 1 {
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
}
awsCredSource := awsCredentialSource{
EnvironmentID: c.CredentialSource.EnvironmentID,
RegionURL: c.CredentialSource.RegionURL,
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
CredVerificationURL: c.CredentialSource.URL,
TargetResource: c.Audience,
ctx: ctx,
}
if c.CredentialSource.IMDSv2SessionTokenURL != "" {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
}
if err := awsCredSource.validateMetadataServers(); err != nil {
return nil, err
}
return awsCredSource, nil
}
} else if c.CredentialSource.File != "" {
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
} else if c.CredentialSource.URL != "" {
return urlCredentialSource{URL: c.CredentialSource.URL, Headers: c.CredentialSource.Headers, Format: c.CredentialSource.Format, ctx: ctx}, nil
} else if c.CredentialSource.Executable != nil {
return CreateExecutableCredential(ctx, c.CredentialSource.Executable, c)
}
return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
}
type baseCredentialSource interface {
subjectToken() (string, error)
}
// tokenSource is the source that handles external credentials. It is used to retrieve Tokens.
type tokenSource struct {
ctx context.Context
conf *Config
}
// Token allows tokenSource to conform to the oauth2.TokenSource interface.
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
credSource, err := conf.parse(ts.ctx)
if err != nil {
return nil, err
}
subjectToken, err := credSource.subjectToken()
if err != nil {
return nil, err
}
stsRequest := stsTokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience,
Scope: conf.Scopes,
RequestedTokenType: "urn:ietf:params:oauth:token-type:access_token",
SubjectToken: subjectToken,
SubjectTokenType: conf.SubjectTokenType,
}
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
clientAuth := clientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
var options map[string]interface{}
// Do not pass workforce_pool_user_project when client authentication is used.
// The client ID is sufficient for determining the user project.
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
options = map[string]interface{}{
"userProject": conf.WorkforcePoolUserProject,
}
}
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil {
return nil, err
}
accessToken := &oauth2.Token{
AccessToken: stsResp.AccessToken,
TokenType: stsResp.TokenType,
}
if stsResp.ExpiresIn < 0 {
return nil, fmt.Errorf("oauth2/google: got invalid expiry from security token service")
} else if stsResp.ExpiresIn >= 0 {
accessToken.Expiry = now().Add(time.Duration(stsResp.ExpiresIn) * time.Second)
}
if stsResp.RefreshToken != "" {
accessToken.RefreshToken = stsResp.RefreshToken
}
return accessToken, nil
}

View File

@@ -1,246 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
)
const (
textBaseCredPath = "testdata/3pi_cred.txt"
jsonBaseCredPath = "testdata/3pi_cred.json"
)
var testBaseCredSource = CredentialSource{
File: textBaseCredPath,
Format: format{Type: fileTypeText},
}
var testConfig = Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}
var (
baseCredsRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
baseCredsResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
workforcePoolRequestBodyWithClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
workforcePoolRequestBodyWithoutClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=%7B%22userProject%22%3A%22myProject%22%7D&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
correctAT = "Sample.Access.Token"
expiry int64 = 234852
)
var (
testNow = func() time.Time { return time.Unix(expiry, 0) }
)
type testExchangeTokenServer struct {
url string
authorization string
contentType string
body string
response string
}
func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.Token, error) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), tets.url; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, tets.authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, tets.contentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), tets.body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(tets.response))
}))
defer server.Close()
config.TokenURL = server.URL
oldNow := now
defer func() { now = oldNow }()
now = testNow
ts := tokenSource{
ctx: context.Background(),
conf: config,
}
return ts.Token()
}
func validateToken(t *testing.T, tok *oauth2.Token) {
if got, want := tok.AccessToken, correctAT; got != want {
t.Errorf("Unexpected access token: got %v, but wanted %v", got, want)
}
if got, want := tok.TokenType, "Bearer"; got != want {
t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want)
}
if got, want := tok.Expiry, testNow().Add(time.Duration(3600)*time.Second); got != want {
t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want)
}
}
func TestToken(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}
server := testExchangeTokenServer{
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
body: baseCredsRequestBody,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
validateToken(t, tok)
}
func TestWorkforcePoolTokenWithClientID(t *testing.T) {
config := Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
server := testExchangeTokenServer{
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
body: workforcePoolRequestBodyWithClientId,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
validateToken(t, tok)
}
func TestWorkforcePoolTokenWithoutClientID(t *testing.T) {
config := Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
server := testExchangeTokenServer{
url: "/",
authorization: "",
contentType: "application/x-www-form-urlencoded",
body: workforcePoolRequestBodyWithoutClientId,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
validateToken(t, tok)
}
func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
TokenURL: "https://sts.googleapis.com",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
_, err := config.TokenSource(context.Background())
if err == nil {
t.Fatalf("Expected error but found none")
}
if got, want := err.Error(), "oauth2/google: workforce_pool_user_project should not be set for non-workforce pool credentials"; got != want {
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
}
}
func TestWorkforcePoolCreation(t *testing.T) {
var audienceValidatyTests = []struct {
audience string
expectSuccess bool
}{
{"//iam.googleapis.com/locations/global/workforcePools/pool-id/providers/provider-id", true},
{"//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", true},
{"//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", true},
{"identitynamespace:1f12345:my_provider", false},
{"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/projects/123456/locations/eu/workloadIdentityPools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/workforcePools/providers/provider-id", false},
{"//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", false},
}
ctx := context.Background()
for _, tt := range audienceValidatyTests {
t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
config.ServiceAccountImpersonationURL = "https://iamcredentials.googleapis.com"
config.Audience = tt.audience
config.WorkforcePoolUserProject = "myProject"
_, err := config.TokenSource(ctx)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
}

View File

@@ -1,18 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import "fmt"
// Error for handling OAuth related error responses as stated in rfc6749#5.2.
type Error struct {
Code string
URI string
Description string
}
func (err *Error) Error() string {
return fmt.Sprintf("got error code %s from %s: %s", err.Code, err.URI, err.Description)
}

View File

@@ -1,19 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import "testing"
func TestError(t *testing.T) {
e := Error{
"42",
"http:thisIsAPlaceholder",
"The Answer!",
}
want := "got error code 42 from http:thisIsAPlaceholder: The Answer!"
if got := e.Error(); got != want {
t.Errorf("Got error message %q; want %q", got, want)
}
}

View File

@@ -1,137 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
var (
baseImpersonateCredsReqBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
baseImpersonateCredsRespBody = `{"accessToken":"Second.Access.Token","expireTime":"2020-12-28T15:01:23Z"}`
)
func createImpersonationServer(urlWanted, authWanted, bodyWanted, response string, t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), urlWanted; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, authWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, "application/json"; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
if got, want := string(body), bodyWanted; got != want {
t.Errorf("Unexpected impersonation payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
}
func createTargetServer(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), "/"; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
if got, want := string(body), baseImpersonateCredsReqBody; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(baseCredsResponseBody))
}))
}
var impersonationTests = []struct {
name string
config Config
expectedImpersonationBody string
}{
{
name: "Base Impersonation",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
},
expectedImpersonationBody: "{\"lifetime\":\"3600s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
},
{
name: "With TokenLifetime Set",
config: Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
TokenInfoURL: "http://localhost:8080/v1/tokeninfo",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
expectedImpersonationBody: "{\"lifetime\":\"10000s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
},
}
func TestImpersonation(t *testing.T) {
for _, tt := range impersonationTests {
t.Run(tt.name, func(t *testing.T) {
testImpersonateConfig := tt.config
impersonateServer := createImpersonationServer("/", "Bearer Sample.Access.Token", tt.expectedImpersonationBody, baseImpersonateCredsRespBody, t)
defer impersonateServer.Close()
testImpersonateConfig.ServiceAccountImpersonationURL = impersonateServer.URL
targetServer := createTargetServer(t)
defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL
ourTS, err := testImpersonateConfig.tokenSource(context.Background(), "http")
if err != nil {
t.Fatalf("Failed to create TokenSource: %v", err)
}
oldNow := now
defer func() { now = oldNow }()
now = testNow
tok, err := ourTS.Token()
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
if got, want := tok.AccessToken, "Second.Access.Token"; got != want {
t.Errorf("Unexpected access token: got %v, but wanted %v", got, want)
}
if got, want := tok.TokenType, "Bearer"; got != want {
t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want)
}
})
}
}

View File

@@ -0,0 +1,114 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"errors"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing.
var now = func() time.Time {
return time.Now().UTC()
}
var tokenValid = func(token oauth2.Token) bool {
return token.Valid()
}
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
// the provider identifier in that pool.
Audience string
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
RefreshToken string
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
// None if the token can not be refreshed.
TokenURL string
// TokenInfoURL is the optional STS endpoint URL for token introspection.
TokenInfoURL string
// ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
// access token. When provided, STS will be called with additional basic authentication using client_id as username
// and client_secret as password.
ClientSecret string
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
Token string
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
Expiry time.Time
// RevokeURL is the optional STS endpoint URL for revoking tokens.
RevokeURL string
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
// project used to create the credentials.
QuotaProjectID string
Scopes []string
}
func (c *Config) canRefresh() bool {
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
}
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
var token oauth2.Token
if c.Token != "" && !c.Expiry.IsZero() {
token = oauth2.Token{
AccessToken: c.Token,
Expiry: c.Expiry,
TokenType: "Bearer",
}
}
if !tokenValid(token) && !c.canRefresh() {
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
return oauth2.ReuseTokenSource(&token, ts), nil
}
type tokenSource struct {
ctx context.Context
conf *Config
}
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
if !conf.canRefresh() {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}
if stsResponse.ExpiresIn < 0 {
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
}
if stsResponse.RefreshToken != "" {
conf.RefreshToken = stsResponse.RefreshToken
}
token := &oauth2.Token{
AccessToken: stsResponse.AccessToken,
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
TokenType: "Bearer",
}
return token, nil
}

View File

@@ -0,0 +1,259 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
const expiryDelta = 10 * time.Second
var (
expiry = time.Unix(234852, 0)
testNow = func() time.Time { return expiry }
testValid = func(t oauth2.Token) bool {
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
}
)
type testRefreshTokenServer struct {
URL string
Authorization string
ContentType string
Body string
ResponsePayload *stsexchange.Response
Response string
server *httptest.Server
}
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expiry: now().Add(time.Hour),
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
if config.RefreshToken != "CCCCCCC" {
t.Fatalf("Refresh token not updated")
}
}
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
testCases := []struct {
name string
config Config
}{
{
name: "empty config",
config: Config{},
},
{
name: "missing refresh token",
config: Config{
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing token url",
config: Config{
RefreshToken: "BBBBBBBBB",
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client id",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client secrect",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
_, err := tc.config.TokenSource((context.Background()))
if err == nil {
t.Fatalf("Expected error, but received none")
}
if got := err.Error(); got != expectErrMsg {
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
}
})
}
}
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
t.Helper()
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) close(t *testing.T) error {
t.Helper()
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package impersonate
import ( import (
"bytes" "bytes"

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"encoding/base64" "encoding/base64"
@@ -12,8 +12,8 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. // ClientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type clientAuthentication struct { type ClientAuthentication struct {
// AuthStyle can be either basic or request-body // AuthStyle can be either basic or request-body
AuthStyle oauth2.AuthStyle AuthStyle oauth2.AuthStyle
ClientID string ClientID string
@@ -23,7 +23,7 @@ type clientAuthentication struct {
// InjectAuthentication is used to add authentication to a Secure Token Service exchange // InjectAuthentication is used to add authentication to a Secure Token Service exchange
// request. It modifies either the passed url.Values or http.Header depending on the desired // request. It modifies either the passed url.Values or http.Header depending on the desired
// authentication format. // authentication format.
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) { func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil { if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
return return
} }

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"net/http" "net/http"
@@ -38,7 +38,7 @@ func TestClientAuthentication_InjectHeaderAuthentication(t *testing.T) {
"Content-Type": ContentType, "Content-Type": ContentType,
} }
headerAuthentication := clientAuthentication{ headerAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
@@ -80,7 +80,7 @@ func TestClientAuthentication_ParamsAuthentication(t *testing.T) {
headerP := http.Header{ headerP := http.Header{
"Content-Type": ContentType, "Content-Type": ContentType,
} }
paramsAuthentication := clientAuthentication{ paramsAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"context" "context"
@@ -18,14 +18,17 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// exchangeToken performs an oauth2 token exchange with the provided endpoint. func defaultHeader() http.Header {
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
return header
}
// ExchangeToken performs an oauth2 token exchange with the provided endpoint.
// The first 4 fields are all mandatory. headers can be used to pass additional // The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can // headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server. // be used to pass additional JSON-structured options to the remote server.
func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchangeRequest, authentication clientAuthentication, headers http.Header, options map[string]interface{}) (*stsTokenExchangeResponse, error) { func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
client := oauth2.NewClient(ctx, nil)
data := url.Values{} data := url.Values{}
data.Set("audience", request.Audience) data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@@ -41,13 +44,28 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
data.Set("options", string(opts)) data.Set("options", string(opts))
} }
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
if headers == nil {
headers = defaultHeader()
}
client := oauth2.NewClient(ctx, nil)
authentication.InjectAuthentication(data, headers) authentication.InjectAuthentication(data, headers)
encodedData := data.Encode() encodedData := data.Encode()
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData)) req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err) return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
for key, list := range headers { for key, list := range headers {
@@ -71,7 +89,7 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
} }
var stsResp stsTokenExchangeResponse var stsResp Response
err = json.Unmarshal(body, &stsResp) err = json.Unmarshal(body, &stsResp)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err) return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
@@ -81,8 +99,8 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
return &stsResp, nil return &stsResp, nil
} }
// stsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange. // TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type stsTokenExchangeRequest struct { type TokenExchangeRequest struct {
ActingParty struct { ActingParty struct {
ActorToken string ActorToken string
ActorTokenType string ActorTokenType string
@@ -96,8 +114,8 @@ type stsTokenExchangeRequest struct {
SubjectTokenType string SubjectTokenType string
} }
// stsTokenExchangeResponse is used to decode the remote server response during an oauth2 token exchange. // Response is used to decode the remote server response during an oauth2 token exchange.
type stsTokenExchangeResponse struct { type Response struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"` IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package stsexchange
import ( import (
"context" "context"
@@ -16,13 +16,13 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
var auth = clientAuthentication{ var auth = ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
} }
var tokenRequest = stsTokenExchangeRequest{ var exchangeTokenRequest = TokenExchangeRequest{
ActingParty: struct { ActingParty: struct {
ActorToken string ActorToken string
ActorTokenType string ActorTokenType string
@@ -36,9 +36,9 @@ var tokenRequest = stsTokenExchangeRequest{
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
} }
var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" var exchangeRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` var exchangeResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedToken = stsTokenExchangeResponse{ var expectedExchangeToken = Response{
AccessToken: "Sample.Access.Token", AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer", TokenType: "Bearer",
@@ -47,6 +47,18 @@ var expectedToken = stsTokenExchangeResponse{
RefreshToken: "", RefreshToken: "",
} }
var refreshToken = "ReFrEsHtOkEn"
var refreshRequestBody = "grant_type=refresh_token&refresh_token=" + refreshToken
var refreshResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform","refresh_token":"REFRESHED_REFRESH"}`
var expectedRefreshResponse = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: "https://www.googleapis.com/auth/cloud-platform",
RefreshToken: "REFRESHED_REFRESH",
}
func TestExchangeToken(t *testing.T) { func TestExchangeToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
@@ -65,26 +77,34 @@ func TestExchangeToken(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %v.", err) t.Errorf("Failed reading request body: %v.", err)
} }
if got, want := string(body), requestbody; got != want { if got, want := string(body), exchangeRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want) t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody)) w.Write([]byte(exchangeResponseBody))
})) }))
defer ts.Close() defer ts.Close()
headers := http.Header{} headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded") headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) resp, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err != nil { if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err) t.Fatalf("exchangeToken failed with error: %v", err)
} }
if expectedToken != *resp { if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedToken, *resp) t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
} }
resp, err = ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, nil, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
} }
func TestExchangeToken_Err(t *testing.T) { func TestExchangeToken_Err(t *testing.T) {
@@ -96,7 +116,7 @@ func TestExchangeToken_Err(t *testing.T) {
headers := http.Header{} headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded") headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) _, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err == nil { if err == nil {
t.Errorf("Expected handled error; instead got nil.") t.Errorf("Expected handled error; instead got nil.")
} }
@@ -171,7 +191,7 @@ func TestExchangeToken_Opts(t *testing.T) {
// Send a proper reply so that no other errors crop up. // Send a proper reply so that no other errors crop up.
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody)) w.Write([]byte(exchangeResponseBody))
})) }))
defer ts.Close() defer ts.Close()
@@ -183,5 +203,69 @@ func TestExchangeToken_Opts(t *testing.T) {
inputOpts := make(map[string]interface{}) inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption inputOpts["one"] = firstOption
inputOpts["two"] = secondOption inputOpts["two"] = secondOption
exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, inputOpts) ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
}
func TestRefreshToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Unexpected request method, %v is found", r.Method)
}
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
if got, want := r.Header.Get("Authorization"), "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("Unexpected authorization header, got %v, want %v", got, want)
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), refreshRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(refreshResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
resp, err = RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
}
func TestRefreshToken_Err(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("what's wrong with this response?"))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
} }

View File

@@ -1,14 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build appengine
// +build appengine
package internal
import "google.golang.org/appengine/urlfetch"
func init() {
appengineClientHook = urlfetch.Client
}

View File

@@ -14,7 +14,7 @@ import (
// ParseKey converts the binary contents of a private key file // ParseKey converts the binary contents of a private key file
// to an *rsa.PrivateKey. It detects whether the private key is in a // to an *rsa.PrivateKey. It detects whether the private key is in a
// PEM container or not. If so, it extracts the the private key // PEM container or not. If so, it extracts the private key
// from PEM container before conversion. It only supports PEM // from PEM container before conversion. It only supports PEM
// containers with no passphrase. // containers with no passphrase.
func ParseKey(key []byte) (*rsa.PrivateKey, error) { func ParseKey(key []byte) (*rsa.PrivateKey, error) {

View File

@@ -18,6 +18,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@@ -55,12 +56,18 @@ type Token struct {
} }
// tokenJSON is the struct representing the HTTP response from OAuth2 // tokenJSON is the struct representing the HTTP response from OAuth2
// providers returning a token in JSON form. // providers returning a token or error in JSON form.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
type tokenJSON struct { type tokenJSON struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
// error fields
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorCode string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
} }
func (e *tokenJSON) expiry() (t time.Time) { func (e *tokenJSON) expiry() (t time.Time) {
@@ -109,41 +116,60 @@ const (
AuthStyleInHeader AuthStyle = 2 AuthStyleInHeader AuthStyle = 2
) )
// authStyleCache is the set of tokenURLs we've successfully used via // LazyAuthStyleCache is a backwards compatibility compromise to let Configs
// have a lazily-initialized AuthStyleCache.
//
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
// both would ideally just embed an unexported AuthStyleCache but because both
// were historically allowed to be copied by value we can't retroactively add an
// uncopyable Mutex to them.
//
// We could use an atomic.Pointer, but that was added recently enough (in Go
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
// still pass. By using an atomic.Value, it supports both Go 1.17 and
// copying by value, even if that's not ideal.
type LazyAuthStyleCache struct {
v atomic.Value // of *AuthStyleCache
}
func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
return c
}
c := new(AuthStyleCache)
if !lc.v.CompareAndSwap(nil, c) {
c = lc.v.Load().(*AuthStyleCache)
}
return c
}
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using. // RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that // It's called a cache, but it doesn't (yet?) shrink. It's expected that
// the set of OAuth2 servers a program contacts over time is fixed and // the set of OAuth2 servers a program contacts over time is fixed and
// small. // small.
var authStyleCache struct { type AuthStyleCache struct {
sync.Mutex mu sync.Mutex
m map[string]AuthStyle // keyed by tokenURL m map[string]AuthStyle // keyed by tokenURL
} }
// ResetAuthCache resets the global authentication style cache used
// for AuthStyleUnknown token requests.
func ResetAuthCache() {
authStyleCache.Lock()
defer authStyleCache.Unlock()
authStyleCache.m = nil
}
// lookupAuthStyle reports which auth style we last used with tokenURL // lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so. // when calling RetrieveToken and whether we have ever done so.
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
authStyleCache.Lock() c.mu.Lock()
defer authStyleCache.Unlock() defer c.mu.Unlock()
style, ok = authStyleCache.m[tokenURL] style, ok = c.m[tokenURL]
return return
} }
// setAuthStyle adds an entry to authStyleCache, documented above. // setAuthStyle adds an entry to authStyleCache, documented above.
func setAuthStyle(tokenURL string, v AuthStyle) { func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
authStyleCache.Lock() c.mu.Lock()
defer authStyleCache.Unlock() defer c.mu.Unlock()
if authStyleCache.m == nil { if c.m == nil {
authStyleCache.m = make(map[string]AuthStyle) c.m = make(map[string]AuthStyle)
} }
authStyleCache.m[tokenURL] = v c.m[tokenURL] = v
} }
// newTokenRequest returns a new *http.Request to retrieve a new token // newTokenRequest returns a new *http.Request to retrieve a new token
@@ -183,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
return v2 return v2
} }
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == 0 needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe { if needsAuthStyleProbe {
if style, ok := lookupAuthStyle(tokenURL); ok { if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style authStyle = style
needsAuthStyleProbe = false needsAuthStyleProbe = false
} else { } else {
@@ -216,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req) token, err = doTokenRoundTrip(ctx, req)
} }
if needsAuthStyleProbe && err == nil { if needsAuthStyleProbe && err == nil {
setAuthStyle(tokenURL, authStyle) styleCache.setAuthStyle(tokenURL, authStyle)
} }
// Don't overwrite `RefreshToken` with an empty value // Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request. // if this was a token refreshing request.
@@ -236,21 +262,29 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{ failureStatus := r.StatusCode < 200 || r.StatusCode > 299
retrieveError := &RetrieveError{
Response: r, Response: r,
Body: body, Body: body,
} // attempt to populate error detail below
} }
var token *Token var token *Token
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content { switch content {
case "application/x-www-form-urlencoded", "text/plain": case "application/x-www-form-urlencoded", "text/plain":
// some endpoints return a query string
vals, err := url.ParseQuery(string(body)) vals, err := url.ParseQuery(string(body))
if err != nil { if err != nil {
return nil, err if failureStatus {
return nil, retrieveError
} }
return nil, fmt.Errorf("oauth2: cannot parse response: %v", err)
}
retrieveError.ErrorCode = vals.Get("error")
retrieveError.ErrorDescription = vals.Get("error_description")
retrieveError.ErrorURI = vals.Get("error_uri")
token = &Token{ token = &Token{
AccessToken: vals.Get("access_token"), AccessToken: vals.Get("access_token"),
TokenType: vals.Get("token_type"), TokenType: vals.Get("token_type"),
@@ -265,8 +299,14 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
default: default:
var tj tokenJSON var tj tokenJSON
if err = json.Unmarshal(body, &tj); err != nil { if err = json.Unmarshal(body, &tj); err != nil {
return nil, err if failureStatus {
return nil, retrieveError
} }
return nil, fmt.Errorf("oauth2: cannot parse json: %v", err)
}
retrieveError.ErrorCode = tj.ErrorCode
retrieveError.ErrorDescription = tj.ErrorDescription
retrieveError.ErrorURI = tj.ErrorURI
token = &Token{ token = &Token{
AccessToken: tj.AccessToken, AccessToken: tj.AccessToken,
TokenType: tj.TokenType, TokenType: tj.TokenType,
@@ -276,17 +316,37 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
} }
json.Unmarshal(body, &token.Raw) // no error checks for optional fields json.Unmarshal(body, &token.Raw) // no error checks for optional fields
} }
// according to spec, servers should respond status 400 in error case
// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
// but some unorthodox servers respond 200 in error case
if failureStatus || retrieveError.ErrorCode != "" {
return nil, retrieveError
}
if token.AccessToken == "" { if token.AccessToken == "" {
return nil, errors.New("oauth2: server response missing access_token") return nil, errors.New("oauth2: server response missing access_token")
} }
return token, nil return token, nil
} }
// mirrors oauth2.RetrieveError
type RetrieveError struct { type RetrieveError struct {
Response *http.Response Response *http.Response
Body []byte Body []byte
ErrorCode string
ErrorDescription string
ErrorURI string
} }
func (r *RetrieveError) Error() string { func (r *RetrieveError) Error() string {
if r.ErrorCode != "" {
s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
if r.ErrorDescription != "" {
s += fmt.Sprintf(" %q", r.ErrorDescription)
}
if r.ErrorURI != "" {
s += fmt.Sprintf(" %q", r.ErrorURI)
}
return s
}
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
} }

View File

@@ -16,7 +16,7 @@ import (
) )
func TestRetrieveToken_InParams(t *testing.T) { func TestRetrieveToken_InParams(t *testing.T) {
ResetAuthCache() styleCache := new(AuthStyleCache)
const clientID = "client-id" const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want { if got, want := r.FormValue("client_id"), clientID; got != want {
@@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
})) }))
defer ts.Close() defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams) _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil { if err != nil {
t.Errorf("RetrieveToken = %v; want no error", err) t.Errorf("RetrieveToken = %v; want no error", err)
} }
} }
func TestRetrieveTokenWithContexts(t *testing.T) { func TestRetrieveTokenWithContexts(t *testing.T) {
ResetAuthCache() styleCache := new(AuthStyleCache)
const clientID = "client-id" const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown) _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil { if err != nil {
t.Errorf("RetrieveToken (with background context) = %v; want no error", err) t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
} }
@@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown) _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
close(retrieved) close(retrieved)
if err == nil { if err == nil {
t.Errorf("RetrieveToken (with cancelled context) = nil; want error") t.Errorf("RetrieveToken (with cancelled context) = nil; want error")

View File

@@ -18,16 +18,11 @@ var HTTPClient ContextKey
// because nobody else can create a ContextKey, being unexported. // because nobody else can create a ContextKey, being unexported.
type ContextKey struct{} type ContextKey struct{}
var appengineClientHook func(context.Context) *http.Client
func ContextClient(ctx context.Context) *http.Client { func ContextClient(ctx context.Context) *http.Client {
if ctx != nil { if ctx != nil {
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc return hc
} }
} }
if appengineClientHook != nil {
return appengineClientHook(ctx)
}
return http.DefaultClient return http.DefaultClient
} }

View File

@@ -16,6 +16,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
"golang.org/x/oauth2/internal" "golang.org/x/oauth2/internal"
) )
@@ -57,6 +58,10 @@ type Config struct {
// Scope specifies optional requested permissions. // Scope specifies optional requested permissions.
Scopes []string Scopes []string
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
} }
// A TokenSource is anything that can return a token. // A TokenSource is anything that can return a token.
@@ -71,6 +76,7 @@ type TokenSource interface {
// endpoint URLs. // endpoint URLs.
type Endpoint struct { type Endpoint struct {
AuthURL string AuthURL string
DeviceAuthURL string
TokenURL string TokenURL string
// AuthStyle optionally specifies how the endpoint wants the // AuthStyle optionally specifies how the endpoint wants the
@@ -138,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page // AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly. // that asks for permissions for the required scopes explicitly.
// //
// State is a token to protect the user from CSRF attacks. You must // State is an opaque value used by the client to maintain state between the
// always provide a non-empty string and validate that it matches the // request and callback. The authorization server includes this value when
// the state query parameter on your redirect callback. // redirecting the user agent back to the client.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// //
// Opts may include AccessTypeOnline or AccessTypeOffline, as well // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce. // as ApprovalForce.
// It can also be used to pass the PKCE challenge. //
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. // To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
// generate a random state parameter and verify it after exchange.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL) buf.WriteString(c.Endpoint.AuthURL)
@@ -161,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
v.Set("scope", strings.Join(c.Scopes, " ")) v.Set("scope", strings.Join(c.Scopes, " "))
} }
if state != "" { if state != "" {
// TODO(light): Docs say never to omit state; don't allow empty.
v.Set("state", state) v.Set("state", state)
} }
for _, opt := range opts { for _, opt := range opts {
@@ -206,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. // The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
// //
// The code will be in the *http.Request.FormValue("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state"). // calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks.
// //
// Opts may include the PKCE verifier code if previously used in AuthCodeURL. // If using PKCE to protect against CSRF attacks, opts should include a
// See https://www.oauth.com/oauth2-servers/pkce/ for more info. // VerifierOption.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) { func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{ v := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
@@ -290,6 +300,8 @@ type reuseTokenSource struct {
mu sync.Mutex // guards t mu sync.Mutex // guards t
t *Token t *Token
expiryDelta time.Duration
} }
// Token returns the current token if it's still valid, else will // Token returns the current token if it's still valid, else will
@@ -305,6 +317,7 @@ func (s *reuseTokenSource) Token() (*Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.expiryDelta = s.expiryDelta
s.t = t s.t = t
return t, nil return t, nil
} }
@@ -379,3 +392,30 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
new: src, new: src,
} }
} }
// ReuseTokenSource returns a TokenSource that acts in the same manner as the
// TokenSource returned by ReuseTokenSource, except the expiry buffer is
// configurable. The expiration time of a token is calculated as
// t.Expiry.Add(-earlyExpiry).
func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource {
// Don't wrap a reuseTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations.
// Just build the equivalent one.
if rt, ok := src.(*reuseTokenSource); ok {
if t == nil {
// Just use it directly, but set the expiryDelta to earlyExpiry,
// so the behavior matches what the user expects.
rt.expiryDelta = earlyExpiry
return rt
}
src = rt.new
}
if t != nil {
t.expiryDelta = earlyExpiry
}
return &reuseTokenSource{
t: t,
new: src,
expiryDelta: earlyExpiry,
}
}

View File

@@ -15,8 +15,6 @@ import (
"net/url" "net/url"
"testing" "testing"
"time" "time"
"golang.org/x/oauth2/internal"
) )
type mockTransport struct { type mockTransport struct {
@@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
} }
func TestExchangeRequest_NonBasicAuth(t *testing.T) { func TestExchangeRequest_NonBasicAuth(t *testing.T) {
internal.ResetAuthCache()
tr := &mockTransport{ tr := &mockTransport{
rt: func(r *http.Request) (w *http.Response, err error) { rt: func(r *http.Request) (w *http.Response, err error) {
headerAuth := r.Header.Get("Authorization") headerAuth := r.Header.Get("Authorization")
@@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
} }
func TestTokenRefreshRequest(t *testing.T) { func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" { if r.URL.String() == "/somethingelse" {
return return
@@ -484,6 +480,7 @@ func TestTokenRetrieveError(t *testing.T) {
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
} }
w.Header().Set("Content-type", "application/json") w.Header().Set("Content-type", "application/json")
// "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error": "invalid_grant"}`)) w.Write([]byte(`{"error": "invalid_grant"}`))
})) }))
@@ -493,15 +490,47 @@ func TestTokenRetrieveError(t *testing.T) {
if err == nil { if err == nil {
t.Fatalf("got no error, expected one") t.Fatalf("got no error, expected one")
} }
_, ok := err.(*RetrieveError) re, ok := err.(*RetrieveError)
if !ok { if !ok {
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
} }
// Test error string for backwards compatibility expected := `oauth2: "invalid_grant"`
expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
if errStr := err.Error(); errStr != expected { if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected) t.Fatalf("got %#v, expected %#v", errStr, expected)
} }
expected = "invalid_grant"
if re.ErrorCode != expected {
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
}
}
// TestTokenRetrieveError200 tests handling of unorthodox server that returns 200 in error case
func TestTokenRetrieveError200(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
}
w.Header().Set("Content-type", "application/json")
w.Write([]byte(`{"error": "invalid_grant"}`))
}))
defer ts.Close()
conf := newConf(ts.URL)
_, err := conf.Exchange(context.Background(), "exchange-code")
if err == nil {
t.Fatalf("got no error, expected one")
}
re, ok := err.(*RetrieveError)
if !ok {
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
}
expected := `oauth2: "invalid_grant"`
if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected)
}
expected = "invalid_grant"
if re.ErrorCode != expected {
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
}
} }
func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { func TestRefreshToken_RefreshTokenReplacement(t *testing.T) {

68
pkce.go Normal file
View File

@@ -0,0 +1,68 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/url"
)
const (
codeChallengeKey = "code_challenge"
codeChallengeMethodKey = "code_challenge_method"
codeVerifierKey = "code_verifier"
)
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
// This follows recommendations in RFC 7636.
//
// A fresh verifier should be generated for each authorization.
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
// (or Config.DeviceAccessToken).
func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then
// base64url-encoded to produce a 43-octet URL-safe string to use as the
// code verifier."
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
data := make([]byte, 32)
if _, err := rand.Read(data); err != nil {
panic(err)
}
return base64.RawURLEncoding.EncodeToString(data)
}
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier}
}
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
//
// Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:])
}
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
// only.
func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{
challenge_method: "S256",
challenge: S256ChallengeFromVerifier(verifier),
}
}
type challengeOption struct{ challenge_method, challenge string }
func (p challengeOption) setValue(m url.Values) {
m.Set(codeChallengeMethodKey, p.challenge_method)
m.Set(codeChallengeKey, p.challenge)
}

View File

@@ -16,10 +16,10 @@ import (
"golang.org/x/oauth2/internal" "golang.org/x/oauth2/internal"
) )
// expiryDelta determines how earlier a token should be considered // defaultExpiryDelta determines how earlier a token should be considered
// expired than its actual expiration time. It is used to avoid late // expired than its actual expiration time. It is used to avoid late
// expirations due to client-server time mismatches. // expirations due to client-server time mismatches.
const expiryDelta = 10 * time.Second const defaultExpiryDelta = 10 * time.Second
// Token represents the credentials used to authorize // Token represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0 // the requests to access protected resources on the OAuth 2.0
@@ -52,6 +52,11 @@ type Token struct {
// raw optionally contains extra metadata from the server // raw optionally contains extra metadata from the server
// when updating a token. // when updating a token.
raw interface{} raw interface{}
// expiryDelta is used to calculate when a token is considered
// expired, by subtracting from Expiry. If zero, defaultExpiryDelta
// is used.
expiryDelta time.Duration
} }
// Type returns t.TokenType if non-empty, else "Bearer". // Type returns t.TokenType if non-empty, else "Bearer".
@@ -127,6 +132,11 @@ func (t *Token) expired() bool {
if t.Expiry.IsZero() { if t.Expiry.IsZero() {
return false return false
} }
expiryDelta := defaultExpiryDelta
if t.expiryDelta != 0 {
expiryDelta = t.expiryDelta
}
return t.Expiry.Round(0).Add(-expiryDelta).Before(timeNow()) return t.Expiry.Round(0).Add(-expiryDelta).Before(timeNow())
} }
@@ -154,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error.. // with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle)) tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil { if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok { if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr) return nil, (*RetrieveError)(rErr)
@@ -165,14 +175,31 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error)
} }
// RetrieveError is the error returned when the token endpoint returns a // RetrieveError is the error returned when the token endpoint returns a
// non-2XX HTTP status code. // non-2XX HTTP status code or populates RFC 6749's 'error' parameter.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
type RetrieveError struct { type RetrieveError struct {
Response *http.Response Response *http.Response
// Body is the body that was consumed by reading Response.Body. // Body is the body that was consumed by reading Response.Body.
// It may be truncated. // It may be truncated.
Body []byte Body []byte
// ErrorCode is RFC 6749's 'error' parameter.
ErrorCode string
// ErrorDescription is RFC 6749's 'error_description' parameter.
ErrorDescription string
// ErrorURI is RFC 6749's 'error_uri' parameter.
ErrorURI string
} }
func (r *RetrieveError) Error() string { func (r *RetrieveError) Error() string {
if r.ErrorCode != "" {
s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
if r.ErrorDescription != "" {
s += fmt.Sprintf(" %q", r.ErrorDescription)
}
if r.ErrorURI != "" {
s += fmt.Sprintf(" %q", r.ErrorURI)
}
return s
}
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
} }

View File

@@ -43,9 +43,13 @@ func TestTokenExpiry(t *testing.T) {
want bool want bool
}{ }{
{name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false}, {name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false},
{name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: false}, {name: "10 seconds", tok: &Token{Expiry: now.Add(defaultExpiryDelta)}, want: false},
{name: "10 seconds-1ns", tok: &Token{Expiry: now.Add(expiryDelta - 1*time.Nanosecond)}, want: true}, {name: "10 seconds-1ns", tok: &Token{Expiry: now.Add(defaultExpiryDelta - 1*time.Nanosecond)}, want: true},
{name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true}, {name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true},
{name: "12 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(12 * time.Second), expiryDelta: time.Second * 5}, want: false},
{name: "5 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second * 5), expiryDelta: time.Second * 5}, want: false},
{name: "5 seconds-1ns, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second*5 - 1*time.Nanosecond), expiryDelta: time.Second * 5}, want: true},
{name: "-1 hour, custom expiryDelta", tok: &Token{Expiry: now.Add(-1 * time.Hour), expiryDelta: time.Second * 5}, want: true},
} }
for _, tc := range cases { for _, tc := range cases {
if got, want := tc.tok.expired(), tc.want; got != want { if got, want := tc.tok.expired(), tc.want; got != want {