google/internal: Add AWS Session Token to Metadata Requests

This commit is contained in:
Sai Sunder
2022-03-08 17:47:17 +00:00
parent ee48083810
commit 29e1f4aad1
3 changed files with 205 additions and 53 deletions

View File

@@ -52,6 +52,13 @@ const (
// The AWS authorization header name for the security session token if available.
awsSecurityTokenHeader = "x-amz-security-token"
// The name of the header containing the session token for metadata endpoint calls
awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
awsIMDSv2SessionTtl = "300"
// The AWS authorization header name for the auto-generated date.
awsDateHeader = "x-amz-date"
@@ -241,6 +248,7 @@ type awsCredentialSource struct {
RegionURL string
RegionalCredVerificationURL string
CredVerificationURL string
IMDSv2SessionTokenURL string
TargetResource string
requestSigner *awsRequestSigner
region string
@@ -268,12 +276,22 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro
func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil {
awsSecurityCredentials, err := cs.getSecurityCredentials()
awsSessionToken, err := cs.getAWSSessionToken()
if err != nil {
return "", err
}
if cs.region, err = cs.getRegion(); err != nil {
headers := make(map[string]string)
if awsSessionToken != "" {
headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
}
awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
if err != nil {
return "", err
}
if cs.region, err = cs.getRegion(headers); err != nil {
return "", err
}
@@ -340,7 +358,37 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
return url.QueryEscape(string(result)), nil
}
func (cs *awsCredentialSource) getRegion() (string, error) {
func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
if cs.IMDSv2SessionTokenURL == "" {
return "", nil
}
req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil)
if err != nil {
return "", err
}
req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
resp, err := cs.doRequest(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody))
}
return string(respBody), nil
}
func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
}
@@ -357,6 +405,10 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
return "", err
}
for name, value := range headers {
req.Header.Add(name, value)
}
resp, err := cs.doRequest(req)
if err != nil {
return "", err
@@ -381,7 +433,7 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
return string(respBody[:respBodyEnd]), nil
}
func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{
@@ -392,12 +444,12 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
}
}
roleName, err := cs.getMetadataRoleName()
roleName, err := cs.getMetadataRoleName(headers)
if err != nil {
return
}
credentials, err := cs.getMetadataSecurityCredentials(roleName)
credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
if err != nil {
return
}
@@ -413,7 +465,7 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
return credentials, nil
}
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
@@ -422,6 +474,10 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
}
req.Header.Add("Content-Type", "application/json")
for name, value := range headers {
req.Header.Add(name, value)
}
resp, err := cs.doRequest(req)
if err != nil {
return result, err
@@ -441,7 +497,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
return result, err
}
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
if cs.CredVerificationURL == "" {
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
}
@@ -451,6 +507,10 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
return "", err
}
for name, value := range headers {
req.Header.Add(name, value)
}
resp, err := cs.doRequest(req)
if err != nil {
return "", err

View File

@@ -20,6 +20,8 @@ import (
var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC)
var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC)
type validateHeaders func(r *http.Request)
func setTime(testTime time.Time) func() time.Time {
return func() time.Time {
return testTime
@@ -82,7 +84,7 @@ func testRequestSigner(t *testing.T, rs *awsRequestSigner, input, expectedOutput
}
}
func TestAwsV4Signature_GetRequest(t *testing.T) {
func TestAWSv4Signature_GetRequest(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com", nil)
setDefaultTime(input)
@@ -100,7 +102,7 @@ func TestAwsV4Signature_GetRequest(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) {
func TestAWSv4Signature_GetRequestWithRelativePath(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil)
setDefaultTime(input)
@@ -118,7 +120,7 @@ func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) {
func TestAWSv4Signature_GetRequestWithDotPath(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil)
setDefaultTime(input)
@@ -136,7 +138,7 @@ func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
func TestAWSv4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil)
setDefaultTime(input)
@@ -154,7 +156,7 @@ func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) {
func TestAWSv4Signature_GetRequestWithUtf8Path(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil)
setDefaultTime(input)
@@ -172,7 +174,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
func TestAWSv4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil)
setDefaultTime(input)
@@ -190,7 +192,7 @@ func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
func TestAWSv4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil)
setDefaultTime(input)
@@ -208,7 +210,7 @@ func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) {
func TestAWSv4Signature_GetRequestWithUtf8Query(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil)
setDefaultTime(input)
@@ -226,7 +228,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_PostRequest(t *testing.T) {
func TestAWSv4Signature_PostRequest(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
setDefaultTime(input)
input.Header.Add("ZOO", "zoobar")
@@ -246,7 +248,7 @@ func TestAwsV4Signature_PostRequest(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
func TestAWSv4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
setDefaultTime(input)
input.Header.Add("zoo", "ZOOBAR")
@@ -266,7 +268,7 @@ func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_PostRequestPhfft(t *testing.T) {
func TestAWSv4Signature_PostRequestPhfft(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
setDefaultTime(input)
input.Header.Add("p", "phfft")
@@ -286,7 +288,7 @@ func TestAwsV4Signature_PostRequestPhfft(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_PostRequestWithBody(t *testing.T) {
func TestAWSv4Signature_PostRequestWithBody(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar"))
setDefaultTime(input)
input.Header.Add("Content-Type", "application/x-www-form-urlencoded")
@@ -306,7 +308,7 @@ func TestAwsV4Signature_PostRequestWithBody(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) {
func TestAWSv4Signature_PostRequestWithQueryString(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil)
setDefaultTime(input)
@@ -324,7 +326,7 @@ func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output)
}
func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) {
func TestAWSv4Signature_GetRequestWithSecurityToken(t *testing.T) {
input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
@@ -342,7 +344,7 @@ func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) {
testRequestSigner(t, requestSignerWithToken, input, output)
}
func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) {
func TestAWSv4Signature_PostRequestWithSecurityToken(t *testing.T) {
input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
@@ -360,7 +362,7 @@ func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) {
testRequestSigner(t, requestSignerWithToken, input, output)
}
func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) {
func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) {
requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}"
input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
input.Header.Add("Content-Type", "application/x-amz-json-1.0")
@@ -383,7 +385,7 @@ func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test
testRequestSigner(t, requestSignerWithToken, input, output)
}
func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
var requestSigner = &awsRequestSigner{
RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{
@@ -413,30 +415,40 @@ type testAwsServer struct {
securityCredentialURL string
regionURL string
regionalCredVerificationURL string
imdsv2SessionTokenUrl string
Credentials map[string]string
WriteRolename func(http.ResponseWriter)
WriteSecurityCredentials func(http.ResponseWriter)
WriteRegion func(http.ResponseWriter)
WriteRolename func(http.ResponseWriter, *http.Request)
WriteSecurityCredentials func(http.ResponseWriter, *http.Request)
WriteRegion func(http.ResponseWriter, *http.Request)
WriteIMDSv2SessionToken func(http.ResponseWriter, *http.Request)
}
func createAwsTestServer(url, regionURL, regionalCredVerificationURL, rolename, region string, credentials map[string]string) *testAwsServer {
func createAwsTestServer(url, regionURL, regionalCredVerificationURL, imdsv2SessionTokenUrl string, rolename, region string, credentials map[string]string, imdsv2SessionToken string, validateHeaders validateHeaders) *testAwsServer {
server := &testAwsServer{
url: url,
securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename),
regionURL: regionURL,
regionalCredVerificationURL: regionalCredVerificationURL,
imdsv2SessionTokenUrl: imdsv2SessionTokenUrl,
Credentials: credentials,
WriteRolename: func(w http.ResponseWriter) {
WriteRolename: func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
w.Write([]byte(rolename))
},
WriteRegion: func(w http.ResponseWriter) {
WriteRegion: func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
w.Write([]byte(region))
},
WriteIMDSv2SessionToken: func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
w.Write([]byte(imdsv2SessionToken))
},
}
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
jsonCredentials, _ := json.Marshal(server.Credentials)
w.Write(jsonCredentials)
}
@@ -449,6 +461,7 @@ func createDefaultAwsTestServer() *testAwsServer {
"/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"",
"gcp-aws-role",
"us-east-2b",
map[string]string{
@@ -456,31 +469,38 @@ func createDefaultAwsTestServer() *testAwsServer {
"AccessKeyId": accessKeyID,
"Token": securityToken,
},
"",
noHeaderValidation,
)
}
func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch p := r.URL.Path; p {
case server.url:
server.WriteRolename(w)
server.WriteRolename(w, r)
case server.securityCredentialURL:
server.WriteSecurityCredentials(w)
server.WriteSecurityCredentials(w, r)
case server.regionURL:
server.WriteRegion(w)
server.WriteRegion(w, r)
case server.imdsv2SessionTokenUrl:
server.WriteIMDSv2SessionToken(w, r)
}
}
func notFound(w http.ResponseWriter) {
func notFound(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte("Not Found"))
}
func noHeaderValidation(r *http.Request) {}
func (server *testAwsServer) getCredentialSource(url string) CredentialSource {
return CredentialSource{
EnvironmentID: "aws1",
URL: url + server.url,
RegionURL: url + server.regionURL,
RegionalCredVerificationURL: server.regionalCredVerificationURL,
IMDSv2SessionTokenURL: url + server.imdsv2SessionTokenUrl,
}
}
@@ -530,7 +550,7 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
return neturl.QueryEscape(string(str))
}
func TestAwsCredential_BasicRequest(t *testing.T) {
func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -567,7 +587,72 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
}
}
func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
func TestAWSCredential_IMDSv2(t *testing.T) {
validateSessionTokenHeaders := func(r *http.Request) {
if r.URL.Path == "/latest/api/token" {
headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
if headerValue != awsIMDSv2SessionTtl {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
}
} else {
headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
if headerValue != "sessiontoken" {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
}
}
}
server := createAwsTestServer(
"/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"/latest/api/token",
"gcp-aws-role",
"us-east-2b",
map[string]string{
"SecretAccessKey": secretAccessKey,
"AccessKeyId": accessKeyID,
"Token": securityToken,
},
"sessiontoken",
validateSessionTokenHeaders,
)
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
securityToken,
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
delete(server.Credentials, "Token")
@@ -605,7 +690,7 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
}
}
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -646,7 +731,7 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
}
}
func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -686,7 +771,7 @@ func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) {
}
}
func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) {
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -727,7 +812,7 @@ func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) {
}
}
func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -748,7 +833,7 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
}
}
func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -775,7 +860,7 @@ func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) {
}
}
func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteRegion = notFound
@@ -802,10 +887,10 @@ func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) {
}
}
func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}"))
}
@@ -831,10 +916,10 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
}
}
func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
}
@@ -860,7 +945,7 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
}
}
func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
@@ -887,7 +972,7 @@ func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) {
}
}
func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteRolename = notFound
@@ -914,7 +999,7 @@ func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) {
}
}
func TestAwsCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteSecurityCredentials = notFound

View File

@@ -175,6 +175,7 @@ type CredentialSource struct {
RegionURL string `json:"region_url"`
RegionalCredVerificationURL string `json:"regional_cred_verification_url"`
CredVerificationURL string `json:"cred_verification_url"`
IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"`
Format format `json:"format"`
}
@@ -185,14 +186,20 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
if awsVersion != 1 {
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
}
return awsCredentialSource{
awsCredSource := awsCredentialSource{
EnvironmentID: c.CredentialSource.EnvironmentID,
RegionURL: c.CredentialSource.RegionURL,
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
CredVerificationURL: c.CredentialSource.URL,
TargetResource: c.Audience,
ctx: ctx,
}, nil
}
if c.CredentialSource.IMDSv2SessionTokenURL != "" {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
}
return awsCredSource, nil
}
} else if c.CredentialSource.File != "" {
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil