Fixes requested by codyoss@

This commit is contained in:
Ryan Kohler
2021-01-20 10:44:08 -08:00
parent 3f1a1ba4db
commit d1a7728e50
2 changed files with 198 additions and 55 deletions

View File

@@ -1,13 +1,16 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"path"
@@ -16,13 +19,13 @@ import (
"time"
)
// A utility class to sign http requests using a AWS V4 signature
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
type RequestSigner struct {
RegionName string
AwsSecurityCredentials map[string]string
debugTimestamp time.Time
}
// NewRequestSigner is a method to create a RequestSigner with the appropriately filled out fields.
func NewRequestSigner(regionName string, awsSecurityCredentials map[string]string) *RequestSigner {
return &RequestSigner{
RegionName: regionName,
@@ -30,32 +33,60 @@ func NewRequestSigner(regionName string, awsSecurityCredentials map[string]strin
}
}
const (
// AWS Signature Version 4 signing algorithm identifier.
const awsAlgorithm = "AWS4-HMAC-SHA256"
awsAlgorithm = "AWS4-HMAC-SHA256"
// The termination string for the AWS credential scope value as defined in
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
const awsRequestType = "aws4_request"
awsRequestType = "aws4_request"
// The AWS authorization header name for the security session token if available.
const awsSecurityTokenHeader = "x-amz-security-token"
awsSecurityTokenHeader = "x-amz-security-token"
// The AWS authorization header name for the auto-generated date.
const awsDateHeader = "x-amz-date"
awsDateHeader = "x-amz-date"
const awsTimeFormatLong = "20060102T150405Z"
const awsTimeFormatShort = "20060102"
awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatShort = "20060102"
)
func getSha256(input []byte) string {
func getSha256(input []byte) (string, error) {
hash := sha256.New()
hash.Write(input)
return hex.EncodeToString(hash.Sum(nil))
if _, err := hash.Write(input); err != nil {
return "", err
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
func getHmacSha256(key, input []byte) []byte {
func getHmacSha256(key, input []byte) ([]byte, error) {
hash := hmac.New(sha256.New, key)
hash.Write(input)
return hash.Sum(nil)
if _, err := hash.Write(input); err != nil {
return nil, err
}
return hash.Sum(nil), nil
}
func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
if r.Header != nil {
r2.Header = make(http.Header, len(r.Header))
// Find total number of values.
headerCount := 0
for _, headerValues := range r.Header {
headerCount += len(headerValues)
}
copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
for headerKey, headerValues := range r.Header {
headerCount = copy(copiedHeaders, headerValues)
r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
copiedHeaders = copiedHeaders[headerCount:]
}
}
return r2
}
func canonicalPath(req *http.Request) string {
@@ -90,20 +121,31 @@ func canonicalHeaders(req *http.Request) (string, string) {
}
sort.Strings(headers)
var fullHeaders []string
var fullHeaders strings.Builder
for _, header := range headers {
headerValue := strings.Join(lowerCaseHeaders[header], ",")
fullHeaders = append(fullHeaders, header+":"+headerValue+"\n")
fullHeaders.WriteString(header)
fullHeaders.WriteRune(':')
fullHeaders.WriteString(headerValue)
fullHeaders.WriteRune('\n')
}
return strings.Join(headers, ";"), strings.Join(fullHeaders, "")
return strings.Join(headers, ";"), fullHeaders.String()
}
func requestDataHash(req *http.Request) string {
requestData := []byte{}
func requestDataHash(req *http.Request) (string, error) {
var requestData []byte
if req.Body != nil {
requestBody, _ := req.GetBody()
requestData, _ = ioutil.ReadAll(requestBody)
requestBody, err := req.GetBody()
if err != nil {
return "", err
}
defer requestBody.Close()
requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
if err != nil {
return "", err
}
}
return getSha256(requestData)
@@ -116,63 +158,93 @@ func requestHost(req *http.Request) string {
return req.URL.Host
}
func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) string {
func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
dataHash, err := requestDataHash(req)
if err != nil {
return "", err
}
return strings.Join([]string{
req.Method,
canonicalPath(req),
canonicalQuery(req),
canonicalHeaderData,
canonicalHeaderColumns,
requestDataHash(req),
}, "\n")
dataHash,
}, "\n"), nil
}
func (rs *RequestSigner) SignedRequest(req *http.Request) *http.Request {
timestamp := rs.debugTimestamp
if timestamp.IsZero() {
timestamp = time.Now()
}
signedRequest := req.Clone(req.Context())
// SignRequest adds the appropriate headers to an http.Request
// or returns an error if something prevented this.
func (rs *RequestSigner) SignRequest(req *http.Request) error {
signedRequest := cloneRequest(req)
timestamp := now()
signedRequest.Header.Add("host", requestHost(req))
securityToken, ok := rs.AwsSecurityCredentials["security_token"]
if ok {
signedRequest.Header.Add("x-amz-security-token", securityToken)
if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
}
if signedRequest.Header.Get("date") == "" {
signedRequest.Header.Add("x-amz-date", timestamp.Format(awsTimeFormatLong))
signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
}
signedRequest.Header.Set("Authorization", rs.generateAuthentication(signedRequest, timestamp))
authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
if err != nil {
return err
}
signedRequest.Header.Set("Authorization", authorizationCode)
return signedRequest
req.Header = signedRequest.Header
return nil
}
func (rs *RequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) string {
func (rs *RequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
if !ok {
return "", errors.New("Missing Secret Access Key")
}
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
if !ok {
return "", errors.New("Missing Access Key Id")
}
canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
dateStamp := timestamp.Format(awsTimeFormatShort)
serviceName := strings.Split(requestHost(req), ".")[0]
serviceName := ""
if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
serviceName = splitHost[0]
}
credentialScope := strings.Join([]string{
dateStamp, rs.RegionName, serviceName, awsRequestType,
}, "/")
credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
if err != nil {
return "", err
}
requestHash, err := getSha256([]byte(requestString))
if err != nil{
return "", err
}
stringToSign := strings.Join([]string{
awsAlgorithm,
timestamp.Format(awsTimeFormatLong),
credentialScope,
getSha256([]byte(canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData))),
requestHash,
}, "\n")
signingKey := []byte("AWS4" + rs.AwsSecurityCredentials["secret_access_key"])
signingKey := []byte("AWS4" + secretAccessKey)
for _, signingInput := range []string{
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
} {
signingKey = getHmacSha256(signingKey, []byte(signingInput))
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
if err != nil{
return "", err
}
}
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials["access_key_id"], credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey))
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
}