Introduce an option function type

- Reduce the duplicate code by merging the flows and
determining the flow type by looking at the provided options.
- Options as a function type allows us to validate an individual
an option in its scope and makes it easier to compose the
built-in options with the third-party ones.
This commit is contained in:
Burcu Dogan
2014-11-07 11:36:41 +11:00
parent 49f4824137
commit 0cf6f9b144
15 changed files with 847 additions and 774 deletions

316
oauth2.go
View File

@@ -8,85 +8,118 @@
package oauth2
import (
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"mime"
"time"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// TokenFetcher refreshes or fetches a new access token from the
// provider. It should return an error if it's not capable of
// retrieving a token.
type TokenFetcher interface {
// FetchToken retrieves a new access token for the provider.
// If the implementation doesn't know how to retrieve a new token,
// it returns an error. The existing token may be nil.
FetchToken(existing *Token) (*Token, error)
// Option represents a function that applies some state to
// an Options object.
type Option func(*Options) error
// Client requires the OAuth 2.0 client credentials. You need to provide
// the client identifier and optionally the client secret that are
// assigned to your application by the OAuth 2.0 provider.
func Client(id, secret string) Option {
return func(opts *Options) error {
opts.ClientID = id
opts.ClientSecret = secret
return nil
}
}
// Options represents options to provide OAuth 2.0 client credentials
// and access level. A sample configuration:
type Options struct {
// ClientID is the OAuth client identifier used when communicating with
// the configured OAuth provider.
ClientID string `json:"client_id"`
// ClientSecret is the OAuth client secret used when communicating with
// the configured OAuth provider.
ClientSecret string `json:"client_secret"`
// RedirectURL is the URL to which the user will be returned after
// granting (or denying) access.
RedirectURL string `json:"redirect_url"`
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string `json:"scopes,omitempty"`
// RedirectURL requires the URL to which the user will be returned after
// granting (or denying) access.
func RedirectURL(url string) Option {
return func(opts *Options) error {
opts.RedirectURL = url
return nil
}
}
// NewConfig creates a generic OAuth 2.0 configuration that talks
// to an OAuth 2.0 provider specified with authURL and tokenURL.
func NewConfig(opts *Options, authURL, tokenURL string) (*Config, error) {
aURL, err := url.Parse(authURL)
if err != nil {
return nil, err
// Scope requires a list of requested permission scopes.
// It is optinal to specify scopes.
func Scope(scopes ...string) Option {
return func(o *Options) error {
o.Scopes = scopes
return nil
}
tURL, err := url.Parse(tokenURL)
if err != nil {
return nil, err
}
if opts.ClientID == "" {
return nil, errors.New("oauth2: missing client ID")
}
return &Config{
opts: opts,
authURL: aURL,
tokenURL: tURL,
}, nil
}
// Config represents the configuration of an OAuth 2.0 consumer client.
type Config struct {
// Client is the HTTP client to be used to retrieve
// tokens from the OAuth 2.0 provider.
Client *http.Client
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
func Endpoint(authURL, tokenURL string) Option {
return func(o *Options) error {
au, err := url.Parse(authURL)
if err != nil {
return err
}
tu, err := url.Parse(tokenURL)
if err != nil {
return err
}
o.TokenFetcherFunc = makeThreeLeggedFetcher(o)
o.AuthURL = au
o.TokenURL = tu
return nil
}
}
// Transport is the http.RoundTripper to be used
// to construct new oauth2.Transport instances from
// this configuration.
Transport http.RoundTripper
// HTTPClient allows you to provide a custom http.Client to be
// used to retrieve tokens from the OAuth 2.0 provider.
func HTTPClient(c *http.Client) Option {
return func(o *Options) error {
o.Client = c
return nil
}
}
opts *Options
// AuthURL is the URL the user will be directed to
// in order to grant access.
authURL *url.URL
// TokenURL is the URL used to retrieve OAuth tokens.
tokenURL *url.URL
// RoundTripper allows you to provide a custom http.RoundTripper
// to be used to construct new oauth2.Transport instances.
// If none is provided a default RoundTripper will be used.
func RoundTripper(tr http.RoundTripper) Option {
return func(o *Options) error {
o.Transport = tr
return nil
}
}
type Flow struct {
opts Options
}
// New initiates a new flow. It determines the type of the OAuth 2.0
// (2-legged, 3-legged or custom) by looking at the provided options.
// If the flow type cannot determined automatically, an error is returned.
func New(options ...Option) (*Flow, error) {
f := &Flow{}
for _, opt := range options {
if err := opt(&f.opts); err != nil {
return nil, err
}
}
switch {
case f.opts.TokenFetcherFunc != nil:
return f, nil
case f.opts.AUD != nil:
// TODO(jbd): Assert required JWT params.
f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts)
return f, nil
case f.opts.AuthURL != nil && f.opts.TokenURL != nil:
// TODO(jbd): Assert required OAuth2 params.
f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts)
return f, nil
default:
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
}
}
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
@@ -112,13 +145,13 @@ type Config struct {
// granted consent and the code can only be exchanged for an
// access token. If set to "force" the user will always be prompted,
// and the code can be exchanged for a refresh token.
func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string) {
u := *c.authURL
func (f *Flow) AuthCodeURL(state, accessType, prompt string) string {
u := f.opts.AuthURL
v := url.Values{
"response_type": {"code"},
"client_id": {c.opts.ClientID},
"redirect_uri": condVal(c.opts.RedirectURL),
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
"client_id": {f.opts.ClientID},
"redirect_uri": condVal(f.opts.RedirectURL),
"scope": condVal(strings.Join(f.opts.Scopes, " ")),
"state": condVal(state),
"access_type": condVal(accessType),
"approval_prompt": condVal(prompt),
@@ -132,65 +165,122 @@ func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string)
return u.String()
}
// NewTransport creates a new authorizable transport. It doesn't
// initialize the new transport with a token, so after creation,
// you need to set a valid token (or an expired token with a valid
// refresh token) in order to be able to do authorized requests.
func (c *Config) NewTransport() *Transport {
return NewTransport(c.transport(), c, nil)
// exchange exchanges the authorization code with the OAuth 2.0 provider
// to retrieve a new access token.
func (f *Flow) exchange(code string) (*Token, error) {
return retrieveToken(&f.opts, url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": condVal(f.opts.RedirectURL),
"scope": condVal(strings.Join(f.opts.Scopes, " ")),
})
}
// NewTransportWithCode exchanges the OAuth 2.0 authorization code with
// the provider to fetch a new access token (and refresh token). Once
// it successfully retrieves a new token, creates a new transport
// authorized with it.
func (c *Config) NewTransportWithCode(code string) (*Transport, error) {
token, err := c.Exchange(code)
// NewTransportFromCode exchanges the code to retrieve a new access token
// and returns an authorized and authenticated Transport.
func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
token, err := f.exchange(code)
if err != nil {
return nil, err
}
return NewTransport(c.transport(), c, token), nil
return f.NewTransportFromToken(token), nil
}
// FetchToken retrieves a new access token and updates the existing token
// with the newly fetched credentials. If existing token doesn't
// contain a refresh token, it returns an error.
func (c *Config) FetchToken(existing *Token) (*Token, error) {
if existing == nil || existing.RefreshToken == "" {
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
// NewTransportFromToken returns a new Transport that is authorized
// and authenticated with the provided token.
func (f *Flow) NewTransportFromToken(t *Token) *Transport {
tr := f.opts.Transport
if tr == nil {
tr = http.DefaultTransport
}
return c.retrieveToken(url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {existing.RefreshToken},
})
return newTransport(tr, f.opts.TokenFetcherFunc, t)
}
// Exchange exchanges the authorization code with the OAuth 2.0 provider
// to retrieve a new access token.
func (c *Config) Exchange(code string) (*Token, error) {
return c.retrieveToken(url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": condVal(c.opts.RedirectURL),
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
})
// NewTransport returns a Transport.
func (f *Flow) NewTransport() *Transport {
return f.NewTransportFromToken(nil)
}
func (c *Config) retrieveToken(v url.Values) (*Token, error) {
v.Set("client_id", c.opts.ClientID)
bustedAuth := !providerAuthHeaderWorks(c.tokenURL.String())
if bustedAuth && c.opts.ClientSecret != "" {
v.Set("client_secret", c.opts.ClientSecret)
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return func(t *Token) (*Token, error) {
if t == nil || t.RefreshToken == "" {
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
}
return retrieveToken(o, url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {t.RefreshToken},
})
}
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(v.Encode()))
}
// Options represents an object to keep the state of the OAuth 2.0 flow.
type Options struct {
// ClientID is the OAuth client identifier used when communicating with
// the configured OAuth provider.
ClientID string
// ClientSecret is the OAuth client secret used when communicating with
// the configured OAuth provider.
ClientSecret string
// RedirectURL is the URL to which the user will be returned after
// granting (or denying) access.
RedirectURL string
// Email is the OAuth client identifier used when communicating with
// the configured OAuth provider.
Email string
// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey *rsa.PrivateKey
// Scopes identify the level of access being requested.
Subject string
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
AuthURL *url.URL
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
TokenURL *url.URL
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
AUD *url.URL
TokenFetcherFunc func(t *Token) (*Token, error)
Transport http.RoundTripper
Client *http.Client
}
func retrieveToken(o *Options, v url.Values) (*Token, error) {
v.Set("client_id", o.ClientID)
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
if bustedAuth && o.ClientSecret != "" {
v.Set("client_secret", o.ClientSecret)
}
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth && c.opts.ClientSecret != "" {
req.SetBasicAuth(c.opts.ClientID, c.opts.ClientSecret)
if !bustedAuth && o.ClientSecret != "" {
req.SetBasicAuth(o.ClientID, o.ClientSecret)
}
r, err := c.client().Do(req)
c := o.Client
if c == nil {
c = &http.Client{}
}
r, err := c.Do(req)
if err != nil {
return nil, err
}
@@ -199,7 +289,7 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) {
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if c := r.StatusCode; c < 200 || c > 299 {
if code := r.StatusCode; code < 200 || code > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
}
@@ -255,20 +345,6 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) {
return token, nil
}
func (c *Config) transport() http.RoundTripper {
if c.Transport != nil {
return c.Transport
}
return http.DefaultTransport
}
func (c *Config) client() *http.Client {
if c.Client != nil {
return c.Client
}
return http.DefaultClient
}
func condVal(v string) []string {
if v == "" {
return nil