17 Commits

Author SHA1 Message Date
35c2a7f188 remove usage of appengine to get rid of unsafe imports 2024-01-17 13:45:16 +01:00
Gopher Robot
39adbb7807 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Icf68cb33585a13df206afacdb79832ea76f82346
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/554676
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
2024-01-08 18:34:15 +00:00
Chris Smith
4ce7bbb2ff google: add Credentials.GetUniverseDomain with GCE MDS support
* Deprecate Credentials.UniverseDomain

Change-Id: I1cbc842fbfce35540c8dff99fec09e036b9e2cdf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/554215
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Viacheslav Rostovtsev <virost@google.com>
2024-01-05 14:38:43 +00:00
Chris Smith
1e6999b1be google: add UniverseDomain to CredentialsParams
Change-Id: I7925b8341e1f047d0115acd7a01a34679a489ee0
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/552716
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Viacheslav Rostovtsev <virost@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2024-01-04 15:11:51 +00:00
Gopher Robot
6e9ec9323d go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Iad79e50dacd89c4cd0a40d966a1a7ba4cdc3d1a4
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/545176
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Gopher Robot <gobot@golang.org>
2023-11-27 17:50:56 +00:00
Gopher Robot
e067960af8 go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: Id1413f67816220ef8039fb933088f4b7f50d70e5
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/540817
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
2023-11-08 20:28:19 +00:00
Leo
4c91c17b32 google: adds header to security considerations section
Change-Id: I29b93715876f233ae52687c8223fd8733a2a3b80
GitHub-Last-Rev: f15c4cf1a5
GitHub-Pull-Request: golang/oauth2#677
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/535895
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-17 20:42:42 +00:00
Gopher Robot
3c5dbf08cc go.mod: update golang.org/x dependencies
Update golang.org/x dependencies to their latest tagged versions.

Change-Id: I39a72a7dbb2205a6638a154892c69948ee2deb0d
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/533241
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
2023-10-06 08:33:24 +00:00
Chris Smith
11625ccb95 google: add authorized_user conditional to Credentials.UniverseDomain
Return default universe domain if credentials type is authorized_user.

Change-Id: I20a9b5fafa562fcec84717914a236d081f630591
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/532196
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-10-02 22:04:54 +00:00
Chris Smith
8d6d45b6cd google: add Credentials.UniverseDomain to support TPC
Read and expose universe_domain from service account JSON files in
CredentialsFromJSONWithParams to support TPC in 1p clients.

Change-Id: I3518a0ec8be5ff7235b946cffd88b26ac8d303cf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531715
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-29 14:33:30 +00:00
Jin Qin
43b6a7ba19 google: adding support for external account authorized user
To support a new type of credential: `ExternalAccountAuthorizedUser`

* Refactor the common dependency STS to a separate package.
* Adding the `externalaccountauthorizeduser` package.

Change-Id: I9b9624f912d216b67a0d31945a50f057f747710b
GitHub-Last-Rev: 6e2aaff345
GitHub-Pull-Request: golang/oauth2#671
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531095
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
2023-09-28 22:24:46 +00:00
M Hickford
14b275c918 oauth2: workaround misspelling of verification_uri
Some servers misspell verification_uri as verification_url, contrary to spec RFC 8628

Example server https://issuetracker.google.com/issues/151238144

Fixes #666

Change-Id: I89e354368bbb0a4e3b979bb547b4cb37bbe1cc02
GitHub-Last-Rev: bbf169b52d
GitHub-Pull-Request: golang/oauth2#667
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/527835
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nikolay Turpitko <nick.turpitko@gmail.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
2023-09-22 21:51:39 +00:00
aeitzman
18352fc433 google/internal/externalaccount: adding BYOID Metrics
Adds framework for sending BYOID metrics via the x-goog-api-client header on outgoing sts requests. Also adds a header file for getting the current version of GoLang

Change-Id: Id5431def96f4cfc03e4ada01d5fb8cac8cfa56a9
GitHub-Last-Rev: c93cd478e5
GitHub-Pull-Request: golang/oauth2#661
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/523595
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
2023-09-22 20:39:34 +00:00
M Hickford
9095a51613 oauth2: clarify error if endpoint missing DeviceAuthURL
Change-Id: I36eb5eb66099161785160f4f39ea1c7f64ad6e74
GitHub-Last-Rev: 31cfe8150f
GitHub-Pull-Request: golang/oauth2#664
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/526302
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
2023-09-22 16:24:29 +00:00
Jin Qin
2d9e4a2adf oauth2/google: remove meta validations for aws external credentials
Remove the url validations to keep a consistency with other libraries.

Change-Id: Icb1767edc000d9695db3f0c7ca271918fb2083f5
GitHub-Last-Rev: af89ee0c72
GitHub-Pull-Request: golang/oauth2#660
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/522395
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
2023-09-12 16:01:49 +00:00
M Hickford
55cd552a36 oauth2: support PKCE
Fixes #603

Fixes golang/go#59835

Change-Id: Ica0cfef975ba9511e00f097498d33ba27dafca0d
GitHub-Last-Rev: f01f7593a3
GitHub-Pull-Request: golang/oauth2#625
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/463979
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
2023-09-07 17:49:42 +00:00
M Hickford
e3fb0fb3af oauth2: support device flow
Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628

Tested with GitHub

Fixes #418

Fixes golang/go#58126

Co-authored-by: cmP <centimitr@gmail.com>

Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14
GitHub-Last-Rev: 9a126d7b53
GitHub-Pull-Request: golang/oauth2#609
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155
Commit-Queue: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Than McIntosh <thanm@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
Run-TryBot: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
2023-09-06 16:35:20 +00:00
38 changed files with 1555 additions and 552 deletions

198
deviceauth.go Normal file
View File

@@ -0,0 +1,198 @@
package oauth2
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
)
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
errAuthorizationPending = "authorization_pending"
errSlowDown = "slow_down"
errAccessDenied = "access_denied"
errExpiredToken = "expired_token"
)
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
// DeviceCode
DeviceCode string `json:"device_code"`
// UserCode is the code the user should enter at the verification uri
UserCode string `json:"user_code"`
// VerificationURI is where user should enter the user code
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
// Expiry is when the device code and user code expire
Expiry time.Time `json:"expires_in,omitempty"`
// Interval is the duration in seconds that Poll should wait between requests
Interval int64 `json:"interval,omitempty"`
}
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
type Alias DeviceAuthResponse
var expiresIn int64
if !d.Expiry.IsZero() {
expiresIn = int64(time.Until(d.Expiry).Seconds())
}
return json.Marshal(&struct {
ExpiresIn int64 `json:"expires_in,omitempty"`
*Alias
}{
ExpiresIn: expiresIn,
Alias: (*Alias)(&d),
})
}
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
type Alias DeviceAuthResponse
aux := &struct {
ExpiresIn int64 `json:"expires_in"`
// workaround misspelling of verification_uri
VerificationURL string `json:"verification_url"`
*Alias
}{
Alias: (*Alias)(c),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ExpiresIn != 0 {
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
}
if c.VerificationURI == "" {
c.VerificationURI = aux.VerificationURL
}
return nil
}
// DeviceAuth returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
v := url.Values{
"client_id": {c.ClientID},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveDeviceAuth(ctx, c, v)
}
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
if c.Endpoint.DeviceAuthURL == "" {
return nil, errors.New("endpoint missing DeviceAuthURL")
}
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
t := time.Now()
r, err := internal.ContextClient(ctx).Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}
da := &DeviceAuthResponse{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}
if !da.Expiry.IsZero() {
// Make a small adjustment to account for time taken by the request
da.Expiry = da.Expiry.Add(-time.Since(t))
}
return da, nil
}
// DeviceAccessToken polls the server to exchange a device code for a token.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
if !da.Expiry.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
defer cancel()
}
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
v := url.Values{
"client_id": {c.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {da.DeviceCode},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
// "If no value is provided, clients MUST use 5 as the default."
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
interval := da.Interval
if interval == 0 {
interval = 5
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
tok, err := retrieveToken(ctx, c, v)
if err == nil {
return tok, nil
}
e, ok := err.(*RetrieveError)
if !ok {
return nil, err
}
switch e.ErrorCode {
case errSlowDown:
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
interval += 5
ticker.Reset(time.Duration(interval) * time.Second)
case errAuthorizationPending:
// Do nothing.
case errAccessDenied, errExpiredToken:
fallthrough
default:
return tok, err
}
}
}
}

97
deviceauth_test.go Normal file
View File

@@ -0,0 +1,97 @@
package oauth2
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
tests := []struct {
name string
response DeviceAuthResponse
want string
}{
{
name: "empty",
response: DeviceAuthResponse{},
want: `{"device_code":"","user_code":"","verification_uri":""}`,
},
{
name: "soon",
response: DeviceAuthResponse{
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
},
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
gotBytes, err := json.Marshal(tc.response)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
t.Skip("test ran too slowly to compare `expires_in`")
}
got := string(gotBytes)
if got != tc.want {
t.Errorf("want=%s, got=%s", tc.want, got)
}
})
}
}
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
tests := []struct {
name string
data string
want DeviceAuthResponse
}{
{
name: "empty",
data: `{}`,
want: DeviceAuthResponse{},
},
{
name: "soon",
data: `{"expires_in":100}`,
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
begin := time.Now()
got := DeviceAuthResponse{}
err := json.Unmarshal([]byte(tc.data), &got)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
t.Errorf("want=%#v, got=%#v", tc.want, got)
}
})
}
}
func ExampleConfig_DeviceAuth() {
var config Config
ctx := context.Background()
response, err := config.DeviceAuth(ctx)
if err != nil {
panic(err)
}
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
token, err := config.DeviceAccessToken(ctx, response)
if err != nil {
panic(err)
}
fmt.Println(token)
}

View File

@@ -57,6 +57,7 @@ var Fitbit = oauth2.Endpoint{
var GitHub = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
DeviceAuthURL: "https://github.com/login/device/code",
}
// GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
var Google = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
}
// Heroku is the endpoint for Heroku.

View File

@@ -26,9 +26,13 @@ func ExampleConfig() {
},
}
// use PKCE to protect against CSRF attacks
// https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6
verifier := oauth2.GenerateVerifier()
// Redirect user to consent page to ask for permission
// for the scopes specified above.
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Use the authorization code that is pushed to the redirect
@@ -39,7 +43,7 @@ func ExampleConfig() {
if _, err := fmt.Scan(&code); err != nil {
log.Fatal(err)
}
tok, err := conf.Exchange(ctx, code)
tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
if err != nil {
log.Fatal(err)
}

View File

@@ -6,11 +6,8 @@
package github // import "golang.org/x/oauth2/github"
import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
// Endpoint is Github's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
}
var Endpoint = endpoints.GitHub

8
go.mod
View File

@@ -5,12 +5,6 @@ go 1.18
require (
cloud.google.com/go/compute/metadata v0.2.3
github.com/google/go-cmp v0.5.9
google.golang.org/appengine v1.6.7
)
require (
cloud.google.com/go/compute v1.20.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
golang.org/x/net v0.15.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
)
require cloud.google.com/go/compute v1.20.1 // indirect

20
go.sum
View File

@@ -2,25 +2,5 @@ cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZN
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

View File

@@ -1,38 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"context"
"time"
"golang.org/x/oauth2"
)
// Set at init time by appengine_gen1.go. If nil, we're not on App Engine standard first generation (<= Go 1.9) or App Engine flexible.
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error)
// Set at init time by appengine_gen1.go. If nil, we're not on App Engine standard first generation (<= Go 1.9) or App Engine flexible.
var appengineAppIDFunc func(c context.Context) string
// AppEngineTokenSource returns a token source that fetches tokens from either
// the current application's service account or from the metadata server,
// depending on the App Engine environment. See below for environment-specific
// details. If you are implementing a 3-legged OAuth 2.0 flow on App Engine that
// involves user accounts, see oauth2.Config instead.
//
// First generation App Engine runtimes (<= Go 1.9):
// AppEngineTokenSource returns a token source that fetches tokens issued to the
// current App Engine application's service account. The provided context must have
// come from appengine.NewContext.
//
// Second generation App Engine runtimes (>= Go 1.11) and App Engine flexible:
// AppEngineTokenSource is DEPRECATED on second generation runtimes and on the
// flexible environment. It delegates to ComputeTokenSource, and the provided
// context and scopes are not used. Please use DefaultTokenSource (or ComputeTokenSource,
// which DefaultTokenSource will use in this case) instead.
func AppEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
return appEngineTokenSource(ctx, scope...)
}

View File

@@ -1,77 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build appengine
// This file applies to App Engine first generation runtimes (<= Go 1.9).
package google
import (
"context"
"sort"
"strings"
"sync"
"golang.org/x/oauth2"
"google.golang.org/appengine"
)
func init() {
appengineTokenFunc = appengine.AccessToken
appengineAppIDFunc = appengine.AppID
}
// See comment on AppEngineTokenSource in appengine.go.
func appEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
scopes := append([]string{}, scope...)
sort.Strings(scopes)
return &gaeTokenSource{
ctx: ctx,
scopes: scopes,
key: strings.Join(scopes, " "),
}
}
// aeTokens helps the fetched tokens to be reused until their expiration.
var (
aeTokensMu sync.Mutex
aeTokens = make(map[string]*tokenLock) // key is space-separated scopes
)
type tokenLock struct {
mu sync.Mutex // guards t; held while fetching or updating t
t *oauth2.Token
}
type gaeTokenSource struct {
ctx context.Context
scopes []string
key string // to aeTokens map; space-separated scopes
}
func (ts *gaeTokenSource) Token() (*oauth2.Token, error) {
aeTokensMu.Lock()
tok, ok := aeTokens[ts.key]
if !ok {
tok = &tokenLock{}
aeTokens[ts.key] = tok
}
aeTokensMu.Unlock()
tok.mu.Lock()
defer tok.mu.Unlock()
if tok.t.Valid() {
return tok.t, nil
}
access, exp, err := appengineTokenFunc(ts.ctx, ts.scopes...)
if err != nil {
return nil, err
}
tok.t = &oauth2.Token{
AccessToken: access,
Expiry: exp,
}
return tok.t, nil
}

View File

@@ -1,27 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !appengine
// This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible.
package google
import (
"context"
"log"
"sync"
"golang.org/x/oauth2"
)
var logOnce sync.Once // only spam about deprecation once
// See comment on AppEngineTokenSource in appengine.go.
func appEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
logOnce.Do(func() {
log.Print("google: AppEngineTokenSource is deprecated on App Engine standard second generation runtimes (>= Go 1.11) and App Engine flexible. Please use DefaultTokenSource or ComputeTokenSource.")
})
return ComputeTokenSource("")
}

View File

@@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"runtime"
"sync"
"time"
"cloud.google.com/go/compute/metadata"
@@ -19,7 +20,10 @@ import (
"golang.org/x/oauth2/authhandler"
)
const adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
const (
adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
universeDomainDefault = "googleapis.com"
)
// Credentials holds Google credentials, including "Application Default Credentials".
// For more details, see:
@@ -37,6 +41,75 @@ type Credentials struct {
// environment and not with a credentials file, e.g. when code is
// running on Google Cloud Platform.
JSON []byte
udMu sync.Mutex // guards universeDomain
// universeDomain is the default service domain for a given Cloud universe.
universeDomain string
}
// UniverseDomain returns the default service domain for a given Cloud universe.
//
// The default value is "googleapis.com".
//
// Deprecated: Use instead (*Credentials).GetUniverseDomain(), which supports
// obtaining the universe domain when authenticating via the GCE metadata server.
// Unlike GetUniverseDomain, this method, UniverseDomain, will always return the
// default value when authenticating via the GCE metadata server.
// See also [The attached service account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
func (c *Credentials) UniverseDomain() string {
if c.universeDomain == "" {
return universeDomainDefault
}
return c.universeDomain
}
// GetUniverseDomain returns the default service domain for a given Cloud
// universe.
//
// The default value is "googleapis.com".
//
// It obtains the universe domain from the attached service account on GCE when
// authenticating via the GCE metadata server. See also [The attached service
// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
// If the GCE metadata server returns a 404 error, the default value is
// returned. If the GCE metadata server returns an error other than 404, the
// error is returned.
func (c *Credentials) GetUniverseDomain() (string, error) {
c.udMu.Lock()
defer c.udMu.Unlock()
if c.universeDomain == "" && metadata.OnGCE() {
// If we're on Google Compute Engine, an App Engine standard second
// generation runtime, or App Engine flexible, use the metadata server.
err := c.computeUniverseDomain()
if err != nil {
return "", err
}
}
// If not on Google Compute Engine, or in case of any non-error path in
// computeUniverseDomain that did not set universeDomain, set the default
// universe domain.
if c.universeDomain == "" {
c.universeDomain = universeDomainDefault
}
return c.universeDomain, nil
}
// computeUniverseDomain fetches the default service domain for a given Cloud
// universe from Google Compute Engine (GCE)'s metadata server. It's only valid
// to use this method if your program is running on a GCE instance.
func (c *Credentials) computeUniverseDomain() error {
var err error
c.universeDomain, err = metadata.Get("universe/universe_domain")
if err != nil {
if _, ok := err.(metadata.NotDefinedError); ok {
// http.StatusNotFound (404)
c.universeDomain = universeDomainDefault
return nil
} else {
return err
}
}
return nil
}
// DefaultCredentials is the old name of Credentials.
@@ -76,6 +149,12 @@ type CredentialsParams struct {
// Note: This option is currently only respected when using credentials
// fetched from the GCE metadata server.
EarlyTokenRefresh time.Duration
// UniverseDomain is the default service domain for a given Cloud universe.
// Only supported in authentication flows that support universe domains.
// This value takes precedence over a universe domain explicitly specified
// in a credentials config file or by the GCE metadata server. Optional.
UniverseDomain string
}
func (params CredentialsParams) deepCopy() CredentialsParams {
@@ -145,16 +224,6 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar
return CredentialsFromJSONWithParams(ctx, b, params)
}
// Third, if we're on a Google App Engine standard first generation runtime (<= Go 1.9)
// use those credentials. App Engine standard second generation runtimes (>= Go 1.11)
// and App Engine flexible use ComputeTokenSource and the metadata server.
if appengineTokenFunc != nil {
return &Credentials{
ProjectID: appengineAppIDFunc(ctx),
TokenSource: AppEngineTokenSource(ctx, params.Scopes...),
}, nil
}
// Fourth, if we're on Google Compute Engine, an App Engine standard second generation runtime,
// or App Engine flexible, use the metadata server.
if metadata.OnGCE() {
@@ -162,6 +231,7 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar
return &Credentials{
ProjectID: id,
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
universeDomain: params.UniverseDomain,
}, nil
}
@@ -200,6 +270,16 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
if err := json.Unmarshal(jsonData, &f); err != nil {
return nil, err
}
universeDomain := f.UniverseDomain
if params.UniverseDomain != "" {
universeDomain = params.UniverseDomain
}
// Authorized user credentials are only supported in the googleapis.com universe.
if f.Type == userCredentialsKey {
universeDomain = universeDomainDefault
}
ts, err := f.tokenSource(ctx, params)
if err != nil {
return nil, err
@@ -209,6 +289,7 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
ProjectID: f.ProjectID,
TokenSource: ts,
JSON: jsonData,
universeDomain: universeDomain,
}, nil
}

297
google/default_test.go Normal file
View File

@@ -0,0 +1,297 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
var saJSONJWT = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var saJSONJWTUniverseDomain = []byte(`{
"type": "service_account",
"project_id": "fake_project",
"universe_domain": "example.com",
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gopher%40fake_project.iam.gserviceaccount.com"
}`)
var userJSON = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2"
}`)
var userJSONUniverseDomain = []byte(`{
"client_id": "abc123.apps.googleusercontent.com",
"client_secret": "shh",
"refresh_token": "refreshing",
"type": "authorized_user",
"quota_project_id": "fake_project2",
"universe_domain": "example.com"
}`)
var universeDomain = "example.com"
var universeDomain2 = "apis-tpclp.goog"
func TestCredentialsFromJSONWithParams_SA(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWT, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}
func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWT, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
}
func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWTUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if creds.UniverseDomain() != universeDomain {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if got != universeDomain {
t.Fatalf("got %q, want %q", got, universeDomain)
}
}
func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, saJSONJWTUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "fake_project"; creds.ProjectID != want {
t.Fatalf("got %q, want %q", creds.ProjectID, want)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if got != universeDomain2 {
t.Fatalf("got %q, want %q", got, universeDomain2)
}
}
func TestCredentialsFromJSONWithParams_User(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSON, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSON, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSONUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain(t *testing.T) {
ctx := context.Background()
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
UniverseDomain: universeDomain2,
}
creds, err := CredentialsFromJSONWithParams(ctx, userJSONUniverseDomain, params)
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
func TestComputeUniverseDomain(t *testing.T) {
universeDomainPath := "/computeMetadata/v1/universe/universe_domain"
universeDomainResponseBody := "example.com"
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != universeDomainPath {
t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath)
}
w.Write([]byte(universeDomainResponseBody))
}))
defer s.Close()
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
// Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block
creds := &Credentials{
ProjectID: "fake_project",
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
universeDomain: params.UniverseDomain, // empty
}
c := make(chan bool)
go func() {
got, err := creds.GetUniverseDomain() // First conflicting access.
if err != nil {
t.Error(err)
}
if want := universeDomainResponseBody; got != want {
t.Errorf("got %q, want %q", got, want)
}
c <- true
}()
got, err := creds.GetUniverseDomain() // Second conflicting access.
<-c
if err != nil {
t.Error(err)
}
if want := universeDomainResponseBody; got != want {
t.Errorf("got %q, want %q", got, want)
}
}

View File

@@ -101,6 +101,8 @@
// executable-sourced credentials), please check out:
// https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in
//
// # Security considerations
//
// Note that this library does not perform any validation on the token_url, token_info_url,
// or service_account_impersonation_url fields of the credential configuration.
// It is not recommended to use a credential configuration that you did not generate with

View File

@@ -16,6 +16,7 @@ import (
"cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/externalaccount"
"golang.org/x/oauth2/google/internal/externalaccountauthorizeduser"
"golang.org/x/oauth2/jwt"
)
@@ -23,6 +24,7 @@ import (
var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
AuthStyle: oauth2.AuthStyleInParams,
}
@@ -98,6 +100,7 @@ const (
serviceAccountKey = "service_account"
userCredentialsKey = "authorized_user"
externalAccountKey = "external_account"
externalAccountAuthorizedUserKey = "external_account_authorized_user"
impersonatedServiceAccount = "impersonated_service_account"
)
@@ -112,6 +115,7 @@ type credentialsFile struct {
AuthURL string `json:"auth_uri"`
TokenURL string `json:"token_uri"`
ProjectID string `json:"project_id"`
UniverseDomain string `json:"universe_domain"`
// User Credential fields
// (These typically come from gcloud auth.)
@@ -131,6 +135,9 @@ type credentialsFile struct {
QuotaProjectID string `json:"quota_project_id"`
WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
// External Account Authorized User fields
RevokeURL string `json:"revoke_url"`
// Service account impersonation
SourceCredentials *credentialsFile `json:"source_credentials"`
}
@@ -199,6 +206,19 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
WorkforcePoolUserProject: f.WorkforcePoolUserProject,
}
return cfg.TokenSource(ctx)
case externalAccountAuthorizedUserKey:
cfg := &externalaccountauthorizeduser.Config{
Audience: f.Audience,
RefreshToken: f.RefreshToken,
TokenURL: f.TokenURLExternal,
TokenInfoURL: f.TokenInfoURL,
ClientID: f.ClientID,
ClientSecret: f.ClientSecret,
RevokeURL: f.RevokeURL,
QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes,
}
return cfg.TokenSource(ctx)
case impersonatedServiceAccount:
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")

View File

@@ -5,6 +5,8 @@
package google
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
@@ -137,3 +139,21 @@ func TestJWTConfigFromJSONNoAudience(t *testing.T) {
t.Errorf("Audience = %q; want %q", got, want)
}
}
func TestComputeTokenSource(t *testing.T) {
tokenPath := "/computeMetadata/v1/instance/service-accounts/default/token"
tokenResponseBody := `{"access_token":"Sample.Access.Token","token_type":"Bearer","expires_in":3600}`
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != tokenPath {
t.Errorf("got %s, want %s", r.URL.Path, tokenPath)
}
w.Write([]byte(tokenResponseBody))
}))
defer s.Close()
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
ts := ComputeTokenSource("")
_, err := ts.Token()
if err != nil {
t.Errorf("ts.Token() = %v", err)
}
}

View File

@@ -274,49 +274,6 @@ type awsRequest struct {
Headers []awsRequestHeader `json:"headers"`
}
func (cs awsCredentialSource) validateMetadataServers() error {
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
return err
}
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
return err
}
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
}
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
if metadataUrl == "" {
// Zero value means use default, which is valid.
return true
}
u, err := url.Parse(metadataUrl)
if err != nil {
// Unparseable URL means invalid
return false
}
for _, validHostname := range validHostnames {
if u.Hostname() == validHostname {
// If it's one of the valid hostnames, everything is good
return true
}
}
// hostname not found in our allowlist, so not valid
return false
}
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
if !cs.isValidMetadataServer(metadataUrl) {
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
}
return nil
}
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil)
@@ -339,6 +296,10 @@ func shouldUseMetadataServer() bool {
return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
}
func (cs awsCredentialSource) credentialSourceType() string {
return "aws"
}
func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil {
headers := make(map[string]string)

View File

@@ -585,25 +585,18 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -631,25 +624,18 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
func TestAWSCredential_IMDSv2(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -677,10 +663,6 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
delete(server.Credentials, "Token")
tfc := testFileConfig
@@ -688,15 +670,12 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -724,21 +703,15 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -746,7 +719,6 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -774,21 +746,15 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -796,7 +762,6 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -823,21 +788,15 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -846,7 +805,6 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
"AWS_DEFAULT_REGION": "us-east-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -873,25 +831,18 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.EnvironmentID = "aws3"
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
_, err = tfc.parse(context.Background())
_, err := tfc.parse(context.Background())
if err == nil {
t.Fatalf("parse() should have failed")
}
@@ -903,23 +854,16 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.RegionURL = ""
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -939,23 +883,17 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRegion = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -975,10 +913,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}"))
}
@@ -987,13 +922,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1013,10 +945,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
}
@@ -1025,13 +954,10 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1051,23 +977,16 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.URL = ""
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1087,23 +1006,16 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRolename = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1123,23 +1035,16 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1159,10 +1064,6 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Metadata server should not have been called.")
@@ -1174,11 +1075,9 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
@@ -1186,7 +1085,6 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1214,28 +1112,21 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": accessKeyID,
"AWS_SECRET_ACCESS_KEY": secretAccessKey,
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1263,28 +1154,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1312,28 +1196,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background())
if err != nil {
@@ -1358,87 +1235,19 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testin
}
}
func TestAWSCredential_Validations(t *testing.T) {
var metadataServerValidityTests = []struct {
name string
credSource CredentialSource
errText string
}{
{
name: "No Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "",
URL: "",
IMDSv2SessionTokenURL: "",
},
}, {
name: "IPv4 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
}, {
name: "IPv6 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
},
}, {
name: "Faulty RegionURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url",
}, {
name: "Faulty CredVerificationURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://abc.com/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url",
}, {
name: "Faulty IMDSv2SessionTokenURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://abc.com/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url",
},
}
func TestAwsCredential_CredentialSourceType(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
for _, tt := range metadataServerValidityTests {
t.Run(tt.name, func(t *testing.T) {
tfc := testFileConfig
tfc.CredentialSource = tt.credSource
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(context.Background())
base, err := tfc.parse(context.Background())
if err != nil {
if tt.errText == "" {
t.Errorf("Didn't expect an error, but got %v", err)
} else if tt.errText != err.Error() {
t.Errorf("Expected %v, but got %v", tt.errText, err)
t.Fatalf("parse() failed %v", err)
}
} else {
if tt.errText != "" {
t.Errorf("Expected error %v, but got none", tt.errText)
}
}
})
if got, want := base.credentialSourceType(), "aws"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}

View File

@@ -8,13 +8,12 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing
@@ -63,31 +62,10 @@ type Config struct {
WorkforcePoolUserProject string
}
// Each element consists of a list of patterns. validateURLs checks for matches
// that include all elements in a given list, in that order.
var (
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
)
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
parsed, err := url.Parse(input)
if err != nil {
return false
}
if !strings.EqualFold(parsed.Scheme, scheme) {
return false
}
toTest := parsed.Host
for _, pattern := range patterns {
if pattern.MatchString(toTest) {
return true
}
}
return false
}
func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input)
}
@@ -185,10 +163,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
}
if err := awsCredSource.validateMetadataServers(); err != nil {
return nil, err
}
return awsCredSource, nil
}
} else if c.CredentialSource.File != "" {
@@ -202,6 +176,7 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
}
type baseCredentialSource interface {
credentialSourceType() string
subjectToken() (string, error)
}
@@ -211,6 +186,15 @@ type tokenSource struct {
conf *Config
}
func getMetricsHeaderValue(conf *Config, credSource baseCredentialSource) string {
return fmt.Sprintf("gl-go/%s auth/%s google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t",
goVersion(),
"unknown",
credSource.credentialSourceType(),
conf.ServiceAccountImpersonationURL != "",
conf.ServiceAccountImpersonationLifetimeSeconds != 0)
}
// Token allows tokenSource to conform to the oauth2.TokenSource interface.
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
@@ -224,7 +208,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, err
}
stsRequest := stsTokenExchangeRequest{
stsRequest := stsexchange.TokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience,
Scope: conf.Scopes,
@@ -234,7 +218,8 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
}
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
clientAuth := clientAuthentication{
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
@@ -247,7 +232,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
"userProject": conf.WorkforcePoolUserProject,
}
}
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ package externalaccount
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@@ -51,6 +52,7 @@ type testExchangeTokenServer struct {
url string
authorization string
contentType string
metricsHeader string
body string
response string
}
@@ -68,6 +70,10 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T
if got, want := headerContentType, tets.contentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, tets.metricsHeader; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
@@ -106,6 +112,10 @@ func validateToken(t *testing.T, tok *oauth2.Token) {
}
}
func getExpectedMetricsHeader(source string, saImpersonation bool, configLifetime bool) string {
return fmt.Sprintf("gl-go/%s auth/unknown google-byoid-sdk source/%s sa-impersonation/%t config-lifetime/%t", goVersion(), source, saImpersonation, configLifetime)
}
func TestToken(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
@@ -120,6 +130,7 @@ func TestToken(t *testing.T) {
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: baseCredsRequestBody,
response: baseCredsResponseBody,
}
@@ -147,6 +158,7 @@ func TestWorkforcePoolTokenWithClientID(t *testing.T) {
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithClientId,
response: baseCredsResponseBody,
}
@@ -173,6 +185,7 @@ func TestWorkforcePoolTokenWithoutClientID(t *testing.T) {
url: "/",
authorization: "",
contentType: "application/x-www-form-urlencoded",
metricsHeader: getExpectedMetricsHeader("file", false, false),
body: workforcePoolRequestBodyWithoutClientId,
response: baseCredsResponseBody,
}

View File

@@ -233,6 +233,10 @@ func (cs executableCredentialSource) parseSubjectTokenFromSource(response []byte
return "", tokenTypeError(source)
}
func (cs executableCredentialSource) credentialSourceType() string {
return "executable"
}
func (cs executableCredentialSource) subjectToken() (string, error) {
if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil {
return token, err

View File

@@ -150,6 +150,9 @@ func TestCreateExecutableCredential(t *testing.T) {
if ecs.Timeout != tt.expectedTimeout {
t.Errorf("ecs.Timeout got %v but want %v", ecs.Timeout, tt.expectedTimeout)
}
if ecs.credentialSourceType() != "executable" {
t.Errorf("ecs.CredentialSourceType() got %s but want executable", ecs.credentialSourceType())
}
}
})
}

View File

@@ -19,6 +19,10 @@ type fileCredentialSource struct {
Format format
}
func (cs fileCredentialSource) credentialSourceType() string {
return "file"
}
func (cs fileCredentialSource) subjectToken() (string, error) {
tokenFile, err := os.Open(cs.File)
if err != nil {

View File

@@ -68,6 +68,9 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
t.Errorf("got %v but want %v", out, test.want)
}
if got, want := base.credentialSourceType(), "file"; got != want {
t.Errorf("got %v but want %v", got, want)
}
})
}
}

View File

@@ -0,0 +1,64 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"strings"
"unicode"
)
var (
// version is a package internal global variable for testing purposes.
version = runtime.Version
)
// versionUnknown is only used when the runtime version cannot be determined.
const versionUnknown = "UNKNOWN"
// goVersion returns a Go runtime version derived from the runtime environment
// that is modified to be suitable for reporting in a header, meaning it has no
// whitespace. If it is unable to determine the Go runtime version, it returns
// versionUnknown.
func goVersion() string {
const develPrefix = "devel +"
s := version()
if strings.HasPrefix(s, develPrefix) {
s = s[len(develPrefix):]
if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
return s
} else if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 {
s = s[:p]
}
notSemverRune := func(r rune) bool {
return !strings.ContainsRune("0123456789.", r)
}
if strings.HasPrefix(s, "go1") {
s = s[2:]
var prerelease string
if p := strings.IndexFunc(s, notSemverRune); p >= 0 {
s, prerelease = s[:p], s[p:]
}
if strings.HasSuffix(s, ".") {
s += "0"
} else if strings.Count(s, ".") < 2 {
s += ".0"
}
if prerelease != "" {
// Some release candidates already have a dash in them.
if !strings.HasPrefix(prerelease, "-") {
prerelease = "-" + prerelease
}
s += prerelease
}
return s
}
return "UNKNOWN"
}

View File

@@ -0,0 +1,48 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"runtime"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGoVersion(t *testing.T) {
testVersion := func(v string) func() string {
return func() string {
return v
}
}
for _, tst := range []struct {
v func() string
want string
}{
{
testVersion("go1.19"),
"1.19.0",
},
{
testVersion("go1.21-20230317-RC01"),
"1.21.0-20230317-RC01",
},
{
testVersion("devel +abc1234"),
"abc1234",
},
{
testVersion("this should be unknown"),
versionUnknown,
},
} {
version = tst.v
got := goVersion()
if diff := cmp.Diff(got, tst.want); diff != "" {
t.Errorf("got(-),want(+):\n%s", diff)
}
}
version = runtime.Version
}

View File

@@ -42,7 +42,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin
}))
}
func createTargetServer(t *testing.T) *httptest.Server {
func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), "/"; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
@@ -55,6 +55,10 @@ func createTargetServer(t *testing.T) *httptest.Server {
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerMetrics := r.Header.Get("x-goog-api-client")
if got, want := headerMetrics, metricsHeaderWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
@@ -71,6 +75,7 @@ var impersonationTests = []struct {
name string
config Config
expectedImpersonationBody string
expectedMetricsHeader string
}{
{
name: "Base Impersonation",
@@ -84,6 +89,7 @@ var impersonationTests = []struct {
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
},
expectedImpersonationBody: "{\"lifetime\":\"3600s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, false),
},
{
name: "With TokenLifetime Set",
@@ -98,6 +104,7 @@ var impersonationTests = []struct {
ServiceAccountImpersonationLifetimeSeconds: 10000,
},
expectedImpersonationBody: "{\"lifetime\":\"10000s\",\"scope\":[\"https://www.googleapis.com/auth/devstorage.full_control\"]}",
expectedMetricsHeader: getExpectedMetricsHeader("file", true, true),
},
}
@@ -109,7 +116,7 @@ func TestImpersonation(t *testing.T) {
defer impersonateServer.Close()
testImpersonateConfig.ServiceAccountImpersonationURL = impersonateServer.URL
targetServer := createTargetServer(t)
targetServer := createTargetServer(tt.expectedMetricsHeader, t)
defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL

View File

@@ -23,6 +23,10 @@ type urlCredentialSource struct {
ctx context.Context
}
func (cs urlCredentialSource) credentialSourceType() string {
return "url"
}
func (cs urlCredentialSource) subjectToken() (string, error) {
client := oauth2.NewClient(cs.ctx, nil)
req, err := http.NewRequest("GET", cs.URL, nil)

View File

@@ -111,3 +111,21 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
t.Errorf("got %v but want %v", out, myURLToken)
}
}
func TestURLCredential_CredentialSourceType(t *testing.T) {
cs := CredentialSource{
URL: "http://example.com",
Format: format{Type: fileTypeText},
}
tfc := testFileConfig
tfc.CredentialSource = cs
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
if got, want := base.credentialSourceType(), "url"; got != want {
t.Errorf("got %v but want %v", got, want)
}
}

View File

@@ -0,0 +1,114 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"errors"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
// now aliases time.Now for testing.
var now = func() time.Time {
return time.Now().UTC()
}
var tokenValid = func(token oauth2.Token) bool {
return token.Valid()
}
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
// the provider identifier in that pool.
Audience string
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
RefreshToken string
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
// None if the token can not be refreshed.
TokenURL string
// TokenInfoURL is the optional STS endpoint URL for token introspection.
TokenInfoURL string
// ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
// access token. When provided, STS will be called with additional basic authentication using client_id as username
// and client_secret as password.
ClientSecret string
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
Token string
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
Expiry time.Time
// RevokeURL is the optional STS endpoint URL for revoking tokens.
RevokeURL string
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
// project used to create the credentials.
QuotaProjectID string
Scopes []string
}
func (c *Config) canRefresh() bool {
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
}
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
var token oauth2.Token
if c.Token != "" && !c.Expiry.IsZero() {
token = oauth2.Token{
AccessToken: c.Token,
Expiry: c.Expiry,
TokenType: "Bearer",
}
}
if !tokenValid(token) && !c.canRefresh() {
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
return oauth2.ReuseTokenSource(&token, ts), nil
}
type tokenSource struct {
ctx context.Context
conf *Config
}
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
if !conf.canRefresh() {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}
if stsResponse.ExpiresIn < 0 {
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
}
if stsResponse.RefreshToken != "" {
conf.RefreshToken = stsResponse.RefreshToken
}
token := &oauth2.Token{
AccessToken: stsResponse.AccessToken,
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
TokenType: "Bearer",
}
return token, nil
}

View File

@@ -0,0 +1,259 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/stsexchange"
)
const expiryDelta = 10 * time.Second
var (
expiry = time.Unix(234852, 0)
testNow = func() time.Time { return expiry }
testValid = func(t oauth2.Token) bool {
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
}
)
type testRefreshTokenServer struct {
URL string
Authorization string
ContentType string
Body string
ResponsePayload *stsexchange.Response
Response string
server *httptest.Server
}
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expiry: now().Add(time.Hour),
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
if config.RefreshToken != "CCCCCCC" {
t.Fatalf("Refresh token not updated")
}
}
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.close(t)
testCases := []struct {
name string
config Config
}{
{
name: "empty config",
config: Config{},
},
{
name: "missing refresh token",
config: Config{
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing token url",
config: Config{
RefreshToken: "BBBBBBBBB",
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client id",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client secrect",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
_, err := tc.config.TokenSource((context.Background()))
if err == nil {
t.Fatalf("Expected error, but received none")
}
if got := err.Error(); got != expectErrMsg {
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
}
})
}
}
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
t.Helper()
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) close(t *testing.T) error {
t.Helper()
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"encoding/base64"
@@ -12,8 +12,8 @@ import (
"golang.org/x/oauth2"
)
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type clientAuthentication struct {
// ClientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type ClientAuthentication struct {
// AuthStyle can be either basic or request-body
AuthStyle oauth2.AuthStyle
ClientID string
@@ -23,7 +23,7 @@ type clientAuthentication struct {
// InjectAuthentication is used to add authentication to a Secure Token Service exchange
// request. It modifies either the passed url.Values or http.Header depending on the desired
// authentication format.
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
return
}

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"net/http"
@@ -38,7 +38,7 @@ func TestClientAuthentication_InjectHeaderAuthentication(t *testing.T) {
"Content-Type": ContentType,
}
headerAuthentication := clientAuthentication{
headerAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID,
ClientSecret: clientSecret,
@@ -80,7 +80,7 @@ func TestClientAuthentication_ParamsAuthentication(t *testing.T) {
headerP := http.Header{
"Content-Type": ContentType,
}
paramsAuthentication := clientAuthentication{
paramsAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInParams,
ClientID: clientID,
ClientSecret: clientSecret,

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"context"
@@ -18,14 +18,17 @@ import (
"golang.org/x/oauth2"
)
// exchangeToken performs an oauth2 token exchange with the provided endpoint.
func defaultHeader() http.Header {
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
return header
}
// ExchangeToken performs an oauth2 token exchange with the provided endpoint.
// The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server.
func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchangeRequest, authentication clientAuthentication, headers http.Header, options map[string]interface{}) (*stsTokenExchangeResponse, error) {
client := oauth2.NewClient(ctx, nil)
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
data := url.Values{}
data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@@ -41,13 +44,28 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
data.Set("options", string(opts))
}
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
if headers == nil {
headers = defaultHeader()
}
client := oauth2.NewClient(ctx, nil)
authentication.InjectAuthentication(data, headers)
encodedData := data.Encode()
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
}
req = req.WithContext(ctx)
for key, list := range headers {
@@ -71,7 +89,7 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
}
var stsResp stsTokenExchangeResponse
var stsResp Response
err = json.Unmarshal(body, &stsResp)
if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
@@ -81,8 +99,8 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
return &stsResp, nil
}
// stsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type stsTokenExchangeRequest struct {
// TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type TokenExchangeRequest struct {
ActingParty struct {
ActorToken string
ActorTokenType string
@@ -96,8 +114,8 @@ type stsTokenExchangeRequest struct {
SubjectTokenType string
}
// stsTokenExchangeResponse is used to decode the remote server response during an oauth2 token exchange.
type stsTokenExchangeResponse struct {
// Response is used to decode the remote server response during an oauth2 token exchange.
type Response struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
package stsexchange
import (
"context"
@@ -16,13 +16,13 @@ import (
"golang.org/x/oauth2"
)
var auth = clientAuthentication{
var auth = ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID,
ClientSecret: clientSecret,
}
var tokenRequest = stsTokenExchangeRequest{
var exchangeTokenRequest = TokenExchangeRequest{
ActingParty: struct {
ActorToken string
ActorTokenType string
@@ -36,9 +36,9 @@ var tokenRequest = stsTokenExchangeRequest{
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
}
var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedToken = stsTokenExchangeResponse{
var exchangeRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var exchangeResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedExchangeToken = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
@@ -47,6 +47,18 @@ var expectedToken = stsTokenExchangeResponse{
RefreshToken: "",
}
var refreshToken = "ReFrEsHtOkEn"
var refreshRequestBody = "grant_type=refresh_token&refresh_token=" + refreshToken
var refreshResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform","refresh_token":"REFRESHED_REFRESH"}`
var expectedRefreshResponse = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: "https://www.googleapis.com/auth/cloud-platform",
RefreshToken: "REFRESHED_REFRESH",
}
func TestExchangeToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
@@ -65,26 +77,34 @@ func TestExchangeToken(t *testing.T) {
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), requestbody; got != want {
if got, want := string(body), exchangeRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody))
w.Write([]byte(exchangeResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
resp, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedToken, *resp)
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
resp, err = ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, nil, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
}
func TestExchangeToken_Err(t *testing.T) {
@@ -96,7 +116,7 @@ func TestExchangeToken_Err(t *testing.T) {
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
_, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
@@ -171,7 +191,7 @@ func TestExchangeToken_Opts(t *testing.T) {
// Send a proper reply so that no other errors crop up.
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody))
w.Write([]byte(exchangeResponseBody))
}))
defer ts.Close()
@@ -183,5 +203,69 @@ func TestExchangeToken_Opts(t *testing.T) {
inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption
inputOpts["two"] = secondOption
exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, inputOpts)
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
}
func TestRefreshToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Unexpected request method, %v is found", r.Method)
}
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
if got, want := r.Header.Get("Authorization"), "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("Unexpected authorization header, got %v, want %v", got, want)
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), refreshRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(refreshResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
resp, err = RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
}
func TestRefreshToken_Err(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("what's wrong with this response?"))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
}

View File

@@ -1,13 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build appengine
package internal
import "google.golang.org/appengine/urlfetch"
func init() {
appengineClientHook = urlfetch.Client
}

View File

@@ -18,16 +18,11 @@ var HTTPClient ContextKey
// because nobody else can create a ContextKey, being unexported.
type ContextKey struct{}
var appengineClientHook func(context.Context) *http.Client
func ContextClient(ctx context.Context) *http.Client {
if ctx != nil {
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc
}
}
if appengineClientHook != nil {
return appengineClientHook(ctx)
}
return http.DefaultClient
}

View File

@@ -76,6 +76,7 @@ type TokenSource interface {
// endpoint URLs.
type Endpoint struct {
AuthURL string
DeviceAuthURL string
TokenURL string
// AuthStyle optionally specifies how the endpoint wants the
@@ -143,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// State is an opaque value used by the client to maintain state between the
// request and callback. The authorization server includes this value when
// redirecting the user agent back to the client.
//
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce.
// It can also be used to pass the PKCE challenge.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
//
// To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
// generate a random state parameter and verify it after exchange.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL)
@@ -166,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
v.Set("scope", strings.Join(c.Scopes, " "))
}
if state != "" {
// TODO(light): Docs say never to omit state; don't allow empty.
v.Set("state", state)
}
for _, opt := range opts {
@@ -211,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
//
// The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state").
// calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks.
//
// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
// If using PKCE to protect against CSRF attacks, opts should include a
// VerifierOption.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{
"grant_type": {"authorization_code"},

68
pkce.go Normal file
View File

@@ -0,0 +1,68 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/url"
)
const (
codeChallengeKey = "code_challenge"
codeChallengeMethodKey = "code_challenge_method"
codeVerifierKey = "code_verifier"
)
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
// This follows recommendations in RFC 7636.
//
// A fresh verifier should be generated for each authorization.
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
// (or Config.DeviceAccessToken).
func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then
// base64url-encoded to produce a 43-octet URL-safe string to use as the
// code verifier."
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
data := make([]byte, 32)
if _, err := rand.Read(data); err != nil {
panic(err)
}
return base64.RawURLEncoding.EncodeToString(data)
}
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier}
}
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
//
// Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:])
}
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
// only.
func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{
challenge_method: "S256",
challenge: S256ChallengeFromVerifier(verifier),
}
}
type challengeOption struct{ challenge_method, challenge string }
func (p challengeOption) setValue(m url.Values) {
m.Set(codeChallengeMethodKey, p.challenge_method)
m.Set(codeChallengeKey, p.challenge)
}