From 9b6b7610ad5cd4b1b578a11438aaa693e36c5db9 Mon Sep 17 00:00:00 2001 From: Burcu Dogan Date: Wed, 10 Dec 2014 23:30:13 -0800 Subject: [PATCH] oauth2: rewrite google package, fix the broken build Change-Id: I2753a88d7be483bdbc0cac09a1beccc4806ea4bc Reviewed-on: https://go-review.googlesource.com/1361 Reviewed-by: Brad Fitzpatrick Reviewed-by: Andrew Gerrand --- example_test.go | 57 +++----- google/appengine.go | 113 +++------------- google/appengine_test.go | 266 ------------------------------------- google/appenginevm.go | 104 +++------------ google/appenginevm_test.go | 265 ------------------------------------ google/example_test.go | 109 +++++++-------- google/google.go | 162 +++++++++++----------- google/source_appengine.go | 68 ++++++++++ internal/oauth2.go | 1 + jwt.go | 12 +- jwt_test.go | 55 +++----- oauth2.go | 9 +- oauth2_test.go | 110 +++++++-------- transport_test.go | 32 ++--- 14 files changed, 337 insertions(+), 1026 deletions(-) delete mode 100644 google/appengine_test.go delete mode 100644 google/appenginevm_test.go create mode 100644 google/source_appengine.go diff --git a/example_test.go b/example_test.go index fb8dd8e..cb4726f 100644 --- a/example_test.go +++ b/example_test.go @@ -7,7 +7,6 @@ package oauth2_test import ( "fmt" "log" - "net/http" "testing" "golang.org/x/oauth2" @@ -17,23 +16,20 @@ import ( // Related to https://codereview.appspot.com/107320046 func TestA(t *testing.T) {} -func Example_regular() { - opts, err := oauth2.New( - oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"), - oauth2.RedirectURL("YOUR_REDIRECT_URL"), - oauth2.Scope("SCOPE1", "SCOPE2"), - oauth2.Endpoint( - "https://provider.com/o/oauth2/auth", - "https://provider.com/o/oauth2/token", - ), - ) - if err != nil { - log.Fatal(err) +func ExampleConfig() { + conf := &oauth2.Config{ + ClientID: "YOUR_CLIENT_ID", + ClientSecret: "YOUR_CLIENT_SECRET", + Scopes: []string{"SCOPE1", "SCOPE2"}, + Endpoint: oauth2.Endpoint{ + AuthURL: "https://provider.com/o/oauth2/auth", + TokenURL: "https://provider.com/o/oauth2/token", + }, } // Redirect user to consent page to ask for permission // for the scopes specified above. - url := opts.AuthCodeURL("state", "online", "auto") + url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline) fmt.Printf("Visit the URL for the auth dialog: %v", url) // Use the authorization code that is pushed to the redirect URL. @@ -41,22 +37,22 @@ func Example_regular() { // an access token and initiate a Transport that is // authorized and authenticated by the retrieved token. var code string - if _, err = fmt.Scan(&code); err != nil { + if _, err := fmt.Scan(&code); err != nil { log.Fatal(err) } - t, err := opts.NewTransportFromCode(code) + tok, err := conf.Exchange(oauth2.NoContext, code) if err != nil { log.Fatal(err) } - // You can use t to initiate a new http.Client and - // start making authenticated requests. - client := http.Client{Transport: t} + client := conf.Client(oauth2.NoContext, tok) client.Get("...") } -func Example_jWT() { - opts, err := oauth2.New( +func ExampleJWTConfig() { + var initialToken *oauth2.Token // nil means no initial token + conf := &oauth2.JWTConfig{ + Email: "xxx@developer.com", // The contents of your RSA private key or your PEM file // that contains a private key. // If you have a p12 file instead, you @@ -65,23 +61,12 @@ func Example_jWT() { // $ openssl pkcs12 -in key.p12 -out key.pem -nodes // // It only supports PEM containers with no passphrase. - oauth2.JWTClient( - "xxx@developer.gserviceaccount.com", - []byte("-----BEGIN RSA PRIVATE KEY-----...")), - oauth2.Scope("SCOPE1", "SCOPE2"), - oauth2.JWTEndpoint("https://provider.com/o/oauth2/token"), - // If you would like to impersonate a user, you can - // create a transport with a subject. The following GET - // request will be made on the behalf of user@example.com. - // Subject is optional. - oauth2.Subject("user@example.com"), - ) - if err != nil { - log.Fatal(err) + PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), + Subject: "user@example.com", + TokenURL: "https://provider.com/o/oauth2/token", } - // Initiate an http.Client, the following GET request will be // authorized and authenticated on the behalf of user@example.com. - client := http.Client{Transport: opts.NewTransport()} + client := conf.Client(oauth2.NoContext, initialToken) client.Get("...") } diff --git a/google/appengine.go b/google/appengine.go index 0502693..c6213d9 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -7,108 +7,31 @@ package google import ( - "net/http" - "strings" - "sync" "time" - "golang.org/x/oauth2" - "appengine" - "appengine/memcache" - "appengine/urlfetch" + + "golang.org/x/oauth2" ) -var ( - // memcacheGob enables mocking of the memcache.Gob calls for unit testing. - memcacheGob memcacher = &aeMemcache{} - - // accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. - accessTokenFunc = appengine.AccessToken - - // mu protects multiple threads from attempting to fetch a token at the same time. - mu sync.Mutex - - // tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. - tokens map[string]*oauth2.Token -) - -// safetyMargin is used to avoid clock-skew problems. -// 5 minutes is conservative because tokens are valid for 60 minutes. -const safetyMargin = 5 * time.Minute - -func init() { - tokens = make(map[string]*oauth2.Token) -} - -// AppEngineContext requires an App Engine request context. -func AppEngineContext(ctx appengine.Context) oauth2.Option { - return func(opts *oauth2.Options) error { - opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts) - opts.Client = &http.Client{ - Transport: &urlfetch.Transport{Context: ctx}, - } - return nil +// AppEngineTokenSource returns a token source that fetches tokens +// issued to the current App Engine application's service account. +// If you are implementing a 3-legged OAuth 2.0 flow on App Engine +// that involves user accounts, see oauth2.Config instead. +// +// You are required to provide a valid appengine.Context as context. +func AppEngineTokenSource(ctx appengine.Context, scope ...string) oauth2.TokenSource { + return &appEngineTokenSource{ + ctx: ctx, + scopes: scope, + fetcherFunc: aeFetcherFunc, } } -// FetchToken fetches a new access token for the provided scopes. -// Tokens are cached locally and also with Memcache so that the app can scale -// without hitting quota limits by calling appengine.AccessToken too frequently. -func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) { - return func(existing *oauth2.Token) (*oauth2.Token, error) { - mu.Lock() - defer mu.Unlock() - - key := ":" + strings.Join(opts.Scopes, "_") - now := time.Now().Add(safetyMargin) - if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { - return t, nil - } - delete(tokens, key) - - // Attempt to get token from Memcache - tok := new(oauth2.Token) - _, err := memcacheGob.Get(ctx, key, tok) - if err == nil && !tok.Expiry.Before(now) { - tokens[key] = tok // Save token locally - return tok, nil - } - - token, expiry, err := accessTokenFunc(ctx, opts.Scopes...) - if err != nil { - return nil, err - } - t := &oauth2.Token{ - AccessToken: token, - Expiry: expiry, - } - tokens[key] = t - // Also back up token in Memcache - if err = memcacheGob.Set(ctx, &memcache.Item{ - Key: key, - Value: []byte{}, - Object: *t, - Expiration: expiry.Sub(now), - }); err != nil { - ctx.Errorf("unexpected memcache.Set error: %v", err) - } - return t, nil +var aeFetcherFunc = func(ctx oauth2.Context, scope ...string) (string, time.Time, error) { + c, ok := ctx.(appengine.Context) + if !ok { + return "", time.Time{}, errInvalidContext } -} - -// aeMemcache wraps the needed Memcache functionality to make it easy to mock -type aeMemcache struct{} - -func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { - return memcache.Gob.Get(c, key, tok) -} - -func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error { - return memcache.Gob.Set(c, item) -} - -type memcacher interface { - Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) - Set(c appengine.Context, item *memcache.Item) error + return appengine.AccessToken(c, scope...) } diff --git a/google/appengine_test.go b/google/appengine_test.go deleted file mode 100644 index 2c07ce4..0000000 --- a/google/appengine_test.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2014 The oauth2 Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build appengine,!appenginevm - -package google - -import ( - "fmt" - "log" - "net/http" - "sync" - "testing" - "time" - - "golang.org/x/oauth2" - - "appengine" - "appengine/memcache" -) - -type tokMap map[string]*oauth2.Token - -type mockMemcache struct { - mu sync.RWMutex - vals tokMap - getCount, setCount int -} - -func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.getCount++ - v, ok := m.vals[key] - if !ok { - return nil, fmt.Errorf("unexpected test error: key %q not found", key) - } - *tok = *v - return nil, nil // memcache.Item is ignored anyway - return nil -} - -func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error { - m.mu.Lock() - defer m.mu.Unlock() - m.setCount++ - tok, ok := item.Object.(oauth2.Token) - if !ok { - log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item) - } - m.vals[item.Key] = &tok - return nil -} - -var accessTokenCount = 0 - -func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) { - accessTokenCount++ - return "mytoken", time.Now(), nil -} - -const ( - testScope = "myscope" - testScopeKey = ":" + testScope -) - -func init() { - accessTokenFunc = mockAccessToken -} - -func TestFetchTokenLocalCacheMiss(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - delete(tokens, testScopeKey) // clear local cache - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 1; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 1; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache has been populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenLocalCacheHit(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - // Pre-populate the local cache - tokens[testScopeKey] = &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(1 * time.Hour), - } - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get("server") - if err != nil { - t.Errorf("unable to FetchToken: %v", err) - } - if w := 0; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 0; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 0; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache remains populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenMemcacheHit(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - delete(tokens, testScopeKey) // clear local cache - // Pre-populate the memcache - tok := &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(1 * time.Hour), - } - m.Set(nil, &memcache.Item{ - Key: testScopeKey, - Object: *tok, - Expiration: 1 * time.Hour, - }) - m.setCount = 0 - - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - c := http.Client{Transport: f.NewTransport()} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 0; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 0; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache has been populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenLocalCacheExpired(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - // Pre-populate the local cache - tokens[testScopeKey] = &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(-1 * time.Hour), - } - // Pre-populate the memcache - tok := &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(1 * time.Hour), - } - m.Set(nil, &memcache.Item{ - Key: testScopeKey, - Object: *tok, - Expiration: 1 * time.Hour, - }) - m.setCount = 0 - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - c := http.Client{Transport: f.NewTransport()} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 0; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 0; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache remains populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenMemcacheExpired(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - delete(tokens, testScopeKey) // clear local cache - // Pre-populate the memcache - tok := &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(-1 * time.Hour), - } - m.Set(nil, &memcache.Item{ - Key: testScopeKey, - Object: *tok, - Expiration: -1 * time.Hour, - }) - m.setCount = 0 - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - c := http.Client{Transport: f.NewTransport()} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 1; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 1; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache has been populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} diff --git a/google/appenginevm.go b/google/appenginevm.go index ce2b1bd..12af742 100644 --- a/google/appenginevm.go +++ b/google/appenginevm.go @@ -7,102 +7,30 @@ package google import ( - "strings" - "sync" "time" "golang.org/x/oauth2" "google.golang.org/appengine" - "google.golang.org/appengine/memcache" ) -var ( - // memcacheGob enables mocking of the memcache.Gob calls for unit testing. - memcacheGob memcacher = &aeMemcache{} - - // accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. - accessTokenFunc = appengine.AccessToken - - // mu protects multiple threads from attempting to fetch a token at the same time. - mu sync.Mutex - - // tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. - tokens map[string]*oauth2.Token -) - -// safetyMargin is used to avoid clock-skew problems. -// 5 minutes is conservative because tokens are valid for 60 minutes. -const safetyMargin = 5 * time.Minute - -func init() { - tokens = make(map[string]*oauth2.Token) -} - -// AppEngineContext requires an App Engine request context. -func AppEngineContext(ctx appengine.Context) oauth2.Option { - return func(opts *oauth2.Options) error { - opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts) - return nil +// AppEngineTokenSource returns a token source that fetches tokens +// issued to the current App Engine application's service account. +// If you are implementing a 3-legged OAuth 2.0 flow on App Engine +// that involves user accounts, see oauth2.Config instead. +// +// You are required to provide a valid appengine.Context as context. +func AppEngineTokenSource(ctx appengine.Context, scope ...string) oauth2.TokenSource { + return &appEngineTokenSource{ + ctx: ctx, + scopes: scope, + fetcherFunc: aeVMFetcherFunc, } } -// FetchToken fetches a new access token for the provided scopes. -// Tokens are cached locally and also with Memcache so that the app can scale -// without hitting quota limits by calling appengine.AccessToken too frequently. -func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) { - return func(existing *oauth2.Token) (*oauth2.Token, error) { - mu.Lock() - defer mu.Unlock() - - key := ":" + strings.Join(opts.Scopes, "_") - now := time.Now().Add(safetyMargin) - if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { - return t, nil - } - delete(tokens, key) - - // Attempt to get token from Memcache - tok := new(oauth2.Token) - _, err := memcacheGob.Get(ctx, key, tok) - if err == nil && !tok.Expiry.Before(now) { - tokens[key] = tok // Save token locally - return tok, nil - } - - token, expiry, err := accessTokenFunc(ctx, opts.Scopes...) - if err != nil { - return nil, err - } - t := &oauth2.Token{ - AccessToken: token, - Expiry: expiry, - } - tokens[key] = t - // Also back up token in Memcache - if err = memcacheGob.Set(ctx, &memcache.Item{ - Key: key, - Value: []byte{}, - Object: *t, - Expiration: expiry.Sub(now), - }); err != nil { - ctx.Errorf("unexpected memcache.Set error: %v", err) - } - return t, nil +var aeVMFetcherFunc = func(ctx oauth2.Context, scope ...string) (string, time.Time, error) { + c, ok := ctx.(appengine.Context) + if !ok { + return "", time.Time{}, errInvalidContext } -} - -// aeMemcache wraps the needed Memcache functionality to make it easy to mock -type aeMemcache struct{} - -func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { - return memcache.Gob.Get(c, key, tok) -} - -func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error { - return memcache.Gob.Set(c, item) -} - -type memcacher interface { - Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) - Set(c appengine.Context, item *memcache.Item) error + return appengine.AccessToken(c, scope...) } diff --git a/google/appenginevm_test.go b/google/appenginevm_test.go deleted file mode 100644 index 3ca4b0d..0000000 --- a/google/appenginevm_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright 2014 The oauth2 Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build appenginevm !appengine - -package google - -import ( - "fmt" - "log" - "net/http" - "sync" - "testing" - "time" - - "golang.org/x/oauth2" - "google.golang.org/appengine" - "google.golang.org/appengine/memcache" -) - -type tokMap map[string]*oauth2.Token - -type mockMemcache struct { - mu sync.RWMutex - vals tokMap - getCount, setCount int -} - -func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.getCount++ - v, ok := m.vals[key] - if !ok { - return nil, fmt.Errorf("unexpected test error: key %q not found", key) - } - *tok = *v - return nil, nil // memcache.Item is ignored anyway - return nil -} - -func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error { - m.mu.Lock() - defer m.mu.Unlock() - m.setCount++ - tok, ok := item.Object.(oauth2.Token) - if !ok { - log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item) - } - m.vals[item.Key] = &tok - return nil -} - -var accessTokenCount = 0 - -func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) { - accessTokenCount++ - return "mytoken", time.Now(), nil -} - -const ( - testScope = "myscope" - testScopeKey = ":" + testScope -) - -func init() { - accessTokenFunc = mockAccessToken -} - -func TestFetchTokenLocalCacheMiss(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - delete(tokens, testScopeKey) // clear local cache - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 1; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 1; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache has been populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenLocalCacheHit(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - // Pre-populate the local cache - tokens[testScopeKey] = &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(1 * time.Hour), - } - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get("server") - if err != nil { - t.Errorf("unable to FetchToken: %v", err) - } - if w := 0; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 0; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 0; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache remains populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenMemcacheHit(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - delete(tokens, testScopeKey) // clear local cache - // Pre-populate the memcache - tok := &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(1 * time.Hour), - } - m.Set(nil, &memcache.Item{ - Key: testScopeKey, - Object: *tok, - Expiration: 1 * time.Hour, - }) - m.setCount = 0 - - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - c := http.Client{Transport: f.NewTransport()} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 0; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 0; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache has been populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenLocalCacheExpired(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - // Pre-populate the local cache - tokens[testScopeKey] = &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(-1 * time.Hour), - } - // Pre-populate the memcache - tok := &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(1 * time.Hour), - } - m.Set(nil, &memcache.Item{ - Key: testScopeKey, - Object: *tok, - Expiration: 1 * time.Hour, - }) - m.setCount = 0 - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - c := http.Client{Transport: f.NewTransport()} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 0; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 0; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache remains populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} - -func TestFetchTokenMemcacheExpired(t *testing.T) { - m := &mockMemcache{vals: make(tokMap)} - memcacheGob = m - accessTokenCount = 0 - delete(tokens, testScopeKey) // clear local cache - // Pre-populate the memcache - tok := &oauth2.Token{ - AccessToken: "mytoken", - Expiry: time.Now().Add(-1 * time.Hour), - } - m.Set(nil, &memcache.Item{ - Key: testScopeKey, - Object: *tok, - Expiration: -1 * time.Hour, - }) - m.setCount = 0 - f, err := oauth2.New( - AppEngineContext(nil), - oauth2.Scope(testScope), - ) - if err != nil { - t.Error(err) - } - c := http.Client{Transport: f.NewTransport()} - c.Get("server") - if w := 1; m.getCount != w { - t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) - } - if w := 1; accessTokenCount != w { - t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) - } - if w := 1; m.setCount != w { - t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) - } - // Make sure local cache has been populated - _, ok := tokens[testScopeKey] - if !ok { - t.Errorf("local cache not populated!") - } -} diff --git a/google/example_test.go b/google/example_test.go index 9fec175..31ff67a 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -8,6 +8,7 @@ package google_test import ( "fmt" + "io/ioutil" "log" "net/http" "testing" @@ -15,6 +16,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/appengine" + "google.golang.org/appengine/urlfetch" ) // Remove after Go 1.4. @@ -24,33 +26,31 @@ func TestA(t *testing.T) {} func Example_webServer() { // Your credentials should be obtained from the Google // Developer Console (https://console.developers.google.com). - opts, err := oauth2.New( - oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"), - oauth2.RedirectURL("YOUR_REDIRECT_URL"), - oauth2.Scope( + conf := &oauth2.Config{ + ClientID: "YOUR_CLIENT_ID", + ClientSecret: "YOUR_CLIENT_SECRET", + RedirectURL: "YOUR_REDIRECT_URL", + Scopes: []string{ "https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/blogger", - ), - google.Endpoint(), - ) - if err != nil { - log.Fatal(err) + }, + Endpoint: google.Endpoint, } // Redirect user to Google's consent page to ask for permission // for the scopes specified above. - url := opts.AuthCodeURL("state", "online", "auto") + url := conf.AuthCodeURL("state") fmt.Printf("Visit the URL for the auth dialog: %v", url) - // Handle the exchange code to initiate a transport - t, err := opts.NewTransportFromCode("exchange-code") + // Handle the exchange code to initiate a transport. + tok, err := conf.Exchange(oauth2.NoContext, "authorization-code") if err != nil { log.Fatal(err) } - client := http.Client{Transport: t} + client := conf.Client(oauth2.NoContext, tok) client.Get("...") } -func Example_serviceAccountsJSON() { +func ExampleJWTConfigFromJSON() { // Your credentials should be obtained from the Google // Developer Console (https://console.developers.google.com). // Navigate to your project, then see the "Credentials" page @@ -58,27 +58,26 @@ func Example_serviceAccountsJSON() { // To create a service account client, click "Create new Client ID", // select "Service Account", and click "Create Client ID". A JSON // key file will then be downloaded to your computer. - opts, err := oauth2.New( - google.ServiceAccountJSONKey("/path/to/your-project-key.json"), - oauth2.Scope( - "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/blogger", - ), - ) + data, err := ioutil.ReadFile("/path/to/your-project-key.json") + if err != nil { + log.Fatal(err) + } + conf, err := google.JWTConfigFromJSON(oauth2.NoContext, data, "https://www.googleapis.com/auth/bigquery") if err != nil { log.Fatal(err) } // Initiate an http.Client. The following GET request will be // authorized and authenticated on the behalf of // your service account. - client := http.Client{Transport: opts.NewTransport()} + client := conf.Client(oauth2.NoContext, nil) client.Get("...") } -func Example_serviceAccounts() { +func Example_serviceAccount() { // Your credentials should be obtained from the Google // Developer Console (https://console.developers.google.com). - opts, err := oauth2.New( + conf := &oauth2.JWTConfig{ + Email: "xxx@developer.gserviceaccount.com", // The contents of your RSA private key or your PEM file // that contains a private key. // If you have a p12 file instead, you @@ -87,58 +86,46 @@ func Example_serviceAccounts() { // $ openssl pkcs12 -in key.p12 -out key.pem -nodes // // It only supports PEM containers with no passphrase. - oauth2.JWTClient( - "xxx@developer.gserviceaccount.com", - []byte("-----BEGIN RSA PRIVATE KEY-----...")), - oauth2.Scope( + PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), + Scopes: []string{ "https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/blogger", - ), - google.JWTEndpoint(), + }, + TokenURL: google.JWTTokenURL, // If you would like to impersonate a user, you can // create a transport with a subject. The following GET // request will be made on the behalf of user@example.com. - // Subject is optional. - oauth2.Subject("user@example.com"), - ) - if err != nil { - log.Fatal(err) + // Optional. + Subject: "user@example.com", } - // Initiate an http.Client, the following GET request will be // authorized and authenticated on the behalf of user@example.com. - client := http.Client{Transport: opts.NewTransport()} + client := conf.Client(oauth2.NoContext, nil) client.Get("...") } -func Example_appEngine() { - ctx := appengine.NewContext(nil) - opts, err := oauth2.New( - google.AppEngineContext(ctx), - oauth2.Scope( - "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/blogger", - ), - ) - if err != nil { - log.Fatal(err) +func ExampleAppEngineTokenSource() { + var req *http.Request // from the ServeHTTP handler + ctx := appengine.NewContext(req) + client := &http.Client{ + Transport: &oauth2.Transport{ + Source: google.AppEngineTokenSource(ctx, "https://www.googleapis.com/auth/bigquery"), + Base: &urlfetch.Transport{ + Context: ctx, + }, + }, } - // The following client will be authorized by the App Engine - // app's service account for the provided scopes. - client := http.Client{Transport: opts.NewTransport()} client.Get("...") } -func Example_computeEngine() { - opts, err := oauth2.New( - // Query Google Compute Engine's metadata server to retrieve - // an access token for the provided account. - // If no account is specified, "default" is used. - google.ComputeEngineAccount(""), - ) - if err != nil { - log.Fatal(err) +func ExampleComputeTokenSource() { + client := &http.Client{ + Transport: &oauth2.Transport{ + // Fetch from Google Compute Engine's metadata server to retrieve + // an access token for the provided account. + // If no account is specified, "default" is used. + Source: google.ComputeTokenSource(""), + }, } - client := http.Client{Transport: opts.NewTransport()} client.Get("...") } diff --git a/google/google.go b/google/google.go index 8256e2c..4890776 100644 --- a/google/google.go +++ b/google/google.go @@ -17,19 +17,41 @@ import ( "encoding/json" "fmt" - "io/ioutil" + "net" "net/http" - "net/url" "time" "golang.org/x/oauth2" - "golang.org/x/oauth2/internal" ) -var ( - uriGoogleAuth, _ = url.Parse("https://accounts.google.com/o/oauth2/auth") - uriGoogleToken, _ = url.Parse("https://accounts.google.com/o/oauth2/token") -) +// Endpoint is Google's OAuth 2.0 endpoint. +var Endpoint = oauth2.Endpoint{ + AuthURL: "https://accounts.google.com/o/oauth2/auth", + TokenURL: "https://accounts.google.com/o/oauth2/token", +} + +// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow. +const JWTTokenURL = "https://accounts.google.com/o/oauth2/token" + +// JWTConfigFromJSON uses a Google Developers service account JSON key file to read +// the credentials that authorize and authenticate the requests. +// Create a service account on "Credentials" page under "APIs & Auth" for your +// project at https://console.developers.google.com to download a JSON key file. +func JWTConfigFromJSON(ctx oauth2.Context, jsonKey []byte, scope ...string) (*oauth2.JWTConfig, error) { + var key struct { + Email string `json:"client_email"` + PrivateKey string `json:"private_key"` + } + if err := json.Unmarshal(jsonKey, &key); err != nil { + return nil, err + } + return &oauth2.JWTConfig{ + Email: key.Email, + PrivateKey: []byte(key.PrivateKey), + Scopes: scope, + TokenURL: JWTTokenURL, + }, nil +} type metaTokenRespBody struct { AccessToken string `json:"access_token"` @@ -37,93 +59,57 @@ type metaTokenRespBody struct { TokenType string `json:"token_type"` } -// JWTEndpoint adds the endpoints required to complete the 2-legged service account flow. -func JWTEndpoint() oauth2.Option { - return func(opts *oauth2.Options) error { - opts.AUD = uriGoogleToken - return nil - } +// ComputeTokenSource returns a token source that fetches access tokens +// from Google Compute Engine (GCE)'s metadata server. It's only valid to use +// this token source if your program is running on a GCE instance. +// If no account is specified, "default" is used. +// Further information about retrieving access tokens from the GCE metadata +// server can be found at https://cloud.google.com/compute/docs/authentication. +func ComputeTokenSource(account string) oauth2.TokenSource { + return &computeSource{account: account} } -// Endpoint adds the endpoints required to do the 3-legged Web server flow. -func Endpoint() oauth2.Option { - return func(opts *oauth2.Options) error { - opts.AuthURL = uriGoogleAuth - opts.TokenURL = uriGoogleToken - return nil - } +type computeSource struct { + account string } -// ComputeEngineAccount uses the specified account to retrieve an access -// token from the Google Compute Engine's metadata server. If no user is -// provided, "default" is being used. -func ComputeEngineAccount(account string) oauth2.Option { - return func(opts *oauth2.Options) error { - if account == "" { - account = "default" - } - opts.TokenFetcherFunc = makeComputeFetcher(opts, account) - return nil - } +var metaClient = &http.Client{ + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 750 * time.Millisecond, + KeepAlive: 30 * time.Second, + }).Dial, + ResponseHeaderTimeout: 750 * time.Millisecond, + }, } -// ServiceAccountJSONKey uses the provided Google Developers -// JSON key file to authorize the user. See the "Credentials" page under -// "APIs & Auth" for your project at https://console.developers.google.com -// to download a JSON key file. -func ServiceAccountJSONKey(filename string) oauth2.Option { - return func(opts *oauth2.Options) error { - b, err := ioutil.ReadFile(filename) - if err != nil { - return err - } - var key struct { - Email string `json:"client_email"` - PrivateKey string `json:"private_key"` - } - if err := json.Unmarshal(b, &key); err != nil { - return err - } - pk, err := internal.ParseKey([]byte(key.PrivateKey)) - if err != nil { - return err - } - opts.Email = key.Email - opts.PrivateKey = pk - opts.AUD = uriGoogleToken - return nil +func (cs *computeSource) Token() (*oauth2.Token, error) { + acct := cs.account + if acct == "" { + acct = "default" } -} - -func makeComputeFetcher(opts *oauth2.Options, account string) func(*oauth2.Token) (*oauth2.Token, error) { - return func(t *oauth2.Token) (*oauth2.Token, error) { - u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token" - req, err := http.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - req.Header.Add("X-Google-Metadata-Request", "True") - c := &http.Client{} - if opts.Client != nil { - c = opts.Client - } - resp, err := c.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode > 299 { - return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode) - } - var tokenResp metaTokenRespBody - err = json.NewDecoder(resp.Body).Decode(&tokenResp) - if err != nil { - return nil, err - } - return &oauth2.Token{ - AccessToken: tokenResp.AccessToken, - TokenType: tokenResp.TokenType, - Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second), - }, nil + u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + acct + "/token" + req, err := http.NewRequest("GET", u, nil) + if err != nil { + return nil, err } + req.Header.Add("X-Google-Metadata-Request", "True") + resp, err := metaClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode) + } + var tokenResp metaTokenRespBody + err = json.NewDecoder(resp.Body).Decode(&tokenResp) + if err != nil { + return nil, err + } + return &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second), + }, nil } diff --git a/google/source_appengine.go b/google/source_appengine.go new file mode 100644 index 0000000..9b8aa97 --- /dev/null +++ b/google/source_appengine.go @@ -0,0 +1,68 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package google + +import ( + "errors" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/oauth2" +) + +var ( + aeTokensMu sync.Mutex // guards aeTokens and appEngineTokenSource.key + + // aeTokens helps the fetched tokens to be reused until their expiration. + aeTokens = make(map[string]*tokenLock) // key is '\0'-separated scopes +) + +var errInvalidContext = errors.New("oauth2: a valid appengine.Context is required") + +type tokenLock struct { + mu sync.Mutex // guards t; held while updating t + t *oauth2.Token +} + +type appEngineTokenSource struct { + ctx oauth2.Context + scopes []string + key string // guarded by package-level mutex, aeTokensMu + + // fetcherFunc makes the actual RPC to fetch a new access token with an expiry time. + // Provider of this function is responsible to assert that the given context is valid. + fetcherFunc func(ctx oauth2.Context, scope ...string) (string, time.Time, error) +} + +func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) { + aeTokensMu.Lock() + if ts.key == "" { + sort.Sort(sort.StringSlice(ts.scopes)) + ts.key = strings.Join(ts.scopes, string(0)) + } + tok, ok := aeTokens[ts.key] + if !ok { + tok = &tokenLock{} + aeTokens[ts.key] = tok + } + aeTokensMu.Unlock() + + tok.mu.Lock() + defer tok.mu.Unlock() + if tok.t != nil && !tok.t.Expired() { + return tok.t, nil + } + access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...) + if err != nil { + return nil, err + } + tok.t = &oauth2.Token{ + AccessToken: access, + Expiry: exp, + } + return tok.t, nil +} diff --git a/internal/oauth2.go b/internal/oauth2.go index b91b662..47c8f14 100644 --- a/internal/oauth2.go +++ b/internal/oauth2.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package internal contains support packages for oauth2 package. package internal import ( diff --git a/jwt.go b/jwt.go index d861e93..eedbfc1 100644 --- a/jwt.go +++ b/jwt.go @@ -5,7 +5,6 @@ package oauth2 import ( - "crypto/rsa" "encoding/json" "fmt" "io" @@ -15,6 +14,7 @@ import ( "strings" "time" + "golang.org/x/oauth2/internal" "golang.org/x/oauth2/jws" ) @@ -38,7 +38,7 @@ type JWTConfig struct { // // $ openssl pkcs12 -in key.p12 -out key.pem -nodes // - PrivateKey *rsa.PrivateKey + PrivateKey []byte // Subject is the optional user to impersonate. Subject string @@ -76,8 +76,8 @@ func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource { func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client { return &http.Client{ Transport: &Transport{ - Source: c.TokenSource(ctx, initialToken), Base: contextTransport(ctx), + Source: c.TokenSource(ctx, initialToken), }, } } @@ -90,6 +90,10 @@ type jwtSource struct { } func (js jwtSource) Token() (*Token, error) { + pk, err := internal.ParseKey(js.conf.PrivateKey) + if err != nil { + return nil, err + } hc, err := contextClient(js.ctx) if err != nil { return nil, err @@ -105,7 +109,7 @@ func (js jwtSource) Token() (*Token, error) { // to be compatible with legacy OAuth 2.0 providers. claimSet.Prn = subject } - payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey) + payload, err := jws.Encode(defaultHeader, claimSet, pk) if err != nil { return nil, err } diff --git a/jwt_test.go b/jwt_test.go index b51c702..2fe371b 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -48,18 +48,16 @@ func TestJWTFetch_JSONResponse(t *testing.T) { }`)) })) defer ts.Close() - f, err := New( - JWTClient("aaa@xxx.com", dummyPrivateKey), - JWTEndpoint(ts.URL), - ) - if err != nil { - t.Error(err) - } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get(ts.URL) - tok := tr.Token() + conf := &JWTConfig{ + Email: "aaa@xxx.com", + PrivateKey: dummyPrivateKey, + TokenURL: ts.URL, + } + tok, err := conf.TokenSource(NoContext, nil).Token() + if err != nil { + t.Fatal(err) + } if tok.Expired() { t.Errorf("Token shouldn't be expired.") } @@ -81,19 +79,15 @@ func TestJWTFetch_BadResponse(t *testing.T) { w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - f, err := New( - JWTClient("aaa@xxx.com", dummyPrivateKey), - JWTEndpoint(ts.URL), - ) - if err != nil { - t.Error(err) + + conf := &JWTConfig{ + Email: "aaa@xxx.com", + PrivateKey: dummyPrivateKey, + TokenURL: ts.URL, } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get(ts.URL) - tok := tr.Token() + tok, err := conf.TokenSource(NoContext, nil).Token() if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Fatal(err) } if tok.AccessToken != "" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) @@ -113,19 +107,14 @@ func TestJWTFetch_BadResponseType(t *testing.T) { w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - f, err := New( - JWTClient("aaa@xxx.com", dummyPrivateKey), - JWTEndpoint(ts.URL), - ) - if err != nil { - t.Error(err) + conf := &JWTConfig{ + Email: "aaa@xxx.com", + PrivateKey: dummyPrivateKey, + TokenURL: ts.URL, } - tr := f.NewTransport() - c := http.Client{Transport: tr} - c.Get(ts.URL) - tok := tr.Token() + tok, err := conf.TokenSource(NoContext, nil).Token() if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Fatal(err) } if tok.AccessToken != "" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) diff --git a/oauth2.go b/oauth2.go index 88121f3..753aa60 100644 --- a/oauth2.go +++ b/oauth2.go @@ -27,9 +27,14 @@ import ( // Context can be an golang.org/x/net.Context, or an App Engine Context. // In the future these will be unified. -// If you don't care and aren't running on App Engine, you may use nil. +// If you don't care and aren't running on App Engine, you may use NoContext. type Context interface{} +// NoContext is the default context. If you're not running this code +// on App Engine or not using golang.org/x/net.Context to provide a custom +// HTTP client, you should use NoContext. +var NoContext Context = nil + // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's URLs. type Config struct { @@ -272,8 +277,8 @@ func contextTransport(ctx Context) http.RoundTripper { func (c *Config) Client(ctx Context, t *Token) *http.Client { return &http.Client{ Transport: &Transport{ - Source: c.TokenSource(ctx, t), Base: contextTransport(ctx), + Source: c.TokenSource(ctx, t), }, } } diff --git a/oauth2_test.go b/oauth2_test.go index 6c21043..8159b86 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -10,6 +10,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "golang.org/x/net/context" ) type mockTransport struct { @@ -33,31 +35,37 @@ func (c *mockCache) WriteToken(*Token) { // do nothing } -func newOpts(url string) *Options { - opts, _ := New( - Client("CLIENT_ID", "CLIENT_SECRET"), - RedirectURL("REDIRECT_URL"), - Scope("scope1", "scope2"), - Endpoint(url+"/auth", url+"/token"), - ) - return opts +func newConf(url string) *Config { + return &Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: []string{"scope1", "scope2"}, + Endpoint: Endpoint{ + AuthURL: url + "/auth", + TokenURL: url + "/token", + }, + } } func TestAuthCodeURL(t *testing.T) { - opts := newOpts("server") - url := opts.AuthCodeURL("foo", "offline", "force") + conf := newConf("server") + url := conf.AuthCodeURL("foo", AccessTypeOffline, ApprovalForce) if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" { t.Errorf("Auth code URL doesn't match the expected, found: %v", url) } } func TestAuthCodeURL_Optional(t *testing.T) { - opts, _ := New( - Client("CLIENT_ID", ""), - Endpoint("auth-url", "token-token"), - ) - url := opts.AuthCodeURL("", "", "") - if url != "auth-url?client_id=CLIENT_ID&response_type=code" { + conf := &Config{ + ClientID: "CLIENT_ID", + Endpoint: Endpoint{ + AuthURL: "/auth-url", + TokenURL: "/token-url", + }, + } + url := conf.AuthCodeURL("") + if url != "/auth-url?client_id=CLIENT_ID&response_type=code" { t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) } } @@ -86,12 +94,11 @@ func TestExchangeRequest(t *testing.T) { w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) })) defer ts.Close() - opts := newOpts(ts.URL) - tr, err := opts.NewTransportFromCode("exchange-code") + conf := newConf(ts.URL) + tok, err := conf.Exchange(NoContext, "exchange-code") if err != nil { t.Error(err) } - tok := tr.Token() if tok.Expired() { t.Errorf("Token shouldn't be expired.") } @@ -131,15 +138,11 @@ func TestExchangeRequest_JSONResponse(t *testing.T) { w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) })) defer ts.Close() - opts := newOpts(ts.URL) - tr, err := opts.NewTransportFromCode("exchange-code") + conf := newConf(ts.URL) + tok, err := conf.Exchange(NoContext, "exchange-code") if err != nil { t.Error(err) } - tok := tr.Token() - if tok.Expiry.IsZero() { - t.Errorf("Token expiry should not be zero.") - } if tok.Expired() { t.Errorf("Token shouldn't be expired.") } @@ -161,12 +164,11 @@ func TestExchangeRequest_BadResponse(t *testing.T) { w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - opts := newOpts(ts.URL) - tr, err := opts.NewTransportFromCode("exchange-code") + conf := newConf(ts.URL) + tok, err := conf.Exchange(NoContext, "code") if err != nil { - t.Error(err) + t.Fatal(err) } - tok := tr.Token() if tok.AccessToken != "" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } @@ -178,12 +180,11 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - opts := newOpts(ts.URL) - tr, err := opts.NewTransportFromCode("exchange-code") + conf := newConf(ts.URL) + tok, err := conf.Exchange(NoContext, "exchange-code") if err != nil { t.Error(err) } - tok := tr.Token() if tok.AccessToken != "" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } @@ -200,15 +201,16 @@ func TestExchangeRequest_NonBasicAuth(t *testing.T) { }, } c := &http.Client{Transport: tr} - opts, err := New( - Client("CLIENT_ID", ""), - Endpoint("https://accounts.google.com/auth", "https://accounts.google.com/token"), - HTTPClient(c), - ) - if err != nil { - t.Error(err) + conf := &Config{ + ClientID: "CLIENT_ID", + Endpoint: Endpoint{ + AuthURL: "https://accounts.google.com/auth", + TokenURL: "https://accounts.google.com/token", + }, } - opts.NewTransportFromCode("code") + + ctx := context.WithValue(context.Background(), HTTPClient, c) + conf.Exchange(ctx, "code") } func TestTokenRefreshRequest(t *testing.T) { @@ -229,10 +231,8 @@ func TestTokenRefreshRequest(t *testing.T) { } })) defer ts.Close() - opts := newOpts(ts.URL) - tr := opts.NewTransport() - tr.token = &Token{RefreshToken: "REFRESH_TOKEN"} - c := http.Client{Transport: tr} + conf := newConf(ts.URL) + c := conf.Client(NoContext, &Token{RefreshToken: "REFRESH_TOKEN"}) c.Get(ts.URL + "/somethingelse") } @@ -254,28 +254,10 @@ func TestFetchWithNoRefreshToken(t *testing.T) { } })) defer ts.Close() - opts := newOpts(ts.URL) - tr := opts.NewTransport() - c := http.Client{Transport: tr} + conf := newConf(ts.URL) + c := conf.Client(NoContext, nil) _, err := c.Get(ts.URL + "/somethingelse") if err == nil { t.Errorf("Fetch should return an error if no refresh token is set") } } - -func TestCacheNoToken(t *testing.T) { - opts, err := New( - Client("CLIENT_ID", "CLIENT_SECRET"), - Endpoint("/auth", "/token"), - ) - if err != nil { - t.Error(err) - } - tr, err := opts.NewTransportFromTokenStore(&mockCache{token: nil, readErr: nil}) - if err != nil { - t.Errorf("No error expected, %v is found", err) - } - if tr != nil { - t.Errorf("No transport should have been initiated, tr is found to be %v", tr) - } -} diff --git a/transport_test.go b/transport_test.go index 5fbccf6..b3414e3 100644 --- a/transport_test.go +++ b/transport_test.go @@ -7,45 +7,29 @@ import ( "time" ) -type mockTokenFetcher struct{ token *Token } +type tokenSource struct{ token *Token } -func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) { - return func(*Token) (*Token, error) { - return f.token, nil - } +func (t *tokenSource) Token() (*Token, error) { + return t.token, nil } -func TestInitialTokenRead(t *testing.T) { - tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) - server := newMockServer(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer abc" { - t.Errorf("Transport doesn't set the Authorization header from the initial token") - } - }) - defer server.Close() - client := http.Client{Transport: tr} - client.Get(server.URL) -} - -func TestTokenFetch(t *testing.T) { - fetcher := &mockTokenFetcher{ +func TestTransportTokenSource(t *testing.T) { + ts := &tokenSource{ token: &Token{ AccessToken: "abc", }, } - tr := newTransport(http.DefaultTransport, &Options{TokenFetcherFunc: fetcher.Fn()}, nil) + tr := &Transport{ + Source: ts, + } server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the fetched token") } }) defer server.Close() - client := http.Client{Transport: tr} client.Get(server.URL) - if tr.Token().AccessToken != "abc" { - t.Errorf("New token is not set, found %v", tr.Token()) - } } func TestExpiredWithNoAccessToken(t *testing.T) {