Introduce an option function type
- Reduce the duplicate code by merging the flows and determining the flow type by looking at the provided options. - Options as a function type allows us to validate an individual an option in its scope and makes it easier to compose the built-in options with the third-party ones.
This commit is contained in:
316
oauth2.go
316
oauth2.go
@@ -8,85 +8,118 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenFetcher refreshes or fetches a new access token from the
|
||||
// provider. It should return an error if it's not capable of
|
||||
// retrieving a token.
|
||||
type TokenFetcher interface {
|
||||
// FetchToken retrieves a new access token for the provider.
|
||||
// If the implementation doesn't know how to retrieve a new token,
|
||||
// it returns an error. The existing token may be nil.
|
||||
FetchToken(existing *Token) (*Token, error)
|
||||
// Option represents a function that applies some state to
|
||||
// an Options object.
|
||||
type Option func(*Options) error
|
||||
|
||||
// Client requires the OAuth 2.0 client credentials. You need to provide
|
||||
// the client identifier and optionally the client secret that are
|
||||
// assigned to your application by the OAuth 2.0 provider.
|
||||
func Client(id, secret string) Option {
|
||||
return func(opts *Options) error {
|
||||
opts.ClientID = id
|
||||
opts.ClientSecret = secret
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Options represents options to provide OAuth 2.0 client credentials
|
||||
// and access level. A sample configuration:
|
||||
type Options struct {
|
||||
// ClientID is the OAuth client identifier used when communicating with
|
||||
// the configured OAuth provider.
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
// ClientSecret is the OAuth client secret used when communicating with
|
||||
// the configured OAuth provider.
|
||||
ClientSecret string `json:"client_secret"`
|
||||
|
||||
// RedirectURL is the URL to which the user will be returned after
|
||||
// granting (or denying) access.
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
|
||||
// Scopes optionally specifies a list of requested permission scopes.
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
// RedirectURL requires the URL to which the user will be returned after
|
||||
// granting (or denying) access.
|
||||
func RedirectURL(url string) Option {
|
||||
return func(opts *Options) error {
|
||||
opts.RedirectURL = url
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewConfig creates a generic OAuth 2.0 configuration that talks
|
||||
// to an OAuth 2.0 provider specified with authURL and tokenURL.
|
||||
func NewConfig(opts *Options, authURL, tokenURL string) (*Config, error) {
|
||||
aURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Scope requires a list of requested permission scopes.
|
||||
// It is optinal to specify scopes.
|
||||
func Scope(scopes ...string) Option {
|
||||
return func(o *Options) error {
|
||||
o.Scopes = scopes
|
||||
return nil
|
||||
}
|
||||
tURL, err := url.Parse(tokenURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.ClientID == "" {
|
||||
return nil, errors.New("oauth2: missing client ID")
|
||||
}
|
||||
return &Config{
|
||||
opts: opts,
|
||||
authURL: aURL,
|
||||
tokenURL: tURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Config represents the configuration of an OAuth 2.0 consumer client.
|
||||
type Config struct {
|
||||
// Client is the HTTP client to be used to retrieve
|
||||
// tokens from the OAuth 2.0 provider.
|
||||
Client *http.Client
|
||||
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
|
||||
func Endpoint(authURL, tokenURL string) Option {
|
||||
return func(o *Options) error {
|
||||
au, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tu, err := url.Parse(tokenURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.TokenFetcherFunc = makeThreeLeggedFetcher(o)
|
||||
o.AuthURL = au
|
||||
o.TokenURL = tu
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Transport is the http.RoundTripper to be used
|
||||
// to construct new oauth2.Transport instances from
|
||||
// this configuration.
|
||||
Transport http.RoundTripper
|
||||
// HTTPClient allows you to provide a custom http.Client to be
|
||||
// used to retrieve tokens from the OAuth 2.0 provider.
|
||||
func HTTPClient(c *http.Client) Option {
|
||||
return func(o *Options) error {
|
||||
o.Client = c
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
opts *Options
|
||||
// AuthURL is the URL the user will be directed to
|
||||
// in order to grant access.
|
||||
authURL *url.URL
|
||||
// TokenURL is the URL used to retrieve OAuth tokens.
|
||||
tokenURL *url.URL
|
||||
// RoundTripper allows you to provide a custom http.RoundTripper
|
||||
// to be used to construct new oauth2.Transport instances.
|
||||
// If none is provided a default RoundTripper will be used.
|
||||
func RoundTripper(tr http.RoundTripper) Option {
|
||||
return func(o *Options) error {
|
||||
o.Transport = tr
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type Flow struct {
|
||||
opts Options
|
||||
}
|
||||
|
||||
// New initiates a new flow. It determines the type of the OAuth 2.0
|
||||
// (2-legged, 3-legged or custom) by looking at the provided options.
|
||||
// If the flow type cannot determined automatically, an error is returned.
|
||||
func New(options ...Option) (*Flow, error) {
|
||||
f := &Flow{}
|
||||
for _, opt := range options {
|
||||
if err := opt(&f.opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case f.opts.TokenFetcherFunc != nil:
|
||||
return f, nil
|
||||
case f.opts.AUD != nil:
|
||||
// TODO(jbd): Assert required JWT params.
|
||||
f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts)
|
||||
return f, nil
|
||||
case f.opts.AuthURL != nil && f.opts.TokenURL != nil:
|
||||
// TODO(jbd): Assert required OAuth2 params.
|
||||
f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts)
|
||||
return f, nil
|
||||
default:
|
||||
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
|
||||
}
|
||||
}
|
||||
|
||||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||
@@ -112,13 +145,13 @@ type Config struct {
|
||||
// granted consent and the code can only be exchanged for an
|
||||
// access token. If set to "force" the user will always be prompted,
|
||||
// and the code can be exchanged for a refresh token.
|
||||
func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string) {
|
||||
u := *c.authURL
|
||||
func (f *Flow) AuthCodeURL(state, accessType, prompt string) string {
|
||||
u := f.opts.AuthURL
|
||||
v := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {c.opts.ClientID},
|
||||
"redirect_uri": condVal(c.opts.RedirectURL),
|
||||
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
|
||||
"client_id": {f.opts.ClientID},
|
||||
"redirect_uri": condVal(f.opts.RedirectURL),
|
||||
"scope": condVal(strings.Join(f.opts.Scopes, " ")),
|
||||
"state": condVal(state),
|
||||
"access_type": condVal(accessType),
|
||||
"approval_prompt": condVal(prompt),
|
||||
@@ -132,65 +165,122 @@ func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// NewTransport creates a new authorizable transport. It doesn't
|
||||
// initialize the new transport with a token, so after creation,
|
||||
// you need to set a valid token (or an expired token with a valid
|
||||
// refresh token) in order to be able to do authorized requests.
|
||||
func (c *Config) NewTransport() *Transport {
|
||||
return NewTransport(c.transport(), c, nil)
|
||||
// exchange exchanges the authorization code with the OAuth 2.0 provider
|
||||
// to retrieve a new access token.
|
||||
func (f *Flow) exchange(code string) (*Token, error) {
|
||||
return retrieveToken(&f.opts, url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"redirect_uri": condVal(f.opts.RedirectURL),
|
||||
"scope": condVal(strings.Join(f.opts.Scopes, " ")),
|
||||
})
|
||||
}
|
||||
|
||||
// NewTransportWithCode exchanges the OAuth 2.0 authorization code with
|
||||
// the provider to fetch a new access token (and refresh token). Once
|
||||
// it successfully retrieves a new token, creates a new transport
|
||||
// authorized with it.
|
||||
func (c *Config) NewTransportWithCode(code string) (*Transport, error) {
|
||||
token, err := c.Exchange(code)
|
||||
// NewTransportFromCode exchanges the code to retrieve a new access token
|
||||
// and returns an authorized and authenticated Transport.
|
||||
func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
|
||||
token, err := f.exchange(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewTransport(c.transport(), c, token), nil
|
||||
return f.NewTransportFromToken(token), nil
|
||||
}
|
||||
|
||||
// FetchToken retrieves a new access token and updates the existing token
|
||||
// with the newly fetched credentials. If existing token doesn't
|
||||
// contain a refresh token, it returns an error.
|
||||
func (c *Config) FetchToken(existing *Token) (*Token, error) {
|
||||
if existing == nil || existing.RefreshToken == "" {
|
||||
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
|
||||
// NewTransportFromToken returns a new Transport that is authorized
|
||||
// and authenticated with the provided token.
|
||||
func (f *Flow) NewTransportFromToken(t *Token) *Transport {
|
||||
tr := f.opts.Transport
|
||||
if tr == nil {
|
||||
tr = http.DefaultTransport
|
||||
}
|
||||
return c.retrieveToken(url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {existing.RefreshToken},
|
||||
})
|
||||
return newTransport(tr, f.opts.TokenFetcherFunc, t)
|
||||
}
|
||||
|
||||
// Exchange exchanges the authorization code with the OAuth 2.0 provider
|
||||
// to retrieve a new access token.
|
||||
func (c *Config) Exchange(code string) (*Token, error) {
|
||||
return c.retrieveToken(url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"redirect_uri": condVal(c.opts.RedirectURL),
|
||||
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
|
||||
})
|
||||
// NewTransport returns a Transport.
|
||||
func (f *Flow) NewTransport() *Transport {
|
||||
return f.NewTransportFromToken(nil)
|
||||
}
|
||||
|
||||
func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||
v.Set("client_id", c.opts.ClientID)
|
||||
bustedAuth := !providerAuthHeaderWorks(c.tokenURL.String())
|
||||
if bustedAuth && c.opts.ClientSecret != "" {
|
||||
v.Set("client_secret", c.opts.ClientSecret)
|
||||
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||
return func(t *Token) (*Token, error) {
|
||||
if t == nil || t.RefreshToken == "" {
|
||||
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
|
||||
}
|
||||
return retrieveToken(o, url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {t.RefreshToken},
|
||||
})
|
||||
}
|
||||
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(v.Encode()))
|
||||
}
|
||||
|
||||
// Options represents an object to keep the state of the OAuth 2.0 flow.
|
||||
type Options struct {
|
||||
// ClientID is the OAuth client identifier used when communicating with
|
||||
// the configured OAuth provider.
|
||||
ClientID string
|
||||
|
||||
// ClientSecret is the OAuth client secret used when communicating with
|
||||
// the configured OAuth provider.
|
||||
ClientSecret string
|
||||
|
||||
// RedirectURL is the URL to which the user will be returned after
|
||||
// granting (or denying) access.
|
||||
RedirectURL string
|
||||
|
||||
// Email is the OAuth client identifier used when communicating with
|
||||
// the configured OAuth provider.
|
||||
Email string
|
||||
|
||||
// PrivateKey contains the contents of an RSA private key or the
|
||||
// contents of a PEM file that contains a private key. The provided
|
||||
// private key is used to sign JWT payloads.
|
||||
// PEM containers with a passphrase are not supported.
|
||||
// Use the following command to convert a PKCS 12 file into a PEM.
|
||||
//
|
||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||
//
|
||||
PrivateKey *rsa.PrivateKey
|
||||
|
||||
// Scopes identify the level of access being requested.
|
||||
Subject string
|
||||
|
||||
// Scopes optionally specifies a list of requested permission scopes.
|
||||
Scopes []string
|
||||
|
||||
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
|
||||
AuthURL *url.URL
|
||||
|
||||
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
|
||||
TokenURL *url.URL
|
||||
|
||||
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
|
||||
AUD *url.URL
|
||||
|
||||
TokenFetcherFunc func(t *Token) (*Token, error)
|
||||
|
||||
Transport http.RoundTripper
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
||||
v.Set("client_id", o.ClientID)
|
||||
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
|
||||
if bustedAuth && o.ClientSecret != "" {
|
||||
v.Set("client_secret", o.ClientSecret)
|
||||
}
|
||||
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if !bustedAuth && c.opts.ClientSecret != "" {
|
||||
req.SetBasicAuth(c.opts.ClientID, c.opts.ClientSecret)
|
||||
if !bustedAuth && o.ClientSecret != "" {
|
||||
req.SetBasicAuth(o.ClientID, o.ClientSecret)
|
||||
}
|
||||
r, err := c.client().Do(req)
|
||||
c := o.Client
|
||||
if c == nil {
|
||||
c = &http.Client{}
|
||||
}
|
||||
r, err := c.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -199,7 +289,7 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
if c := r.StatusCode; c < 200 || c > 299 {
|
||||
if code := r.StatusCode; code < 200 || code > 299 {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
||||
}
|
||||
|
||||
@@ -255,20 +345,6 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *Config) transport() http.RoundTripper {
|
||||
if c.Transport != nil {
|
||||
return c.Transport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func (c *Config) client() *http.Client {
|
||||
if c.Client != nil {
|
||||
return c.Client
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func condVal(v string) []string {
|
||||
if v == "" {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user