forked from remote/oauth2
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07085280e4 | ||
|
|
a835fc4358 | ||
|
|
2e4a4e2bfb | ||
|
|
ac6658e9cb |
@@ -47,6 +47,10 @@ type Config struct {
|
|||||||
// client ID & client secret sent. The zero value means to
|
// client ID & client secret sent. The zero value means to
|
||||||
// auto-detect.
|
// auto-detect.
|
||||||
AuthStyle oauth2.AuthStyle
|
AuthStyle oauth2.AuthStyle
|
||||||
|
|
||||||
|
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
|
||||||
|
// the zero value (AuthStyleAutoDetect).
|
||||||
|
authStyleCache internal.LazyAuthStyleCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token uses client credentials to retrieve a token.
|
// Token uses client credentials to retrieve a token.
|
||||||
@@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
|
|||||||
v[k] = p
|
v[k] = p
|
||||||
}
|
}
|
||||||
|
|
||||||
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
|
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if rErr, ok := err.(*internal.RetrieveError); ok {
|
if rErr, ok := err.(*internal.RetrieveError); ok {
|
||||||
return nil, (*oauth2.RetrieveError)(rErr)
|
return nil, (*oauth2.RetrieveError)(rErr)
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/oauth2/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newConf(serverURL string) *Config {
|
func newConf(serverURL string) *Config {
|
||||||
@@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenRefreshRequest(t *testing.T) {
|
func TestTokenRefreshRequest(t *testing.T) {
|
||||||
internal.ResetAuthCache()
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.String() == "/somethingelse" {
|
if r.URL.String() == "/somethingelse" {
|
||||||
return
|
return
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module golang.org/x/oauth2
|
module golang.org/x/oauth2
|
||||||
|
|
||||||
go 1.17
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cloud.google.com/go/compute/metadata v0.2.3
|
cloud.google.com/go/compute/metadata v0.2.3
|
||||||
@@ -11,6 +11,6 @@ require (
|
|||||||
require (
|
require (
|
||||||
cloud.google.com/go/compute v1.20.1 // indirect
|
cloud.google.com/go/compute v1.20.1 // indirect
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
golang.org/x/net v0.12.0 // indirect
|
golang.org/x/net v0.15.0 // indirect
|
||||||
google.golang.org/protobuf v1.31.0 // indirect
|
google.golang.org/protobuf v1.31.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
//go:build appengine
|
//go:build appengine
|
||||||
// +build appengine
|
|
||||||
|
|
||||||
// This file applies to App Engine first generation runtimes (<= Go 1.9).
|
// This file applies to App Engine first generation runtimes (<= Go 1.9).
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
//go:build !appengine
|
//go:build !appengine
|
||||||
// +build !appengine
|
|
||||||
|
|
||||||
// This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible.
|
// This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible.
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
//go:build appengine
|
//go:build appengine
|
||||||
// +build appengine
|
|
||||||
|
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,41 +116,60 @@ const (
|
|||||||
AuthStyleInHeader AuthStyle = 2
|
AuthStyleInHeader AuthStyle = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
// authStyleCache is the set of tokenURLs we've successfully used via
|
// LazyAuthStyleCache is a backwards compatibility compromise to let Configs
|
||||||
|
// have a lazily-initialized AuthStyleCache.
|
||||||
|
//
|
||||||
|
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
|
||||||
|
// both would ideally just embed an unexported AuthStyleCache but because both
|
||||||
|
// were historically allowed to be copied by value we can't retroactively add an
|
||||||
|
// uncopyable Mutex to them.
|
||||||
|
//
|
||||||
|
// We could use an atomic.Pointer, but that was added recently enough (in Go
|
||||||
|
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
|
||||||
|
// still pass. By using an atomic.Value, it supports both Go 1.17 and
|
||||||
|
// copying by value, even if that's not ideal.
|
||||||
|
type LazyAuthStyleCache struct {
|
||||||
|
v atomic.Value // of *AuthStyleCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
||||||
|
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
c := new(AuthStyleCache)
|
||||||
|
if !lc.v.CompareAndSwap(nil, c) {
|
||||||
|
c = lc.v.Load().(*AuthStyleCache)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthStyleCache is the set of tokenURLs we've successfully used via
|
||||||
// RetrieveToken and which style auth we ended up using.
|
// RetrieveToken and which style auth we ended up using.
|
||||||
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
||||||
// the set of OAuth2 servers a program contacts over time is fixed and
|
// the set of OAuth2 servers a program contacts over time is fixed and
|
||||||
// small.
|
// small.
|
||||||
var authStyleCache struct {
|
type AuthStyleCache struct {
|
||||||
sync.Mutex
|
mu sync.Mutex
|
||||||
m map[string]AuthStyle // keyed by tokenURL
|
m map[string]AuthStyle // keyed by tokenURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetAuthCache resets the global authentication style cache used
|
|
||||||
// for AuthStyleUnknown token requests.
|
|
||||||
func ResetAuthCache() {
|
|
||||||
authStyleCache.Lock()
|
|
||||||
defer authStyleCache.Unlock()
|
|
||||||
authStyleCache.m = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupAuthStyle reports which auth style we last used with tokenURL
|
// lookupAuthStyle reports which auth style we last used with tokenURL
|
||||||
// when calling RetrieveToken and whether we have ever done so.
|
// when calling RetrieveToken and whether we have ever done so.
|
||||||
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
||||||
authStyleCache.Lock()
|
c.mu.Lock()
|
||||||
defer authStyleCache.Unlock()
|
defer c.mu.Unlock()
|
||||||
style, ok = authStyleCache.m[tokenURL]
|
style, ok = c.m[tokenURL]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// setAuthStyle adds an entry to authStyleCache, documented above.
|
// setAuthStyle adds an entry to authStyleCache, documented above.
|
||||||
func setAuthStyle(tokenURL string, v AuthStyle) {
|
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
|
||||||
authStyleCache.Lock()
|
c.mu.Lock()
|
||||||
defer authStyleCache.Unlock()
|
defer c.mu.Unlock()
|
||||||
if authStyleCache.m == nil {
|
if c.m == nil {
|
||||||
authStyleCache.m = make(map[string]AuthStyle)
|
c.m = make(map[string]AuthStyle)
|
||||||
}
|
}
|
||||||
authStyleCache.m[tokenURL] = v
|
c.m[tokenURL] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTokenRequest returns a new *http.Request to retrieve a new token
|
// newTokenRequest returns a new *http.Request to retrieve a new token
|
||||||
@@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
|
|||||||
return v2
|
return v2
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
|
||||||
needsAuthStyleProbe := authStyle == 0
|
needsAuthStyleProbe := authStyle == 0
|
||||||
if needsAuthStyleProbe {
|
if needsAuthStyleProbe {
|
||||||
if style, ok := lookupAuthStyle(tokenURL); ok {
|
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
|
||||||
authStyle = style
|
authStyle = style
|
||||||
needsAuthStyleProbe = false
|
needsAuthStyleProbe = false
|
||||||
} else {
|
} else {
|
||||||
@@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
|||||||
token, err = doTokenRoundTrip(ctx, req)
|
token, err = doTokenRoundTrip(ctx, req)
|
||||||
}
|
}
|
||||||
if needsAuthStyleProbe && err == nil {
|
if needsAuthStyleProbe && err == nil {
|
||||||
setAuthStyle(tokenURL, authStyle)
|
styleCache.setAuthStyle(tokenURL, authStyle)
|
||||||
}
|
}
|
||||||
// Don't overwrite `RefreshToken` with an empty value
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
// if this was a token refreshing request.
|
// if this was a token refreshing request.
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRetrieveToken_InParams(t *testing.T) {
|
func TestRetrieveToken_InParams(t *testing.T) {
|
||||||
ResetAuthCache()
|
styleCache := new(AuthStyleCache)
|
||||||
const clientID = "client-id"
|
const clientID = "client-id"
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if got, want := r.FormValue("client_id"), clientID; got != want {
|
if got, want := r.FormValue("client_id"), clientID; got != want {
|
||||||
@@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
|
|||||||
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
|
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
|
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RetrieveToken = %v; want no error", err)
|
t.Errorf("RetrieveToken = %v; want no error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRetrieveTokenWithContexts(t *testing.T) {
|
func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||||
ResetAuthCache()
|
styleCache := new(AuthStyleCache)
|
||||||
const clientID = "client-id"
|
const clientID = "client-id"
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
|
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
|
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
|
||||||
}
|
}
|
||||||
@@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel()
|
cancel()
|
||||||
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
|
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
|
||||||
close(retrieved)
|
close(retrieved)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")
|
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")
|
||||||
|
|||||||
@@ -58,6 +58,10 @@ type Config struct {
|
|||||||
|
|
||||||
// Scope specifies optional requested permissions.
|
// Scope specifies optional requested permissions.
|
||||||
Scopes []string
|
Scopes []string
|
||||||
|
|
||||||
|
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
|
||||||
|
// the zero value (AuthStyleAutoDetect).
|
||||||
|
authStyleCache internal.LazyAuthStyleCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// A TokenSource is anything that can return a token.
|
// A TokenSource is anything that can return a token.
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockTransport struct {
|
type mockTransport struct {
|
||||||
@@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
||||||
internal.ResetAuthCache()
|
|
||||||
tr := &mockTransport{
|
tr := &mockTransport{
|
||||||
rt: func(r *http.Request) (w *http.Response, err error) {
|
rt: func(r *http.Request) (w *http.Response, err error) {
|
||||||
headerAuth := r.Header.Get("Authorization")
|
headerAuth := r.Header.Get("Authorization")
|
||||||
@@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenRefreshRequest(t *testing.T) {
|
func TestTokenRefreshRequest(t *testing.T) {
|
||||||
internal.ResetAuthCache()
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.String() == "/somethingelse" {
|
if r.URL.String() == "/somethingelse" {
|
||||||
return
|
return
|
||||||
|
|||||||
2
token.go
2
token.go
@@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
|
|||||||
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
||||||
// with an error..
|
// with an error..
|
||||||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
||||||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle))
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if rErr, ok := err.(*internal.RetrieveError); ok {
|
if rErr, ok := err.(*internal.RetrieveError); ok {
|
||||||
return nil, (*RetrieveError)(rErr)
|
return nil, (*RetrieveError)(rErr)
|
||||||
|
|||||||
Reference in New Issue
Block a user