forked from remote/oauth2
First wave of review updates.
Change-Id: Ibfe8cb23f12c516d9264fcbbee8d8af64b458c89
This commit is contained in:
@@ -66,11 +66,11 @@ type CredentialSource struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parse determines the type of CredentialSource needed
|
// parse determines the type of CredentialSource needed
|
||||||
func (c *Config) parse() baseCredentialSource {
|
func (c *Config) parse(ctx context.Context) baseCredentialSource {
|
||||||
if c.CredentialSource.File != "" {
|
if c.CredentialSource.File != "" {
|
||||||
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}
|
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}
|
||||||
} else if c.CredentialSource.URL != "" {
|
} else if c.CredentialSource.URL != "" {
|
||||||
return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format}
|
return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -89,7 +89,7 @@ type tokenSource struct {
|
|||||||
func (ts tokenSource) Token() (*oauth2.Token, error) {
|
func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||||
conf := ts.conf
|
conf := ts.conf
|
||||||
|
|
||||||
credSource := conf.parse()
|
credSource := conf.parse(ts.ctx)
|
||||||
if credSource == nil {
|
if credSource == nil {
|
||||||
return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
|
return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
package externalaccount
|
package externalaccount
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
|
|||||||
tfc.CredentialSource = test.cs
|
tfc.CredentialSource = test.cs
|
||||||
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
out, err := tfc.parse().subjectToken()
|
out, err := tfc.parse(context.Background()).subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Method subjectToken() errored.")
|
t.Errorf("Method subjectToken() errored.")
|
||||||
} else if test.want != out {
|
} else if test.want != out {
|
||||||
|
|||||||
@@ -5,39 +5,39 @@
|
|||||||
package externalaccount
|
package externalaccount
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type urlCredentialSource struct {
|
type urlCredentialSource struct {
|
||||||
URL string
|
URL string
|
||||||
Headers map[string]string
|
Headers map[string]string
|
||||||
Format format
|
Format format
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs urlCredentialSource) subjectToken() (string, error) {
|
func (cs urlCredentialSource) subjectToken() (string, error) {
|
||||||
client := http.Client{}
|
client := oauth2.NewClient(cs.ctx, nil)
|
||||||
req, err := http.NewRequest("GET", cs.URL, strings.NewReader(""))
|
req, err := http.NewRequest("GET", cs.URL, nil)
|
||||||
|
|
||||||
for key, val := range cs.Headers {
|
for key, val := range cs.Headers {
|
||||||
req.Header.Add(key, val)
|
req.Header.Add(key, val)
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Errorf("oauth2/google: invalid response when retrieving subject token: %v", err)
|
return "", fmt.Errorf("oauth2/google: invalid response when retrieving subject token: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
tokenBytes, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
tokenBytes, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Errorf("oauth2/google: invalid body in subject token URL query: %v", err)
|
return "", fmt.Errorf("oauth2/google: invalid body in subject token URL query: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch cs.Format.Type {
|
switch cs.Format.Type {
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
var myURLToken = "testTokenValue"
|
var myURLToken = "testTokenValue"
|
||||||
|
|
||||||
func TestRetrieveURLSubjectToken_Text(t *testing.T) {
|
func TestRetrieveURLSubjectToken_Text(t *testing.T) {
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
t.Errorf("Unexpected request method, %v is found", r.Method)
|
t.Errorf("Unexpected request method, %v is found", r.Method)
|
||||||
@@ -35,12 +34,10 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) {
|
|||||||
if out != myURLToken {
|
if out != myURLToken {
|
||||||
t.Errorf("got %v but want %v", out, myURLToken)
|
t.Errorf("got %v but want %v", out, myURLToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checking that retrieveSubjectToken properly defaults to type text
|
// Checking that retrieveSubjectToken properly defaults to type text
|
||||||
func TestRetrieveURLSubjectToken_Untyped(t *testing.T) {
|
func TestRetrieveURLSubjectToken_Untyped(t *testing.T) {
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
t.Errorf("Unexpected request method, %v is found", r.Method)
|
t.Errorf("Unexpected request method, %v is found", r.Method)
|
||||||
@@ -55,12 +52,11 @@ func TestRetrieveURLSubjectToken_Untyped(t *testing.T) {
|
|||||||
|
|
||||||
out, err := tfc.parse().subjectToken()
|
out, err := tfc.parse().subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to retrieve USRL subject token: %v", err)
|
t.Fatalf("Failed to retrieve URL subject token: %v", err)
|
||||||
}
|
}
|
||||||
if out != myURLToken {
|
if out != myURLToken {
|
||||||
t.Errorf("got %v but want %v", out, myURLToken)
|
t.Errorf("got %v but want %v", out, myURLToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
|
func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
|
||||||
@@ -86,12 +82,10 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
|
|||||||
tfc.CredentialSource = cs
|
tfc.CredentialSource = cs
|
||||||
|
|
||||||
out, err := tfc.parse().subjectToken()
|
out, err := tfc.parse().subjectToken()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
if out != myURLToken {
|
if out != myURLToken {
|
||||||
t.Errorf("got %v but want %v", out, myURLToken)
|
t.Errorf("got %v but want %v", out, myURLToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user