changes requested by codyoss@

This commit is contained in:
Ryan Kohler
2021-01-29 15:34:54 -08:00
parent d6857d1e58
commit ba9ae8ce1b
5 changed files with 213 additions and 99 deletions

View File

@@ -5,30 +5,31 @@
package externalaccount
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"golang.org/x/oauth2"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"sort"
"strconv"
"strings"
"time"
)
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsSecurityCredentials struct {
AccessKeyId string
SecretAccessKey string
SecurityToken string
AccessKeyID string `json:"AccessKeyID"`
SecretAccessKey string `json:"SecretAccessKey"`
SecurityToken string `json:"Token"`
}
// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct {
RegionName string
AwsSecurityCredentials awsSecurityCredentials
@@ -229,7 +230,7 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
}
}
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
}
type awsCredentialSource struct {
@@ -240,24 +241,29 @@ type awsCredentialSource struct {
TargetResource string
requestSigner *awsRequestSigner
region string
ctx context.Context
client *http.Client
}
type AwsRequestHeader struct {
type awsRequestHeader struct {
Key string `json:"key"`
Value string `json:"value"`
}
type AwsRequest struct {
type awsRequest struct {
URL string `json:"url"`
Method string `json:"method"`
Headers []AwsRequestHeader `json:"headers"`
Headers []awsRequestHeader `json:"headers"`
}
func (cs awsCredentialSource) request(req *http.Request) (*http.Response, error) {
if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil)
}
return cs.client.Do(req.WithContext(cs.ctx))
}
func (cs awsCredentialSource) subjectToken() (string, error) {
if version, _ := strconv.Atoi(cs.EnvironmentID[3:]); version != 1 {
return "", errors.New(fmt.Sprintf("oauth2/google: aws version '%d' is not supported in the current build.", version))
}
if cs.requestSigner == nil {
awsSecurityCredentials, err := cs.getSecurityCredentials()
if err != nil {
@@ -304,13 +310,13 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
# }))
*/
awsSignedReq := AwsRequest{
awsSignedReq := awsRequest{
URL: req.URL.String(),
Method: req.Method,
Method: "POST",
}
for headerKey, headerList := range req.Header {
for _, headerValue := range headerList {
awsSignedReq.Headers = append(awsSignedReq.Headers, AwsRequestHeader{
awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
Key: headerKey,
Value: headerValue,
})
@@ -337,10 +343,15 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
}
if cs.RegionURL == "" {
return "", errors.New("oauth2/google: Unable to determine AWS region.")
return "", errors.New("oauth2/google: unable to determine AWS region")
}
resp, err := http.Get(cs.RegionURL)
req, err := http.NewRequest("GET", cs.RegionURL, nil)
if err != nil {
return "", err
}
resp, err := cs.request(req)
if err != nil {
return "", err
}
@@ -352,17 +363,23 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
}
if resp.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS region - %s.", string(respBody)))
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
}
return string(respBody)[:len(respBody)-1], nil
// This endpoint will return the region in format: us-east-2b.
// Only the us-east-2 part should be used.
respBodyEnd := 0
if len(respBody) > 1 {
respBodyEnd = len(respBody) - 1
}
return string(respBody[:respBodyEnd]), nil
}
func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials awsSecurityCredentials, err error) {
if accessKeyId := getenv("AWS_ACCESS_KEY_ID"); accessKeyId != "" {
func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{
AccessKeyId: accessKeyId,
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SecurityToken: getenv("AWS_SESSION_TOKEN"),
}, nil
@@ -371,70 +388,64 @@ func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials aws
roleName, err := cs.getMetadataRoleName()
if err != nil {
return awsSecurityCredentials{}, err
return
}
credentials, err := cs.getMetadataSecurityCredentials(roleName)
if err != nil {
return awsSecurityCredentials{}, err
return
}
accessKeyId, ok := credentials["AccessKeyId"]
if !ok {
return awsSecurityCredentials{}, errors.New("oauth2/google: missing AccessKeyId credential.")
if credentials.AccessKeyID == "" {
return result, errors.New("oauth2/google: missing AccessKeyId credential")
}
secretAccessKey, ok := credentials["SecretAccessKey"]
if !ok {
return awsSecurityCredentials{}, errors.New("oauth2/google: missing SecretAccessKey credential.")
if credentials.SecretAccessKey == "" {
return result, errors.New("oauth2/google: missing SecretAccessKey credential")
}
securityToken, _ := credentials["Token"]
return awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
}, nil
return credentials, nil
}
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (map[string]string, error) {
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
if err != nil {
return nil, err
return result, err
}
req.Header.Add("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := cs.request(req)
if err != nil {
return nil, err
return result, err
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, err
return result, err
}
if resp.StatusCode != 200 {
return nil, errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS security credentials - %s.", string(respBody)))
return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
}
var result map[string]string
err = json.Unmarshal(respBody, &result)
if err != nil {
return nil, err
}
return result, nil
return result, err
}
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
if cs.CredVerificationURL == "" {
return "", errors.New("oauth2/google: Unable to determine the AWS metadata server security credentials endpoint.")
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
}
resp, err := http.Get(cs.CredVerificationURL)
req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
if err != nil {
return "", err
}
resp, err := cs.request(req)
if err != nil {
return "", err
}
@@ -446,7 +457,7 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
}
if resp.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS role name - %s.", string(respBody)))
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
}
return string(respBody), nil