From 770ce2d874bcb053af5c76ee07b3caeafb212bd2 Mon Sep 17 00:00:00 2001
From: Tim Hockin
Date: Sat, 8 Mar 2025 15:38:10 -0800
Subject: [PATCH] Better handling of YAML that tastes like JSON

For the most part, JSON is a subset of YAML. This might lead one to
think that we should ALWAYS use YAML processing. Unfortunately, a JSON
"stream" (as defined by Go's encoding/json and many other places,
though not by the JSON spec) is a series of JSON objects. E.g., this:

```
{}{}{}
```

...is a valid JSON stream. YAML does NOT accept that, insisting on
`---` on a new line between YAML documents.

Before this commit, YAMLOrJSONDecoder tries to detect whether the input
is JSON by looking at the first few characters for "{". Unfortunately,
some perfectly valid YAML also tastes like that.

After this commit, YAMLOrJSONDecoder will detect a failure to parse as
JSON and instead flip to YAML parsing. This should handle the ambiguous
YAML. Once we flip to YAML we never flip back, and once we detect a
JSON stream (as defined above) we lose the ability to flip to YAML. A
multi-document stream is either all JSON or all YAML, even if we use
the JSON parser to decode the first object (because JSON is YAML for a
single object).
---
 .../apimachinery/pkg/util/yaml/decoder.go     | 163 ++++++--
 .../pkg/util/yaml/decoder_test.go             |  52 ++-
 .../pkg/util/yaml/stream_reader.go            | 130 ++++++
 .../pkg/util/yaml/stream_reader_test.go       | 388 ++++++++++++++++++
 4 files changed, 684 insertions(+), 49 deletions(-)
 create mode 100644 staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader.go
 create mode 100644 staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader_test.go

diff --git a/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder.go b/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder.go
index 9837b3df281..7342f8d1e62 100644
--- a/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder.go
+++ b/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder.go
@@ -20,10 +20,12 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"strings"
 	"unicode"
+	"unicode/utf8"
 
 	jsonutil "k8s.io/apimachinery/pkg/util/json"
 
@@ -92,7 +94,7 @@ func UnmarshalStrict(data []byte, v interface{}) error {
 // YAML decoding path is not used (so that error messages are
 // JSON specific).
 func ToJSON(data []byte) ([]byte, error) {
-	if hasJSONPrefix(data) {
+	if IsJSONBuffer(data) {
 		return data, nil
 	}
 	return yaml.YAMLToJSON(data)
@@ -102,7 +104,8 @@ func ToJSON(data []byte) ([]byte, error) {
 // separating individual documents. It first converts the YAML
 // body to JSON, then unmarshals the JSON.
 type YAMLToJSONDecoder struct {
-	reader Reader
+	reader      Reader
+	inputOffset int
 }
 
 // NewYAMLToJSONDecoder decodes YAML documents from the provided
@@ -121,7 +124,7 @@ func NewYAMLToJSONDecoder(r io.Reader) *YAMLToJSONDecoder {
 // yaml.Unmarshal.
 func (d *YAMLToJSONDecoder) Decode(into interface{}) error {
 	bytes, err := d.reader.Read()
-	if err != nil && err != io.EOF {
+	if err != nil && err != io.EOF { //nolint:errorlint
 		return err
 	}
 
@@ -131,9 +134,14 @@ func (d *YAMLToJSONDecoder) Decode(into interface{}) error {
 			return YAMLSyntaxError{err}
 		}
 	}
+	d.inputOffset += len(bytes)
 	return err
 }
 
+func (d *YAMLToJSONDecoder) InputOffset() int {
+	return d.inputOffset
+}
+
 // YAMLDecoder reads chunks of objects and returns ErrShortBuffer if
 // the data is not sufficient.
 type YAMLDecoder struct {
@@ -229,18 +237,20 @@ func splitYAMLDocument(data []byte, atEOF bool) (advance int, token []byte, err
 	return 0, nil, nil
 }
 
-// decoder is a convenience interface for Decode.
-type decoder interface {
-	Decode(into interface{}) error
-}
-
-// YAMLOrJSONDecoder attempts to decode a stream of JSON documents or
-// YAML documents by sniffing for a leading { character.
+// YAMLOrJSONDecoder attempts to decode a stream of JSON or YAML documents.
+// While JSON is a subset of YAML, Go's JSON decoder defines a multi-document
+// stream as a series of JSON objects (e.g. {}{}), while YAML defines a
+// multi-document stream as a series of documents separated by "---".
+//
+// This decoder will attempt to decode the stream as JSON first, and if that
+// fails, it will switch to YAML. Once it determines the stream is JSON (by
+// finding a non-YAML-delimited series of objects), it will not switch to YAML.
+// Once it switches to YAML it will not switch back to JSON.
 type YAMLOrJSONDecoder struct {
-	r          io.Reader
-	bufferSize int
-
-	decoder decoder
+	json   *json.Decoder
+	yaml   *YAMLToJSONDecoder
+	stream *StreamReader
+	count  int // how many objects have been decoded
 }
 
 type JSONSyntaxError struct {
@@ -265,31 +275,108 @@ func (e YAMLSyntaxError) Error() string {
 // how far into the stream the decoder will look to figure out whether this
 // is a JSON stream (has whitespace followed by an open brace).
 func NewYAMLOrJSONDecoder(r io.Reader, bufferSize int) *YAMLOrJSONDecoder {
-	return &YAMLOrJSONDecoder{
-		r:          r,
-		bufferSize: bufferSize,
+	d := &YAMLOrJSONDecoder{}
+
+	reader, _, mightBeJSON := GuessJSONStream(r, bufferSize)
+	d.stream = reader
+	if mightBeJSON {
+		d.json = json.NewDecoder(reader)
+	} else {
+		d.yaml = NewYAMLToJSONDecoder(reader)
 	}
+	return d
 }
 
 // Decode unmarshals the next object from the underlying stream into the
 // provide object, or returns an error.
 func (d *YAMLOrJSONDecoder) Decode(into interface{}) error {
-	if d.decoder == nil {
-		buffer, _, isJSON := GuessJSONStream(d.r, d.bufferSize)
-		if isJSON {
-			d.decoder = json.NewDecoder(buffer)
+	// Because we don't know if this is a JSON or YAML stream, a failure from
+	// both decoders is ambiguous. When in doubt, it will return the error from
+	// the JSON decoder. Unfortunately, this means that if the first document
+	// is invalid YAML, the error won't be awesome.
+	// TODO: the errors from YAML are not great, we could improve them a lot.
+	var firstErr error
+	if d.json != nil {
+		err := d.json.Decode(into)
+		if err == nil {
+			d.stream.Consume(int(d.json.InputOffset()) - d.stream.Consumed())
+			d.count++
+			return nil
+		}
+		if err == io.EOF { //nolint:errorlint
+			return err
+		}
+		var syntax *json.SyntaxError
+		if ok := errors.As(err, &syntax); ok {
+			firstErr = JSONSyntaxError{
+				Offset: syntax.Offset,
+				Err:    syntax,
+			}
 		} else {
-			d.decoder = NewYAMLToJSONDecoder(buffer)
+			firstErr = err
+		}
+		if d.count > 1 {
+			// If we found 0 or 1 JSON object(s), this stream is still
+			// ambiguous. But if we found more than 1 JSON object, then this
+			// is an unambiguous JSON stream, and we should not switch to YAML.
+			return err
+		}
+		// If JSON decoding hits the end of one object and then fails on the
+		// next, it leaves any leading whitespace in the buffer, which can
+		// confuse the YAML decoder. We just eat any whitespace we find, up to
+		// and including the first newline.
+		d.stream.Rewind()
+		if err := d.consumeWhitespace(); err == nil {
+			d.yaml = NewYAMLToJSONDecoder(d.stream)
+		}
+		d.json = nil
+	}
+	if d.yaml != nil {
+		err := d.yaml.Decode(into)
+		if err == nil {
+			d.stream.Consume(d.yaml.InputOffset() - d.stream.Consumed())
+			d.count++
+			return nil
+		}
+		if err == io.EOF { //nolint:errorlint
+			return err
+		}
+		if firstErr == nil {
+			firstErr = err
+		}
 	}
-	err := d.decoder.Decode(into)
-	if syntax, ok := err.(*json.SyntaxError); ok {
-		return JSONSyntaxError{
-			Offset: syntax.Offset,
-			Err:    syntax,
+	if firstErr != nil {
+		return firstErr
+	}
+	return fmt.Errorf("decoding failed as both JSON and YAML")
+}
+
+func (d *YAMLOrJSONDecoder) consumeWhitespace() error {
+	consumed := 0
+	for {
+		buf, err := d.stream.ReadN(4)
+		if err != nil && err == io.EOF { //nolint:errorlint
+			return err
+		}
+		r, sz := utf8.DecodeRune(buf)
+		if r == utf8.RuneError || sz == 0 {
+			return fmt.Errorf("invalid utf8 rune")
+		}
+		d.stream.RewindN(len(buf) - sz)
+		if !unicode.IsSpace(r) {
+			d.stream.RewindN(sz)
+			d.stream.Consume(consumed)
+			return nil
+		}
+		if r == '\n' {
+			d.stream.Consume(consumed)
+			return nil
+		}
+		if err == io.EOF { //nolint:errorlint
+			break
 		}
 	}
-	return err
+	return io.EOF
 }
 
 type Reader interface {
@@ -311,7 +398,7 @@ func (r *YAMLReader) Read() ([]byte, error) {
 	var buffer bytes.Buffer
 	for {
 		line, err := r.reader.Read()
-		if err != nil && err != io.EOF {
+		if err != nil && err != io.EOF { //nolint:errorlint
 			return nil, err
 		}
 
@@ -329,11 +416,11 @@ func (r *YAMLReader) Read() ([]byte, error) {
 			if buffer.Len() != 0 {
 				return buffer.Bytes(), nil
 			}
-			if err == io.EOF {
+			if err == io.EOF { //nolint:errorlint
 				return nil, err
 			}
 		}
-		if err == io.EOF {
+		if err == io.EOF { //nolint:errorlint
 			if buffer.Len() != 0 {
 				// If we're at EOF, we have a final, non-terminated line. Return it.
 				return buffer.Bytes(), nil
@@ -369,26 +456,20 @@ func (r *LineReader) Read() ([]byte, error) {
 // GuessJSONStream scans the provided reader up to size, looking
 // for an open brace indicating this is JSON. It will return the
 // bufio.Reader it creates for the consumer.
-func GuessJSONStream(r io.Reader, size int) (io.Reader, []byte, bool) {
-	buffer := bufio.NewReaderSize(r, size)
+func GuessJSONStream(r io.Reader, size int) (*StreamReader, []byte, bool) {
+	buffer := NewStreamReader(r, size)
 	b, _ := buffer.Peek(size)
-	return buffer, b, hasJSONPrefix(b)
+	return buffer, b, IsJSONBuffer(b)
 }
 
 // IsJSONBuffer scans the provided buffer, looking
 // for an open brace indicating this is JSON.
 func IsJSONBuffer(buf []byte) bool {
-	return hasJSONPrefix(buf)
+	return hasPrefix(buf, jsonPrefix)
 }
 
 var jsonPrefix = []byte("{")
 
-// hasJSONPrefix returns true if the provided buffer appears to start with
-// a JSON open brace.
-func hasJSONPrefix(buf []byte) bool {
-	return hasPrefix(buf, jsonPrefix)
-}
-
 // Return true if the first non-whitespace bytes in buf is
 // prefix.
 func hasPrefix(buf []byte, prefix []byte) bool {
diff --git a/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder_test.go b/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder_test.go
index 844d7ff5b5e..a9d328b182d 100644
--- a/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder_test.go
+++ b/staging/src/k8s.io/apimachinery/pkg/util/yaml/decoder_test.go
@@ -19,7 +19,6 @@ package yaml
 import (
 	"bufio"
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"math/rand"
@@ -105,7 +104,7 @@ stuff: 1
 	}
 	b = make([]byte, 15)
 	n, err = r.Read(b)
-	if err != io.EOF || n != 0 {
+	if err != io.EOF || n != 0 { //nolint:errorlint
 		t.Fatalf("expected EOF: %d / %v", n, err)
 	}
 }
@@ -205,7 +204,7 @@ stuff: 1
 		t.Fatalf("unexpected object: %#v", obj)
 	}
 	obj = generic{}
-	if err := s.Decode(&obj); err != io.EOF {
+	if err := s.Decode(&obj); err != io.EOF { //nolint:errorlint
 		t.Fatalf("unexpected error: %v", err)
 	}
 }
@@ -319,6 +318,11 @@ func TestYAMLOrJSONDecoder(t *testing.T) {
 			{"foo": "bar"},
 			{"baz": "biz"},
 		}},
+		// Spaces for indent, tabs are not allowed in YAML.
+		{"foo:\n  field: bar\n---\nbaz:\n  field: biz", 100, false, false, []generic{
+			{"foo": map[string]any{"field": "bar"}},
+			{"baz": map[string]any{"field": "biz"}},
+		}},
 		{"foo: bar\n---\n", 100, false, false, []generic{
 			{"foo": "bar"},
 		}},
@@ -334,6 +338,38 @@ func TestYAMLOrJSONDecoder(t *testing.T) {
 		{"foo: bar\n", 100, false, false, []generic{
 			{"foo": "bar"},
 		}},
+		// First document is JSON, second is YAML
+		{"{\"foo\": \"bar\"}\n---\n{baz: biz}", 100, false, false, []generic{
+			{"foo": "bar"},
+			{"baz": "biz"},
+		}},
+		// First document is JSON, second is YAML, longer than the buffer
+		{"{\"foo\": \"bar\"}\n---\n{baz: biz0123456780123456780123456780123456780123456789}", 20, false, false, []generic{
+			{"foo": "bar"},
+			{"baz": "biz0123456780123456780123456780123456780123456789"},
+		}},
+		// First document is JSON, then whitespace, then YAML
+		{"{\"foo\": \"bar\"} \n---\n{baz: biz}", 100, false, false, []generic{
+			{"foo": "bar"},
+			{"baz": "biz"},
+		}},
+		// First document is YAML, second is JSON
+		{"{foo: bar}\n---\n{\"baz\": \"biz\"}", 100, false, false, []generic{
+			{"foo": "bar"},
+			{"baz": "biz"},
+		}},
+		// First document is JSON, second is YAML, using spaces
+		{"{\n \"foo\": \"bar\"\n}\n---\n{\n baz: biz\n}", 100, false, false, []generic{
+			{"foo": "bar"},
+			{"baz": "biz"},
+		}},
+		// First document is JSON, second is YAML, using tabs
+		{"{\n\t\"foo\": \"bar\"\n}\n---\n{\n\tbaz: biz\n}", 100, false, false, []generic{
+			{"foo": "bar"},
+			{"baz": "biz"},
+		}},
+		// First 2 documents are JSON, third is YAML (stream is JSON)
+		{"{\"foo\": \"bar\"}\n{\"baz\": \"biz\"}\n---\n{qux: zrb}", 100, true, true, nil},
 	}
 	for i, testCase := range testCases {
 		decoder := NewYAMLOrJSONDecoder(bytes.NewReader([]byte(testCase.input)), testCase.buffer)
@@ -348,7 +384,7 @@ func TestYAMLOrJSONDecoder(t *testing.T) {
 			}
 			objs = append(objs, out)
 		}
-		if err != io.EOF {
+		if err != io.EOF { //nolint:errorlint
 			switch {
 			case testCase.err && err == nil:
 				t.Errorf("%d: unexpected non-error", i)
@@ -360,12 +396,12 @@ func TestYAMLOrJSONDecoder(t *testing.T) {
 				continue
 			}
 		}
-		switch decoder.decoder.(type) {
-		case *YAMLToJSONDecoder:
+		switch {
+		case decoder.yaml != nil:
 			if testCase.isJSON {
 				t.Errorf("%d: expected JSON decoder, got YAML", i)
 			}
-		case *json.Decoder:
+		case decoder.json != nil:
 			if !testCase.isJSON {
 				t.Errorf("%d: expected YAML decoder, got JSON", i)
 			}
@@ -419,7 +455,7 @@ func testReadLines(t *testing.T, lineLengths []int) {
 	var readLines [][]byte
 	for range lines {
 		bytes, err := lineReader.Read()
-		if err != nil && err != io.EOF {
+		if err != nil && err != io.EOF { //nolint:errorlint
 			t.Fatalf("failed to read lines: %v", err)
 		}
 		readLines = append(readLines, bytes)
diff --git a/staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader.go b/staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader.go
new file mode 100644
index 00000000000..d06991057f6
--- /dev/null
+++ b/staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader.go
@@ -0,0 +1,130 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package yaml
+
+import "io"
+
+// StreamReader is a reader designed for consuming streams of variable-length
+// messages. It buffers data until it is explicitly consumed, and can be
+// rewound to re-read previous data.
+type StreamReader struct {
+	r           io.Reader
+	buf         []byte
+	head        int // current read offset into buf
+	ttlConsumed int // number of bytes which have been consumed
+}
+
+// NewStreamReader creates a new StreamReader wrapping the provided
+// io.Reader.
+func NewStreamReader(r io.Reader, size int) *StreamReader {
+	if size == 0 {
+		size = 4096
+	}
+	return &StreamReader{
+		r:   r,
+		buf: make([]byte, 0, size), // Start with a reasonable capacity
+	}
+}
+
+// Read implements io.Reader. It first returns any buffered data after the
+// current offset, and if that's exhausted, reads from the underlying reader
+// and buffers the data. The returned data is not considered consumed until the
+// Consume method is called.
+func (r *StreamReader) Read(p []byte) (n int, err error) {
+	// If we have buffered data, return it
+	if r.head < len(r.buf) {
+		n = copy(p, r.buf[r.head:])
+		r.head += n
+		return n, nil
+	}
+
+	// If we've already hit EOF, return it
+	if r.r == nil {
+		return 0, io.EOF
+	}
+
+	// Read from the underlying reader
+	n, err = r.r.Read(p)
+	if n > 0 {
+		r.buf = append(r.buf, p[:n]...)
+		r.head += n
+	}
+	if err == nil {
+		return n, nil
+	}
+	if err == io.EOF {
+		// Store that we've hit EOF by setting r to nil
+		r.r = nil
+	}
+	return n, err
+}
+
+// ReadN reads exactly n bytes from the reader, blocking until all bytes are
+// read or an error occurs. If an error occurs, the number of bytes read is
+// returned along with the error. If EOF is hit before n bytes are read, this
+// will return the bytes read so far, along with io.EOF. The returned data is
+// not considered consumed until the Consume method is called.
+func (r *StreamReader) ReadN(want int) ([]byte, error) {
+	ret := make([]byte, want)
+	off := 0
+	for off < want {
+		n, err := r.Read(ret[off:])
+		if err != nil {
+			return ret[:off+n], err
+		}
+		off += n
+	}
+	return ret, nil
+}
+
+// Peek returns the next n bytes without advancing the reader. The returned
+// bytes are valid until the next call to Consume.
+func (r *StreamReader) Peek(n int) ([]byte, error) {
+	buf, err := r.ReadN(n)
+	r.RewindN(len(buf))
+	if err != nil {
+		return buf, err
+	}
+	return buf, nil
+}
+
+// Rewind resets the reader to the beginning of the buffered data.
+func (r *StreamReader) Rewind() {
+	r.head = 0
+}
+
+// RewindN rewinds the reader by n bytes. If n is greater than the current
+// buffer, the reader is rewound to the beginning of the buffer.
+func (r *StreamReader) RewindN(n int) {
+	r.head -= min(n, r.head)
+}
+
+// Consume discards up to n bytes of previously read data from the beginning of
+// the buffer. Once consumed, that data is no longer available for rewinding.
+// If n is greater than the current buffer, the buffer is cleared. Consume
+// never consumes data from the underlying reader.
+func (r *StreamReader) Consume(n int) {
+	n = min(n, len(r.buf))
+	r.buf = r.buf[n:]
+	r.head -= n
+	r.ttlConsumed += n
+}
+
+// Consumed returns the number of bytes consumed from the input reader.
+func (r *StreamReader) Consumed() int {
+	return r.ttlConsumed
+}
diff --git a/staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader_test.go b/staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader_test.go
new file mode 100644
index 00000000000..b010610c47a
--- /dev/null
+++ b/staging/src/k8s.io/apimachinery/pkg/util/yaml/stream_reader_test.go
@@ -0,0 +1,388 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package yaml
+
+import (
+	"io"
+	"strings"
+	"testing"
+)
+
+// srt = StreamReaderTest
+type srtStep struct {
+	op       string
+	expected string
+	err      error
+	size     int
+}
+
+func srtRead(size int, expected string, err error) srtStep {
+	return srtStep{op: "Read", size: size, expected: expected, err: err}
+}
+func srtReadN(size int, expected string, err error) srtStep {
+	return srtStep{op: "ReadN", size: size, expected: expected, err: err}
+}
+func srtPeek(size int, expected string, err error) srtStep {
+	return srtStep{op: "Peek", size: size, expected: expected, err: err}
+}
+func srtRewind() srtStep {
+	return srtStep{op: "Rewind"}
+}
+func srtRewindN(size int) srtStep {
+	return srtStep{op: "RewindN", size: size}
+}
+func srtConsume(size int) srtStep {
+	return srtStep{op: "Consume", size: size}
+}
+func srtConsumed(exp int) srtStep {
+	return srtStep{op: "Consumed", size: exp}
+}
+
+func srtRun(t *testing.T, reader *StreamReader, steps []srtStep) {
+	t.Helper()
+
+	checkRead := func(i int, step srtStep, buf []byte, err error) {
+		t.Helper()
+		if err != nil && step.err == nil {
+			t.Errorf("step %d: unexpected error: %v", i, err)
+		} else if err == nil && step.err != nil {
+			t.Errorf("step %d: expected error %v", i, step.err)
+		} else if err != nil && err != step.err { //nolint:errorlint
+			t.Errorf("step %d: expected error %v, got %v", i, step.err, err)
+		}
+		if got := string(buf); got != step.expected {
+			t.Errorf("step %d: expected %q, got %q", i, step.expected, got)
+		}
+	}
+
+	for i, step := range steps {
+		switch step.op {
+		case "Read":
+			buf := make([]byte, step.size)
+			n, err := reader.Read(buf)
+			buf = buf[:n]
+			checkRead(i, step, buf, err)
+		case "ReadN":
+			buf, err := reader.ReadN(step.size)
+			checkRead(i, step, buf, err)
+		case "Peek":
+			buf, err := reader.Peek(step.size)
+			checkRead(i, step, buf, err)
+		case "Rewind":
+			reader.Rewind()
+		case "RewindN":
+			reader.RewindN(step.size)
+		case "Consume":
+			reader.Consume(step.size)
+		case "Consumed":
+			if n := reader.Consumed(); n != step.size {
+				t.Errorf("step %d: expected %d consumed, got %d", i, step.size, n)
+			}
+		default:
+			t.Fatalf("step %d: unknown operation %q", i, step.op)
+		}
+	}
+}
+
+func TestStreamReader_Read(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		steps []srtStep
+	}{{
+		name:  "empty input",
+		input: "",
+		steps: []srtStep{
+			srtRead(1, "", io.EOF),
+			srtRead(1, "", io.EOF), // still EOF
+		},
+	}, {
+		name:  "simple reads",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(5, "01234", nil),
+			srtRead(5, "56789", nil),
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "short read at EOF",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(8, "01234567", nil),
+			srtRead(8, "89", nil), // short read, no error
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "short reads from buffer",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(3, "012", nil), // fill buffer
+			srtRewind(),
+			srtRead(4, "012", nil), // short read from buffer
+			srtRewind(),
+			srtRead(4, "012", nil),  // still short
+			srtRead(4, "3456", nil), // from reader
+			srtRewind(),
+			srtRead(10, "0123456", nil), // short read from buffer
+		},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reader := NewStreamReader(strings.NewReader(tt.input), 4) // small initial buffer
+			srtRun(t, reader, tt.steps)
+		})
+	}
+}
+
+func TestStreamReader_Rewind(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		steps []srtStep
+	}{{
+		name:  "simple read and rewind",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(4, "0123", nil),
+			srtRead(4, "4567", nil),
+			srtRead(4, "89", nil),
+			srtRead(1, "", io.EOF),
+			srtRewind(),
+			srtRead(4, "0123", nil),
+			srtRead(4, "4567", nil),
+			srtRead(4, "89", nil),
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "multiple rewinds",
+		input: "01234",
+		steps: []srtStep{
+			srtRead(2, "01", nil),
+			srtRewind(),
+			srtRead(2, "01", nil),
+			srtRead(2, "23", nil),
+			srtRewind(),
+			srtRead(2, "01", nil),
+			srtRead(2, "23", nil),
+			srtRead(2, "4", nil),
+			srtRead(1, "", io.EOF),
+			srtRewind(),
+			srtRead(100, "01234", nil),
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "empty input",
+		input: "",
+		steps: []srtStep{
+			srtRead(1, "", io.EOF),
+			srtRewind(),
+			srtRead(1, "", io.EOF),
+		},
+	}}

+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reader := NewStreamReader(strings.NewReader(tt.input), 4) // small initial buffer
+			srtRun(t, reader, tt.steps)
+		})
+	}
+}
+
+func TestStreamReader_RewindN(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		steps []srtStep
+	}{{
+		name:  "simple rewindn",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(4, "0123", nil),
+			srtRead(4, "4567", nil),
+			srtRead(4, "89", nil),
+			srtRead(1, "", io.EOF),
+			srtRewindN(4),
+			srtRead(2, "67", nil),
+			srtRewindN(4),
+			srtRead(10, "456789", nil),
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "empty input",
+		input: "",
+		steps: []srtStep{
+			srtRead(1, "", io.EOF),
+			srtRewindN(100),
+			srtRead(1, "", io.EOF),
+		},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reader := NewStreamReader(strings.NewReader(tt.input), 4) // small initial buffer
+			srtRun(t, reader, tt.steps)
+		})
+	}
+}
+
+func TestStreamReader_Consume(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		steps []srtStep
+	}{{
+		name:  "simple consume",
+		input: "0123456789",
+		steps: []srtStep{
+			srtConsumed(0),
+			srtRead(4, "0123", nil),
+			srtRead(4, "4567", nil),
+			srtConsume(2), // drops 01
+			srtConsumed(2),
+			srtRead(4, "89", nil),
+			srtRead(1, "", io.EOF),
+			srtRewind(),
+			srtRead(5, "23456", nil),
+			srtRead(5, "789", nil),
+			srtRead(1, "", io.EOF),
+			srtConsumed(2),
+		},
+	}, {
+		name:  "consume too much",
+		input: "01234",
+		steps: []srtStep{
+			srtConsumed(0),
+			srtRead(5, "01234", nil),
+			srtConsume(5),
+			srtConsumed(5),
+			srtConsume(5),
+			srtConsumed(5),
+			srtRead(1, "", io.EOF),
+			srtConsume(5),
+			srtConsumed(5),
+			srtRead(1, "", io.EOF),
+			srtConsumed(5),
+		},
+	}, {
+		name:  "empty input",
+		input: "",
+		steps: []srtStep{
+			srtConsumed(0),
+			srtConsume(5),
+			srtRead(1, "", io.EOF),
+			srtConsumed(0),
+		},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reader := NewStreamReader(strings.NewReader(tt.input), 4) // small initial buffer
+			srtRun(t, reader, tt.steps)
+		})
+	}
+}
+
+func TestStreamReader_ReadN(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		steps []srtStep
+	}{{
+		name:  "short read full readN",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(3, "012", nil), // fill buffer
+			srtRewind(),
+			srtRead(5, "012", nil), // short read from buffer
+			srtRewind(),
+			srtReadN(5, "01234", nil), // full readN
+			srtRewind(),
+			srtRead(10, "01234", nil), // short read from buffer
+			srtRewind(),
+			srtReadN(10, "0123456789", nil), // full readN
+			srtRewind(),
+			srtRead(10, "0123456789", nil), // full read from buffer
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "short read consume readN",
+		input: "0123456789",
+		steps: []srtStep{
+			srtRead(3, "012", nil), // fill buffer
+			srtRewind(),
+			srtRead(4, "012", nil), // short read from buffer
+			srtConsume(1),
+			srtRewind(),
+			srtRead(4, "12", nil), // short read from buffer
+			srtRewind(),
+			srtReadN(4, "1234", nil), // full read
+			srtConsume(1),
+			srtRewind(),
+			srtRead(4, "234", nil), // short read from buffer
+			srtRewind(),
+			srtReadN(10, "23456789", io.EOF), // short readN, EOF
+			srtRewind(),
+			srtRead(10, "23456789", nil), // full read from buffer
+			srtRead(1, "", io.EOF),
+		},
+	}, {
+		name:  "empty input",
+		input: "",
+		steps: []srtStep{
+			srtReadN(1, "", io.EOF),
+		},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reader := NewStreamReader(strings.NewReader(tt.input), 4) // small initial buffer
+			srtRun(t, reader, tt.steps)
+		})
+	}
+}
+
+func TestStreamReader_Peek(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		steps []srtStep
+	}{{
+		name:  "simple peek",
+		input: "0123456789",
+		steps: []srtStep{
+			srtPeek(3, "012", nil), // fill buffer
+			srtRead(5, "012", nil), // short read from buffer
+			srtRewind(),
+			srtPeek(6, "012345", nil),  // fill buffer
+			srtRead(10, "012345", nil), // short read from buffer
+		},
+	}, {
+		name:  "empty input",
+		input: "",
+		steps: []srtStep{
+			srtPeek(1, "", io.EOF),
+		},
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reader := NewStreamReader(strings.NewReader(tt.input), 0)
+			srtRun(t, reader, tt.steps)
+		})
+	}
+}