From 20fc22f46131a18c8599369b0612e28660ab6911 Mon Sep 17 00:00:00 2001 From: Tim Allclair Date: Tue, 10 Sep 2019 10:27:24 -0700 Subject: [PATCH] Add LimitWriter util --- pkg/kubelet/util/ioutils/ioutils.go | 33 +++++++++ pkg/kubelet/util/ioutils/ioutils_test.go | 93 ++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 pkg/kubelet/util/ioutils/ioutils_test.go diff --git a/pkg/kubelet/util/ioutils/ioutils.go b/pkg/kubelet/util/ioutils/ioutils.go index 42f1998c794..1b2b5a6d5dd 100644 --- a/pkg/kubelet/util/ioutils/ioutils.go +++ b/pkg/kubelet/util/ioutils/ioutils.go @@ -35,3 +35,36 @@ func (w *writeCloserWrapper) Close() error { func WriteCloserWrapper(w io.Writer) io.WriteCloser { return &writeCloserWrapper{w} } + +// LimitWriter is a copy of the standard library ioutils.LimitReader, +// applied to the writer interface. +// LimitWriter returns a Writer that writes to w +// but stops with EOF after n bytes. +// The underlying implementation is a *LimitedWriter. +func LimitWriter(w io.Writer, n int64) io.Writer { return &LimitedWriter{w, n} } + +// A LimitedWriter writes to W but limits the amount of +// data returned to just N bytes. Each call to Write +// updates N to reflect the new amount remaining. +// Write returns EOF when N <= 0 or when the underlying W returns EOF. +type LimitedWriter struct { + W io.Writer // underlying writer + N int64 // max bytes remaining +} + +func (l *LimitedWriter) Write(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, io.ErrShortWrite + } + truncated := false + if int64(len(p)) > l.N { + p = p[0:l.N] + truncated = true + } + n, err = l.W.Write(p) + l.N -= int64(n) + if err == nil && truncated { + err = io.ErrShortWrite + } + return +} diff --git a/pkg/kubelet/util/ioutils/ioutils_test.go b/pkg/kubelet/util/ioutils/ioutils_test.go new file mode 100644 index 00000000000..524a4aed67d --- /dev/null +++ b/pkg/kubelet/util/ioutils/ioutils_test.go @@ -0,0 +1,93 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ioutils + +import ( + "bytes" + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLimitWriter(t *testing.T) { + r := rand.New(rand.NewSource(1234)) // Fixed source to prevent flakes. + + tests := []struct { + inputSize, limit, writeSize int64 + }{ + // Single write tests + {100, 101, 100}, + {100, 100, 100}, + {100, 99, 100}, + {1, 1, 1}, + {100, 10, 100}, + {100, 0, 100}, + {100, -1, 100}, + // Multi write tests + {100, 101, 10}, + {100, 100, 10}, + {100, 99, 10}, + {100, 10, 10}, + {100, 0, 10}, + {100, -1, 10}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("inputSize=%d limit=%d writes=%d", test.inputSize, test.limit, test.writeSize), func(t *testing.T) { + input := make([]byte, test.inputSize) + r.Read(input) + output := &bytes.Buffer{} + w := LimitWriter(output, test.limit) + + var ( + err error + written int64 + n int + ) + for written < test.inputSize && err == nil { + n, err = w.Write(input[written : written+test.writeSize]) + written += int64(n) + } + + expectWritten := bounded(0, test.inputSize, test.limit) + assert.EqualValues(t, expectWritten, written) + if expectWritten <= 0 { + assert.Empty(t, output) + } else { + assert.Equal(t, input[:expectWritten], output.Bytes()) + } + + if test.limit < test.inputSize { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func bounded(min, val, max int64) int64 { + if max < val { + val = max + } + if val < min { + val = min + } + return val +}