oauth2: auto-detect auth style by default, add Endpoint.AuthStyle
Instead of maintaining a global map of which OAuth2 servers do which auth style and/or requiring the user to tell us, just try both ways and remember which way worked. But if users want to tell us in the Endpoint, this CL also add Endpoint.AuthStyle. Fixes golang/oauth2#111 Fixes golang/oauth2#365 Fixes golang/oauth2#362 Fixes golang/oauth2#357 Fixes golang/oauth2#353 Fixes golang/oauth2#345 Fixes golang/oauth2#326 Fixes golang/oauth2#352 Fixes golang/oauth2#268 Fixes https://go-review.googlesource.com/c/oauth2/+/58510 (... and surely many more ...) Change-Id: I7b4d98ba1900ee2d3e11e629316b0bf867f7d237 Reviewed-on: https://go-review.googlesource.com/c/157820 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Ross Light <light@google.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
@@ -90,102 +91,71 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var brokenAuthHeaderProviders = []string{
|
||||
"https://accounts.google.com/",
|
||||
"https://api.codeswholesale.com/oauth/token",
|
||||
"https://api.dropbox.com/",
|
||||
"https://api.dropboxapi.com/",
|
||||
"https://api.instagram.com/",
|
||||
"https://api.netatmo.net/",
|
||||
"https://api.odnoklassniki.ru/",
|
||||
"https://api.pushbullet.com/",
|
||||
"https://api.soundcloud.com/",
|
||||
"https://api.twitch.tv/",
|
||||
"https://id.twitch.tv/",
|
||||
"https://app.box.com/",
|
||||
"https://api.box.com/",
|
||||
"https://connect.stripe.com/",
|
||||
"https://login.mailchimp.com/",
|
||||
"https://login.microsoftonline.com/",
|
||||
"https://login.salesforce.com/",
|
||||
"https://login.windows.net",
|
||||
"https://login.live.com/",
|
||||
"https://login.live-int.com/",
|
||||
"https://oauth.sandbox.trainingpeaks.com/",
|
||||
"https://oauth.trainingpeaks.com/",
|
||||
"https://oauth.vk.com/",
|
||||
"https://openapi.baidu.com/",
|
||||
"https://slack.com/",
|
||||
"https://test-sandbox.auth.corp.google.com",
|
||||
"https://test.salesforce.com/",
|
||||
"https://user.gini.net/",
|
||||
"https://www.douban.com/",
|
||||
"https://www.googleapis.com/",
|
||||
"https://www.linkedin.com/",
|
||||
"https://www.strava.com/oauth/",
|
||||
"https://www.wunderlist.com/oauth/",
|
||||
"https://api.patreon.com/",
|
||||
"https://sandbox.codeswholesale.com/oauth/token",
|
||||
"https://api.sipgate.com/v1/authorization/oauth",
|
||||
"https://api.medium.com/v1/tokens",
|
||||
"https://log.finalsurge.com/oauth/token",
|
||||
"https://multisport.todaysplan.com.au/rest/oauth/access_token",
|
||||
"https://whats.todaysplan.com.au/rest/oauth/access_token",
|
||||
"https://stackoverflow.com/oauth/access_token",
|
||||
"https://account.health.nokia.com",
|
||||
"https://accounts.zoho.com",
|
||||
"https://gitter.im/login/oauth/token",
|
||||
"https://openid-connect.onelogin.com/oidc",
|
||||
"https://api.dailymotion.com/oauth/token",
|
||||
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
||||
//
|
||||
// Deprecated: this function no longer does anything. Caller code that
|
||||
// wants to avoid potential extra HTTP requests made during
|
||||
// auto-probing of the provider's auth style should set
|
||||
// Endpoint.AuthStyle.
|
||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
||||
|
||||
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
||||
type AuthStyle int
|
||||
|
||||
const (
|
||||
AuthStyleUnknown AuthStyle = 0
|
||||
AuthStyleInParams AuthStyle = 1
|
||||
AuthStyleInHeader AuthStyle = 2
|
||||
)
|
||||
|
||||
// authStyleCache is the set of tokenURLs we've successfully used via
|
||||
// RetrieveToken and which style auth we ended up using.
|
||||
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
||||
// the set of OAuth2 servers a program contacts over time is fixed and
|
||||
// small.
|
||||
var authStyleCache struct {
|
||||
sync.Mutex
|
||||
m map[string]AuthStyle // keyed by tokenURL
|
||||
}
|
||||
|
||||
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
|
||||
var brokenAuthHeaderDomains = []string{
|
||||
".auth0.com",
|
||||
".force.com",
|
||||
".myshopify.com",
|
||||
".okta.com",
|
||||
".oktapreview.com",
|
||||
// ResetAuthCache resets the global authentication style cache used
|
||||
// for AuthStyleUnknown token requests.
|
||||
func ResetAuthCache() {
|
||||
authStyleCache.Lock()
|
||||
defer authStyleCache.Unlock()
|
||||
authStyleCache.m = nil
|
||||
}
|
||||
|
||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
||||
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
|
||||
// lookupAuthStyle reports which auth style we last used with tokenURL
|
||||
// when calling RetrieveToken and whether we have ever done so.
|
||||
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
||||
authStyleCache.Lock()
|
||||
defer authStyleCache.Unlock()
|
||||
style, ok = authStyleCache.m[tokenURL]
|
||||
return
|
||||
}
|
||||
|
||||
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
|
||||
// implements the OAuth2 spec correctly
|
||||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||
// In summary:
|
||||
// - Reddit only accepts client secret in the Authorization header
|
||||
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
||||
// - Google only accepts URL param (not spec compliant?), not Auth header
|
||||
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
||||
func providerAuthHeaderWorks(tokenURL string) bool {
|
||||
for _, s := range brokenAuthHeaderProviders {
|
||||
if strings.HasPrefix(tokenURL, s) {
|
||||
// Some sites fail to implement the OAuth2 spec fully.
|
||||
return false
|
||||
}
|
||||
// setAuthStyle adds an entry to authStyleCache, documented above.
|
||||
func setAuthStyle(tokenURL string, v AuthStyle) {
|
||||
authStyleCache.Lock()
|
||||
defer authStyleCache.Unlock()
|
||||
if authStyleCache.m == nil {
|
||||
authStyleCache.m = make(map[string]AuthStyle)
|
||||
}
|
||||
|
||||
if u, err := url.Parse(tokenURL); err == nil {
|
||||
for _, s := range brokenAuthHeaderDomains {
|
||||
if strings.HasSuffix(u.Host, s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assume the provider implements the spec properly
|
||||
// otherwise. We can add more exceptions as they're
|
||||
// discovered. We will _not_ be adding configurable hooks
|
||||
// to this package to let users select server bugs.
|
||||
return true
|
||||
authStyleCache.m[tokenURL] = v
|
||||
}
|
||||
|
||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
|
||||
bustedAuth := !providerAuthHeaderWorks(tokenURL)
|
||||
if bustedAuth {
|
||||
// newTokenRequest returns a new *http.Request to retrieve a new token
|
||||
// from tokenURL using the provided clientID, clientSecret, and POST
|
||||
// body parameters.
|
||||
//
|
||||
// inParams is whether the clientID & clientSecret should be encoded
|
||||
// as the POST body. An 'inParams' value of true means to send it in
|
||||
// the POST body (along with any values in v); false means to send it
|
||||
// in the Authorization header.
|
||||
func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
|
||||
if authStyle == AuthStyleInParams {
|
||||
v = cloneURLValues(v)
|
||||
if clientID != "" {
|
||||
v.Set("client_id", clientID)
|
||||
}
|
||||
@@ -198,15 +168,70 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if !bustedAuth {
|
||||
if authStyle == AuthStyleInHeader {
|
||||
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func cloneURLValues(v url.Values) url.Values {
|
||||
v2 := make(url.Values, len(v))
|
||||
for k, vv := range v {
|
||||
v2[k] = append([]string(nil), vv...)
|
||||
}
|
||||
return v2
|
||||
}
|
||||
|
||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
|
||||
needsAuthStyleProbe := authStyle == 0
|
||||
if needsAuthStyleProbe {
|
||||
if style, ok := lookupAuthStyle(tokenURL); ok {
|
||||
authStyle = style
|
||||
needsAuthStyleProbe = false
|
||||
} else {
|
||||
authStyle = AuthStyleInHeader // the first way we'll try
|
||||
}
|
||||
}
|
||||
req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := doTokenRoundTrip(ctx, req)
|
||||
if err != nil && needsAuthStyleProbe {
|
||||
// If we get an error, assume the server wants the
|
||||
// clientID & clientSecret in a different form.
|
||||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||
// In summary:
|
||||
// - Reddit only accepts client secret in the Authorization header
|
||||
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
||||
// - Google only accepts URL param (not spec compliant?), not Auth header
|
||||
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
||||
//
|
||||
// We used to maintain a big table in this code of all the sites and which way
|
||||
// they went, but maintaining it didn't scale & got annoying.
|
||||
// So just try both ways.
|
||||
authStyle = AuthStyleInParams // the second way we'll try
|
||||
req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
|
||||
token, err = doTokenRoundTrip(ctx, req)
|
||||
}
|
||||
if needsAuthStyleProbe && err == nil {
|
||||
setAuthStyle(tokenURL, authStyle)
|
||||
}
|
||||
// Don't overwrite `RefreshToken` with an empty value
|
||||
// if this was a token refreshing request.
|
||||
if token != nil && token.RefreshToken == "" {
|
||||
token.RefreshToken = v.Get("refresh_token")
|
||||
}
|
||||
return token, err
|
||||
}
|
||||
|
||||
func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
|
||||
r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Body.Close()
|
||||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
@@ -256,13 +281,8 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||
}
|
||||
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
||||
}
|
||||
// Don't overwrite `RefreshToken` with an empty value
|
||||
// if this was a token refreshing request.
|
||||
if token.RefreshToken == "" {
|
||||
token.RefreshToken = v.Get("refresh_token")
|
||||
}
|
||||
if token.AccessToken == "" {
|
||||
return token, errors.New("oauth2: server response missing access_token")
|
||||
return nil, errors.New("oauth2: server response missing access_token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -14,17 +13,9 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegisterBrokenAuthHeaderProvider(t *testing.T) {
|
||||
RegisterBrokenAuthHeaderProvider("https://aaa.com/")
|
||||
tokenURL := "https://aaa.com/token"
|
||||
if providerAuthHeaderWorks(tokenURL) {
|
||||
t.Errorf("got %q as unbroken; want broken", tokenURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveTokenBustedNoSecret(t *testing.T) {
|
||||
func TestRetrieveToken_InParams(t *testing.T) {
|
||||
ResetAuthCache()
|
||||
const clientID = "client-id"
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got, want := r.FormValue("client_id"), clientID; got != want {
|
||||
t.Errorf("client_id = %q; want %q", got, want)
|
||||
@@ -36,52 +27,14 @@ func TestRetrieveTokenBustedNoSecret(t *testing.T) {
|
||||
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
RegisterBrokenAuthHeaderProvider(ts.URL)
|
||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{})
|
||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
|
||||
if err != nil {
|
||||
t.Errorf("RetrieveToken = %v; want no error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_providerAuthHeaderWorks(t *testing.T) {
|
||||
for _, p := range brokenAuthHeaderProviders {
|
||||
if providerAuthHeaderWorks(p) {
|
||||
t.Errorf("got %q as unbroken; want broken", p)
|
||||
}
|
||||
p := fmt.Sprintf("%ssomesuffix", p)
|
||||
if providerAuthHeaderWorks(p) {
|
||||
t.Errorf("got %q as unbroken; want broken", p)
|
||||
}
|
||||
}
|
||||
p := "https://api.not-in-the-list-example.com/"
|
||||
if !providerAuthHeaderWorks(p) {
|
||||
t.Errorf("got %q as unbroken; want broken", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderAuthHeaderWorksDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
tokenURL string
|
||||
wantWorks bool
|
||||
}{
|
||||
{"https://dev-12345.okta.com/token-url", false},
|
||||
{"https://dev-12345.oktapreview.com/token-url", false},
|
||||
{"https://dev-12345.okta.org/token-url", true},
|
||||
{"https://foo.bar.force.com/token-url", false},
|
||||
{"https://foo.force.com/token-url", false},
|
||||
{"https://force.com/token-url", true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
got := providerAuthHeaderWorks(test.tokenURL)
|
||||
if got != test.wantWorks {
|
||||
t.Errorf("providerAuthHeaderWorks(%q) = %v; want %v", test.tokenURL, got, test.wantWorks)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||
ResetAuthCache()
|
||||
const clientID = "client-id"
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -90,7 +43,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{})
|
||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
|
||||
if err != nil {
|
||||
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
|
||||
}
|
||||
@@ -103,7 +56,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{})
|
||||
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
|
||||
close(retrieved)
|
||||
if err == nil {
|
||||
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")
|
||||
|
||||
Reference in New Issue
Block a user