Fix headerInterceptingConn handling

This commit is contained in:
Jordan Liggitt 2024-03-02 17:57:39 -05:00 committed by Sean Sullivan
parent 9e15462843
commit 2443b3fa69
2 changed files with 246 additions and 58 deletions

View File

@ -21,7 +21,6 @@ import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
@ -203,19 +202,29 @@ type headerInterceptingConn struct {
// and initializableConn#InitializeWrite() has been called with the result.
initializableConn
lock sync.Mutex
headerBuffer []byte
initialized bool
lock sync.Mutex
headerBuffer []byte
initialized bool
initializeErr error
}
// initializableConn is a connection that will be initialized before any calls to Write are made
type initializableConn interface {
net.Conn
InitializeWrite(backendResponse *http.Response) error
// InitializeWrite is called when the backend response headers have been read.
// backendResponse contains the parsed headers.
// backendResponseBytes are the raw bytes the headers were parsed from.
InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error
}
const maxHeaderBytes = 1 << 20
// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent)
var lfCRLF = []byte("\n\r\n")
// token for header / body separation without \r (which go tolerates)
var lfLF = []byte("\n\n")
// Write intercepts to initially swallow the HTTP response, then
// delegate to the tunneling "net.Conn" once the response has been
// seen and processed.
@ -223,39 +232,51 @@ func (h *headerInterceptingConn) Write(b []byte) (int, error) {
h.lock.Lock()
defer h.lock.Unlock()
if h.initializeErr != nil {
return 0, h.initializeErr
}
if h.initialized {
return h.initializableConn.Write(b)
}
// Write into the headerBuffer, then attempt to parse the bytes
// as an http response.
// Guard against excessive buffering
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
return 0, fmt.Errorf("header size limit exceeded")
}
// Accumulate into headerBuffer
h.headerBuffer = append(h.headerBuffer, b...)
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
resp, err := http.ReadResponse(bufferedReader, nil)
if errors.Is(err, io.ErrUnexpectedEOF) {
// don't yet have a complete set of headers
// Attempt to parse http response headers
var headerBytes, bodyBytes []byte
if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 {
// headers terminated with \n\r\n
headerBytes = h.headerBuffer[0 : i+len(lfCRLF)]
bodyBytes = h.headerBuffer[i+len(lfCRLF):]
} else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 {
// headers terminated with \n\n (which go tolerates)
headerBytes = h.headerBuffer[0 : i+len(lfLF)]
bodyBytes = h.headerBuffer[i+len(lfLF):]
} else {
// don't yet have a complete set of headers yet
return len(b), nil
}
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
if err != nil {
klog.Errorf("invalid headers: %v", err)
h.initializeErr = err
return len(b), err
}
resp.Body.Close() //nolint:errcheck
h.headerBuffer = nil
err = h.initializableConn.InitializeWrite(resp)
h.initialized = true
if err != nil {
return len(b), err
h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes)
if h.initializeErr != nil {
return len(b), h.initializeErr
}
// Copy any remaining buffered data to the underlying conn
remainingBuffer, _ := io.ReadAll(bufferedReader)
if len(remainingBuffer) > 0 {
_, err = h.initializableConn.Write(remainingBuffer)
if len(bodyBytes) > 0 {
_, err = h.initializableConn.Write(bodyBytes)
}
return len(b), err
}
@ -274,7 +295,7 @@ type tunnelingWebsocketUpgraderConn struct {
err error
}
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) {
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) {
// make sure we close a connection we open in error cases
var conn net.Conn
defer func() {
@ -337,9 +358,9 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
u.err = err
return u.err
}
// replay the backend response to the hijacked conn
// replay the backend response bytes to the hijacked conn
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
err = backendResponse.Write(conn)
_, err = conn.Write(backendResponseBytes)
if err != nil {
u.err = err
return u.err

View File

@ -24,6 +24,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@ -106,56 +107,223 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
assert.Equal(t, randomData, actual, "error validating tunneled random data")
}
const responseStr = `HTTP/1.1 101 Switching Protocols
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
var expectedContentLengthHeaders = http.Header{
"Content-Length": []string{"25"},
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
"Split-Point": []string{"split"},
"X-App-Protocol": []string{"portforward.k8s.io"},
}
`
const contentLengthHeaders = "HTTP/1.1 400 Error\r\n" +
"Content-Length: 25\r\n" +
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
"Split-Point: split\r\n" +
"X-App-Protocol: portforward.k8s.io\r\n" +
"\r\n"
const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
const contentLengthBody = "0123456789split0123456789"
This is extra data.
`
var contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody
const invalidResponseStr = `INVALID/1.1 101 Switching Protocols
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
var expectedResponseHeaders = http.Header{
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
"Split-Point": []string{"split"},
"X-App-Protocol": []string{"portforward.k8s.io"},
}
`
const responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" +
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
"Split-Point: split\r\n" +
"X-App-Protocol: portforward.k8s.io\r\n" +
"\r\n"
const responseBody = "This is extra split data.\n"
var responseHeadersAndBody = responseHeaders + responseBody
const invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" +
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
"Split-Point: split\r\n" +
"X-App-Protocol: portforward.k8s.io\r\n" +
"\r\n"
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(responseStr))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn")
t.Run("simple-no-body", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(responseHeaders))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, "", string(testConnConstructor.mockConn.written))
})
// Extra data after response headers should be sent to net.Conn.
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
_, err = hic.Write([]byte(responseWithExtraStr))
require.NoError(t, err)
assert.True(t, hic.initialized)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
t.Run("simple-single-write", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(responseHeadersAndBody))
require.NoError(t, err)
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Partially written headers are buffered and decoded
t.Run("simple-byte-by-byte", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
// write one byte at a time
for _, b := range []byte(responseHeadersAndBody) {
_, err := hic.Write([]byte{b})
require.NoError(t, err)
}
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Writes spanning the header/body breakpoint are buffered and decoded
t.Run("simple-span-headerbody", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
// write one chunk at a time
for i, chunk := range strings.Split(responseHeadersAndBody, "split") {
if i > 0 {
n, err := hic.Write([]byte("split"))
require.Equal(t, n, len("split"))
require.NoError(t, err)
}
n, err := hic.Write([]byte(chunk))
require.Equal(t, n, len(chunk))
require.NoError(t, err)
}
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn.
t.Run("simple-tolerate-lf", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", "")))
require.NoError(t, err)
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Content-Length handling
t.Run("content-length-body", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(contentLengthHeadersAndBody))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Content-Length separately written headers and body
t.Run("content-length-headers-body", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(contentLengthHeaders))
require.NoError(t, err)
_, err = hic.Write([]byte(contentLengthBody))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Content-Length accumulating byte-by-byte
t.Run("content-length-byte-by-byte", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
for _, b := range []byte(contentLengthHeadersAndBody) {
_, err := hic.Write([]byte{b})
require.NoError(t, err)
}
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Content-Length writes spanning headers / body
t.Run("content-length-span-headerbody", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
// write one chunk at a time
for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") {
if i > 0 {
n, err := hic.Write([]byte("split"))
require.Equal(t, n, len("split"))
require.NoError(t, err)
}
n, err := hic.Write([]byte(chunk))
require.Equal(t, n, len(chunk))
require.NoError(t, err)
}
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Invalid response returns error.
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
_, err = hic.Write([]byte(invalidResponseStr))
assert.Error(t, err, "expected error from invalid http response")
t.Run("invalid-single-write", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(invalidResponseData))
assert.Error(t, err, "expected error from invalid http response")
})
// Invalid response written byte by byte returns error.
t.Run("invalid-byte-by-byte", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
var err error
for _, b := range []byte(invalidResponseData) {
_, err = hic.Write([]byte{b})
if err != nil {
break
}
}
assert.Error(t, err, "expected error from invalid http response")
})
}
type mockConnInitializer struct {
resp *http.Response
resp *http.Response
initializeWriteConn *mockConn
*mockConn
}
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error {
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error {
m.resp = backendResponse
return nil
_, err := m.initializeWriteConn.Write(backendResponseBytes)
return err
}
// mockConn implements "net.Conn" interface.
@ -166,9 +334,8 @@ type mockConn struct {
}
func (mc *mockConn) Write(p []byte) (int, error) {
mc.written = make([]byte, len(p))
copy(mc.written, p)
return len(mc.written), nil
mc.written = append(mc.written, p...)
return len(p), nil
}
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }