oauth2: redesign the API

Tests and examples aren't updated yet. The tree will be broken after this,
but nobody should be using this yet anyway.

Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f
Reviewed-on: https://go-review.googlesource.com/1246
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Brad Fitzpatrick
2014-12-10 10:17:33 +11:00
parent b3f9a68f05
commit a568078818
4 changed files with 556 additions and 432 deletions

View File

@@ -5,136 +5,90 @@
package oauth2
import (
"errors"
"io"
"net/http"
"net/url"
"sync"
"time"
)
const (
defaultTokenType = "Bearer"
)
// Token represents the crendentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
type Token struct {
// AccessToken is the token that authorizes and authenticates the requests.
AccessToken string `json:"access_token"`
// TokenType identifies the type of token returned.
TokenType string `json:"token_type,omitempty"`
// RefreshToken is a token that may be used to obtain a new access token.
RefreshToken string `json:"refresh_token,omitempty"`
// Expiry is the expiration datetime of the access token.
Expiry time.Time `json:"expiry,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw interface{}
}
// Extra returns an extra field returned from the server during token retrieval.
// E.g.
// idToken := token.Extra("id_token")
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
// wrapping a base RoundTripper and adding an Authorization header
// with a token from the supplied Sources.
//
func (t *Token) Extra(key string) string {
if vals, ok := t.raw.(url.Values); ok {
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok {
if val, ok := raw[key].(string); ok {
return val
}
}
return ""
}
// Expired returns true if there is no access token or the
// access token is expired.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
// Transport is a low-level mechanism. Most code will use the
// higher-level Config.Client method instead.
type Transport struct {
opts *Options
base http.RoundTripper
// Source supplies the token to add to outgoing requests'
// Authorization headers.
Source TokenSource
mu sync.RWMutex
token *Token
}
// Base is the base RoundTripper used to make HTTP requests.
// If nil, http.DefaultTransport is used.
Base http.RoundTripper
// NewTransport creates a new Transport that uses the provided
// token fetcher as token retrieving strategy. It authenticates
// the requests and delegates origTransport to make the actual requests.
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
return &Transport{
base: base,
opts: opts,
token: token,
}
mu sync.Mutex // guards modReq
modReq map[*http.Request]*http.Request // original -> modified
}
// RoundTrip authorizes and authenticates the request with an
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token := t.token
if token == nil || token.Expired() {
// Check if the token is refreshable.
// If token is refreshable, don't return an error,
// rather refresh.
if err := t.refreshToken(); err != nil {
return nil, err
}
token = t.token
if t.opts.TokenStore != nil {
t.opts.TokenStore.WriteToken(token)
}
if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
}
token, err := t.Source.Token()
if err != nil {
return nil, err
}
// To set the Authorization header, we must make a copy of the Request
// so that we don't modify the Request we were given.
// This is required by the specification of http.RoundTripper.
req = cloneRequest(req)
typ := token.TokenType
if typ == "" {
typ = defaultTokenType
req2 := cloneRequest(req) // per RoundTripper contract
token.SetAuthHeader(req2)
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)
if err != nil {
t.setModReq(req, nil)
return nil, err
}
req.Header.Set("Authorization", typ+" "+token.AccessToken)
return t.base.RoundTrip(req)
res.Body = &onEOFReader{
rc: res.Body,
fn: func() { t.setModReq(req, nil) },
}
return res, nil
}
// Token returns the token that authorizes and
// authenticates the transport.
func (t *Transport) Token() *Token {
t.mu.RLock()
defer t.mu.RUnlock()
return t.token
// CancelRequest cancels an in-flight request by closing its connection.
func (t *Transport) CancelRequest(req *http.Request) {
type canceler interface {
CancelRequest(*http.Request)
}
if cr, ok := t.Base.(canceler); ok {
t.mu.Lock()
modReq := t.modReq[req]
delete(t.modReq, req)
t.mu.Unlock()
cr.CancelRequest(modReq)
}
}
// refreshToken retrieves a new token, if a refreshing/fetching
// method is known and required credentials are presented
// (such as a refresh token).
func (t *Transport) refreshToken() error {
func (t *Transport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func (t *Transport) setModReq(orig, mod *http.Request) {
t.mu.Lock()
defer t.mu.Unlock()
token, err := t.opts.TokenFetcherFunc(t.token)
if err != nil {
return err
if t.modReq == nil {
t.modReq = make(map[*http.Request]*http.Request)
}
if mod == nil {
delete(t.modReq, orig)
} else {
t.modReq[orig] = mod
}
t.token = token
return nil
}
// cloneRequest returns a clone of the provided *http.Request.
@@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = s
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
type onEOFReader struct {
rc io.ReadCloser
fn func()
}
func (r *onEOFReader) Read(p []byte) (n int, err error) {
n, err = r.rc.Read(p)
if err == io.EOF {
r.runFunc()
}
return
}
func (r *onEOFReader) Close() error {
err := r.rc.Close()
r.runFunc()
return err
}
func (r *onEOFReader) runFunc() {
if fn := r.fn; fn != nil {
fn()
r.fn = nil
}
}
type errorTransport struct{ err error }
func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
return nil, t.err
}