adds portforward streamtunnel unit tests

This commit is contained in:
Sean Sullivan 2024-03-06 16:37:17 -08:00
parent 05cb0a55c8
commit ffafb2b9ca

View File

@ -19,6 +19,8 @@ package proxy
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -48,7 +50,6 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
defer close(streamChan) defer close(streamChan)
stopServerChan := make(chan struct{}) stopServerChan := make(chan struct{})
defer close(stopServerChan) defer close(stopServerChan)
// Create fake upstream SPDY server.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name}) _, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
require.NoError(t, err) require.NoError(t, err)
@ -107,6 +108,120 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
assert.Equal(t, randomData, actual, "error validating tunneled random data") assert.Equal(t, randomData, actual, "error validating tunneled random data")
} }
func TestTunnelingResponseWriter_Hijack(t *testing.T) {
// Regular hijack returns connection, nil bufio, and no error.
trw := &tunnelingResponseWriter{conn: &mockConn{}}
assert.False(t, trw.hijacked, "hijacked field starts false before Hijack()")
assert.False(t, trw.written, "written field startes false before Hijack()")
actual, bufio, err := trw.Hijack()
assert.NoError(t, err, "Hijack() does not return error")
assert.NotNil(t, actual, "conn returned from Hijack() is not nil")
assert.Nil(t, bufio, "bufio returned from Hijack() is always nil")
assert.True(t, trw.hijacked, "hijacked field becomes true after Hijack()")
assert.False(t, trw.written, "written field stays false after Hijack()")
// Hijacking after writing to response writer is an error.
trw = &tunnelingResponseWriter{written: true}
_, _, err = trw.Hijack()
assert.Error(t, err, "Hijack after writing to response writer is error")
assert.True(t, strings.Contains(err.Error(), "connection has already been written to"))
// Hijacking after already hijacked is an error.
trw = &tunnelingResponseWriter{hijacked: true}
_, _, err = trw.Hijack()
assert.Error(t, err, "Hijack after writing to response writer is error")
assert.True(t, strings.Contains(err.Error(), "connection has already been hijacked"))
}
func TestTunnelingResponseWriter_DelegateResponseWriter(t *testing.T) {
// Validate Header() for delegate response writer.
expectedHeader := http.Header{}
expectedHeader.Set("foo", "bar")
trw := &tunnelingResponseWriter{w: &mockResponseWriter{header: expectedHeader}}
assert.Equal(t, expectedHeader, trw.Header(), "")
// Validate Write() for delegate response writer.
expectedWrite := []byte("this is a test write string")
assert.False(t, trw.written, "written field is before Write()")
_, err := trw.Write(expectedWrite)
assert.NoError(t, err, "No error expected after Write() on tunneling response writer")
assert.True(t, trw.written, "written field is set after writing to tunneling response writer")
// Writing to response writer after hijacked is an error.
trw.hijacked = true
_, err = trw.Write(expectedWrite)
assert.Error(t, err, "Writing to ResponseWriter after Hijack() is an error")
assert.True(t, errors.Is(err, http.ErrHijacked), "Hijacked error returned if writing after hijacked")
// Validate WriteHeader().
trw = &tunnelingResponseWriter{w: &mockResponseWriter{}}
expectedStatusCode := 201
assert.False(t, trw.written, "Written field originally false in delegate response writer")
trw.WriteHeader(expectedStatusCode)
assert.Equal(t, expectedStatusCode, trw.w.(*mockResponseWriter).statusCode, "Expected written status code is correct")
assert.True(t, trw.written, "Written field set to true after writing delegate response writer")
// Response writer already written to does not write status.
trw = &tunnelingResponseWriter{w: &mockResponseWriter{}}
trw.written = true
trw.WriteHeader(expectedStatusCode)
assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code for previously written response writer")
// Hijacked response writer does not write status.
trw = &tunnelingResponseWriter{w: &mockResponseWriter{}}
trw.hijacked = true
trw.WriteHeader(expectedStatusCode)
assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code written to hijacked response writer")
assert.False(t, trw.written, "Hijacked response writer does not write status")
// Writing "101 Switching Protocols" status is an error, since it should happen via hijacked connection.
trw = &tunnelingResponseWriter{w: &mockResponseWriter{header: http.Header{}}}
trw.WriteHeader(http.StatusSwitchingProtocols)
assert.Equal(t, http.StatusInternalServerError, trw.w.(*mockResponseWriter).statusCode, "Internal server error written")
}
func TestTunnelingWebsocketUpgraderConn_LocalRemoteAddress(t *testing.T) {
expectedLocalAddr := &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 80,
}
expectedRemoteAddr := &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 2),
Port: 443,
}
tc := &tunnelingWebsocketUpgraderConn{
conn: &mockConn{
localAddr: expectedLocalAddr,
remoteAddr: expectedRemoteAddr,
},
}
assert.Equal(t, expectedLocalAddr, tc.LocalAddr(), "LocalAddr() returns expected TCPAddr")
assert.Equal(t, expectedRemoteAddr, tc.RemoteAddr(), "RemoteAddr() returns expected TCPAddr")
// Connection nil, returns empty address
tc.conn = nil
assert.Equal(t, noopAddr{}, tc.LocalAddr(), "nil connection, LocalAddr() returns noopAddr")
assert.Equal(t, noopAddr{}, tc.RemoteAddr(), "nil connection, RemoteAddr() returns noopAddr")
// Validate the empty strings from noopAddr
assert.Equal(t, "", noopAddr{}.Network(), "noopAddr Network() returns empty string")
assert.Equal(t, "", noopAddr{}.String(), "noopAddr String() returns empty string")
}
func TestTunnelingWebsocketUpgraderConn_SetDeadline(t *testing.T) {
tc := &tunnelingWebsocketUpgraderConn{conn: &mockConn{}}
expected := time.Now()
assert.Nil(t, tc.SetDeadline(expected), "SetDeadline does not return error")
assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "SetDeadline() sets read deadline")
assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "SetDeadline() sets write deadline")
expected = time.Now()
assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline does not return error")
assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "Expected write deadline set")
expected = time.Now()
assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline does not return error")
assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "Expected read deadline set")
expectedErr := fmt.Errorf("deadline error")
tc = &tunnelingWebsocketUpgraderConn{conn: &mockConn{deadlineErr: expectedErr}}
expected = time.Now()
actualErr := tc.SetDeadline(expected)
assert.Equal(t, expectedErr, actualErr, "SetDeadline() expected error returned")
// Connection nil, returns nil error.
tc.conn = nil
assert.Nil(t, tc.SetDeadline(expected), "SetDeadline() with nil connection always returns nil error")
assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline() with nil connection always returns nil error")
assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline() with nil connection always returns nil error")
}
var expectedContentLengthHeaders = http.Header{ var expectedContentLengthHeaders = http.Header{
"Content-Length": []string{"25"}, "Content-Length": []string{"25"},
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"}, "Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
@ -330,7 +445,12 @@ func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, ba
var _ net.Conn = &mockConn{} var _ net.Conn = &mockConn{}
type mockConn struct { type mockConn struct {
written []byte written []byte
localAddr *net.TCPAddr
remoteAddr *net.TCPAddr
readDeadline time.Time
writeDeadline time.Time
deadlineErr error
} }
func (mc *mockConn) Write(p []byte) (int, error) { func (mc *mockConn) Write(p []byte) (int, error) {
@ -338,13 +458,31 @@ func (mc *mockConn) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil } func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
func (mc *mockConn) Close() error { return nil } func (mc *mockConn) Close() error { return nil }
func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} } func (mc *mockConn) LocalAddr() net.Addr { return mc.localAddr }
func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } func (mc *mockConn) RemoteAddr() net.Addr { return mc.remoteAddr }
func (mc *mockConn) SetDeadline(t time.Time) error { return nil } func (mc *mockConn) SetDeadline(t time.Time) error {
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil } mc.SetReadDeadline(t) //nolint:errcheck
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil } mc.SetWriteDeadline(t) // nolint:errcheck
return mc.deadlineErr
}
func (mc *mockConn) SetReadDeadline(t time.Time) error { mc.readDeadline = t; return mc.deadlineErr }
func (mc *mockConn) SetWriteDeadline(t time.Time) error { mc.writeDeadline = t; return mc.deadlineErr }
// mockResponseWriter implements "http.ResponseWriter" interface
type mockResponseWriter struct {
header http.Header
written []byte
statusCode int
}
func (mrw *mockResponseWriter) Header() http.Header { return mrw.header }
func (mrw *mockResponseWriter) Write(p []byte) (int, error) {
mrw.written = append(mrw.written, p...)
return len(p), nil
}
func (mrw *mockResponseWriter) WriteHeader(statusCode int) { mrw.statusCode = statusCode }
// fakeResponder implements "rest.Responder" interface. // fakeResponder implements "rest.Responder" interface.
var _ rest.Responder = &fakeResponder{} var _ rest.Responder = &fakeResponder{}