mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-09 12:07:47 +00:00
Fix headerInterceptingConn handling
This commit is contained in:
parent
9e15462843
commit
2443b3fa69
@ -21,7 +21,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -206,16 +205,26 @@ type headerInterceptingConn struct {
|
|||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
headerBuffer []byte
|
headerBuffer []byte
|
||||||
initialized bool
|
initialized bool
|
||||||
|
initializeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
// initializableConn is a connection that will be initialized before any calls to Write are made
|
// initializableConn is a connection that will be initialized before any calls to Write are made
|
||||||
type initializableConn interface {
|
type initializableConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
InitializeWrite(backendResponse *http.Response) error
|
// InitializeWrite is called when the backend response headers have been read.
|
||||||
|
// backendResponse contains the parsed headers.
|
||||||
|
// backendResponseBytes are the raw bytes the headers were parsed from.
|
||||||
|
InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxHeaderBytes = 1 << 20
|
const maxHeaderBytes = 1 << 20
|
||||||
|
|
||||||
|
// token for normal header / body separation (\r\n\r\n, but go tolerates the leading \r being absent)
|
||||||
|
var lfCRLF = []byte("\n\r\n")
|
||||||
|
|
||||||
|
// token for header / body separation without \r (which go tolerates)
|
||||||
|
var lfLF = []byte("\n\n")
|
||||||
|
|
||||||
// Write intercepts to initially swallow the HTTP response, then
|
// Write intercepts to initially swallow the HTTP response, then
|
||||||
// delegate to the tunneling "net.Conn" once the response has been
|
// delegate to the tunneling "net.Conn" once the response has been
|
||||||
// seen and processed.
|
// seen and processed.
|
||||||
@ -223,39 +232,51 @@ func (h *headerInterceptingConn) Write(b []byte) (int, error) {
|
|||||||
h.lock.Lock()
|
h.lock.Lock()
|
||||||
defer h.lock.Unlock()
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
if h.initializeErr != nil {
|
||||||
|
return 0, h.initializeErr
|
||||||
|
}
|
||||||
if h.initialized {
|
if h.initialized {
|
||||||
return h.initializableConn.Write(b)
|
return h.initializableConn.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write into the headerBuffer, then attempt to parse the bytes
|
// Guard against excessive buffering
|
||||||
// as an http response.
|
|
||||||
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
|
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
|
||||||
return 0, fmt.Errorf("header size limit exceeded")
|
return 0, fmt.Errorf("header size limit exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accumulate into headerBuffer
|
||||||
h.headerBuffer = append(h.headerBuffer, b...)
|
h.headerBuffer = append(h.headerBuffer, b...)
|
||||||
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
|
|
||||||
resp, err := http.ReadResponse(bufferedReader, nil)
|
// Attempt to parse http response headers
|
||||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
var headerBytes, bodyBytes []byte
|
||||||
// don't yet have a complete set of headers
|
if i := bytes.Index(h.headerBuffer, lfCRLF); i != -1 {
|
||||||
|
// headers terminated with \n\r\n
|
||||||
|
headerBytes = h.headerBuffer[0 : i+len(lfCRLF)]
|
||||||
|
bodyBytes = h.headerBuffer[i+len(lfCRLF):]
|
||||||
|
} else if i := bytes.Index(h.headerBuffer, lfLF); i != -1 {
|
||||||
|
// headers terminated with \n\n (which go tolerates)
|
||||||
|
headerBytes = h.headerBuffer[0 : i+len(lfLF)]
|
||||||
|
bodyBytes = h.headerBuffer[i+len(lfLF):]
|
||||||
|
} else {
|
||||||
|
// don't yet have a complete set of headers yet
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
klog.Errorf("invalid headers: %v", err)
|
klog.Errorf("invalid headers: %v", err)
|
||||||
|
h.initializeErr = err
|
||||||
return len(b), err
|
return len(b), err
|
||||||
}
|
}
|
||||||
resp.Body.Close() //nolint:errcheck
|
resp.Body.Close() //nolint:errcheck
|
||||||
|
|
||||||
h.headerBuffer = nil
|
h.headerBuffer = nil
|
||||||
err = h.initializableConn.InitializeWrite(resp)
|
|
||||||
h.initialized = true
|
h.initialized = true
|
||||||
if err != nil {
|
h.initializeErr = h.initializableConn.InitializeWrite(resp, headerBytes)
|
||||||
return len(b), err
|
if h.initializeErr != nil {
|
||||||
|
return len(b), h.initializeErr
|
||||||
}
|
}
|
||||||
|
if len(bodyBytes) > 0 {
|
||||||
// Copy any remaining buffered data to the underlying conn
|
_, err = h.initializableConn.Write(bodyBytes)
|
||||||
remainingBuffer, _ := io.ReadAll(bufferedReader)
|
|
||||||
if len(remainingBuffer) > 0 {
|
|
||||||
_, err = h.initializableConn.Write(remainingBuffer)
|
|
||||||
}
|
}
|
||||||
return len(b), err
|
return len(b), err
|
||||||
}
|
}
|
||||||
@ -274,7 +295,7 @@ type tunnelingWebsocketUpgraderConn struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) {
|
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) (err error) {
|
||||||
// make sure we close a connection we open in error cases
|
// make sure we close a connection we open in error cases
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -337,9 +358,9 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
|
|||||||
u.err = err
|
u.err = err
|
||||||
return u.err
|
return u.err
|
||||||
}
|
}
|
||||||
// replay the backend response to the hijacked conn
|
// replay the backend response bytes to the hijacked conn
|
||||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
|
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
|
||||||
err = backendResponse.Write(conn)
|
_, err = conn.Write(backendResponseBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.err = err
|
u.err = err
|
||||||
return u.err
|
return u.err
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -106,56 +107,223 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
|
|||||||
assert.Equal(t, randomData, actual, "error validating tunneled random data")
|
assert.Equal(t, randomData, actual, "error validating tunneled random data")
|
||||||
}
|
}
|
||||||
|
|
||||||
const responseStr = `HTTP/1.1 101 Switching Protocols
|
var expectedContentLengthHeaders = http.Header{
|
||||||
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
"Content-Length": []string{"25"},
|
||||||
X-App-Protocol: portforward.k8s.io
|
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
|
||||||
|
"Split-Point": []string{"split"},
|
||||||
|
"X-App-Protocol": []string{"portforward.k8s.io"},
|
||||||
|
}
|
||||||
|
|
||||||
`
|
const contentLengthHeaders = "HTTP/1.1 400 Error\r\n" +
|
||||||
|
"Content-Length: 25\r\n" +
|
||||||
|
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
|
||||||
|
"Split-Point: split\r\n" +
|
||||||
|
"X-App-Protocol: portforward.k8s.io\r\n" +
|
||||||
|
"\r\n"
|
||||||
|
|
||||||
const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols
|
const contentLengthBody = "0123456789split0123456789"
|
||||||
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
|
||||||
X-App-Protocol: portforward.k8s.io
|
|
||||||
|
|
||||||
This is extra data.
|
var contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody
|
||||||
`
|
|
||||||
|
|
||||||
const invalidResponseStr = `INVALID/1.1 101 Switching Protocols
|
var expectedResponseHeaders = http.Header{
|
||||||
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
|
||||||
X-App-Protocol: portforward.k8s.io
|
"Split-Point": []string{"split"},
|
||||||
|
"X-App-Protocol": []string{"portforward.k8s.io"},
|
||||||
|
}
|
||||||
|
|
||||||
`
|
const responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" +
|
||||||
|
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
|
||||||
|
"Split-Point: split\r\n" +
|
||||||
|
"X-App-Protocol: portforward.k8s.io\r\n" +
|
||||||
|
"\r\n"
|
||||||
|
|
||||||
|
const responseBody = "This is extra split data.\n"
|
||||||
|
|
||||||
|
var responseHeadersAndBody = responseHeaders + responseBody
|
||||||
|
|
||||||
|
const invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" +
|
||||||
|
"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
|
||||||
|
"Split-Point: split\r\n" +
|
||||||
|
"X-App-Protocol: portforward.k8s.io\r\n" +
|
||||||
|
"\r\n"
|
||||||
|
|
||||||
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
|
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
|
||||||
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
|
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
|
||||||
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}}
|
t.Run("simple-no-body", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
_, err := hic.Write([]byte(responseStr))
|
_, err := hic.Write([]byte(responseHeaders))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||||
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
|
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
|
||||||
assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn")
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||||
|
assert.Equal(t, "", string(testConnConstructor.mockConn.written))
|
||||||
|
})
|
||||||
|
|
||||||
// Extra data after response headers should be sent to net.Conn.
|
// Extra data after response headers should be sent to net.Conn.
|
||||||
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
|
t.Run("simple-single-write", func(t *testing.T) {
|
||||||
_, err = hic.Write([]byte(responseWithExtraStr))
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err := hic.Write([]byte(responseHeadersAndBody))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, hic.initialized)
|
assert.True(t, hic.initialized)
|
||||||
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||||
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||||
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Partially written headers are buffered and decoded
|
||||||
|
t.Run("simple-byte-by-byte", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
// write one byte at a time
|
||||||
|
for _, b := range []byte(responseHeadersAndBody) {
|
||||||
|
_, err := hic.Write([]byte{b})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.True(t, hic.initialized)
|
||||||
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||||
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Writes spanning the header/body breakpoint are buffered and decoded
|
||||||
|
t.Run("simple-span-headerbody", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
// write one chunk at a time
|
||||||
|
for i, chunk := range strings.Split(responseHeadersAndBody, "split") {
|
||||||
|
if i > 0 {
|
||||||
|
n, err := hic.Write([]byte("split"))
|
||||||
|
require.Equal(t, n, len("split"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
n, err := hic.Write([]byte(chunk))
|
||||||
|
require.Equal(t, n, len(chunk))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.True(t, hic.initialized)
|
||||||
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
||||||
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn.
|
||||||
|
t.Run("simple-tolerate-lf", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", "")))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, hic.initialized)
|
||||||
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite")
|
||||||
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Content-Length handling
|
||||||
|
t.Run("content-length-body", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err := hic.Write([]byte(contentLengthHeadersAndBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||||
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||||
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Content-Length separately written headers and body
|
||||||
|
t.Run("content-length-headers-body", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err := hic.Write([]byte(contentLengthHeaders))
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = hic.Write([]byte(contentLengthBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||||
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||||
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Content-Length accumulating byte-by-byte
|
||||||
|
t.Run("content-length-byte-by-byte", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
for _, b := range []byte(contentLengthHeadersAndBody) {
|
||||||
|
_, err := hic.Write([]byte{b})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||||
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||||
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Content-Length writes spanning headers / body
|
||||||
|
t.Run("content-length-span-headerbody", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
// write one chunk at a time
|
||||||
|
for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") {
|
||||||
|
if i > 0 {
|
||||||
|
n, err := hic.Write([]byte("split"))
|
||||||
|
require.Equal(t, n, len("split"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
n, err := hic.Write([]byte(chunk))
|
||||||
|
require.Equal(t, n, len(chunk))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||||
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
||||||
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
||||||
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
||||||
|
})
|
||||||
|
|
||||||
// Invalid response returns error.
|
// Invalid response returns error.
|
||||||
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
|
t.Run("invalid-single-write", func(t *testing.T) {
|
||||||
_, err = hic.Write([]byte(invalidResponseStr))
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err := hic.Write([]byte(invalidResponseData))
|
||||||
assert.Error(t, err, "expected error from invalid http response")
|
assert.Error(t, err, "expected error from invalid http response")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Invalid response written byte by byte returns error.
|
||||||
|
t.Run("invalid-byte-by-byte", func(t *testing.T) {
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
var err error
|
||||||
|
for _, b := range []byte(invalidResponseData) {
|
||||||
|
_, err = hic.Write([]byte{b})
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Error(t, err, "expected error from invalid http response")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockConnInitializer struct {
|
type mockConnInitializer struct {
|
||||||
resp *http.Response
|
resp *http.Response
|
||||||
|
initializeWriteConn *mockConn
|
||||||
*mockConn
|
*mockConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error {
|
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error {
|
||||||
m.resp = backendResponse
|
m.resp = backendResponse
|
||||||
return nil
|
_, err := m.initializeWriteConn.Write(backendResponseBytes)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockConn implements "net.Conn" interface.
|
// mockConn implements "net.Conn" interface.
|
||||||
@ -166,9 +334,8 @@ type mockConn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mockConn) Write(p []byte) (int, error) {
|
func (mc *mockConn) Write(p []byte) (int, error) {
|
||||||
mc.written = make([]byte, len(p))
|
mc.written = append(mc.written, p...)
|
||||||
copy(mc.written, p)
|
return len(p), nil
|
||||||
return len(mc.written), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
|
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
|
||||||
|
Loading…
Reference in New Issue
Block a user