forked from remote/oauth2
google/internal: Add AWS Session Token to Metadata Requests
This commit is contained in:
@@ -52,6 +52,13 @@ const (
|
|||||||
// The AWS authorization header name for the security session token if available.
|
// The AWS authorization header name for the security session token if available.
|
||||||
awsSecurityTokenHeader = "x-amz-security-token"
|
awsSecurityTokenHeader = "x-amz-security-token"
|
||||||
|
|
||||||
|
// The name of the header containing the session token for metadata endpoint calls
|
||||||
|
awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
|
||||||
|
|
||||||
|
awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
|
||||||
|
|
||||||
|
awsIMDSv2SessionTtl = "300"
|
||||||
|
|
||||||
// The AWS authorization header name for the auto-generated date.
|
// The AWS authorization header name for the auto-generated date.
|
||||||
awsDateHeader = "x-amz-date"
|
awsDateHeader = "x-amz-date"
|
||||||
|
|
||||||
@@ -241,6 +248,7 @@ type awsCredentialSource struct {
|
|||||||
RegionURL string
|
RegionURL string
|
||||||
RegionalCredVerificationURL string
|
RegionalCredVerificationURL string
|
||||||
CredVerificationURL string
|
CredVerificationURL string
|
||||||
|
IMDSv2SessionTokenURL string
|
||||||
TargetResource string
|
TargetResource string
|
||||||
requestSigner *awsRequestSigner
|
requestSigner *awsRequestSigner
|
||||||
region string
|
region string
|
||||||
@@ -268,12 +276,22 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro
|
|||||||
|
|
||||||
func (cs awsCredentialSource) subjectToken() (string, error) {
|
func (cs awsCredentialSource) subjectToken() (string, error) {
|
||||||
if cs.requestSigner == nil {
|
if cs.requestSigner == nil {
|
||||||
awsSecurityCredentials, err := cs.getSecurityCredentials()
|
awsSessionToken, err := cs.getAWSSessionToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if cs.region, err = cs.getRegion(); err != nil {
|
headers := make(map[string]string)
|
||||||
|
if awsSessionToken != "" {
|
||||||
|
headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
|
||||||
|
}
|
||||||
|
|
||||||
|
awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cs.region, err = cs.getRegion(headers); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,7 +358,37 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
|
|||||||
return url.QueryEscape(string(result)), nil
|
return url.QueryEscape(string(result)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getRegion() (string, error) {
|
func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
|
||||||
|
if cs.IMDSv2SessionTokenURL == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
|
||||||
|
|
||||||
|
resp, err := cs.doRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(respBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
|
||||||
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
|
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
|
||||||
return envAwsRegion, nil
|
return envAwsRegion, nil
|
||||||
}
|
}
|
||||||
@@ -357,6 +405,10 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for name, value := range headers {
|
||||||
|
req.Header.Add(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := cs.doRequest(req)
|
resp, err := cs.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -381,7 +433,7 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
|
|||||||
return string(respBody[:respBodyEnd]), nil
|
return string(respBody[:respBodyEnd]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
|
func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
|
||||||
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
|
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
|
||||||
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
|
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
|
||||||
return awsSecurityCredentials{
|
return awsSecurityCredentials{
|
||||||
@@ -392,12 +444,12 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
roleName, err := cs.getMetadataRoleName()
|
roleName, err := cs.getMetadataRoleName(headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials, err := cs.getMetadataSecurityCredentials(roleName)
|
credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -413,7 +465,7 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
|
|||||||
return credentials, nil
|
return credentials, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
|
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) {
|
||||||
var result awsSecurityCredentials
|
var result awsSecurityCredentials
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
|
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
|
||||||
@@ -422,6 +474,10 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
|
|||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
|
for name, value := range headers {
|
||||||
|
req.Header.Add(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := cs.doRequest(req)
|
resp, err := cs.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return result, err
|
||||||
@@ -441,7 +497,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
|
|||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
|
func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
|
||||||
if cs.CredVerificationURL == "" {
|
if cs.CredVerificationURL == "" {
|
||||||
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
|
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
|
||||||
}
|
}
|
||||||
@@ -451,6 +507,10 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for name, value := range headers {
|
||||||
|
req.Header.Add(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := cs.doRequest(req)
|
resp, err := cs.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import (
|
|||||||
var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC)
|
var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC)
|
||||||
var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC)
|
var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC)
|
||||||
|
|
||||||
|
type validateHeaders func(r *http.Request)
|
||||||
|
|
||||||
func setTime(testTime time.Time) func() time.Time {
|
func setTime(testTime time.Time) func() time.Time {
|
||||||
return func() time.Time {
|
return func() time.Time {
|
||||||
return testTime
|
return testTime
|
||||||
@@ -82,7 +84,7 @@ func testRequestSigner(t *testing.T, rs *awsRequestSigner, input, expectedOutput
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequest(t *testing.T) {
|
func TestAWSv4Signature_GetRequest(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -100,7 +102,7 @@ func TestAwsV4Signature_GetRequest(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithRelativePath(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -118,7 +120,7 @@ func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithDotPath(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -136,7 +138,7 @@ func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -154,7 +156,7 @@ func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithUtf8Path(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -172,7 +174,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -190,7 +192,7 @@ func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -208,7 +210,7 @@ func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithUtf8Query(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil)
|
input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -226,7 +228,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequest(t *testing.T) {
|
func TestAWSv4Signature_PostRequest(t *testing.T) {
|
||||||
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
|
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
input.Header.Add("ZOO", "zoobar")
|
input.Header.Add("ZOO", "zoobar")
|
||||||
@@ -246,7 +248,7 @@ func TestAwsV4Signature_PostRequest(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
|
func TestAWSv4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
|
||||||
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
|
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
input.Header.Add("zoo", "ZOOBAR")
|
input.Header.Add("zoo", "ZOOBAR")
|
||||||
@@ -266,7 +268,7 @@ func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestPhfft(t *testing.T) {
|
func TestAWSv4Signature_PostRequestPhfft(t *testing.T) {
|
||||||
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
|
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
input.Header.Add("p", "phfft")
|
input.Header.Add("p", "phfft")
|
||||||
@@ -286,7 +288,7 @@ func TestAwsV4Signature_PostRequestPhfft(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestWithBody(t *testing.T) {
|
func TestAWSv4Signature_PostRequestWithBody(t *testing.T) {
|
||||||
input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar"))
|
input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar"))
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
input.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
input.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -306,7 +308,7 @@ func TestAwsV4Signature_PostRequestWithBody(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) {
|
func TestAWSv4Signature_PostRequestWithQueryString(t *testing.T) {
|
||||||
input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil)
|
input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil)
|
||||||
setDefaultTime(input)
|
setDefaultTime(input)
|
||||||
|
|
||||||
@@ -324,7 +326,7 @@ func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) {
|
|||||||
testRequestSigner(t, defaultRequestSigner, input, output)
|
testRequestSigner(t, defaultRequestSigner, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) {
|
func TestAWSv4Signature_GetRequestWithSecurityToken(t *testing.T) {
|
||||||
input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
|
input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
|
||||||
|
|
||||||
output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
|
output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
|
||||||
@@ -342,7 +344,7 @@ func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) {
|
|||||||
testRequestSigner(t, requestSignerWithToken, input, output)
|
testRequestSigner(t, requestSignerWithToken, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) {
|
func TestAWSv4Signature_PostRequestWithSecurityToken(t *testing.T) {
|
||||||
input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
|
input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
|
||||||
|
|
||||||
output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
|
output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
|
||||||
@@ -360,7 +362,7 @@ func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) {
|
|||||||
testRequestSigner(t, requestSignerWithToken, input, output)
|
testRequestSigner(t, requestSignerWithToken, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) {
|
func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) {
|
||||||
requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}"
|
requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}"
|
||||||
input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
|
input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
|
||||||
input.Header.Add("Content-Type", "application/x-amz-json-1.0")
|
input.Header.Add("Content-Type", "application/x-amz-json-1.0")
|
||||||
@@ -383,7 +385,7 @@ func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test
|
|||||||
testRequestSigner(t, requestSignerWithToken, input, output)
|
testRequestSigner(t, requestSignerWithToken, input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
|
func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
|
||||||
var requestSigner = &awsRequestSigner{
|
var requestSigner = &awsRequestSigner{
|
||||||
RegionName: "us-east-2",
|
RegionName: "us-east-2",
|
||||||
AwsSecurityCredentials: awsSecurityCredentials{
|
AwsSecurityCredentials: awsSecurityCredentials{
|
||||||
@@ -413,30 +415,40 @@ type testAwsServer struct {
|
|||||||
securityCredentialURL string
|
securityCredentialURL string
|
||||||
regionURL string
|
regionURL string
|
||||||
regionalCredVerificationURL string
|
regionalCredVerificationURL string
|
||||||
|
imdsv2SessionTokenUrl string
|
||||||
|
|
||||||
Credentials map[string]string
|
Credentials map[string]string
|
||||||
|
|
||||||
WriteRolename func(http.ResponseWriter)
|
WriteRolename func(http.ResponseWriter, *http.Request)
|
||||||
WriteSecurityCredentials func(http.ResponseWriter)
|
WriteSecurityCredentials func(http.ResponseWriter, *http.Request)
|
||||||
WriteRegion func(http.ResponseWriter)
|
WriteRegion func(http.ResponseWriter, *http.Request)
|
||||||
|
WriteIMDSv2SessionToken func(http.ResponseWriter, *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAwsTestServer(url, regionURL, regionalCredVerificationURL, rolename, region string, credentials map[string]string) *testAwsServer {
|
func createAwsTestServer(url, regionURL, regionalCredVerificationURL, imdsv2SessionTokenUrl string, rolename, region string, credentials map[string]string, imdsv2SessionToken string, validateHeaders validateHeaders) *testAwsServer {
|
||||||
server := &testAwsServer{
|
server := &testAwsServer{
|
||||||
url: url,
|
url: url,
|
||||||
securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename),
|
securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename),
|
||||||
regionURL: regionURL,
|
regionURL: regionURL,
|
||||||
regionalCredVerificationURL: regionalCredVerificationURL,
|
regionalCredVerificationURL: regionalCredVerificationURL,
|
||||||
|
imdsv2SessionTokenUrl: imdsv2SessionTokenUrl,
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
WriteRolename: func(w http.ResponseWriter) {
|
WriteRolename: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
validateHeaders(r)
|
||||||
w.Write([]byte(rolename))
|
w.Write([]byte(rolename))
|
||||||
},
|
},
|
||||||
WriteRegion: func(w http.ResponseWriter) {
|
WriteRegion: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
validateHeaders(r)
|
||||||
w.Write([]byte(region))
|
w.Write([]byte(region))
|
||||||
},
|
},
|
||||||
|
WriteIMDSv2SessionToken: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
validateHeaders(r)
|
||||||
|
w.Write([]byte(imdsv2SessionToken))
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
|
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
validateHeaders(r)
|
||||||
jsonCredentials, _ := json.Marshal(server.Credentials)
|
jsonCredentials, _ := json.Marshal(server.Credentials)
|
||||||
w.Write(jsonCredentials)
|
w.Write(jsonCredentials)
|
||||||
}
|
}
|
||||||
@@ -449,6 +461,7 @@ func createDefaultAwsTestServer() *testAwsServer {
|
|||||||
"/latest/meta-data/iam/security-credentials",
|
"/latest/meta-data/iam/security-credentials",
|
||||||
"/latest/meta-data/placement/availability-zone",
|
"/latest/meta-data/placement/availability-zone",
|
||||||
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
|
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
|
||||||
|
"",
|
||||||
"gcp-aws-role",
|
"gcp-aws-role",
|
||||||
"us-east-2b",
|
"us-east-2b",
|
||||||
map[string]string{
|
map[string]string{
|
||||||
@@ -456,31 +469,38 @@ func createDefaultAwsTestServer() *testAwsServer {
|
|||||||
"AccessKeyId": accessKeyID,
|
"AccessKeyId": accessKeyID,
|
||||||
"Token": securityToken,
|
"Token": securityToken,
|
||||||
},
|
},
|
||||||
|
"",
|
||||||
|
noHeaderValidation,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
switch p := r.URL.Path; p {
|
switch p := r.URL.Path; p {
|
||||||
case server.url:
|
case server.url:
|
||||||
server.WriteRolename(w)
|
server.WriteRolename(w, r)
|
||||||
case server.securityCredentialURL:
|
case server.securityCredentialURL:
|
||||||
server.WriteSecurityCredentials(w)
|
server.WriteSecurityCredentials(w, r)
|
||||||
case server.regionURL:
|
case server.regionURL:
|
||||||
server.WriteRegion(w)
|
server.WriteRegion(w, r)
|
||||||
|
case server.imdsv2SessionTokenUrl:
|
||||||
|
server.WriteIMDSv2SessionToken(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func notFound(w http.ResponseWriter) {
|
func notFound(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
w.Write([]byte("Not Found"))
|
w.Write([]byte("Not Found"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func noHeaderValidation(r *http.Request) {}
|
||||||
|
|
||||||
func (server *testAwsServer) getCredentialSource(url string) CredentialSource {
|
func (server *testAwsServer) getCredentialSource(url string) CredentialSource {
|
||||||
return CredentialSource{
|
return CredentialSource{
|
||||||
EnvironmentID: "aws1",
|
EnvironmentID: "aws1",
|
||||||
URL: url + server.url,
|
URL: url + server.url,
|
||||||
RegionURL: url + server.regionURL,
|
RegionURL: url + server.regionURL,
|
||||||
RegionalCredVerificationURL: server.regionalCredVerificationURL,
|
RegionalCredVerificationURL: server.regionalCredVerificationURL,
|
||||||
|
IMDSv2SessionTokenURL: url + server.imdsv2SessionTokenUrl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,7 +550,7 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
|
|||||||
return neturl.QueryEscape(string(str))
|
return neturl.QueryEscape(string(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_BasicRequest(t *testing.T) {
|
func TestAWSCredential_BasicRequest(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -567,7 +587,72 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
func TestAWSCredential_IMDSv2(t *testing.T) {
|
||||||
|
validateSessionTokenHeaders := func(r *http.Request) {
|
||||||
|
if r.URL.Path == "/latest/api/token" {
|
||||||
|
headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
|
||||||
|
if headerValue != awsIMDSv2SessionTtl {
|
||||||
|
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
|
||||||
|
if headerValue != "sessiontoken" {
|
||||||
|
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
server := createAwsTestServer(
|
||||||
|
"/latest/meta-data/iam/security-credentials",
|
||||||
|
"/latest/meta-data/placement/availability-zone",
|
||||||
|
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
|
||||||
|
"/latest/api/token",
|
||||||
|
"gcp-aws-role",
|
||||||
|
"us-east-2b",
|
||||||
|
map[string]string{
|
||||||
|
"SecretAccessKey": secretAccessKey,
|
||||||
|
"AccessKeyId": accessKeyID,
|
||||||
|
"Token": securityToken,
|
||||||
|
},
|
||||||
|
"sessiontoken",
|
||||||
|
validateSessionTokenHeaders,
|
||||||
|
)
|
||||||
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
|
tfc := testFileConfig
|
||||||
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
|
oldGetenv := getenv
|
||||||
|
defer func() { getenv = oldGetenv }()
|
||||||
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
oldNow := now
|
||||||
|
defer func() { now = oldNow }()
|
||||||
|
now = setTime(defaultTime)
|
||||||
|
|
||||||
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := getExpectedSubjectToken(
|
||||||
|
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
|
||||||
|
"us-east-2",
|
||||||
|
accessKeyID,
|
||||||
|
secretAccessKey,
|
||||||
|
securityToken,
|
||||||
|
)
|
||||||
|
|
||||||
|
if got, want := out, expected; !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
delete(server.Credentials, "Token")
|
delete(server.Credentials, "Token")
|
||||||
@@ -605,7 +690,7 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -646,7 +731,7 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -686,7 +771,7 @@ func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -727,7 +812,7 @@ func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
|
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -748,7 +833,7 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) {
|
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -775,7 +860,7 @@ func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) {
|
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
server.WriteRegion = notFound
|
server.WriteRegion = notFound
|
||||||
@@ -802,10 +887,10 @@ func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
|
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
|
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("{}"))
|
w.Write([]byte("{}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -831,10 +916,10 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
|
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
|
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
|
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -860,7 +945,7 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) {
|
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
|
||||||
@@ -887,7 +972,7 @@ func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) {
|
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
server.WriteRolename = notFound
|
server.WriteRolename = notFound
|
||||||
@@ -914,7 +999,7 @@ func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAwsCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
server.WriteSecurityCredentials = notFound
|
server.WriteSecurityCredentials = notFound
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ type CredentialSource struct {
|
|||||||
RegionURL string `json:"region_url"`
|
RegionURL string `json:"region_url"`
|
||||||
RegionalCredVerificationURL string `json:"regional_cred_verification_url"`
|
RegionalCredVerificationURL string `json:"regional_cred_verification_url"`
|
||||||
CredVerificationURL string `json:"cred_verification_url"`
|
CredVerificationURL string `json:"cred_verification_url"`
|
||||||
|
IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"`
|
||||||
Format format `json:"format"`
|
Format format `json:"format"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,14 +186,20 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
|
|||||||
if awsVersion != 1 {
|
if awsVersion != 1 {
|
||||||
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
|
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
|
||||||
}
|
}
|
||||||
return awsCredentialSource{
|
|
||||||
|
awsCredSource := awsCredentialSource{
|
||||||
EnvironmentID: c.CredentialSource.EnvironmentID,
|
EnvironmentID: c.CredentialSource.EnvironmentID,
|
||||||
RegionURL: c.CredentialSource.RegionURL,
|
RegionURL: c.CredentialSource.RegionURL,
|
||||||
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
|
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
|
||||||
CredVerificationURL: c.CredentialSource.URL,
|
CredVerificationURL: c.CredentialSource.URL,
|
||||||
TargetResource: c.Audience,
|
TargetResource: c.Audience,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
}, nil
|
}
|
||||||
|
if c.CredentialSource.IMDSv2SessionTokenURL != "" {
|
||||||
|
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return awsCredSource, nil
|
||||||
}
|
}
|
||||||
} else if c.CredentialSource.File != "" {
|
} else if c.CredentialSource.File != "" {
|
||||||
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
|
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user