forked from remote/oauth2
address comments
This commit is contained in:
@@ -8,14 +8,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google/internal/sts_exchange"
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||||
)
|
)
|
||||||
|
|
||||||
// now aliases time.Now for testing
|
// now aliases time.Now for testing
|
||||||
@@ -64,31 +62,10 @@ type Config struct {
|
|||||||
WorkforcePoolUserProject string
|
WorkforcePoolUserProject string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each element consists of a list of patterns. validateURLs checks for matches
|
|
||||||
// that include all elements in a given list, in that order.
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
|
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
|
|
||||||
parsed, err := url.Parse(input)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !strings.EqualFold(parsed.Scheme, scheme) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
toTest := parsed.Host
|
|
||||||
|
|
||||||
for _, pattern := range patterns {
|
|
||||||
if pattern.MatchString(toTest) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateWorkforceAudience(input string) bool {
|
func validateWorkforceAudience(input string) bool {
|
||||||
return validWorkforceAudiencePattern.MatchString(input)
|
return validWorkforceAudiencePattern.MatchString(input)
|
||||||
}
|
}
|
||||||
@@ -231,7 +208,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stsRequest := sts_exchange.StsTokenExchangeRequest{
|
stsRequest := stsexchange.StsTokenExchangeRequest{
|
||||||
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
|
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
|
||||||
Audience: conf.Audience,
|
Audience: conf.Audience,
|
||||||
Scope: conf.Scopes,
|
Scope: conf.Scopes,
|
||||||
@@ -242,7 +219,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
|||||||
header := make(http.Header)
|
header := make(http.Header)
|
||||||
header.Add("Content-Type", "application/x-www-form-urlencoded")
|
header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
|
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
|
||||||
clientAuth := sts_exchange.ClientAuthentication{
|
clientAuth := stsexchange.ClientAuthentication{
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
ClientID: conf.ClientID,
|
ClientID: conf.ClientID,
|
||||||
ClientSecret: conf.ClientSecret,
|
ClientSecret: conf.ClientSecret,
|
||||||
@@ -255,7 +232,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
|||||||
"userProject": conf.WorkforcePoolUserProject,
|
"userProject": conf.WorkforcePoolUserProject,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stsResp, err := sts_exchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
|
stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google/internal/sts_exchange"
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||||
)
|
)
|
||||||
|
|
||||||
// now aliases time.Now for testing.
|
// now aliases time.Now for testing.
|
||||||
@@ -87,13 +87,13 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
|||||||
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
|
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAuth := sts_exchange.ClientAuthentication{
|
clientAuth := stsexchange.ClientAuthentication{
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
ClientID: conf.ClientID,
|
ClientID: conf.ClientID,
|
||||||
ClientSecret: conf.ClientSecret,
|
ClientSecret: conf.ClientSecret,
|
||||||
}
|
}
|
||||||
|
|
||||||
stsResponse, err := sts_exchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
|
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google/internal/sts_exchange"
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||||
)
|
)
|
||||||
|
|
||||||
const expiryDelta = 10 * time.Second
|
const expiryDelta = 10 * time.Second
|
||||||
@@ -33,59 +33,11 @@ type testRefreshTokenServer struct {
|
|||||||
Authorization string
|
Authorization string
|
||||||
ContentType string
|
ContentType string
|
||||||
Body string
|
Body string
|
||||||
ResponsePayload *sts_exchange.Response
|
ResponsePayload *stsexchange.Response
|
||||||
Response string
|
Response string
|
||||||
server *httptest.Server
|
server *httptest.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (trts *testRefreshTokenServer) Run(t *testing.T) (string, error) {
|
|
||||||
if trts.server != nil {
|
|
||||||
return "", errors.New("Server is already running")
|
|
||||||
}
|
|
||||||
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if got, want := r.URL.String(), trts.URL; got != want {
|
|
||||||
t.Errorf("URL.String(): got %v but want %v", got, want)
|
|
||||||
}
|
|
||||||
headerAuth := r.Header.Get("Authorization")
|
|
||||||
if got, want := headerAuth, trts.Authorization; got != want {
|
|
||||||
t.Errorf("got %v but want %v", got, want)
|
|
||||||
}
|
|
||||||
headerContentType := r.Header.Get("Content-Type")
|
|
||||||
if got, want := headerContentType, trts.ContentType; got != want {
|
|
||||||
t.Errorf("got %v but want %v", got, want)
|
|
||||||
}
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed reading request body: %s.", err)
|
|
||||||
}
|
|
||||||
if got, want := string(body), trts.Body; got != want {
|
|
||||||
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if trts.ResponsePayload != nil {
|
|
||||||
content, err := json.Marshal(trts.ResponsePayload)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to marshall response JSON")
|
|
||||||
}
|
|
||||||
w.Write(content)
|
|
||||||
} else {
|
|
||||||
w.Write([]byte(trts.Response))
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
return trts.server.URL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (trts *testRefreshTokenServer) Close() error {
|
|
||||||
if trts.server == nil {
|
|
||||||
return errors.New("No server is running")
|
|
||||||
}
|
|
||||||
trts.server.Close()
|
|
||||||
trts.server = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests
|
|
||||||
|
|
||||||
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
|
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Token: "AAAAAAA",
|
Token: "AAAAAAA",
|
||||||
@@ -111,18 +63,18 @@ func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t
|
|||||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
ContentType: "application/x-www-form-urlencoded",
|
ContentType: "application/x-www-form-urlencoded",
|
||||||
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||||
ResponsePayload: &sts_exchange.Response{
|
ResponsePayload: &stsexchange.Response{
|
||||||
ExpiresIn: 3600,
|
ExpiresIn: 3600,
|
||||||
AccessToken: "AAAAAAA",
|
AccessToken: "AAAAAAA",
|
||||||
RefreshToken: "CCCCCCC",
|
RefreshToken: "CCCCCCC",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := server.Run(t)
|
url, err := server.run(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error starting server")
|
t.Fatalf("Error starting server")
|
||||||
}
|
}
|
||||||
defer server.Close()
|
defer server.close(t)
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
RefreshToken: "BBBBBBBBB",
|
RefreshToken: "BBBBBBBBB",
|
||||||
@@ -153,17 +105,17 @@ func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing
|
|||||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
ContentType: "application/x-www-form-urlencoded",
|
ContentType: "application/x-www-form-urlencoded",
|
||||||
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||||
ResponsePayload: &sts_exchange.Response{
|
ResponsePayload: &stsexchange.Response{
|
||||||
ExpiresIn: 3600,
|
ExpiresIn: 3600,
|
||||||
AccessToken: "AAAAAAA",
|
AccessToken: "AAAAAAA",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := server.Run(t)
|
url, err := server.run(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error starting server")
|
t.Fatalf("Error starting server")
|
||||||
}
|
}
|
||||||
defer server.Close()
|
defer server.close(t)
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
RefreshToken: "BBBBBBBBB",
|
RefreshToken: "BBBBBBBBB",
|
||||||
@@ -191,17 +143,17 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
|
|||||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
ContentType: "application/x-www-form-urlencoded",
|
ContentType: "application/x-www-form-urlencoded",
|
||||||
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||||
ResponsePayload: &sts_exchange.Response{
|
ResponsePayload: &stsexchange.Response{
|
||||||
ExpiresIn: 3600,
|
ExpiresIn: 3600,
|
||||||
AccessToken: "AAAAAAA",
|
AccessToken: "AAAAAAA",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := server.Run(t)
|
url, err := server.run(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error starting server")
|
t.Fatalf("Error starting server")
|
||||||
}
|
}
|
||||||
defer server.Close()
|
defer server.close(t)
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
config Config
|
config Config
|
||||||
@@ -257,3 +209,51 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
if trts.server != nil {
|
||||||
|
return "", errors.New("Server is already running")
|
||||||
|
}
|
||||||
|
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got, want := r.URL.String(), trts.URL; got != want {
|
||||||
|
t.Errorf("URL.String(): got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
headerAuth := r.Header.Get("Authorization")
|
||||||
|
if got, want := headerAuth, trts.Authorization; got != want {
|
||||||
|
t.Errorf("got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
|
if got, want := headerContentType, trts.ContentType; got != want {
|
||||||
|
t.Errorf("got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed reading request body: %s.", err)
|
||||||
|
}
|
||||||
|
if got, want := string(body), trts.Body; got != want {
|
||||||
|
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if trts.ResponsePayload != nil {
|
||||||
|
content, err := json.Marshal(trts.ResponsePayload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to marshall response JSON")
|
||||||
|
}
|
||||||
|
w.Write(content)
|
||||||
|
} else {
|
||||||
|
w.Write([]byte(trts.Response))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return trts.server.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trts *testRefreshTokenServer) close(t *testing.T) error {
|
||||||
|
t.Helper()
|
||||||
|
if trts.server == nil {
|
||||||
|
return errors.New("No server is running")
|
||||||
|
}
|
||||||
|
trts.server.Close()
|
||||||
|
trts.server = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package sts_exchange
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package sts_exchange
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package sts_exchange
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package sts_exchange
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
Reference in New Issue
Block a user