Skip to content
This repository was archived by the owner on Mar 18, 2025. It is now read-only.

Commit 4ac35b2

Browse files
committed
- Ported code from aws-sdk-go for buildingCanonicalHeaders
- Ported code from aws-sdk-go for proper query string handling - Refactored code to reduce methods - Added validations for tripper
1 parent 5f2f038 commit 4ac35b2

File tree

6 files changed

+299
-109
lines changed

6 files changed

+299
-109
lines changed

pkg/remote/client.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ func NewWriteClient(endpoint string, cfg *HTTPConfig) (*WriteClient, error) {
6363
}
6464
}
6565
if cfg.SigV4 != nil {
66-
wc.hc.Transport = sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
66+
tripper, err := sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
67+
if err != nil {
68+
return nil, err
69+
}
70+
wc.hc.Transport = tripper
6771
}
6872
return wc, nil
6973
}

pkg/sigv4/const.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package sigv4
2+
3+
const (
4+
awsServiceName = "aps"
5+
signingAlgorithm = "AWS4-HMAC-SHA256"
6+
7+
authorizationHeaderKey = "Authorization"
8+
amzDateKey = "X-Amz-Date"
9+
10+
// emptyStringSHA256 is the hex encoded sha256 value of an empty string
11+
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
12+
13+
// timeFormat is the time format to be used in the X-Amz-Date header or query parameter
14+
timeFormat = "20060102T150405Z"
15+
16+
// shortTimeFormat is the shorten time format used in the credential scope
17+
shortTimeFormat = "20060102"
18+
19+
// contentSHAKey is the SHA256 of request body
20+
contentSHAKey = "X-Amz-Content-Sha256"
21+
)

pkg/sigv4/sigv4.go

Lines changed: 151 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -10,73 +10,84 @@ import (
1010
"net/http"
1111
"net/url"
1212
"sort"
13+
"strconv"
1314
"strings"
1415
"time"
1516
)
1617

17-
const signingAlgo = "AWS4-HMAC-SHA256"
18-
const awsServiceName = "aps"
19-
2018
type Signer interface {
2119
Sign(req *http.Request) error
2220
}
2321

2422
type DefaultSigner struct {
25-
iSO8601Date string
26-
canonicalHeaders string
27-
signedHeaders string
28-
credentialScope string
29-
config *Config
30-
payloadHash string
23+
config *Config
3124
}
3225

3326
func NewDefaultSigner(config *Config) Signer {
34-
return &DefaultSigner{
35-
config: config,
27+
// initialize noEscape array. This way we can avoid using init() functions
28+
for i := 0; i < len(noEscape); i++ {
29+
// AWS expects every character except these to be escaped
30+
noEscape[i] = (i >= 'A' && i <= 'Z') ||
31+
(i >= 'a' && i <= 'z') ||
32+
(i >= '0' && i <= '9') ||
33+
i == '-' ||
34+
i == '.' ||
35+
i == '_' ||
36+
i == '~'
3637
}
38+
39+
return &DefaultSigner{config: config}
3740
}
3841

3942
func (d *DefaultSigner) Sign(req *http.Request) error {
4043
now := time.Now().UTC()
41-
d.iSO8601Date = now.Format("20060102T150405Z")
42-
d.credentialScope = fmt.Sprintf(
43-
"%s/%s/%s/aws4_request",
44-
now.UTC().Format("20060102"),
45-
d.config.Region,
46-
awsServiceName,
47-
)
44+
iSO8601Date := now.Format(timeFormat)
45+
46+
credentialScope := buildCredentialScope(now, d.config.Region)
4847

4948
payloadHash, err := d.getPayloadHash(req)
5049
if err != nil {
5150
return err
5251
}
5352

54-
d.payloadHash = payloadHash
55-
d.addRequiredHeaders(req)
56-
d.canonicalHeaders, d.signedHeaders = d.getCanonicalAndSignedHeaders(req)
53+
req.Header.Set("Host", req.Host)
54+
req.Header.Set(amzDateKey, iSO8601Date)
55+
req.Header.Set(contentSHAKey, payloadHash)
5756

58-
canonicalReq := d.createCanonicalRequest(req)
59-
stringToSign, err := d.createStringToSign(canonicalReq)
60-
if err != nil {
61-
return err
62-
}
57+
_, signedHeadersStr, canonicalHeaderStr := buildCanonicalHeaders(req)
58+
59+
canonicalQueryString := getCanonicalQueryString(req.URL)
60+
canonicalReq := buildCanonicalString(
61+
req.Method,
62+
getCanonicalURI(req.URL),
63+
canonicalQueryString,
64+
canonicalHeaderStr,
65+
signedHeadersStr,
66+
payloadHash,
67+
)
68+
69+
signature := sign(
70+
deriveKey(d.config.AwsSecretAccessKey, d.config.Region),
71+
buildStringToSign(iSO8601Date, credentialScope, canonicalReq),
72+
)
6373

64-
signature := d.sign(d.createSigningKey(), stringToSign)
6574
authorizationHeader := fmt.Sprintf(
6675
"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
67-
signingAlgo,
76+
signingAlgorithm,
6877
d.config.AwsAccessKeyID,
69-
d.credentialScope,
70-
d.signedHeaders,
78+
credentialScope,
79+
signedHeadersStr,
7180
signature,
7281
)
73-
req.Header.Set("Authorization", authorizationHeader)
82+
83+
req.URL.RawQuery = canonicalQueryString
84+
req.Header.Set(authorizationHeaderKey, authorizationHeader)
7485
return nil
7586
}
7687

7788
func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
7889
if req.Body == nil {
79-
return hex.EncodeToString(sha256.New().Sum(nil)), nil
90+
return emptyStringSHA256, nil
8091
}
8192

8293
reqBody, err := io.ReadAll(req.Body)
@@ -98,101 +109,144 @@ func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
98109
return payloadHash, nil
99110
}
100111

101-
func (d *DefaultSigner) addRequiredHeaders(req *http.Request) {
102-
req.Header.Set("Host", req.Host)
103-
req.Header.Set("x-amz-date", d.iSO8601Date)
104-
req.Header.Set("x-amz-content-sha256", d.payloadHash)
112+
func buildCredentialScope(signingTime time.Time, region string) string {
113+
return fmt.Sprintf(
114+
"%s/%s/%s/aws4_request",
115+
signingTime.UTC().Format(shortTimeFormat),
116+
region,
117+
awsServiceName,
118+
)
105119
}
106120

107-
func (d *DefaultSigner) getCanonicalAndSignedHeaders(req *http.Request) (string, string) {
108-
var headers []string
109-
var signedHeaders []string
121+
func buildCanonicalString(method, uri, query, canonicalHeaders, signedHeaders, payloadHash string) string {
122+
return strings.Join([]string{
123+
method,
124+
uri,
125+
query,
126+
canonicalHeaders,
127+
signedHeaders,
128+
payloadHash,
129+
}, "\n")
130+
}
110131

111-
for key, value := range req.Header {
112-
lowercaseKey := strings.ToLower(key)
113-
encodedValue := strings.TrimSpace(strings.Join(value, ","))
114-
headers = append(headers, lowercaseKey+":"+encodedValue)
115-
signedHeaders = append(signedHeaders, lowercaseKey)
116-
}
132+
var ignoredHeaders = map[string]struct{}{
133+
"Authorization": struct{}{},
134+
"User-Agent": struct{}{},
135+
"X-Amzn-Trace-Id": struct{}{},
136+
"Expect": struct{}{},
137+
}
117138

118-
sort.Strings(headers)
119-
sort.Strings(signedHeaders)
139+
func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
140+
host, header, length := req.Host, req.Header, req.ContentLength
120141

121-
canonicalHeaders := strings.Join(headers, "\n") + "\n"
122-
canonicalSignedHeaders := strings.Join(signedHeaders, ";")
123-
return canonicalHeaders, canonicalSignedHeaders
124-
}
142+
signed = make(http.Header)
125143

126-
func (d *DefaultSigner) createCanonicalRequest(req *http.Request) string {
127-
return strings.Join([]string{
128-
req.Method,
129-
d.getCanonicalURI(req.URL),
130-
d.getCanonicalQueryString(req.URL),
131-
d.canonicalHeaders,
132-
d.signedHeaders,
133-
d.payloadHash,
134-
}, "\n")
135-
}
144+
var headers []string
145+
const hostHeader = "host"
146+
headers = append(headers, hostHeader)
147+
signed[hostHeader] = append(signed[hostHeader], host)
148+
149+
const contentLengthHeader = "content-length"
150+
if length > 0 {
151+
headers = append(headers, contentLengthHeader)
152+
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
153+
}
154+
155+
for k, v := range header {
156+
if _, ok := ignoredHeaders[k]; ok {
157+
continue // ignored header
158+
}
159+
if strings.EqualFold(k, contentLengthHeader) {
160+
// prevent signing already handled content-length header.
161+
continue
162+
}
136163

137-
func (d *DefaultSigner) getCanonicalURI(u *url.URL) string {
138-
if u.Path == "" {
139-
return "/"
164+
lowerCaseKey := strings.ToLower(k)
165+
if _, ok := signed[lowerCaseKey]; ok {
166+
// include additional values
167+
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
168+
continue
169+
}
170+
171+
headers = append(headers, lowerCaseKey)
172+
signed[lowerCaseKey] = v
140173
}
174+
sort.Strings(headers)
141175

142-
// The spec requires not to encode `/`
143-
segments := strings.Split(u.Path, "/")
144-
for i, segment := range segments {
145-
segments[i] = url.PathEscape(segment)
176+
signedHeaders = strings.Join(headers, ";")
177+
178+
var canonicalHeaders strings.Builder
179+
n := len(headers)
180+
const colon = ':'
181+
for i := 0; i < n; i++ {
182+
if headers[i] == hostHeader {
183+
canonicalHeaders.WriteString(hostHeader)
184+
canonicalHeaders.WriteRune(colon)
185+
canonicalHeaders.WriteString(stripExcessSpaces(host))
186+
} else {
187+
canonicalHeaders.WriteString(headers[i])
188+
canonicalHeaders.WriteRune(colon)
189+
// Trim out leading, trailing, and dedup inner spaces from signed header values.
190+
values := signed[headers[i]]
191+
for j, v := range values {
192+
cleanedValue := strings.TrimSpace(stripExcessSpaces(v))
193+
canonicalHeaders.WriteString(cleanedValue)
194+
if j < len(values)-1 {
195+
canonicalHeaders.WriteRune(',')
196+
}
197+
}
198+
}
199+
canonicalHeaders.WriteRune('\n')
146200
}
201+
canonicalHeadersStr = canonicalHeaders.String()
147202

148-
return strings.Join(segments, "/")
203+
return signed, signedHeaders, canonicalHeadersStr
149204
}
150205

151-
func (d *DefaultSigner) getCanonicalQueryString(u *url.URL) string {
152-
queryParams := u.Query()
153-
var queryPairs []string
206+
func getCanonicalURI(u *url.URL) string {
207+
return escapePath(getURIPath(u), false)
208+
}
154209

155-
for key, values := range queryParams {
156-
for _, value := range values {
157-
queryPairs = append(queryPairs, url.QueryEscape(key)+"="+url.QueryEscape(value))
158-
}
159-
}
210+
func getCanonicalQueryString(u *url.URL) string {
211+
query := u.Query()
160212

161-
sort.Strings(queryPairs)
213+
// Sort Each Query Key's Values
214+
for key := range query {
215+
sort.Strings(query[key])
216+
}
162217

163-
return strings.Join(queryPairs, "&")
218+
var rawQuery strings.Builder
219+
rawQuery.WriteString(strings.Replace(query.Encode(), "+", "%20", -1))
220+
return rawQuery.String()
164221
}
165222

166-
func (d *DefaultSigner) createStringToSign(canonicalRequest string) (string, error) {
223+
func buildStringToSign(amzDate, credentialScope, canonicalRequestString string) string {
167224
hash := sha256.New()
168-
if _, err := hash.Write([]byte(canonicalRequest)); err != nil {
169-
return "", err
170-
}
171-
return fmt.Sprintf(
172-
"%s\n%s\n%s\n%s",
173-
signingAlgo,
174-
d.iSO8601Date,
175-
d.credentialScope,
225+
hash.Write([]byte(canonicalRequestString))
226+
return strings.Join([]string{
227+
signingAlgorithm,
228+
amzDate,
229+
credentialScope,
176230
hex.EncodeToString(hash.Sum(nil)),
177-
), nil
231+
}, "\n")
178232
}
179233

180-
func (d *DefaultSigner) createSigningKey() string {
181-
signingDate := time.Now().UTC().Format("20060102")
182-
dateKey := d.hmacSHA256([]byte("AWS4"+d.config.AwsSecretAccessKey), signingDate)
183-
dateRegionKey := d.hmacSHA256(dateKey, d.config.Region)
184-
dateRegionServiceKey := d.hmacSHA256(dateRegionKey, awsServiceName)
185-
signingKey := d.hmacSHA256(dateRegionServiceKey, "aws4_request")
234+
func deriveKey(secretKey, region string) string {
235+
signingDate := time.Now().UTC().Format(shortTimeFormat)
236+
hmacDate := hmacSHA256([]byte("AWS4"+secretKey), signingDate)
237+
hmacRegion := hmacSHA256(hmacDate, region)
238+
hmacService := hmacSHA256(hmacRegion, awsServiceName)
239+
signingKey := hmacSHA256(hmacService, "aws4_request")
186240
return string(signingKey)
187241
}
188242

189-
func (d *DefaultSigner) hmacSHA256(key []byte, data string) []byte {
243+
func hmacSHA256(key []byte, data string) []byte {
190244
h := hmac.New(sha256.New, key)
191245
h.Write([]byte(data))
192246
return h.Sum(nil)
193247
}
194248

195-
func (d *DefaultSigner) sign(signingKey string, strToSign string) string {
249+
func sign(signingKey string, strToSign string) string {
196250
h := hmac.New(sha256.New, []byte(signingKey))
197251
h.Write([]byte(strToSign))
198252
sig := hex.EncodeToString(h.Sum(nil))

pkg/sigv4/tripper.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sigv4
22

33
import (
4+
"errors"
45
"net/http"
56
)
67

@@ -16,15 +17,21 @@ type Config struct {
1617
AwsAccessKeyID string
1718
}
1819

19-
func NewRoundTripper(config *Config, next http.RoundTripper) *Tripper {
20+
func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) {
21+
if config == nil {
22+
return nil, errors.New("can't initialize a sigv4 round tripper with nil config")
23+
}
24+
2025
if next == nil {
2126
next = http.DefaultTransport
2227
}
23-
return &Tripper{
28+
29+
tripper := &Tripper{
2430
config: config,
2531
next: next,
2632
signer: NewDefaultSigner(config),
2733
}
34+
return tripper, nil
2835
}
2936

3037
func (c *Tripper) RoundTrip(req *http.Request) (*http.Response, error) {

0 commit comments

Comments
 (0)