Skip to content
48 changes: 45 additions & 3 deletions http/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/gob"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"mime"
Expand Down Expand Up @@ -175,9 +176,50 @@ func RequestEncoder(r *http.Request) Encoder {
if h := r.Header.Get(k); h == "" {
r.Header.Set(k, "application/json")
}
var buf bytes.Buffer
r.Body = io.NopCloser(&buf)
return json.NewEncoder(&buf)
enc := new(jsonEncoder)
r.Body = enc
// GetBody enables request retry on HTTP/2 connections when the server
// sends GOAWAY during graceful shutdown. Without GetBody, the HTTP transport
// cannot retry because the request body has already been consumed.
r.GetBody = enc.GetBody
return enc
}

// jsonEncoder implements io.ReadCloser and provides GetBody functionality
// to support HTTP/2 request retries during server graceful shutdown (GOAWAY).
type jsonEncoder struct {
b []byte
r bytes.Reader
}

var errEncodeNotCalled = errors.New("RequestEncoder: Encode must be called prior to reading")

func (je *jsonEncoder) Read(b []byte) (n int, err error) {
if len(je.b) == 0 {
return 0, errEncodeNotCalled
}
return je.r.Read(b)
}

func (*jsonEncoder) Close() (err error) { return nil }

func (je *jsonEncoder) Encode(v any) error {
b, err := json.Marshal(v)
if err != nil {
return err
}
je.b = b
je.r = *bytes.NewReader(b)
return nil
}

// GetBody returns a new reader of the encoded bytes, enabling request retries.
// This is required for HTTP/2 connections to handle server GOAWAY during graceful shutdown.
func (je *jsonEncoder) GetBody() (io.ReadCloser, error) {
if len(je.b) == 0 {
return nil, errEncodeNotCalled
}
return io.NopCloser(bytes.NewReader(je.b)), nil
}

// ResponseDecoder returns a HTTP response decoder.
Expand Down
29 changes: 27 additions & 2 deletions http/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -44,14 +45,38 @@ func TestRequestEncoder(t *testing.T) {
r.Header.Set(ct, c.requestCT)
}

encoder := RequestEncoder(r)
_ = RequestEncoder(r)

assert.Equal(t, wantT, fmt.Sprintf("%T", encoder))
assert.Equal(t, c.wantCT, r.Header.Get(ct))
})
}
}

func TestRequestEncoderGetBody(t *testing.T) {
r := &http.Request{Header: http.Header{}}
encoder := RequestEncoder(r)

_, err := r.Body.Read(nil)
assert.Error(t, err, "request Body should error (but not panic) if read before data is encoded")

_, err = r.GetBody()
assert.Error(t, err, "request GetBody should error (but not panic) if read before data is encoded")

err = encoder.Encode("body")
require.NoError(t, err)

bodyContents, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, `"body"`, string(bodyContents))

newBody, err := r.GetBody()
require.NoError(t, err)

newBodyContents, err := io.ReadAll(newBody)
require.NoError(t, err)
assert.Equal(t, bodyContents, newBodyContents)
}

func TestRequestDecoder(t *testing.T) {
const (
ct = "Content-Type"
Expand Down
Loading