diff --git a/staging/src/k8s.io/apimachinery/pkg/util/framer/framer.go b/staging/src/k8s.io/apimachinery/pkg/util/framer/framer.go index 1ab8fd396ed..f18845a417c 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/framer/framer.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/framer/framer.go @@ -91,12 +91,12 @@ func (r *lengthDelimitedFrameReader) Read(data []byte) (int, error) { } n, err := io.ReadAtLeast(r.r, data[:max], int(max)) r.remaining -= n - if err == io.ErrShortBuffer || r.remaining > 0 { - return n, io.ErrShortBuffer - } if err != nil { return n, err } + if r.remaining > 0 { + return n, io.ErrShortBuffer + } if n != expect { return n, io.ErrUnexpectedEOF } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/framer/framer_test.go b/staging/src/k8s.io/apimachinery/pkg/util/framer/framer_test.go index 7275796ab44..7a04fc1b7a7 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/framer/framer_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/framer/framer_test.go @@ -20,7 +20,12 @@ import ( "bytes" "errors" "io" + "net/http" + "net/http/httptest" "testing" + "time" + + netutil "k8s.io/apimachinery/pkg/util/net" ) func TestRead(t *testing.T) { @@ -98,6 +103,7 @@ func TestReadLarge(t *testing.T) { t.Fatalf("unexpected: %v %d", err, n) } } + func TestReadInvalidFrame(t *testing.T) { data := []byte{ 0x00, 0x00, 0x00, 0x04, @@ -120,6 +126,46 @@ func TestReadInvalidFrame(t *testing.T) { } } +func TestReadClientTimeout(t *testing.T) { + header := []byte{ + 0x00, 0x00, 0x00, 0x04, + } + data := []byte{ + 0x01, 0x02, 0x03, 0x04, + 0x00, 0x00, 0x00, 0x03, + 0x05, 0x06, 0x07, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + 0x08, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(header) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + time.Sleep(1 * time.Second) + _, _ = w.Write(data) + })) + defer server.Close() + + client := &http.Client{ + Timeout: 500 * time.Millisecond, + } + + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + r := NewLengthDelimitedFrameReader(resp.Body) + buf := make([]byte, 1) + if n, err := r.Read(buf); err == nil || !netutil.IsTimeout(err) { + t.Fatalf("unexpected: %v %d", err, n) + } +} + func TestJSONFrameReader(t *testing.T) { b := bytes.NewBufferString("{\"test\":true}\n1\n[\"a\"]") r := NewJSONFramedReader(io.NopCloser(b))