diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go index c38a2ad604b..8bb6f7b04a4 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go @@ -21,7 +21,6 @@ import ( "bytes" "errors" "fmt" - "io" "net" "net/http" "strings" @@ -203,19 +202,29 @@ type headerInterceptingConn struct { // and initializableConn#InitializeWrite() has been called with the result. initializableConn - lock sync.Mutex - headerBuffer []byte - initialized bool + lock sync.Mutex + headerBuffer []byte + initialized bool + initializeErr error } // initializableConn is a connection that will be initialized before any calls to Write are made type initializableConn interface { net.Conn - InitializeWrite(backendResponse *http.Response) error + // InitializeWrite is called when the backend response headers have been read. + // backendResponse contains the parsed headers. + // backendResponseBytes are the raw bytes the headers were parsed from. + InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error } const maxHeaderBytes = 1 << 20 +// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent) +var lfCRLF = []byte("\n\r\n") + +// token for header / body separation without \r (which go tolerates) +var lfLF = []byte("\n\n") + // Write intercepts to initially swallow the HTTP response, then // delegate to the tunneling "net.Conn" once the response has been // seen and processed. @@ -223,39 +232,51 @@ func (h *headerInterceptingConn) Write(b []byte) (int, error) { h.lock.Lock() defer h.lock.Unlock() + if h.initializeErr != nil { + return 0, h.initializeErr + } if h.initialized { return h.initializableConn.Write(b) } - // Write into the headerBuffer, then attempt to parse the bytes - // as an http response. + // Guard against excessive buffering if len(h.headerBuffer)+len(b) > maxHeaderBytes { return 0, fmt.Errorf("header size limit exceeded") } + + // Accumulate into headerBuffer h.headerBuffer = append(h.headerBuffer, b...) - bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer)) - resp, err := http.ReadResponse(bufferedReader, nil) - if errors.Is(err, io.ErrUnexpectedEOF) { - // don't yet have a complete set of headers + + // Attempt to parse http response headers + var headerBytes, bodyBytes []byte + if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 { + // headers terminated with \n\r\n + headerBytes = h.headerBuffer[0 : i+len(lfCRLF)] + bodyBytes = h.headerBuffer[i+len(lfCRLF):] + } else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 { + // headers terminated with \n\n (which go tolerates) + headerBytes = h.headerBuffer[0 : i+len(lfLF)] + bodyBytes = h.headerBuffer[i+len(lfLF):] + } else { + // don't yet have a complete set of headers yet return len(b), nil } + resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil) if err != nil { klog.Errorf("invalid headers: %v", err) + h.initializeErr = err return len(b), err } resp.Body.Close() //nolint:errcheck h.headerBuffer = nil - err = h.initializableConn.InitializeWrite(resp) h.initialized = true - if err != nil { - return len(b), err + h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes) + if h.initializeErr != nil { + return len(b), h.initializeErr } - - // Copy any remaining buffered data to the underlying conn - remainingBuffer, _ := io.ReadAll(bufferedReader) - if len(remainingBuffer) > 0 { - _, err = h.initializableConn.Write(remainingBuffer) + if len(bodyBytes) > 0 { + _, err = h.initializableConn.Write(bodyBytes) } return len(b), err } @@ -274,7 +295,7 @@ type tunnelingWebsocketUpgraderConn struct { err error } -func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) { +func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) { // make sure we close a connection we open in error cases var conn net.Conn defer func() { @@ -337,9 +358,9 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R u.err = err return u.err } - // replay the backend response to the hijacked conn + // replay the backend response bytes to the hijacked conn conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck - err = backendResponse.Write(conn) + _, err = conn.Write(backendResponseBytes) if err != nil { u.err = err return u.err diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go index 858ad6c8f87..b9645588b77 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go @@ -24,6 +24,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" @@ -106,56 +107,223 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { assert.Equal(t, randomData, actual, "error validating tunneled random data") } -const responseStr = `HTTP/1.1 101 Switching Protocols -Date: Sun, 25 Feb 2024 08:09:25 GMT -X-App-Protocol: portforward.k8s.io +var expectedContentLengthHeaders = http.Header{ + "Content-Length": []string{"25"}, + "Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"}, + "Split-Point": []string{"split"}, + "X-App-Protocol": []string{"portforward.k8s.io"}, +} -` +const contentLengthHeaders = "HTTP/1.1 400 Error\r\n" + + "Content-Length: 25\r\n" + + "Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" + + "Split-Point: split\r\n" + + "X-App-Protocol: portforward.k8s.io\r\n" + + "\r\n" -const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols -Date: Sun, 25 Feb 2024 08:09:25 GMT -X-App-Protocol: portforward.k8s.io +const contentLengthBody = "0123456789split0123456789" -This is extra data. -` +var contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody -const invalidResponseStr = `INVALID/1.1 101 Switching Protocols -Date: Sun, 25 Feb 2024 08:09:25 GMT -X-App-Protocol: portforward.k8s.io +var expectedResponseHeaders = http.Header{ + "Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"}, + "Split-Point": []string{"split"}, + "X-App-Protocol": []string{"portforward.k8s.io"}, +} -` +const responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" + + "Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" + + "Split-Point: split\r\n" + + "X-App-Protocol: portforward.k8s.io\r\n" + + "\r\n" + +const responseBody = "This is extra split data.\n" + +var responseHeadersAndBody = responseHeaders + responseBody + +const invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" + + "Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" + + "Split-Point: split\r\n" + + "X-App-Protocol: portforward.k8s.io\r\n" + + "\r\n" func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) { // Basic http response is intercepted correctly; no extra data sent to net.Conn. - testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}} - hic := &headerInterceptingConn{initializableConn: testConnConstructor} - _, err := hic.Write([]byte(responseStr)) - require.NoError(t, err) - assert.True(t, hic.initialized, "successfully parsed http response headers") - assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) - assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol")) - assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn") + t.Run("simple-no-body", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(responseHeaders)) + require.NoError(t, err) + assert.True(t, hic.initialized, "successfully parsed http response headers") + assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol")) + assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite") + assert.Equal(t, "", string(testConnConstructor.mockConn.written)) + }) + // Extra data after response headers should be sent to net.Conn. - hic = &headerInterceptingConn{initializableConn: testConnConstructor} - _, err = hic.Write([]byte(responseWithExtraStr)) - require.NoError(t, err) - assert.True(t, hic.initialized) - assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) - assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn") + t.Run("simple-single-write", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(responseHeadersAndBody)) + require.NoError(t, err) + assert.True(t, hic.initialized) + assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite") + assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn") + }) + + // Partially written headers are buffered and decoded + t.Run("simple-byte-by-byte", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + // write one byte at a time + for _, b := range []byte(responseHeadersAndBody) { + _, err := hic.Write([]byte{b}) + require.NoError(t, err) + } + assert.True(t, hic.initialized) + assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite") + assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn") + }) + + // Writes spanning the header/body breakpoint are buffered and decoded + t.Run("simple-span-headerbody", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + // write one chunk at a time + for i, chunk := range strings.Split(responseHeadersAndBody, "split") { + if i > 0 { + n, err := hic.Write([]byte("split")) + require.Equal(t, n, len("split")) + require.NoError(t, err) + } + n, err := hic.Write([]byte(chunk)) + require.Equal(t, n, len(chunk)) + require.NoError(t, err) + } + assert.True(t, hic.initialized) + assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite") + assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn") + }) + + // Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn. + t.Run("simple-tolerate-lf", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", ""))) + require.NoError(t, err) + assert.True(t, hic.initialized) + assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite") + assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn") + }) + + // Content-Length handling + t.Run("content-length-body", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(contentLengthHeadersAndBody)) + require.NoError(t, err) + assert.True(t, hic.initialized, "successfully parsed http response headers") + assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "400 Error", testConnConstructor.resp.Status) + assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite") + assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written)) + }) + + // Content-Length separately written headers and body + t.Run("content-length-headers-body", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(contentLengthHeaders)) + require.NoError(t, err) + _, err = hic.Write([]byte(contentLengthBody)) + require.NoError(t, err) + assert.True(t, hic.initialized, "successfully parsed http response headers") + assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "400 Error", testConnConstructor.resp.Status) + assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite") + assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written)) + }) + + // Content-Length accumulating byte-by-byte + t.Run("content-length-byte-by-byte", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + for _, b := range []byte(contentLengthHeadersAndBody) { + _, err := hic.Write([]byte{b}) + require.NoError(t, err) + } + assert.True(t, hic.initialized, "successfully parsed http response headers") + assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "400 Error", testConnConstructor.resp.Status) + assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite") + assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written)) + }) + + // Content-Length writes spanning headers / body + t.Run("content-length-span-headerbody", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + // write one chunk at a time + for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") { + if i > 0 { + n, err := hic.Write([]byte("split")) + require.Equal(t, n, len("split")) + require.NoError(t, err) + } + n, err := hic.Write([]byte(chunk)) + require.Equal(t, n, len(chunk)) + require.NoError(t, err) + } + assert.True(t, hic.initialized, "successfully parsed http response headers") + assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header) + assert.Equal(t, "400 Error", testConnConstructor.resp.Status) + assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite") + assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written)) + }) + // Invalid response returns error. - hic = &headerInterceptingConn{initializableConn: testConnConstructor} - _, err = hic.Write([]byte(invalidResponseStr)) - assert.Error(t, err, "expected error from invalid http response") + t.Run("invalid-single-write", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(invalidResponseData)) + assert.Error(t, err, "expected error from invalid http response") + }) + + // Invalid response written byte by byte returns error. + t.Run("invalid-byte-by-byte", func(t *testing.T) { + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + var err error + for _, b := range []byte(invalidResponseData) { + _, err = hic.Write([]byte{b}) + if err != nil { + break + } + } + assert.Error(t, err, "expected error from invalid http response") + }) } type mockConnInitializer struct { - resp *http.Response + resp *http.Response + initializeWriteConn *mockConn *mockConn } -func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error { +func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error { m.resp = backendResponse - return nil + _, err := m.initializeWriteConn.Write(backendResponseBytes) + return err } // mockConn implements "net.Conn" interface. @@ -166,9 +334,8 @@ type mockConn struct { } func (mc *mockConn) Write(p []byte) (int, error) { - mc.written = make([]byte, len(p)) - copy(mc.written, p) - return len(mc.written), nil + mc.written = append(mc.written, p...) + return len(p), nil } func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }