diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go index b9645588b77..74b02c37ff8 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go @@ -19,6 +19,8 @@ package proxy import ( "bytes" "crypto/rand" + "errors" + "fmt" "io" "net" "net/http" @@ -48,7 +50,6 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { defer close(streamChan) stopServerChan := make(chan struct{}) defer close(stopServerChan) - // Create fake upstream SPDY server. spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name}) require.NoError(t, err) @@ -107,6 +108,120 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { assert.Equal(t, randomData, actual, "error validating tunneled random data") } +func TestTunnelingResponseWriter_Hijack(t *testing.T) { + // Regular hijack returns connection, nil bufio, and no error. + trw := &tunnelingResponseWriter{conn: &mockConn{}} + assert.False(t, trw.hijacked, "hijacked field starts false before Hijack()") + assert.False(t, trw.written, "written field startes false before Hijack()") + actual, bufio, err := trw.Hijack() + assert.NoError(t, err, "Hijack() does not return error") + assert.NotNil(t, actual, "conn returned from Hijack() is not nil") + assert.Nil(t, bufio, "bufio returned from Hijack() is always nil") + assert.True(t, trw.hijacked, "hijacked field becomes true after Hijack()") + assert.False(t, trw.written, "written field stays false after Hijack()") + // Hijacking after writing to response writer is an error. + trw = &tunnelingResponseWriter{written: true} + _, _, err = trw.Hijack() + assert.Error(t, err, "Hijack after writing to response writer is error") + assert.True(t, strings.Contains(err.Error(), "connection has already been written to")) + // Hijacking after already hijacked is an error. + trw = &tunnelingResponseWriter{hijacked: true} + _, _, err = trw.Hijack() + assert.Error(t, err, "Hijack after writing to response writer is error") + assert.True(t, strings.Contains(err.Error(), "connection has already been hijacked")) +} + +func TestTunnelingResponseWriter_DelegateResponseWriter(t *testing.T) { + // Validate Header() for delegate response writer. + expectedHeader := http.Header{} + expectedHeader.Set("foo", "bar") + trw := &tunnelingResponseWriter{w: &mockResponseWriter{header: expectedHeader}} + assert.Equal(t, expectedHeader, trw.Header(), "") + // Validate Write() for delegate response writer. + expectedWrite := []byte("this is a test write string") + assert.False(t, trw.written, "written field is before Write()") + _, err := trw.Write(expectedWrite) + assert.NoError(t, err, "No error expected after Write() on tunneling response writer") + assert.True(t, trw.written, "written field is set after writing to tunneling response writer") + // Writing to response writer after hijacked is an error. + trw.hijacked = true + _, err = trw.Write(expectedWrite) + assert.Error(t, err, "Writing to ResponseWriter after Hijack() is an error") + assert.True(t, errors.Is(err, http.ErrHijacked), "Hijacked error returned if writing after hijacked") + // Validate WriteHeader(). + trw = &tunnelingResponseWriter{w: &mockResponseWriter{}} + expectedStatusCode := 201 + assert.False(t, trw.written, "Written field originally false in delegate response writer") + trw.WriteHeader(expectedStatusCode) + assert.Equal(t, expectedStatusCode, trw.w.(*mockResponseWriter).statusCode, "Expected written status code is correct") + assert.True(t, trw.written, "Written field set to true after writing delegate response writer") + // Response writer already written to does not write status. + trw = &tunnelingResponseWriter{w: &mockResponseWriter{}} + trw.written = true + trw.WriteHeader(expectedStatusCode) + assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code for previously written response writer") + // Hijacked response writer does not write status. + trw = &tunnelingResponseWriter{w: &mockResponseWriter{}} + trw.hijacked = true + trw.WriteHeader(expectedStatusCode) + assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code written to hijacked response writer") + assert.False(t, trw.written, "Hijacked response writer does not write status") + // Writing "101 Switching Protocols" status is an error, since it should happen via hijacked connection. + trw = &tunnelingResponseWriter{w: &mockResponseWriter{header: http.Header{}}} + trw.WriteHeader(http.StatusSwitchingProtocols) + assert.Equal(t, http.StatusInternalServerError, trw.w.(*mockResponseWriter).statusCode, "Internal server error written") +} + +func TestTunnelingWebsocketUpgraderConn_LocalRemoteAddress(t *testing.T) { + expectedLocalAddr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 80, + } + expectedRemoteAddr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 2), + Port: 443, + } + tc := &tunnelingWebsocketUpgraderConn{ + conn: &mockConn{ + localAddr: expectedLocalAddr, + remoteAddr: expectedRemoteAddr, + }, + } + assert.Equal(t, expectedLocalAddr, tc.LocalAddr(), "LocalAddr() returns expected TCPAddr") + assert.Equal(t, expectedRemoteAddr, tc.RemoteAddr(), "RemoteAddr() returns expected TCPAddr") + // Connection nil, returns empty address + tc.conn = nil + assert.Equal(t, noopAddr{}, tc.LocalAddr(), "nil connection, LocalAddr() returns noopAddr") + assert.Equal(t, noopAddr{}, tc.RemoteAddr(), "nil connection, RemoteAddr() returns noopAddr") + // Validate the empty strings from noopAddr + assert.Equal(t, "", noopAddr{}.Network(), "noopAddr Network() returns empty string") + assert.Equal(t, "", noopAddr{}.String(), "noopAddr String() returns empty string") +} + +func TestTunnelingWebsocketUpgraderConn_SetDeadline(t *testing.T) { + tc := &tunnelingWebsocketUpgraderConn{conn: &mockConn{}} + expected := time.Now() + assert.Nil(t, tc.SetDeadline(expected), "SetDeadline does not return error") + assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "SetDeadline() sets read deadline") + assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "SetDeadline() sets write deadline") + expected = time.Now() + assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline does not return error") + assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "Expected write deadline set") + expected = time.Now() + assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline does not return error") + assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "Expected read deadline set") + expectedErr := fmt.Errorf("deadline error") + tc = &tunnelingWebsocketUpgraderConn{conn: &mockConn{deadlineErr: expectedErr}} + expected = time.Now() + actualErr := tc.SetDeadline(expected) + assert.Equal(t, expectedErr, actualErr, "SetDeadline() expected error returned") + // Connection nil, returns nil error. + tc.conn = nil + assert.Nil(t, tc.SetDeadline(expected), "SetDeadline() with nil connection always returns nil error") + assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline() with nil connection always returns nil error") + assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline() with nil connection always returns nil error") +} + var expectedContentLengthHeaders = http.Header{ "Content-Length": []string{"25"}, "Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"}, @@ -330,7 +445,12 @@ func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, ba var _ net.Conn = &mockConn{} type mockConn struct { - written []byte + written []byte + localAddr *net.TCPAddr + remoteAddr *net.TCPAddr + readDeadline time.Time + writeDeadline time.Time + deadlineErr error } func (mc *mockConn) Write(p []byte) (int, error) { @@ -338,13 +458,31 @@ func (mc *mockConn) Write(p []byte) (int, error) { return len(p), nil } -func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil } -func (mc *mockConn) Close() error { return nil } -func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} } -func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } -func (mc *mockConn) SetDeadline(t time.Time) error { return nil } -func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil } -func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil } +func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil } +func (mc *mockConn) Close() error { return nil } +func (mc *mockConn) LocalAddr() net.Addr { return mc.localAddr } +func (mc *mockConn) RemoteAddr() net.Addr { return mc.remoteAddr } +func (mc *mockConn) SetDeadline(t time.Time) error { + mc.SetReadDeadline(t) //nolint:errcheck + mc.SetWriteDeadline(t) // nolint:errcheck + return mc.deadlineErr +} +func (mc *mockConn) SetReadDeadline(t time.Time) error { mc.readDeadline = t; return mc.deadlineErr } +func (mc *mockConn) SetWriteDeadline(t time.Time) error { mc.writeDeadline = t; return mc.deadlineErr } + +// mockResponseWriter implements "http.ResponseWriter" interface +type mockResponseWriter struct { + header http.Header + written []byte + statusCode int +} + +func (mrw *mockResponseWriter) Header() http.Header { return mrw.header } +func (mrw *mockResponseWriter) Write(p []byte) (int, error) { + mrw.written = append(mrw.written, p...) + return len(p), nil +} +func (mrw *mockResponseWriter) WriteHeader(statusCode int) { mrw.statusCode = statusCode } // fakeResponder implements "rest.Responder" interface. var _ rest.Responder = &fakeResponder{}