oauth2: redesign the API
Tests and examples aren't updated yet. The tree will be broken after this, but nobody should be using this yet anyway. Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f Reviewed-on: https://go-review.googlesource.com/1246 Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
202
transport.go
202
transport.go
@@ -5,136 +5,90 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTokenType = "Bearer"
|
||||
)
|
||||
|
||||
// Token represents the crendentials used to authorize
|
||||
// the requests to access protected resources on the OAuth 2.0
|
||||
// provider's backend.
|
||||
type Token struct {
|
||||
// AccessToken is the token that authorizes and authenticates the requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// TokenType identifies the type of token returned.
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
|
||||
// RefreshToken is a token that may be used to obtain a new access token.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
|
||||
// Expiry is the expiration datetime of the access token.
|
||||
Expiry time.Time `json:"expiry,omitempty"`
|
||||
|
||||
// raw optionally contains extra metadata from the server
|
||||
// when updating a token.
|
||||
raw interface{}
|
||||
}
|
||||
|
||||
// Extra returns an extra field returned from the server during token retrieval.
|
||||
// E.g.
|
||||
// idToken := token.Extra("id_token")
|
||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
||||
// wrapping a base RoundTripper and adding an Authorization header
|
||||
// with a token from the supplied Sources.
|
||||
//
|
||||
func (t *Token) Extra(key string) string {
|
||||
if vals, ok := t.raw.(url.Values); ok {
|
||||
return vals.Get(key)
|
||||
}
|
||||
if raw, ok := t.raw.(map[string]interface{}); ok {
|
||||
if val, ok := raw[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Expired returns true if there is no access token or the
|
||||
// access token is expired.
|
||||
func (t *Token) Expired() bool {
|
||||
if t.AccessToken == "" {
|
||||
return true
|
||||
}
|
||||
if t.Expiry.IsZero() {
|
||||
return false
|
||||
}
|
||||
return t.Expiry.Before(time.Now())
|
||||
}
|
||||
|
||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
||||
// Transport is a low-level mechanism. Most code will use the
|
||||
// higher-level Config.Client method instead.
|
||||
type Transport struct {
|
||||
opts *Options
|
||||
base http.RoundTripper
|
||||
// Source supplies the token to add to outgoing requests'
|
||||
// Authorization headers.
|
||||
Source TokenSource
|
||||
|
||||
mu sync.RWMutex
|
||||
token *Token
|
||||
}
|
||||
// Base is the base RoundTripper used to make HTTP requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Base http.RoundTripper
|
||||
|
||||
// NewTransport creates a new Transport that uses the provided
|
||||
// token fetcher as token retrieving strategy. It authenticates
|
||||
// the requests and delegates origTransport to make the actual requests.
|
||||
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
|
||||
return &Transport{
|
||||
base: base,
|
||||
opts: opts,
|
||||
token: token,
|
||||
}
|
||||
mu sync.Mutex // guards modReq
|
||||
modReq map[*http.Request]*http.Request // original -> modified
|
||||
}
|
||||
|
||||
// RoundTrip authorizes and authenticates the request with an
|
||||
// access token. If no token exists or token is expired,
|
||||
// tries to refresh/fetch a new token.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token := t.token
|
||||
|
||||
if token == nil || token.Expired() {
|
||||
// Check if the token is refreshable.
|
||||
// If token is refreshable, don't return an error,
|
||||
// rather refresh.
|
||||
if err := t.refreshToken(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token = t.token
|
||||
if t.opts.TokenStore != nil {
|
||||
t.opts.TokenStore.WriteToken(token)
|
||||
}
|
||||
if t.Source == nil {
|
||||
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||
}
|
||||
token, err := t.Source.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// To set the Authorization header, we must make a copy of the Request
|
||||
// so that we don't modify the Request we were given.
|
||||
// This is required by the specification of http.RoundTripper.
|
||||
req = cloneRequest(req)
|
||||
typ := token.TokenType
|
||||
if typ == "" {
|
||||
typ = defaultTokenType
|
||||
req2 := cloneRequest(req) // per RoundTripper contract
|
||||
token.SetAuthHeader(req2)
|
||||
t.setModReq(req, req2)
|
||||
res, err := t.base().RoundTrip(req2)
|
||||
if err != nil {
|
||||
t.setModReq(req, nil)
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", typ+" "+token.AccessToken)
|
||||
return t.base.RoundTrip(req)
|
||||
res.Body = &onEOFReader{
|
||||
rc: res.Body,
|
||||
fn: func() { t.setModReq(req, nil) },
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Token returns the token that authorizes and
|
||||
// authenticates the transport.
|
||||
func (t *Transport) Token() *Token {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.token
|
||||
// CancelRequest cancels an in-flight request by closing its connection.
|
||||
func (t *Transport) CancelRequest(req *http.Request) {
|
||||
type canceler interface {
|
||||
CancelRequest(*http.Request)
|
||||
}
|
||||
if cr, ok := t.Base.(canceler); ok {
|
||||
t.mu.Lock()
|
||||
modReq := t.modReq[req]
|
||||
delete(t.modReq, req)
|
||||
t.mu.Unlock()
|
||||
cr.CancelRequest(modReq)
|
||||
}
|
||||
}
|
||||
|
||||
// refreshToken retrieves a new token, if a refreshing/fetching
|
||||
// method is known and required credentials are presented
|
||||
// (such as a refresh token).
|
||||
func (t *Transport) refreshToken() error {
|
||||
func (t *Transport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func (t *Transport) setModReq(orig, mod *http.Request) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
token, err := t.opts.TokenFetcherFunc(t.token)
|
||||
if err != nil {
|
||||
return err
|
||||
if t.modReq == nil {
|
||||
t.modReq = make(map[*http.Request]*http.Request)
|
||||
}
|
||||
if mod == nil {
|
||||
delete(t.modReq, orig)
|
||||
} else {
|
||||
t.modReq[orig] = mod
|
||||
}
|
||||
t.token = token
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRequest returns a clone of the provided *http.Request.
|
||||
@@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request {
|
||||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header)
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = s
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
return r2
|
||||
}
|
||||
|
||||
type onEOFReader struct {
|
||||
rc io.ReadCloser
|
||||
fn func()
|
||||
}
|
||||
|
||||
func (r *onEOFReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.rc.Read(p)
|
||||
if err == io.EOF {
|
||||
r.runFunc()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *onEOFReader) Close() error {
|
||||
err := r.rc.Close()
|
||||
r.runFunc()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *onEOFReader) runFunc() {
|
||||
if fn := r.fn; fn != nil {
|
||||
fn()
|
||||
r.fn = nil
|
||||
}
|
||||
}
|
||||
|
||||
type errorTransport struct{ err error }
|
||||
|
||||
func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, t.err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user