mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-07 19:23:40 +00:00
Fix headerInterceptingConn handling
This commit is contained in:
parent
9e15462843
commit
2443b3fa69
@ -21,7 +21,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -203,19 +202,29 @@ type headerInterceptingConn struct {
|
||||
// and initializableConn#InitializeWrite() has been called with the result.
|
||||
initializableConn
|
||||
|
||||
lock sync.Mutex
|
||||
headerBuffer []byte
|
||||
initialized bool
|
||||
lock sync.Mutex
|
||||
headerBuffer []byte
|
||||
initialized bool
|
||||
initializeErr error
|
||||
}
|
||||
|
||||
// initializableConn is a connection that will be initialized before any calls to Write are made
|
||||
type initializableConn interface {
|
||||
net.Conn
|
||||
InitializeWrite(backendResponse *http.Response) error
|
||||
// InitializeWrite is called when the backend response headers have been read.
|
||||
// backendResponse contains the parsed headers.
|
||||
// backendResponseBytes are the raw bytes the headers were parsed from.
|
||||
InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error
|
||||
}
|
||||
|
||||
const maxHeaderBytes = 1 << 20
|
||||
|
||||
// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent)
|
||||
var lfCRLF = []byte("\n\r\n")
|
||||
|
||||
// token for header / body separation without \r (which go tolerates)
|
||||
var lfLF = []byte("\n\n")
|
||||
|
||||
// Write intercepts to initially swallow the HTTP response, then
|
||||
// delegate to the tunneling "net.Conn" once the response has been
|
||||
// seen and processed.
|
||||
@ -223,39 +232,51 @@ func (h *headerInterceptingConn) Write(b []byte) (int, error) {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
if h.initializeErr != nil {
|
||||
return 0, h.initializeErr
|
||||
}
|
||||
if h.initialized {
|
||||
return h.initializableConn.Write(b)
|
||||
}
|
||||
|
||||
// Write into the headerBuffer, then attempt to parse the bytes
|
||||
// as an http response.
|
||||
// Guard against excessive buffering
|
||||
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
|
||||
return 0, fmt.Errorf("header size limit exceeded")
|
||||
}
|
||||
|
||||
// Accumulate into headerBuffer
|
||||
h.headerBuffer = append(h.headerBuffer, b...)
|
||||
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
|
||||
resp, err := http.ReadResponse(bufferedReader, nil)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// don't yet have a complete set of headers
|
||||
|
||||
// Attempt to parse http response headers
|
||||
var headerBytes, bodyBytes []byte
|
||||
if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 {
|
||||
// headers terminated with \n\r\n
|
||||
headerBytes = h.headerBuffer[0 : i+len(lfCRLF)]
|
||||
bodyBytes = h.headerBuffer[i+len(lfCRLF):]
|
||||
} else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 {
|
||||
// headers terminated with \n\n (which go tolerates)
|
||||
headerBytes = h.headerBuffer[0 : i+len(lfLF)]
|
||||
bodyBytes = h.headerBuffer[i+len(lfLF):]
|
||||
} else {
|
||||
// don't yet have a complete set of headers yet
|
||||
return len(b), nil
|
||||
}
|
||||
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
|
||||
if err != nil {
|
||||
klog.Errorf("invalid headers: %v", err)
|
||||
h.initializeErr = err
|
||||
return len(b), err
|
||||
}
|
||||
resp.Body.Close() //nolint:errcheck
|
||||
|
||||
h.headerBuffer = nil
|
||||
err = h.initializableConn.InitializeWrite(resp)
|
||||
h.initialized = true
|
||||
if err != nil {
|
||||
return len(b), err
|
||||
h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes)
|
||||
if h.initializeErr != nil {
|
||||
return len(b), h.initializeErr
|
||||
}
|
||||
|
||||
// Copy any remaining buffered data to the underlying conn
|
||||
remainingBuffer, _ := io.ReadAll(bufferedReader)
|
||||
if len(remainingBuffer) > 0 {
|
||||
_, err = h.initializableConn.Write(remainingBuffer)
|
||||
if len(bodyBytes) > 0 {
|
||||
_, err = h.initializableConn.Write(bodyBytes)
|
||||
}
|
||||
return len(b), err
|
||||
}
|
||||
@ -274,7 +295,7 @@ type tunnelingWebsocketUpgraderConn struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) {
|
||||
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) {
|
||||
// make sure we close a connection we open in error cases
|
||||
var conn net.Conn
|
||||
defer func() {
|
||||
@ -337,9 +358,9 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
|
||||
u.err = err
|
||||
return u.err
|
||||
}
|
||||
// replay the backend response to the hijacked conn
|
||||
// replay the backend response bytes to the hijacked conn
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
|
||||
err = backendResponse.Write(conn)
|
||||
_, err = conn.Write(backendResponseBytes)
|
||||
if err != nil {
|
||||
u.err = err
|
||||
return u.err
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -106,56 +107,223 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
|
||||
assert.Equal(t, randomData, actual, "error validating tunneled random data")
|
||||
}
|
||||
|
||||
const responseStr = `HTTP/1.1 101 Switching Protocols
|
||||
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
||||
X-App-Protocol: portforward.k8s.io
|
||||
var expectedContentLengthHeaders = http.Header{
|
||||
"Content-Length": []string{"25"},
|
||||
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
|
||||
"Split-Point": []string{"split"},
|
||||
"X-App-Protocol": []string{"portforward.k8s.io"},
|
||||
}
|
||||
|
||||
`
|
||||
const contentLengthHeaders = "HTTP/1.1 400 Error\r\n" +
|
||||
"Content-Length: 25\r\n" +
|
||||
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
|
||||
"Split-Point: split\r\n" +
|
||||
"X-App-Protocol: portforward.k8s.io\r\n" +
|
||||
"\r\n"
|
||||
|
||||
const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols
|
||||
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
||||
X-App-Protocol: portforward.k8s.io
|
||||
const contentLengthBody = "0123456789split0123456789"
|
||||
|
||||
This is extra data.
|
||||
`
|
||||
var contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody
|
||||
|
||||
const invalidResponseStr = `INVALID/1.1 101 Switching Protocols
|
||||
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
||||
X-App-Protocol: portforward.k8s.io
|
||||
var expectedResponseHeaders = http.Header{
|
||||
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
|
||||
"Split-Point": []string{"split"},
|
||||
"X-App-Protocol": []string{"portforward.k8s.io"},
|
||||
}
|
||||
|
||||
`
|
||||
const responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" +
|
||||
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
|
||||
"Split-Point: split\r\n" +
|
||||
"X-App-Protocol: portforward.k8s.io\r\n" +
|
||||
"\r\n"
|
||||
|
||||
const responseBody = "This is extra split data.\n"
|
||||
|
||||
var responseHeadersAndBody = responseHeaders + responseBody
|
||||
|
||||
const invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" +
|
||||
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
|
||||
"Split-Point: split\r\n" +
|
||||
"X-App-Protocol: portforward.k8s.io\r\n" +
|
||||
"\r\n"
|
||||
|
||||
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
|
||||
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(responseStr))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
|
||||
assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn")
|
||||
t.Run("simple-no-body", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(responseHeaders))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
|
||||
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||
assert.Equal(t, "", string(testConnConstructor.mockConn.written))
|
||||
})
|
||||
|
||||
// Extra data after response headers should be sent to net.Conn.
|
||||
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err = hic.Write([]byte(responseWithExtraStr))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized)
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||
t.Run("simple-single-write", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(responseHeadersAndBody))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized)
|
||||
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||
})
|
||||
|
||||
// Partially written headers are buffered and decoded
|
||||
t.Run("simple-byte-by-byte", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
// write one byte at a time
|
||||
for _, b := range []byte(responseHeadersAndBody) {
|
||||
_, err := hic.Write([]byte{b})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.True(t, hic.initialized)
|
||||
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||
})
|
||||
|
||||
// Writes spanning the header/body breakpoint are buffered and decoded
|
||||
t.Run("simple-span-headerbody", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
// write one chunk at a time
|
||||
for i, chunk := range strings.Split(responseHeadersAndBody, "split") {
|
||||
if i > 0 {
|
||||
n, err := hic.Write([]byte("split"))
|
||||
require.Equal(t, n, len("split"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
n, err := hic.Write([]byte(chunk))
|
||||
require.Equal(t, n, len(chunk))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.True(t, hic.initialized)
|
||||
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||
})
|
||||
|
||||
// Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn.
|
||||
t.Run("simple-tolerate-lf", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", "")))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized)
|
||||
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite")
|
||||
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||
})
|
||||
|
||||
// Content-Length handling
|
||||
t.Run("content-length-body", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(contentLengthHeadersAndBody))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||
})
|
||||
|
||||
// Content-Length separately written headers and body
|
||||
t.Run("content-length-headers-body", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(contentLengthHeaders))
|
||||
require.NoError(t, err)
|
||||
_, err = hic.Write([]byte(contentLengthBody))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||
})
|
||||
|
||||
// Content-Length accumulating byte-by-byte
|
||||
t.Run("content-length-byte-by-byte", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
for _, b := range []byte(contentLengthHeadersAndBody) {
|
||||
_, err := hic.Write([]byte{b})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||
})
|
||||
|
||||
// Content-Length writes spanning headers / body
|
||||
t.Run("content-length-span-headerbody", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
// write one chunk at a time
|
||||
for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") {
|
||||
if i > 0 {
|
||||
n, err := hic.Write([]byte("split"))
|
||||
require.Equal(t, n, len("split"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
n, err := hic.Write([]byte(chunk))
|
||||
require.Equal(t, n, len(chunk))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||
})
|
||||
|
||||
// Invalid response returns error.
|
||||
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err = hic.Write([]byte(invalidResponseStr))
|
||||
assert.Error(t, err, "expected error from invalid http response")
|
||||
t.Run("invalid-single-write", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
_, err := hic.Write([]byte(invalidResponseData))
|
||||
assert.Error(t, err, "expected error from invalid http response")
|
||||
})
|
||||
|
||||
// Invalid response written byte by byte returns error.
|
||||
t.Run("invalid-byte-by-byte", func(t *testing.T) {
|
||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||
var err error
|
||||
for _, b := range []byte(invalidResponseData) {
|
||||
_, err = hic.Write([]byte{b})
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Error(t, err, "expected error from invalid http response")
|
||||
})
|
||||
}
|
||||
|
||||
type mockConnInitializer struct {
|
||||
resp *http.Response
|
||||
resp *http.Response
|
||||
initializeWriteConn *mockConn
|
||||
*mockConn
|
||||
}
|
||||
|
||||
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error {
|
||||
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error {
|
||||
m.resp = backendResponse
|
||||
return nil
|
||||
_, err := m.initializeWriteConn.Write(backendResponseBytes)
|
||||
return err
|
||||
}
|
||||
|
||||
// mockConn implements "net.Conn" interface.
|
||||
@ -166,9 +334,8 @@ type mockConn struct {
|
||||
}
|
||||
|
||||
func (mc *mockConn) Write(p []byte) (int, error) {
|
||||
mc.written = make([]byte, len(p))
|
||||
copy(mc.written, p)
|
||||
return len(mc.written), nil
|
||||
mc.written = append(mc.written, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
|
||||
|
Loading…
Reference in New Issue
Block a user