portforward: tunnel spdy through websockets

This commit is contained in:
Sean Sullivan 2024-02-21 08:56:07 +00:00 committed by Sean Sullivan
parent 89cbd94e68
commit 8b447d8c97
20 changed files with 1560 additions and 19 deletions

View File

@ -619,6 +619,13 @@ const (
// Enable users to specify when a Pod is ready for scheduling.
PodSchedulingReadiness featuregate.Feature = "PodSchedulingReadiness"
// owner: @seans3
// kep: http://kep.k8s.io/4006
// alpha: v1.30
//
// Enables PortForward to be proxied with a websocket client
PortForwardWebsockets featuregate.Feature = "PortForwardWebsockets"
// owner: @jessfraz
// alpha: v1.12
//
@ -1101,6 +1108,8 @@ var defaultKubernetesFeatureGates = map[featuregate.Feature]featuregate.FeatureS
PodSchedulingReadiness: {Default: true, PreRelease: featuregate.GA, LockToDefault: true}, // GA in 1.30; remove in 1.32
PortForwardWebsockets: {Default: false, PreRelease: featuregate.Alpha},
ProcMountType: {Default: false, PreRelease: featuregate.Alpha},
QOSReserved: {Default: false, PreRelease: featuregate.Alpha},

View File

@ -242,7 +242,12 @@ func (r *PortForwardREST) Connect(ctx context.Context, name string, opts runtime
if err != nil {
return nil, err
}
return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil
handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder)
if utilfeature.DefaultFeatureGate.Enabled(features.PortForwardWebsockets) {
tunnelingHandler := translator.NewTunnelingHandler(handler)
handler = translator.NewTranslatingHandler(handler, tunnelingHandler, wsstream.IsWebSocketRequestWithTunnelingProtocol)
}
return handler, nil
}
func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler {

View File

@ -27,6 +27,7 @@ import (
"golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/portforward"
"k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/klog/v2"
@ -106,6 +107,23 @@ func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
return false
}
// IsWebSocketRequestWithTunnelingProtocol returns true if the request contains headers
// identifying that it is requesting a websocket upgrade with a tunneling protocol;
// false otherwise.
func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool {
if !IsWebSocketRequest(req) {
return false
}
requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) {
return true
}
}
return false
}
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
// read and write deadlines are pushed every time a new message is received.
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
@ -301,6 +319,12 @@ func protocolSupportsStreamClose(protocol string) bool {
return protocol == remotecommand.StreamProtocolV5Name
}
// protocolSupportsWebsocketTunneling returns true if the passed protocol
// is a tunneled Kubernetes spdy protocol; false otherwise.
func protocolSupportsWebsocketTunneling(protocol string) bool {
return strings.HasPrefix(protocol, portforward.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, portforward.KubernetesSuffix)
}
// handle implements a websocket handler.
func (conn *Conn) handle(ws *websocket.Conn) {
conn.initialize(ws)

View File

@ -0,0 +1,24 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
const (
PortForwardV1Name = "portforward.k8s.io"
WebsocketsSPDYTunnelingPrefix = "SPDY/3.1+"
KubernetesSuffix = ".k8s.io"
WebsocketsSPDYTunnelingPortForwardV1 = WebsocketsSPDYTunnelingPrefix + PortForwardV1Name
)

View File

@ -36,6 +36,7 @@ import (
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"github.com/mxk/go-flowrate/flowrate"
"k8s.io/klog/v2"
)
@ -336,6 +337,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
clone.Host = h.Location.Host
}
clone.URL = &location
klog.V(6).Infof("UpgradeAwareProxy: dialing for SPDY upgrade with headers: %v", clone.Header)
backendConn, err = h.DialForUpgrade(clone)
if err != nil {
klog.V(6).Infof("Proxy connection error: %v", err)
@ -370,13 +372,13 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
// hijacking should be the last step in the upgrade.
requestHijacker, ok := w.(http.Hijacker)
if !ok {
klog.V(6).Infof("Unable to hijack response writer: %T", w)
klog.Errorf("Unable to hijack response writer: %T", w)
h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
return true
}
requestHijackedConn, _, err := requestHijacker.Hijack()
if err != nil {
klog.V(6).Infof("Unable to hijack response: %v", err)
klog.Errorf("Unable to hijack response: %v", err)
h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
return true
}
@ -420,7 +422,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
} else {
writer = backendConn
}
_, err := io.Copy(writer, requestHijackedConn)
_, err := io.Copy(writer, &loggingReader{name: "client->backend", delegate: requestHijackedConn})
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
klog.Errorf("Error proxying data from client to backend: %v", err)
}
@ -434,7 +436,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
} else {
reader = backendConn
}
_, err := io.Copy(requestHijackedConn, reader)
_, err := io.Copy(requestHijackedConn, &loggingReader{name: "backend->client", delegate: reader})
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
klog.Errorf("Error proxying data from backend to client: %v", err)
}
@ -452,6 +454,18 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
return true
}
// loggingReader logs the bytes read from the "delegate" with a "name" prefix.
type loggingReader struct {
name string
delegate io.Reader
}
func (l *loggingReader) Read(p []byte) (int, error) {
n, err := l.delegate.Read(p)
klog.V(8).Infof("%s: %d bytes, err=%v, bytes=% X", l.name, n, err, p[:n])
return n, err
}
// FIXME: Taken from net/http/httputil/reverseproxy.go as singleJoiningSlash is not exported to be re-used.
// See-also: https://github.com/golang/go/issues/44290
func singleJoiningSlash(a, b string) string {

View File

@ -16,6 +16,7 @@ require (
github.com/google/go-cmp v0.6.0
github.com/google/gofuzz v1.2.0
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
@ -77,7 +78,6 @@ require (
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/btree v1.0.1 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect

View File

@ -0,0 +1,433 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
gwebsocket "github.com/gorilla/websocket"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
constants "k8s.io/apimachinery/pkg/util/portforward"
"k8s.io/client-go/tools/portforward"
"k8s.io/klog/v2"
)
// TunnelingHandler is a handler which tunnels SPDY through WebSockets.
type TunnelingHandler struct {
// Used to communicate between upstream SPDY and downstream tunnel.
upgradeHandler http.Handler
}
// NewTunnelingHandler is used to create the tunnel between an upstream
// SPDY connection and a downstream tunneling connection through the stored
// UpgradeAwareProxy.
func NewTunnelingHandler(upgradeHandler http.Handler) *TunnelingHandler {
return &TunnelingHandler{upgradeHandler: upgradeHandler}
}
// ServeHTTP uses the upgradeHandler to tunnel between a downstream tunneling
// connection and an upstream SPDY connection. The tunneling connection is
// a wrapped WebSockets connection which communicates SPDY framed data.
func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
klog.V(4).Infoln("TunnelingHandler ServeHTTP")
spdyProtocols := spdyProtocolsFromWebsocketProtocols(req)
if len(spdyProtocols) == 0 {
http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest)
return
}
spdyRequest := createSPDYRequest(req, spdyProtocols...)
writer := &tunnelingResponseWriter{
w: w,
conn: &headerInterceptingConn{
initializableConn: &tunnelingWebsocketUpgraderConn{
w: w,
req: req,
},
},
}
klog.V(4).Infoln("Tunnel spdy through websockets using the UpgradeAwareProxy")
h.upgradeHandler.ServeHTTP(writer, spdyRequest)
}
// createSPDYRequest modifies the passed request to remove
// WebSockets headers and add SPDY upgrade information, including
// spdy protocols acceptable to the client.
func createSPDYRequest(req *http.Request, spdyProtocols ...string) *http.Request {
clone := utilnet.CloneRequest(req)
// Clean up the websocket headers from the http request.
clone.Header.Del(wsstream.WebSocketProtocolHeader)
clone.Header.Del("Sec-Websocket-Key")
clone.Header.Del("Sec-Websocket-Version")
clone.Header.Del(httpstream.HeaderUpgrade)
// Update the http request for an upstream SPDY upgrade.
clone.Method = "POST"
clone.Body = nil // Remove the request body which is unused.
clone.Header.Set(httpstream.HeaderUpgrade, spdy.HeaderSpdy31)
clone.Header.Del(httpstream.HeaderProtocolVersion)
for i := range spdyProtocols {
clone.Header.Add(httpstream.HeaderProtocolVersion, spdyProtocols[i])
}
return clone
}
// spdyProtocolsFromWebsocketProtocols returns a list of spdy protocols by filtering
// to Kubernetes websocket subprotocols prefixed with "SPDY/3.1+", then removing the prefix
func spdyProtocolsFromWebsocketProtocols(req *http.Request) []string {
var spdyProtocols []string
for _, protocol := range gwebsocket.Subprotocols(req) {
if strings.HasPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, constants.KubernetesSuffix) {
spdyProtocols = append(spdyProtocols, strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix))
}
}
return spdyProtocols
}
var _ http.ResponseWriter = &tunnelingResponseWriter{}
var _ http.Hijacker = &tunnelingResponseWriter{}
// tunnelingResponseWriter implements the http.ResponseWriter and http.Hijacker interfaces.
// Only non-upgrade responses can be written using WriteHeader() and Write().
// Once Write or WriteHeader is called, Hijack returns an error.
// Once Hijack is called, Write, WriteHeader, and Hijack return errors.
type tunnelingResponseWriter struct {
// w is used to delegate Header(), WriteHeader(), and Write() calls
w http.ResponseWriter
// conn is returned from Hijack()
conn net.Conn
// mu guards writes
mu sync.Mutex
// wrote tracks whether WriteHeader or Write has been called
written bool
// hijacked tracks whether Hijack has been called
hijacked bool
}
// Hijack returns a delegate "net.Conn".
// An error is returned if Write(), WriteHeader(), or Hijack() was previously called.
// The returned bufio.ReadWriter is always nil.
func (w *tunnelingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.written {
klog.Errorf("Hijack called after write")
return nil, nil, errors.New("connection has already been written to")
}
if w.hijacked {
klog.Errorf("Hijack called after hijack")
return nil, nil, errors.New("connection has already been hijacked")
}
w.hijacked = true
klog.V(6).Infof("Hijack returning websocket tunneling net.Conn")
return w.conn, nil, nil
}
// Header is delegated to the stored "http.ResponseWriter".
func (w *tunnelingResponseWriter) Header() http.Header {
return w.w.Header()
}
// Write is delegated to the stored "http.ResponseWriter".
func (w *tunnelingResponseWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.hijacked {
klog.Errorf("Write called after hijack")
return 0, http.ErrHijacked
}
w.written = true
return w.w.Write(p)
}
// WriteHeader is delegated to the stored "http.ResponseWriter".
func (w *tunnelingResponseWriter) WriteHeader(statusCode int) {
w.mu.Lock()
defer w.mu.Unlock()
if w.written {
klog.Errorf("WriteHeader called after write")
return
}
if w.hijacked {
klog.Errorf("WriteHeader called after hijack")
return
}
w.written = true
if statusCode == http.StatusSwitchingProtocols {
// 101 upgrade responses must come via the hijacked connection, not WriteHeader
klog.Errorf("WriteHeader called with 101 upgrade")
http.Error(w.w, "unexpected upgrade", http.StatusInternalServerError)
return
}
// pass through non-upgrade responses we don't need to translate
w.w.WriteHeader(statusCode)
}
// headerInterceptingConn wraps the tunneling "net.Conn" to drain the
// HTTP response status/headers from the upstream SPDY connection, then use
// that to decide how to initialize the delegate connection for writes.
type headerInterceptingConn struct {
// initializableConn is delegated to for all net.Conn methods.
// initializableConn.Write() is not called until response headers have been read
// and initializableConn#InitializeWrite() has been called with the result.
initializableConn
lock sync.Mutex
headerBuffer []byte
initialized bool
}
// initializableConn is a connection that will be initialized before any calls to Write are made
type initializableConn interface {
net.Conn
InitializeWrite(backendResponse *http.Response) error
}
const maxHeaderBytes = 1 << 20
// Write intercepts to initially swallow the HTTP response, then
// delegate to the tunneling "net.Conn" once the response has been
// seen and processed.
func (h *headerInterceptingConn) Write(b []byte) (int, error) {
h.lock.Lock()
defer h.lock.Unlock()
if h.initialized {
return h.initializableConn.Write(b)
}
// Write into the headerBuffer, then attempt to parse the bytes
// as an http response.
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
return 0, fmt.Errorf("header size limit exceeded")
}
h.headerBuffer = append(h.headerBuffer, b...)
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
resp, err := http.ReadResponse(bufferedReader, nil)
if errors.Is(err, io.ErrUnexpectedEOF) {
// don't yet have a complete set of headers
return len(b), nil
}
if err != nil {
klog.Errorf("invalid headers: %v", err)
return len(b), err
}
resp.Body.Close() //nolint:errcheck
h.headerBuffer = nil
err = h.initializableConn.InitializeWrite(resp)
h.initialized = true
if err != nil {
return len(b), err
}
// Copy any remaining buffered data to the underlying conn
remainingBuffer, _ := io.ReadAll(bufferedReader)
if len(remainingBuffer) > 0 {
_, err = h.initializableConn.Write(remainingBuffer)
}
return len(b), err
}
type tunnelingWebsocketUpgraderConn struct {
// req is the websocket request, used for upgrading
req *http.Request
// w is the websocket writer, used for upgrading and writing error responses
w http.ResponseWriter
// lock guards conn and err
lock sync.RWMutex
// if conn is non-nil, InitializeWrite succeeded
conn net.Conn
// if err is non-nil, InitializeWrite failed or Close was called before InitializeWrite
err error
}
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) {
// make sure we close a connection we open in error cases
var conn net.Conn
defer func() {
if err != nil && conn != nil {
conn.Close() //nolint:errcheck
}
}()
u.lock.Lock()
defer u.lock.Unlock()
if u.conn != nil {
return fmt.Errorf("InitializeWrite already called")
}
if u.err != nil {
return u.err
}
if backendResponse.StatusCode == http.StatusSwitchingProtocols {
connectionHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) {
klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header)
u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response")
http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
return u.err
}
// Translate the server's chosen SPDY protocol into the tunneled websocket protocol for the handshake
var serverWebsocketProtocols []string
if backendSPDYProtocol := strings.TrimSpace(backendResponse.Header.Get(httpstream.HeaderProtocolVersion)); backendSPDYProtocol != "" {
serverWebsocketProtocols = []string{constants.WebsocketsSPDYTunnelingPrefix + backendSPDYProtocol}
} else {
serverWebsocketProtocols = []string{}
}
// Try to upgrade the websocket connection.
// Beyond this point, we don't need to write errors to the response.
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: serverWebsocketProtocols,
}
conn, err := upgrader.Upgrade(u.w, u.req, nil)
if err != nil {
klog.Errorf("error upgrading websocket connection: %v", err)
u.err = err
return u.err
}
klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol())
u.conn = portforward.NewTunnelingConnection("server", conn)
return nil
}
// anything other than an upgrade should pass through the backend response
// try to hijack
conn, _, err = u.w.(http.Hijacker).Hijack()
if err != nil {
klog.Errorf("Unable to hijack response: %v", err)
u.err = err
return u.err
}
// replay the backend response to the hijacked conn
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
err = backendResponse.Write(conn)
if err != nil {
u.err = err
return u.err
}
u.conn = conn
return nil
}
func (u *tunnelingWebsocketUpgraderConn) Read(b []byte) (n int, err error) {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.Read(b)
}
if u.err != nil {
return 0, u.err
}
// return empty read without blocking until we are initialized
return 0, nil
}
func (u *tunnelingWebsocketUpgraderConn) Write(b []byte) (n int, err error) {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.Write(b)
}
if u.err != nil {
return 0, u.err
}
return 0, fmt.Errorf("Write called before Initialize")
}
func (u *tunnelingWebsocketUpgraderConn) Close() error {
u.lock.Lock()
defer u.lock.Unlock()
if u.conn != nil {
return u.conn.Close()
}
if u.err != nil {
return u.err
}
// record that we closed so we don't write again or try to initialize
u.err = fmt.Errorf("connection closed")
// write a response
http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
return nil
}
func (u *tunnelingWebsocketUpgraderConn) LocalAddr() net.Addr {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.LocalAddr()
}
return noopAddr{}
}
func (u *tunnelingWebsocketUpgraderConn) RemoteAddr() net.Addr {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.RemoteAddr()
}
return noopAddr{}
}
func (u *tunnelingWebsocketUpgraderConn) SetDeadline(t time.Time) error {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.SetDeadline(t)
}
return nil
}
func (u *tunnelingWebsocketUpgraderConn) SetReadDeadline(t time.Time) error {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.SetReadDeadline(t)
}
return nil
}
func (u *tunnelingWebsocketUpgraderConn) SetWriteDeadline(t time.Time) error {
u.lock.RLock()
defer u.lock.RUnlock()
if u.conn != nil {
return u.conn.SetWriteDeadline(t)
}
return nil
}
type noopAddr struct{}
func (n noopAddr) Network() string { return "" }
func (n noopAddr) String() string { return "" }

View File

@ -0,0 +1,197 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bytes"
"crypto/rand"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/portforward"
"k8s.io/apimachinery/pkg/util/proxy"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apiserver/pkg/registry/rest"
restconfig "k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
)
func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
// Create fake upstream SPDY server, with channel receiving SPDY streams.
streamChan := make(chan httpstream.Stream)
defer close(streamChan)
stopServerChan := make(chan struct{})
defer close(stopServerChan)
// Create fake upstream SPDY server.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
require.NoError(t, err)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, justQueueStream(streamChan))
require.NotNil(t, conn)
defer conn.Close() //nolint:errcheck
<-stopServerChan
}))
defer spdyServer.Close()
// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then
// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer.
url, err := url.Parse(spdyServer.URL)
require.NoError(t, err)
transport, err := fakeTransport()
require.NoError(t, err)
upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
tunnelingHandler := NewTunnelingHandler(upgradeHandler)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
tunnelingHandler.ServeHTTP(w, req)
}))
defer tunnelingServer.Close()
// Create SPDY client connection containing a TunnelingConnection by upgrading
// a request to TunnelingHandler using new portforward version 2.
tunnelingURL, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
require.NoError(t, err)
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
require.NoError(t, err)
assert.Equal(t, constants.PortForwardV1Name, protocol)
defer spdyClient.Close() //nolint:errcheck
// Create a SPDY client stream, which will queue a SPDY server stream
// on the stream creation channel. Send random data on the client stream
// reading off the SPDY server stream, and validating it was tunneled.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
_, err = rand.Read(randomData)
require.NoError(t, err)
var actual []byte
go func() {
clientStream, err := spdyClient.CreateStream(http.Header{})
require.NoError(t, err)
_, err = io.Copy(clientStream, bytes.NewReader(randomData))
require.NoError(t, err)
clientStream.Close() //nolint:errcheck
}()
select {
case serverStream := <-streamChan:
actual, err = io.ReadAll(serverStream)
require.NoError(t, err)
defer serverStream.Close() //nolint:errcheck
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
}
assert.Equal(t, randomData, actual, "error validating tunneled random data")
}
const responseStr = `HTTP/1.1 101 Switching Protocols
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
`
const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
This is extra data.
`
const invalidResponseStr = `INVALID/1.1 101 Switching Protocols
Date: Sun, 25 Feb 2024 08:09:25 GMT
X-App-Protocol: portforward.k8s.io
`
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}}
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
_, err := hic.Write([]byte(responseStr))
require.NoError(t, err)
assert.True(t, hic.initialized, "successfully parsed http response headers")
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn")
// Extra data after response headers should be sent to net.Conn.
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
_, err = hic.Write([]byte(responseWithExtraStr))
require.NoError(t, err)
assert.True(t, hic.initialized)
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
// Invalid response returns error.
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
_, err = hic.Write([]byte(invalidResponseStr))
assert.Error(t, err, "expected error from invalid http response")
}
type mockConnInitializer struct {
resp *http.Response
*mockConn
}
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error {
m.resp = backendResponse
return nil
}
// mockConn implements "net.Conn" interface.
var _ net.Conn = &mockConn{}
type mockConn struct {
written []byte
}
func (mc *mockConn) Write(p []byte) (int, error) {
mc.written = make([]byte, len(p))
copy(mc.written, p)
return len(mc.written), nil
}
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
func (mc *mockConn) Close() error { return nil }
func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (mc *mockConn) SetDeadline(t time.Time) error { return nil }
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
// fakeResponder implements "rest.Responder" interface.
var _ rest.Responder = &fakeResponder{}
type fakeResponder struct{}
func (fr *fakeResponder) Object(statusCode int, obj runtime.Object) {}
func (fr *fakeResponder) Error(err error) {}
// justQueueStream skips the usual stream validation before
// queueing the stream on the stream channel.
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
streams <- stream
return nil
}
}

View File

@ -0,0 +1,57 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/klog/v2"
)
var _ httpstream.Dialer = &fallbackDialer{}
// fallbackDialer encapsulates a primary and secondary dialer, including
// the boolean function to determine if the primary dialer failed. Implements
// the httpstream.Dialer interface.
type fallbackDialer struct {
primary httpstream.Dialer
secondary httpstream.Dialer
shouldFallback func(error) bool
}
// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers,
// as well as the boolean function to determine if the primary dialer failed.
func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer {
return &fallbackDialer{
primary: primary,
secondary: secondary,
shouldFallback: shouldFallback,
}
}
// Dial is the single function necessary to implement the "httpstream.Dialer" interface.
// It takes the protocol version strings to request, returning an the upgraded
// httstream.Connection and the negotiated protocol version accepted. If the initial
// primary dialer fails, this function attempts the secondary dialer. Returns an error
// if one occurs.
func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
conn, version, err := f.primary.Dial(protocols...)
if err != nil && f.shouldFallback(err) {
klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
return f.secondary.Dial(protocols...)
}
return conn, version, err
}

View File

@ -0,0 +1,53 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFallbackDialer(t *testing.T) {
protocol := "v6.fake.k8s.io"
// If "shouldFallback" is false, then only primary should be dialed.
primary := &fakeDialer{dialed: false}
primary.negotiatedProtocol = protocol
secondary := &fakeDialer{dialed: false}
fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse)
_, negotiated, err := fallbackDialer.Dial(protocol)
assert.True(t, primary.dialed, "no fallback; primary should have dialed")
assert.Equal(t, protocol, negotiated, "")
assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
assert.Nil(t, err, "error should be nil")
// If "shouldFallback" is true, then primary AND secondary should be dialed.
primary.dialed = false // reset dialed field
primary.err = fmt.Errorf("bad handshake")
secondary.dialed = false // reset dialed field
secondary.negotiatedProtocol = protocol
fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue)
_, negotiated, err = fallbackDialer.Dial(protocol)
assert.True(t, primary.dialed, "fallback; primary should have dialed (first)")
assert.True(t, secondary.dialed, "fallback; secondary should have dialed")
assert.Equal(t, protocol, negotiated)
assert.Nil(t, err)
}
func alwaysTrue(err error) bool { return true }
func alwaysFalse(err error) bool { return false }

View File

@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
var protocol string
pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
if protocol != PortForwardProtocolV1Name {
return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
}
return pf.forward()
}

View File

@ -431,6 +431,7 @@ func TestGetListener(t *testing.T) {
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
dialer := &fakeDialer{
conn: newFakeConnection(),
negotiatedProtocol: PortForwardProtocolV1Name,
}
stopChan := make(chan struct{})
@ -571,6 +572,7 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
dialer := &fakeDialer{
conn: newFakeConnection(),
negotiatedProtocol: PortForwardProtocolV1Name,
}
stopChan := make(chan struct{})
@ -602,6 +604,7 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
dialer := &fakeDialer{
conn: newFakeConnection(),
negotiatedProtocol: PortForwardProtocolV1Name,
}
stopChan := make(chan struct{})

View File

@ -0,0 +1,153 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"io"
"net"
"sync"
"time"
gwebsocket "github.com/gorilla/websocket"
"k8s.io/klog/v2"
)
var _ net.Conn = &TunnelingConnection{}
// TunnelingConnection implements the "httpstream.Connection" interface, wrapping
// a websocket connection that tunnels SPDY.
type TunnelingConnection struct {
name string
conn *gwebsocket.Conn
inProgressMessage io.Reader
closeOnce sync.Once
}
// NewTunnelingConnection wraps the passed gorilla/websockets connection
// with the TunnelingConnection struct (implementing net.Conn).
func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
return &TunnelingConnection{
name: name,
conn: conn,
}
}
// Read implements "io.Reader" interface, reading from the stored connection
// into the passed buffer "p". Returns the number of bytes read and an error.
// Can keep track of the "inProgress" messsage from the tunneled connection.
func (c *TunnelingConnection) Read(p []byte) (int, error) {
klog.V(7).Infof("%s: tunneling connection read...", c.name)
defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name)
for {
if c.inProgressMessage == nil {
klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name)
messageType, nextReader, err := c.conn.NextReader()
if err != nil {
closeError := &gwebsocket.CloseError{}
if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
return 0, io.EOF
}
klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err)
return 0, err
}
if messageType != gwebsocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type received")
}
c.inProgressMessage = nextReader
}
klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name)
i, err := c.inProgressMessage.Read(p)
if i == 0 && err == io.EOF {
c.inProgressMessage = nil
} else {
klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
return i, err
}
}
}
// Write implements "io.Writer" interface, copying the data in the passed
// byte array "p" into the stored tunneled connection. Returns the number
// of bytes written and an error.
func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p)
defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name)
w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
if err != nil {
return 0, err
}
defer func() {
// close, which flushes the message
closeErr := w.Close()
if closeErr != nil && err == nil {
// if closing/flushing errored and we weren't already returning an error, return the close error
err = closeErr
}
}()
n, err = w.Write(p)
return
}
// Close implements "io.Closer" interface, signaling the other tunneled connection
// endpoint, and closing the tunneled connection only once.
func (c *TunnelingConnection) Close() error {
var err error
c.closeOnce.Do(func() {
klog.V(7).Infof("%s: tunneling connection Close()...", c.name)
// Signal other endpoint that websocket connection is closing; ignore error.
normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck
err = c.conn.Close()
})
return err
}
// LocalAddr implements part of the "net.Conn" interface, returning the local
// endpoint network address of the tunneled connection.
func (c *TunnelingConnection) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// LocalAddr implements part of the "net.Conn" interface, returning the remote
// endpoint network address of the tunneled connection.
func (c *TunnelingConnection) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the *absolute* time in the future for both
// read and write deadlines. Returns an error if one occurs.
func (c *TunnelingConnection) SetDeadline(t time.Time) error {
rerr := c.SetReadDeadline(t)
werr := c.SetWriteDeadline(t)
return errors.Join(rerr, werr)
}
// SetDeadline sets the *absolute* time in the future for the
// read deadlines. Returns an error if one occurs.
func (c *TunnelingConnection) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetDeadline sets the *absolute* time in the future for the
// write deadlines. Returns an error if one occurs.
func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@ -0,0 +1,190 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
gwebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/portforward"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
"k8s.io/client-go/transport/websocket"
)
func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
// Stream channel that will receive streams created on upstream SPDY server.
streamChan := make(chan httpstream.Stream)
defer close(streamChan)
stopServerChan := make(chan struct{})
defer close(stopServerChan)
// Create tunneling connection server endpoint with fake upstream SPDY server.
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
}
conn, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
defer conn.Close() //nolint:errcheck
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
tunnelingConn := NewTunnelingConnection("server", conn)
spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
require.NoError(t, err)
defer spdyConn.Close() //nolint:errcheck
<-stopServerChan
}))
defer tunnelingServer.Close()
// Dial the client tunneling connection to the tunneling server.
url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
require.NoError(t, err)
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
require.NoError(t, err)
assert.Equal(t, constants.PortForwardV1Name, protocol)
defer spdyClient.Close() //nolint:errcheck
// Create a SPDY client stream, which will queue a SPDY server stream
// on the stream creation channel. Send data on the client stream
// reading off the SPDY server stream, and validating it was tunneled.
expected := "This is a test tunneling SPDY data through websockets."
var actual []byte
go func() {
clientStream, err := spdyClient.CreateStream(http.Header{})
require.NoError(t, err)
_, err = io.Copy(clientStream, strings.NewReader(expected))
require.NoError(t, err)
clientStream.Close() //nolint:errcheck
}()
select {
case serverStream := <-streamChan:
actual, err = io.ReadAll(serverStream)
require.NoError(t, err)
defer serverStream.Close() //nolint:errcheck
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
}
assert.Equal(t, expected, string(actual), "error validating tunneled string")
}
func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
stopServerChan := make(chan struct{})
defer close(stopServerChan)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
}
conn, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
defer conn.Close() //nolint:errcheck
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
<-stopServerChan
}))
defer tunnelingServer.Close()
// Create the client side tunneling connection.
url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
tConn, err := dialForTunnelingConnection(url)
require.NoError(t, err, "error creating client tunneling connection")
defer tConn.Close() //nolint:errcheck
// Validate "LocalAddr()" and "RemoteAddr()"
localAddr := tConn.LocalAddr()
remoteAddr := tConn.RemoteAddr()
assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
_, err = net.ResolveTCPAddr("tcp", localAddr.String())
assert.NoError(t, err, "tunneling connection local addr should parse")
_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
assert.NoError(t, err, "tunneling connection remote addr should parse")
}
func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
stopServerChan := make(chan struct{})
defer close(stopServerChan)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
}
conn, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
defer conn.Close() //nolint:errcheck
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
<-stopServerChan
}))
defer tunnelingServer.Close()
// Create the client side tunneling connection.
url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
tConn, err := dialForTunnelingConnection(url)
require.NoError(t, err, "error creating client tunneling connection")
defer tConn.Close() //nolint:errcheck
// Validate the read and write deadlines.
err = tConn.SetReadDeadline(time.Time{})
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
err = tConn.SetWriteDeadline(time.Time{})
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
err = tConn.SetDeadline(time.Time{})
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
}
// dialForTunnelingConnection upgrades a request at the passed "url", creating
// a websocket connection. Returns the TunnelingConnection injected with the
// websocket connection or an error if one occurs.
func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return nil, err
}
// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
if err != nil {
return nil, err
}
conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
if err != nil {
return nil, err
}
return NewTunnelingConnection("client", conn), nil
}
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
streams <- stream
return nil
}
}

View File

@ -0,0 +1,93 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/portforward"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/transport/websocket"
"k8s.io/klog/v2"
)
const PingPeriod = 10 * time.Second
// tunnelingDialer implements "httpstream.Dial" interface
type tunnelingDialer struct {
url *url.URL
transport http.RoundTripper
holder websocket.ConnectionHolder
}
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
// returns an error if one occurs.
func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
transport, holder, err := websocket.RoundTripperFor(config)
if err != nil {
return nil, err
}
return &tunnelingDialer{
url: url,
transport: transport,
holder: holder,
}, nil
}
// Dial upgrades to a tunneling streaming connection, returning a SPDY connection
// containing a WebSockets connection (which implements "net.Conn"). Also
// returns the protocol negotiated, or an error.
func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
// There is no passed context, so skip the context when creating request for now.
// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
req, err := http.NewRequest("GET", d.url.String(), nil)
if err != nil {
return nil, "", err
}
// Add the spdy tunneling prefix to the requested protocols. The tunneling
// handler will know how to negotiate these protocols.
tunnelingProtocols := []string{}
for _, protocol := range protocols {
tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
}
klog.V(4).Infoln("Before WebSocket Upgrade Connection...")
conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
if err != nil {
return nil, "", err
}
if conn == nil {
return nil, "", fmt.Errorf("negotiated websocket connection is nil")
}
protocol := conn.Subprotocol()
protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
klog.V(4).Infof("negotiated protocol: %s", protocol)
// Wrap the websocket connection which implements "net.Conn".
tConn := NewTunnelingConnection("client", conn)
// Create SPDY connection injecting the previously created tunneling connection.
spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)
return spdyConn, protocol, err
}

View File

@ -31,6 +31,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/cli-runtime/pkg/genericiooptions"
"k8s.io/client-go/kubernetes/scheme"
@ -50,7 +51,7 @@ import (
type PortForwardOptions struct {
Namespace string
PodName string
RESTClient *restclient.RESTClient
RESTClient restclient.Interface
Config *restclient.Config
PodClient corev1client.PodsGetter
Address []string
@ -99,11 +100,7 @@ const (
)
func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command {
opts := &PortForwardOptions{
PortForwarder: &defaultPortForwarder{
IOStreams: streams,
},
}
opts := NewDefaultPortForwardOptions(streams)
cmd := &cobra.Command{
Use: "port-forward TYPE/NAME [options] [LOCAL_PORT:]REMOTE_PORT [...[LOCAL_PORT_N:]REMOTE_PORT_N]",
DisableFlagsInUseLine: true,
@ -123,6 +120,14 @@ func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *c
return cmd
}
func NewDefaultPortForwardOptions(streams genericiooptions.IOStreams) *PortForwardOptions {
return &PortForwardOptions{
PortForwarder: &defaultPortForwarder{
IOStreams: streams,
},
}
}
type portForwarder interface {
ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error
}
@ -137,6 +142,14 @@ func (f *defaultPortForwarder) ForwardPorts(method string, url *url.URL, opts Po
return err
}
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, method, url)
if cmdutil.PortForwardWebsockets.IsEnabled() {
tunnelingDialer, err := portforward.NewSPDYOverWebsocketDialer(url, opts.Config)
if err != nil {
return err
}
// First attempt tunneling (websocket) dialer, then fallback to spdy dialer.
dialer = portforward.NewFallbackDialer(tunnelingDialer, dialer, httpstream.IsUpgradeFailure)
}
fw, err := portforward.NewOnAddresses(dialer, opts.Address, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.Out, f.ErrOut)
if err != nil {
return err
@ -387,7 +400,14 @@ func (o PortForwardOptions) Validate() error {
// RunPortForward implements all the necessary functionality for port-forward cmd.
func (o PortForwardOptions) RunPortForward() error {
pod, err := o.PodClient.Pods(o.Namespace).Get(context.TODO(), o.PodName, metav1.GetOptions{})
return o.RunPortForwardContext(context.Background())
}
// RunPortForwardContext implements all the necessary functionality for port-forward cmd.
// It ends portforwarding when an error is received from the backend, or an os.Interrupt
// signal is received, or the provided context is done.
func (o PortForwardOptions) RunPortForwardContext(ctx context.Context) error {
pod, err := o.PodClient.Pods(o.Namespace).Get(ctx, o.PodName, metav1.GetOptions{})
if err != nil {
return err
}
@ -401,7 +421,10 @@ func (o PortForwardOptions) RunPortForward() error {
defer signal.Stop(signals)
go func() {
<-signals
select {
case <-signals:
case <-ctx.Done():
}
if o.StopChannel != nil {
close(o.StopChannel)
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package portforward
import (
"context"
"fmt"
"net/http"
"net/url"
@ -101,6 +102,8 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) {
}
opts := &PortForwardOptions{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cmd := NewCmdPortForward(tf, genericiooptions.NewTestIOStreamsDiscard())
cmd.Run = func(cmd *cobra.Command, args []string) {
if err = opts.Complete(tf, cmd, args); err != nil {
@ -110,7 +113,7 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) {
if err = opts.Validate(); err != nil {
return
}
err = opts.RunPortForward()
err = opts.RunPortForwardContext(ctx)
}
for name, value := range flags {

View File

@ -430,6 +430,7 @@ const (
InteractiveDelete FeatureGate = "KUBECTL_INTERACTIVE_DELETE"
OpenAPIV3Patch FeatureGate = "KUBECTL_OPENAPIV3_PATCH"
RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS"
PortForwardWebsockets FeatureGate = "KUBECTL_PORT_FORWARD_WEBSOCKETS"
)
// IsEnabled returns true iff environment variable is set to true.

View File

@ -0,0 +1,27 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"testing"
"k8s.io/kubernetes/test/integration/framework"
)
func TestMain(m *testing.M) {
framework.EtcdMain(m.Run)
}

View File

@ -0,0 +1,228 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/wait"
utilfeature "k8s.io/apiserver/pkg/util/feature"
"k8s.io/cli-runtime/pkg/genericiooptions"
"k8s.io/client-go/kubernetes"
featuregatetesting "k8s.io/component-base/featuregate/testing"
"k8s.io/kubectl/pkg/cmd/portforward"
kubeletportforward "k8s.io/kubelet/pkg/cri/streaming/portforward"
kastesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing"
kubefeatures "k8s.io/kubernetes/pkg/features"
"k8s.io/kubernetes/test/integration/framework"
)
const remotePort = "8765"
func TestPortforward(t *testing.T) {
defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, kubefeatures.PortForwardWebsockets, true)()
t.Setenv("KUBECTL_PORT_FORWARD_WEBSOCKETS", "true")
var podName string
var podUID types.UID
backendServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
t.Logf("backend saw request: %v", req.URL.String())
kubeletportforward.ServePortForward(
w,
req,
&dummyPortForwarder{t: t},
podName,
podUID,
&kubeletportforward.V4Options{},
wait.ForeverTestTimeout, // idle timeout
remotecommand.DefaultStreamCreationTimeout, // stream creation timeout
[]string{kubeletportforward.ProtocolV1Name},
)
}))
defer backendServer.Close()
backendURL, _ := url.Parse(backendServer.URL)
backendHost := backendURL.Hostname()
backendPort, _ := strconv.Atoi(backendURL.Port())
etcd := framework.SharedEtcd()
server := kastesting.StartTestServerOrDie(t, nil, []string{"--disable-admission-plugins=ServiceAccount"}, etcd)
defer server.TearDownFn()
adminClient, err := kubernetes.NewForConfig(server.ClientConfig)
require.NoError(t, err)
node := &corev1.Node{
ObjectMeta: metav1.ObjectMeta{Name: "mynode"},
Status: corev1.NodeStatus{
DaemonEndpoints: corev1.NodeDaemonEndpoints{KubeletEndpoint: corev1.DaemonEndpoint{Port: int32(backendPort)}},
Addresses: []corev1.NodeAddress{{Type: corev1.NodeInternalIP, Address: backendHost}},
},
}
if _, err := adminClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}); err != nil {
t.Fatal(err)
}
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Namespace: "default", Name: "mypod"},
Spec: corev1.PodSpec{
NodeName: "mynode",
Containers: []corev1.Container{{Name: "test", Image: "test"}},
},
}
if _, err := adminClient.CoreV1().Pods("default").Create(context.Background(), pod, metav1.CreateOptions{}); err != nil {
t.Fatal(err)
}
if _, err := adminClient.CoreV1().Pods("default").Patch(context.Background(), "mypod", types.MergePatchType, []byte(`{"status":{"phase":"Running"}}`), metav1.PatchOptions{}, "status"); err != nil {
t.Fatal(err)
}
// local port missing asks os to find random open port.
// Example: ":8000" (local = random, remote = 8000)
localRemotePort := fmt.Sprintf(":%s", remotePort)
streams, _, out, errOut := genericiooptions.NewTestIOStreams()
portForwardOptions := portforward.NewDefaultPortForwardOptions(streams)
portForwardOptions.Namespace = "default"
portForwardOptions.PodName = "mypod"
portForwardOptions.RESTClient = adminClient.CoreV1().RESTClient()
portForwardOptions.Config = server.ClientConfig
portForwardOptions.PodClient = adminClient.CoreV1()
portForwardOptions.Address = []string{"127.0.0.1"}
portForwardOptions.Ports = []string{localRemotePort}
portForwardOptions.StopChannel = make(chan struct{}, 1)
portForwardOptions.ReadyChannel = make(chan struct{})
if err := portForwardOptions.Validate(); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
if err := portForwardOptions.RunPortForwardContext(ctx); err != nil {
t.Error(err)
}
}()
t.Log("waiting for port forward to be ready")
select {
case <-portForwardOptions.ReadyChannel:
t.Log("port forward was ready")
case <-time.After(wait.ForeverTestTimeout):
t.Error("port forward was never ready")
}
// Parse out the randomly selected local port from "out" stream.
localPort, err := parsePort(out.String())
require.NoError(t, err)
t.Logf("Local Port: %s", localPort)
timeoutContext, cleanupTimeoutContext := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer cleanupTimeoutContext()
testReq, _ := http.NewRequest("GET", fmt.Sprintf("http://127.0.0.1:%s/test", localPort), nil)
testReq = testReq.WithContext(timeoutContext)
testResp, err := http.DefaultClient.Do(testReq)
if err != nil {
t.Error(err)
} else {
t.Log(testResp.StatusCode)
data, err := io.ReadAll(testResp.Body)
if err != nil {
t.Error(err)
} else {
t.Log("client saw response:", string(data))
}
if string(data) != fmt.Sprintf("request to %s was ok", remotePort) {
t.Errorf("unexpected data")
}
if testResp.StatusCode != 200 {
t.Error("expected success")
}
}
cancel()
wg.Wait()
t.Logf("stdout: %s", out.String())
t.Logf("stderr: %s", errOut.String())
}
// parsePort parses out the local port from the port-forward output string.
// This should work for both IP4 and IP6 addresses.
//
// Example: "Forwarding from 127.0.0.1:8000 -> 4000", returns "8000".
func parsePort(forwardAddr string) (string, error) {
parts := strings.Split(forwardAddr, " ")
if len(parts) != 5 {
return "", fmt.Errorf("unable to parse local port from stdout: %s", forwardAddr)
}
// parts[2] = "127.0.0.1:<LOCAL_PORT>"
_, localPort, err := net.SplitHostPort(parts[2])
if err != nil {
return "", fmt.Errorf("unable to parse local port: %w", err)
}
return localPort, nil
}
type dummyPortForwarder struct {
t *testing.T
}
func (d *dummyPortForwarder) PortForward(ctx context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
d.t.Logf("handling port forward request for %d", port)
req, err := http.ReadRequest(bufio.NewReader(stream))
if err != nil {
d.t.Logf("error reading request: %v", err)
return err
}
d.t.Log(req.URL.String())
defer req.Body.Close() //nolint:errcheck
resp := &http.Response{
StatusCode: 200,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Body: io.NopCloser(bytes.NewBufferString(fmt.Sprintf("request to %d was ok", port))),
}
resp.Write(stream) //nolint:errcheck
return stream.Close()
}