diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index b070133..dff0881 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -66,11 +66,11 @@ type CredentialSource struct { } // parse determines the type of CredentialSource needed -func (c *Config) parse() baseCredentialSource { +func (c *Config) parse(ctx context.Context) baseCredentialSource { if c.CredentialSource.File != "" { return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format} } else if c.CredentialSource.URL != "" { - return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format} + return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx} } return nil } @@ -89,7 +89,7 @@ type tokenSource struct { func (ts tokenSource) Token() (*oauth2.Token, error) { conf := ts.conf - credSource := conf.parse() + credSource := conf.parse(ts.ctx) if credSource == nil { return nil, fmt.Errorf("oauth2/google: unable to parse credential source") } diff --git a/google/internal/externalaccount/filecredsource_test.go b/google/internal/externalaccount/filecredsource_test.go index 0bc8048..56dd71e 100644 --- a/google/internal/externalaccount/filecredsource_test.go +++ b/google/internal/externalaccount/filecredsource_test.go @@ -5,6 +5,7 @@ package externalaccount import ( + "context" "testing" ) @@ -55,7 +56,7 @@ func TestRetrieveFileSubjectToken(t *testing.T) { tfc.CredentialSource = test.cs t.Run(test.name, func(t *testing.T) { - out, err := tfc.parse().subjectToken() + out, err := tfc.parse(context.Background()).subjectToken() if err != nil { t.Errorf("Method subjectToken() errored.") } else if test.want != out { diff --git a/google/internal/externalaccount/urlcredsource.go b/google/internal/externalaccount/urlcredsource.go index 0318d48..d3818be 100644 --- a/google/internal/externalaccount/urlcredsource.go +++ b/google/internal/externalaccount/urlcredsource.go @@ -5,39 +5,39 @@ package externalaccount import ( + "context" "encoding/json" "errors" "fmt" + "golang.org/x/oauth2" "io" "io/ioutil" "net/http" - "strings" ) type urlCredentialSource struct { URL string Headers map[string]string Format format + ctx context.Context } func (cs urlCredentialSource) subjectToken() (string, error) { - client := http.Client{} - req, err := http.NewRequest("GET", cs.URL, strings.NewReader("")) + client := oauth2.NewClient(cs.ctx, nil) + req, err := http.NewRequest("GET", cs.URL, nil) for key, val := range cs.Headers { req.Header.Add(key, val) } resp, err := client.Do(req) if err != nil { - fmt.Errorf("oauth2/google: invalid response when retrieving subject token: %v", err) - return "", err + return "", fmt.Errorf("oauth2/google: invalid response when retrieving subject token: %v", err) } defer resp.Body.Close() tokenBytes, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - fmt.Errorf("oauth2/google: invalid body in subject token URL query: %v", err) - return "", err + return "", fmt.Errorf("oauth2/google: invalid body in subject token URL query: %v", err) } switch cs.Format.Type { diff --git a/google/internal/externalaccount/urlcredsource_test.go b/google/internal/externalaccount/urlcredsource_test.go index 5d66c2f..1deea05 100644 --- a/google/internal/externalaccount/urlcredsource_test.go +++ b/google/internal/externalaccount/urlcredsource_test.go @@ -14,7 +14,6 @@ import ( var myURLToken = "testTokenValue" func TestRetrieveURLSubjectToken_Text(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { t.Errorf("Unexpected request method, %v is found", r.Method) @@ -35,12 +34,10 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) { if out != myURLToken { t.Errorf("got %v but want %v", out, myURLToken) } - } // Checking that retrieveSubjectToken properly defaults to type text func TestRetrieveURLSubjectToken_Untyped(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { t.Errorf("Unexpected request method, %v is found", r.Method) @@ -55,12 +52,11 @@ func TestRetrieveURLSubjectToken_Untyped(t *testing.T) { out, err := tfc.parse().subjectToken() if err != nil { - t.Fatalf("Failed to retrieve USRL subject token: %v", err) + t.Fatalf("Failed to retrieve URL subject token: %v", err) } if out != myURLToken { t.Errorf("got %v but want %v", out, myURLToken) } - } func TestRetrieveURLSubjectToken_JSON(t *testing.T) { @@ -86,12 +82,10 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) { tfc.CredentialSource = cs out, err := tfc.parse().subjectToken() - if err != nil { t.Fatalf("%v", err) } if out != myURLToken { t.Errorf("got %v but want %v", out, myURLToken) } - }