Fix headerInterceptingConn handling

This commit is contained in:
Jordan Liggitt 2024-03-02 17:57:39 -05:00 committed by Sean Sullivan
parent 9e15462843
commit 2443b3fa69
2 changed files with 246 additions and 58 deletions

View File

@ -21,7 +21,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -203,19 +202,29 @@ type headerInterceptingConn struct {
// and initializableConn#InitializeWrite() has been called with the result. // and initializableConn#InitializeWrite() has been called with the result.
initializableConn initializableConn
lock sync.Mutex lock sync.Mutex
headerBuffer []byte headerBuffer []byte
initialized bool initialized bool
initializeErr error
} }
// initializableConn is a connection that will be initialized before any calls to Write are made // initializableConn is a connection that will be initialized before any calls to Write are made
type initializableConn interface { type initializableConn interface {
net.Conn net.Conn
InitializeWrite(backendResponse *http.Response) error // InitializeWrite is called when the backend response headers have been read.
// backendResponse contains the parsed headers.
// backendResponseBytes are the raw bytes the headers were parsed from.
InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error
} }
const maxHeaderBytes = 1 << 20 const maxHeaderBytes = 1 << 20
// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent)
var lfCRLF = []byte("\n\r\n")
// token for header / body separation without \r (which go tolerates)
var lfLF = []byte("\n\n")
// Write intercepts to initially swallow the HTTP response, then // Write intercepts to initially swallow the HTTP response, then
// delegate to the tunneling "net.Conn" once the response has been // delegate to the tunneling "net.Conn" once the response has been
// seen and processed. // seen and processed.
@ -223,39 +232,51 @@ func (h *headerInterceptingConn) Write(b []byte) (int, error) {
h.lock.Lock() h.lock.Lock()
defer h.lock.Unlock() defer h.lock.Unlock()
if h.initializeErr != nil {
return 0, h.initializeErr
}
if h.initialized { if h.initialized {
return h.initializableConn.Write(b) return h.initializableConn.Write(b)
} }
// Write into the headerBuffer, then attempt to parse the bytes // Guard against excessive buffering
// as an http response.
if len(h.headerBuffer)+len(b) > maxHeaderBytes { if len(h.headerBuffer)+len(b) > maxHeaderBytes {
return 0, fmt.Errorf("header size limit exceeded") return 0, fmt.Errorf("header size limit exceeded")
} }
// Accumulate into headerBuffer
h.headerBuffer = append(h.headerBuffer, b...) h.headerBuffer = append(h.headerBuffer, b...)
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
resp, err := http.ReadResponse(bufferedReader, nil) // Attempt to parse http response headers
if errors.Is(err, io.ErrUnexpectedEOF) { var headerBytes, bodyBytes []byte
// don't yet have a complete set of headers if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 {
// headers terminated with \n\r\n
headerBytes = h.headerBuffer[0 : i+len(lfCRLF)]
bodyBytes = h.headerBuffer[i+len(lfCRLF):]
} else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 {
// headers terminated with \n\n (which go tolerates)
headerBytes = h.headerBuffer[0 : i+len(lfLF)]
bodyBytes = h.headerBuffer[i+len(lfLF):]
} else {
// don't yet have a complete set of headers yet
return len(b), nil return len(b), nil
} }
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
if err != nil { if err != nil {
klog.Errorf("invalid headers: %v", err) klog.Errorf("invalid headers: %v", err)
h.initializeErr = err
return len(b), err return len(b), err
} }
resp.Body.Close() //nolint:errcheck resp.Body.Close() //nolint:errcheck
h.headerBuffer = nil h.headerBuffer = nil
err = h.initializableConn.InitializeWrite(resp)
h.initialized = true h.initialized = true
if err != nil { h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes)
return len(b), err if h.initializeErr != nil {
return len(b), h.initializeErr
} }
if len(bodyBytes) > 0 {
// Copy any remaining buffered data to the underlying conn _, err = h.initializableConn.Write(bodyBytes)
remainingBuffer, _ := io.ReadAll(bufferedReader)
if len(remainingBuffer) > 0 {
_, err = h.initializableConn.Write(remainingBuffer)
} }
return len(b), err return len(b), err
} }
@ -274,7 +295,7 @@ type tunnelingWebsocketUpgraderConn struct {
err error err error
} }
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) { func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) {
// make sure we close a connection we open in error cases // make sure we close a connection we open in error cases
var conn net.Conn var conn net.Conn
defer func() { defer func() {
@ -337,9 +358,9 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
u.err = err u.err = err
return u.err return u.err
} }
// replay the backend response to the hijacked conn // replay the backend response bytes to the hijacked conn
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
err = backendResponse.Write(conn) _, err = conn.Write(backendResponseBytes)
if err != nil { if err != nil {
u.err = err u.err = err
return u.err return u.err

View File

@ -24,6 +24,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"time" "time"
@ -106,56 +107,223 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
assert.Equal(t, randomData, actual, "error validating tunneled random data") assert.Equal(t, randomData, actual, "error validating tunneled random data")
} }
const responseStr = `HTTP/1.1 101 Switching Protocols var expectedContentLengthHeaders = http.Header{
Date: Sun, 25 Feb 2024 08:09:25 GMT "Content-Length": []string{"25"},
X-App-Protocol: portforward.k8s.io "Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
"Split-Point": []string{"split"},
"X-App-Protocol": []string{"portforward.k8s.io"},
}
` const contentLengthHeaders = "HTTP/1.1 400 Error\r\n" +
"Content-Length: 25\r\n" +
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
"Split-Point: split\r\n" +
"X-App-Protocol: portforward.k8s.io\r\n" +
"\r\n"
const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols const contentLengthBody = "0123456789split0123456789"
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
This is extra data. var contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody
`
const invalidResponseStr = `INVALID/1.1 101 Switching Protocols var expectedResponseHeaders = http.Header{
Date: Sun, 25 Feb 2024 08:09:25 GMT "Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
X-App-Protocol: portforward.k8s.io "Split-Point": []string{"split"},
"X-App-Protocol": []string{"portforward.k8s.io"},
}
` const responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" +
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
"Split-Point: split\r\n" +
"X-App-Protocol: portforward.k8s.io\r\n" +
"\r\n"
const responseBody = "This is extra split data.\n"
var responseHeadersAndBody = responseHeaders + responseBody
const invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" +
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
"Split-Point: split\r\n" +
"X-App-Protocol: portforward.k8s.io\r\n" +
"\r\n"
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) { func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
// Basic http response is intercepted correctly; no extra data sent to net.Conn. // Basic http response is intercepted correctly; no extra data sent to net.Conn.
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}} t.Run("simple-no-body", func(t *testing.T) {
hic := &headerInterceptingConn{initializableConn: testConnConstructor} testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
_, err := hic.Write([]byte(responseStr)) hic := &headerInterceptingConn{initializableConn: testConnConstructor}
require.NoError(t, err) _, err := hic.Write([]byte(responseHeaders))
assert.True(t, hic.initialized, "successfully parsed http response headers") require.NoError(t, err)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol")) assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn") assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, "", string(testConnConstructor.mockConn.written))
})
// Extra data after response headers should be sent to net.Conn. // Extra data after response headers should be sent to net.Conn.
hic = &headerInterceptingConn{initializableConn: testConnConstructor} t.Run("simple-single-write", func(t *testing.T) {
_, err = hic.Write([]byte(responseWithExtraStr)) testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
require.NoError(t, err) hic := &headerInterceptingConn{initializableConn: testConnConstructor}
assert.True(t, hic.initialized) _, err := hic.Write([]byte(responseHeadersAndBody))
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) require.NoError(t, err)
assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn") assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Partially written headers are buffered and decoded
t.Run("simple-byte-by-byte", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
// write one byte at a time
for _, b := range []byte(responseHeadersAndBody) {
_, err := hic.Write([]byte{b})
require.NoError(t, err)
}
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Writes spanning the header/body breakpoint are buffered and decoded
t.Run("simple-span-headerbody", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
// write one chunk at a time
for i, chunk := range strings.Split(responseHeadersAndBody, "split") {
if i > 0 {
n, err := hic.Write([]byte("split"))
require.Equal(t, n, len("split"))
require.NoError(t, err)
}
n, err := hic.Write([]byte(chunk))
require.Equal(t, n, len(chunk))
require.NoError(t, err)
}
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn.
t.Run("simple-tolerate-lf", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", "")))
require.NoError(t, err)
assert.True(t, hic.initialized)
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite")
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
})
// Content-Length handling
t.Run("content-length-body", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(contentLengthHeadersAndBody))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Content-Length separately written headers and body
t.Run("content-length-headers-body", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(contentLengthHeaders))
require.NoError(t, err)
_, err = hic.Write([]byte(contentLengthBody))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Content-Length accumulating byte-by-byte
t.Run("content-length-byte-by-byte", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
for _, b := range []byte(contentLengthHeadersAndBody) {
_, err := hic.Write([]byte{b})
require.NoError(t, err)
}
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Content-Length writes spanning headers / body
t.Run("content-length-span-headerbody", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
// write one chunk at a time
for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") {
if i > 0 {
n, err := hic.Write([]byte("split"))
require.Equal(t, n, len("split"))
require.NoError(t, err)
}
n, err := hic.Write([]byte(chunk))
require.Equal(t, n, len(chunk))
require.NoError(t, err)
}
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
})
// Invalid response returns error. // Invalid response returns error.
hic = &headerInterceptingConn{initializableConn: testConnConstructor} t.Run("invalid-single-write", func(t *testing.T) {
_, err = hic.Write([]byte(invalidResponseStr)) testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
assert.Error(t, err, "expected error from invalid http response") hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(invalidResponseData))
assert.Error(t, err, "expected error from invalid http response")
})
// Invalid response written byte by byte returns error.
t.Run("invalid-byte-by-byte", func(t *testing.T) {
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
var err error
for _, b := range []byte(invalidResponseData) {
_, err = hic.Write([]byte{b})
if err != nil {
break
}
}
assert.Error(t, err, "expected error from invalid http response")
})
} }
type mockConnInitializer struct { type mockConnInitializer struct {
resp *http.Response resp *http.Response
initializeWriteConn *mockConn
*mockConn *mockConn
} }
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error { func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error {
m.resp = backendResponse m.resp = backendResponse
return nil _, err := m.initializeWriteConn.Write(backendResponseBytes)
return err
} }
// mockConn implements "net.Conn" interface. // mockConn implements "net.Conn" interface.
@ -166,9 +334,8 @@ type mockConn struct {
} }
func (mc *mockConn) Write(p []byte) (int, error) { func (mc *mockConn) Write(p []byte) (int, error) {
mc.written = make([]byte, len(p)) mc.written = append(mc.written, p...)
copy(mc.written, p) return len(p), nil
return len(mc.written), nil
} }
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil } func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }