Add a test to validate deferredResponseWriteron multiple write calls

Signed-off-by: nkeert <197718357+nkeert@users.noreply.github.com>
This commit is contained in:
nkeert
2025-02-15 10:23:21 +05:30
parent 8dbc6739e0
commit 45e2f3e438

View File

@@ -19,6 +19,7 @@ package responsewriters
import (
"bytes"
"compress/gzip"
"context"
"encoding/hex"
"encoding/json"
"errors"
@@ -32,6 +33,7 @@ import (
"os"
"reflect"
"strconv"
"strings"
"testing"
"time"
@@ -371,6 +373,124 @@ func TestSerializeObject(t *testing.T) {
}
}
func TestDeferredResponseWriter_Write(t *testing.T) {
smallChunk := bytes.Repeat([]byte("b"), defaultGzipThresholdBytes-1)
largeChunk := bytes.Repeat([]byte("b"), defaultGzipThresholdBytes+1)
tests := []struct {
name string
chunks [][]byte
expectGzip bool
}{
{
name: "one small chunk write",
chunks: [][]byte{smallChunk},
expectGzip: false,
},
{
name: "two small chunk writes",
chunks: [][]byte{smallChunk, smallChunk},
expectGzip: false,
},
{
name: "one large chunk writes",
chunks: [][]byte{largeChunk},
expectGzip: true,
},
{
name: "two large chunk writes",
chunks: [][]byte{largeChunk, largeChunk},
expectGzip: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockResponseWriter := httptest.NewRecorder()
drw := &deferredResponseWriter{
mediaType: "text/plain",
statusCode: 200,
contentEncoding: "gzip",
hw: mockResponseWriter,
ctx: context.Background(),
}
fullPayload := []byte{}
for _, chunk := range tt.chunks {
n, err := drw.Write(chunk)
if err != nil {
t.Fatalf("unexpected error while writing chunk: %v", err)
}
if n != len(chunk) {
t.Errorf("write is not complete, expected: %d bytes, written: %d bytes", len(chunk), n)
}
fullPayload = append(fullPayload, chunk...)
}
err := drw.Close()
if err != nil {
t.Fatalf("unexpected error when closing deferredResponseWriter: %v", err)
}
res := mockResponseWriter.Result()
if res.StatusCode != http.StatusOK {
t.Fatalf("status code is not writtend properly, expected: 200, got: %d", res.StatusCode)
}
contentEncoding := res.Header.Get("Content-Encoding")
varyHeader := res.Header.Get("Vary")
resBytes, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("unexpected error occurred while reading response body: %v", err)
}
if tt.expectGzip {
if contentEncoding != "gzip" {
t.Fatalf("content-encoding is not set properly, expected: gzip, got: %s", contentEncoding)
}
if !strings.Contains(varyHeader, "Accept-Encoding") {
t.Errorf("vary header doesn't have Accept-Encoding")
}
gr, err := gzip.NewReader(bytes.NewReader(resBytes))
if err != nil {
t.Fatalf("failed to create gzip reader: %v", err)
}
decompressed, err := io.ReadAll(gr)
if err != nil {
t.Fatalf("failed to decompress: %v", err)
}
if !bytes.Equal(fullPayload, decompressed) {
t.Errorf("payload mismatch, expected: %s, got: %s", fullPayload, decompressed)
}
} else {
if contentEncoding != "" {
t.Errorf("content-encoding is set unexpectedly")
}
if strings.Contains(varyHeader, "Accept-Encoding") {
t.Errorf("accept encoding is set unexpectedly")
}
if !bytes.Equal(fullPayload, resBytes) {
t.Errorf("payload mismatch, expected: %s, got: %s", fullPayload, resBytes)
}
}
})
}
}
func randTime(t *time.Time, r *rand.Rand) {
*t = time.Unix(r.Int63n(1000*365*24*60*60), r.Int63())
}