forked from remote/oauth2
google: adding support for external account authorized user
To support a new type of credential: `ExternalAccountAuthorizedUser`
* Refactor the common dependency STS to a separate package.
* Adding the `externalaccountauthorizeduser` package.
Change-Id: I9b9624f912d216b67a0d31945a50f057f747710b
GitHub-Last-Rev: 6e2aaff345
GitHub-Pull-Request: golang/oauth2#671
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531095
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
@@ -0,0 +1,114 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package externalaccountauthorizeduser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||
)
|
||||
|
||||
// now aliases time.Now for testing.
|
||||
var now = func() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
var tokenValid = func(token oauth2.Token) bool {
|
||||
return token.Valid()
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
|
||||
// the provider identifier in that pool.
|
||||
Audience string
|
||||
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
|
||||
RefreshToken string
|
||||
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
|
||||
// None if the token can not be refreshed.
|
||||
TokenURL string
|
||||
// TokenInfoURL is the optional STS endpoint URL for token introspection.
|
||||
TokenInfoURL string
|
||||
// ClientID is only required in conjunction with ClientSecret, as described above.
|
||||
ClientID string
|
||||
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
|
||||
// access token. When provided, STS will be called with additional basic authentication using client_id as username
|
||||
// and client_secret as password.
|
||||
ClientSecret string
|
||||
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
|
||||
Token string
|
||||
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
|
||||
Expiry time.Time
|
||||
// RevokeURL is the optional STS endpoint URL for revoking tokens.
|
||||
RevokeURL string
|
||||
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
|
||||
// project used to create the credentials.
|
||||
QuotaProjectID string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
func (c *Config) canRefresh() bool {
|
||||
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
|
||||
}
|
||||
|
||||
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
|
||||
var token oauth2.Token
|
||||
if c.Token != "" && !c.Expiry.IsZero() {
|
||||
token = oauth2.Token{
|
||||
AccessToken: c.Token,
|
||||
Expiry: c.Expiry,
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
}
|
||||
if !tokenValid(token) && !c.canRefresh() {
|
||||
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
|
||||
}
|
||||
|
||||
ts := tokenSource{
|
||||
ctx: ctx,
|
||||
conf: c,
|
||||
}
|
||||
|
||||
return oauth2.ReuseTokenSource(&token, ts), nil
|
||||
}
|
||||
|
||||
type tokenSource struct {
|
||||
ctx context.Context
|
||||
conf *Config
|
||||
}
|
||||
|
||||
func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||
conf := ts.conf
|
||||
if !conf.canRefresh() {
|
||||
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
|
||||
}
|
||||
|
||||
clientAuth := stsexchange.ClientAuthentication{
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
ClientID: conf.ClientID,
|
||||
ClientSecret: conf.ClientSecret,
|
||||
}
|
||||
|
||||
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stsResponse.ExpiresIn < 0 {
|
||||
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
|
||||
}
|
||||
|
||||
if stsResponse.RefreshToken != "" {
|
||||
conf.RefreshToken = stsResponse.RefreshToken
|
||||
}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: stsResponse.AccessToken,
|
||||
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package externalaccountauthorizeduser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||
)
|
||||
|
||||
const expiryDelta = 10 * time.Second
|
||||
|
||||
var (
|
||||
expiry = time.Unix(234852, 0)
|
||||
testNow = func() time.Time { return expiry }
|
||||
testValid = func(t oauth2.Token) bool {
|
||||
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
|
||||
}
|
||||
)
|
||||
|
||||
type testRefreshTokenServer struct {
|
||||
URL string
|
||||
Authorization string
|
||||
ContentType string
|
||||
Body string
|
||||
ResponsePayload *stsexchange.Response
|
||||
Response string
|
||||
server *httptest.Server
|
||||
}
|
||||
|
||||
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
|
||||
config := &Config{
|
||||
Token: "AAAAAAA",
|
||||
Expiry: now().Add(time.Hour),
|
||||
}
|
||||
ts, err := config.TokenSource(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting token source: %v", err)
|
||||
}
|
||||
|
||||
token, err := ts.Token()
|
||||
if err != nil {
|
||||
t.Fatalf("Error retrieving Token: %v", err)
|
||||
}
|
||||
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
||||
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
|
||||
server := &testRefreshTokenServer{
|
||||
URL: "/",
|
||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||
ContentType: "application/x-www-form-urlencoded",
|
||||
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||
ResponsePayload: &stsexchange.Response{
|
||||
ExpiresIn: 3600,
|
||||
AccessToken: "AAAAAAA",
|
||||
RefreshToken: "CCCCCCC",
|
||||
},
|
||||
}
|
||||
|
||||
url, err := server.run(t)
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting server")
|
||||
}
|
||||
defer server.close(t)
|
||||
|
||||
config := &Config{
|
||||
RefreshToken: "BBBBBBBBB",
|
||||
TokenURL: url,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
}
|
||||
ts, err := config.TokenSource(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting token source: %v", err)
|
||||
}
|
||||
|
||||
token, err := ts.Token()
|
||||
if err != nil {
|
||||
t.Fatalf("Error retrieving Token: %v", err)
|
||||
}
|
||||
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
||||
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
||||
}
|
||||
if config.RefreshToken != "CCCCCCC" {
|
||||
t.Fatalf("Refresh token not updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
|
||||
server := &testRefreshTokenServer{
|
||||
URL: "/",
|
||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||
ContentType: "application/x-www-form-urlencoded",
|
||||
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||
ResponsePayload: &stsexchange.Response{
|
||||
ExpiresIn: 3600,
|
||||
AccessToken: "AAAAAAA",
|
||||
},
|
||||
}
|
||||
|
||||
url, err := server.run(t)
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting server")
|
||||
}
|
||||
defer server.close(t)
|
||||
|
||||
config := &Config{
|
||||
RefreshToken: "BBBBBBBBB",
|
||||
TokenURL: url,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
}
|
||||
ts, err := config.TokenSource(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting token source: %v", err)
|
||||
}
|
||||
|
||||
token, err := ts.Token()
|
||||
if err != nil {
|
||||
t.Fatalf("Error retrieving Token: %v", err)
|
||||
}
|
||||
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
||||
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
|
||||
server := &testRefreshTokenServer{
|
||||
URL: "/",
|
||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||
ContentType: "application/x-www-form-urlencoded",
|
||||
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||
ResponsePayload: &stsexchange.Response{
|
||||
ExpiresIn: 3600,
|
||||
AccessToken: "AAAAAAA",
|
||||
},
|
||||
}
|
||||
|
||||
url, err := server.run(t)
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting server")
|
||||
}
|
||||
defer server.close(t)
|
||||
testCases := []struct {
|
||||
name string
|
||||
config Config
|
||||
}{
|
||||
{
|
||||
name: "empty config",
|
||||
config: Config{},
|
||||
},
|
||||
{
|
||||
name: "missing refresh token",
|
||||
config: Config{
|
||||
TokenURL: url,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token url",
|
||||
config: Config{
|
||||
RefreshToken: "BBBBBBBBB",
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing client id",
|
||||
config: Config{
|
||||
RefreshToken: "BBBBBBBBB",
|
||||
TokenURL: url,
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing client secrect",
|
||||
config: Config{
|
||||
RefreshToken: "BBBBBBBBB",
|
||||
TokenURL: url,
|
||||
ClientID: "CLIENT_ID",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
|
||||
_, err := tc.config.TokenSource((context.Background()))
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error, but received none")
|
||||
}
|
||||
if got := err.Error(); got != expectErrMsg {
|
||||
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
|
||||
t.Helper()
|
||||
if trts.server != nil {
|
||||
return "", errors.New("Server is already running")
|
||||
}
|
||||
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got, want := r.URL.String(), trts.URL; got != want {
|
||||
t.Errorf("URL.String(): got %v but want %v", got, want)
|
||||
}
|
||||
headerAuth := r.Header.Get("Authorization")
|
||||
if got, want := headerAuth, trts.Authorization; got != want {
|
||||
t.Errorf("got %v but want %v", got, want)
|
||||
}
|
||||
headerContentType := r.Header.Get("Content-Type")
|
||||
if got, want := headerContentType, trts.ContentType; got != want {
|
||||
t.Errorf("got %v but want %v", got, want)
|
||||
}
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed reading request body: %s.", err)
|
||||
}
|
||||
if got, want := string(body), trts.Body; got != want {
|
||||
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if trts.ResponsePayload != nil {
|
||||
content, err := json.Marshal(trts.ResponsePayload)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to marshall response JSON")
|
||||
}
|
||||
w.Write(content)
|
||||
} else {
|
||||
w.Write([]byte(trts.Response))
|
||||
}
|
||||
}))
|
||||
return trts.server.URL, nil
|
||||
}
|
||||
|
||||
func (trts *testRefreshTokenServer) close(t *testing.T) error {
|
||||
t.Helper()
|
||||
if trts.server == nil {
|
||||
return errors.New("No server is running")
|
||||
}
|
||||
trts.server.Close()
|
||||
trts.server = nil
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user