From 8b447d8c97e8823b4308eb91cf7d75693e867c61 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Wed, 21 Feb 2024 08:56:07 +0000 Subject: [PATCH] portforward: tunnel spdy through websockets --- pkg/features/kube_features.go | 9 + pkg/registry/core/pod/rest/subresources.go | 7 +- .../pkg/util/httpstream/wsstream/conn.go | 24 + .../pkg/util/portforward/constants.go | 24 + .../pkg/util/proxy/upgradeaware.go | 22 +- staging/src/k8s.io/apiserver/go.mod | 2 +- .../apiserver/pkg/util/proxy/streamtunnel.go | 433 ++++++++++++++++++ .../pkg/util/proxy/streamtunnel_test.go | 197 ++++++++ .../tools/portforward/fallback_dialer.go | 57 +++ .../tools/portforward/fallback_dialer_test.go | 53 +++ .../tools/portforward/portforward.go | 6 +- .../tools/portforward/portforward_test.go | 9 +- .../tools/portforward/tunneling_connection.go | 153 +++++++ .../portforward/tunneling_connection_test.go | 190 ++++++++ .../tools/portforward/tunneling_dialer.go | 93 ++++ .../pkg/cmd/portforward/portforward.go | 39 +- .../pkg/cmd/portforward/portforward_test.go | 5 +- .../k8s.io/kubectl/pkg/cmd/util/helpers.go | 1 + .../apiserver/portforward/main_test.go | 27 ++ .../apiserver/portforward/portforward_test.go | 228 +++++++++ 20 files changed, 1560 insertions(+), 19 deletions(-) create mode 100644 staging/src/k8s.io/apimachinery/pkg/util/portforward/constants.go create mode 100644 staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go create mode 100644 staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go create mode 100644 staging/src/k8s.io/client-go/tools/portforward/fallback_dialer.go create mode 100644 staging/src/k8s.io/client-go/tools/portforward/fallback_dialer_test.go create mode 100644 staging/src/k8s.io/client-go/tools/portforward/tunneling_connection.go create mode 100644 staging/src/k8s.io/client-go/tools/portforward/tunneling_connection_test.go create mode 100644 staging/src/k8s.io/client-go/tools/portforward/tunneling_dialer.go create mode 100644 test/integration/apiserver/portforward/main_test.go create mode 100644 test/integration/apiserver/portforward/portforward_test.go diff --git a/pkg/features/kube_features.go b/pkg/features/kube_features.go index d172264efe5..b3ec278265f 100644 --- a/pkg/features/kube_features.go +++ b/pkg/features/kube_features.go @@ -619,6 +619,13 @@ const ( // Enable users to specify when a Pod is ready for scheduling. PodSchedulingReadiness featuregate.Feature = "PodSchedulingReadiness" + // owner: @seans3 + // kep: http://kep.k8s.io/4006 + // alpha: v1.30 + // + // Enables PortForward to be proxied with a websocket client + PortForwardWebsockets featuregate.Feature = "PortForwardWebsockets" + // owner: @jessfraz // alpha: v1.12 // @@ -1101,6 +1108,8 @@ var defaultKubernetesFeatureGates = map[featuregate.Feature]featuregate.FeatureS PodSchedulingReadiness: {Default: true, PreRelease: featuregate.GA, LockToDefault: true}, // GA in 1.30; remove in 1.32 + PortForwardWebsockets: {Default: false, PreRelease: featuregate.Alpha}, + ProcMountType: {Default: false, PreRelease: featuregate.Alpha}, QOSReserved: {Default: false, PreRelease: featuregate.Alpha}, diff --git a/pkg/registry/core/pod/rest/subresources.go b/pkg/registry/core/pod/rest/subresources.go index 0e031412fbd..87787678271 100644 --- a/pkg/registry/core/pod/rest/subresources.go +++ b/pkg/registry/core/pod/rest/subresources.go @@ -242,7 +242,12 @@ func (r *PortForwardREST) Connect(ctx context.Context, name string, opts runtime if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil + handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder) + if utilfeature.DefaultFeatureGate.Enabled(features.PortForwardWebsockets) { + tunnelingHandler := translator.NewTunnelingHandler(handler) + handler = translator.NewTranslatingHandler(handler, tunnelingHandler, wsstream.IsWebSocketRequestWithTunnelingProtocol) + } + return handler, nil } func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler { diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go index 8a741936a3d..2e477fee2ae 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go @@ -27,6 +27,7 @@ import ( "golang.org/x/net/websocket" "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/portforward" "k8s.io/apimachinery/pkg/util/remotecommand" "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/klog/v2" @@ -106,6 +107,23 @@ func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool { return false } +// IsWebSocketRequestWithTunnelingProtocol returns true if the request contains headers +// identifying that it is requesting a websocket upgrade with a tunneling protocol; +// false otherwise. +func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool { + if !IsWebSocketRequest(req) { + return false + } + requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader)) + for _, requestedProtocol := range strings.Split(requestedProtocols, ",") { + if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) { + return true + } + } + + return false +} + // IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the // read and write deadlines are pushed every time a new message is received. func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) { @@ -301,6 +319,12 @@ func protocolSupportsStreamClose(protocol string) bool { return protocol == remotecommand.StreamProtocolV5Name } +// protocolSupportsWebsocketTunneling returns true if the passed protocol +// is a tunneled Kubernetes spdy protocol; false otherwise. +func protocolSupportsWebsocketTunneling(protocol string) bool { + return strings.HasPrefix(protocol, portforward.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, portforward.KubernetesSuffix) +} + // handle implements a websocket handler. func (conn *Conn) handle(ws *websocket.Conn) { conn.initialize(ws) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/portforward/constants.go b/staging/src/k8s.io/apimachinery/pkg/util/portforward/constants.go new file mode 100644 index 00000000000..68532881565 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/util/portforward/constants.go @@ -0,0 +1,24 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +const ( + PortForwardV1Name = "portforward.k8s.io" + WebsocketsSPDYTunnelingPrefix = "SPDY/3.1+" + KubernetesSuffix = ".k8s.io" + WebsocketsSPDYTunnelingPortForwardV1 = WebsocketsSPDYTunnelingPrefix + PortForwardV1Name +) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go b/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go index 76acdfb4aca..1fdae735af5 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go @@ -36,6 +36,7 @@ import ( utilruntime "k8s.io/apimachinery/pkg/util/runtime" "github.com/mxk/go-flowrate/flowrate" + "k8s.io/klog/v2" ) @@ -336,6 +337,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques clone.Host = h.Location.Host } clone.URL = &location + klog.V(6).Infof("UpgradeAwareProxy: dialing for SPDY upgrade with headers: %v", clone.Header) backendConn, err = h.DialForUpgrade(clone) if err != nil { klog.V(6).Infof("Proxy connection error: %v", err) @@ -370,13 +372,13 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques // hijacking should be the last step in the upgrade. requestHijacker, ok := w.(http.Hijacker) if !ok { - klog.V(6).Infof("Unable to hijack response writer: %T", w) + klog.Errorf("Unable to hijack response writer: %T", w) h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w)) return true } requestHijackedConn, _, err := requestHijacker.Hijack() if err != nil { - klog.V(6).Infof("Unable to hijack response: %v", err) + klog.Errorf("Unable to hijack response: %v", err) h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err)) return true } @@ -420,7 +422,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques } else { writer = backendConn } - _, err := io.Copy(writer, requestHijackedConn) + _, err := io.Copy(writer, &loggingReader{name: "client->backend", delegate: requestHijackedConn}) if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { klog.Errorf("Error proxying data from client to backend: %v", err) } @@ -434,7 +436,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques } else { reader = backendConn } - _, err := io.Copy(requestHijackedConn, reader) + _, err := io.Copy(requestHijackedConn, &loggingReader{name: "backend->client", delegate: reader}) if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { klog.Errorf("Error proxying data from backend to client: %v", err) } @@ -452,6 +454,18 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques return true } +// loggingReader logs the bytes read from the "delegate" with a "name" prefix. +type loggingReader struct { + name string + delegate io.Reader +} + +func (l *loggingReader) Read(p []byte) (int, error) { + n, err := l.delegate.Read(p) + klog.V(8).Infof("%s: %d bytes, err=%v, bytes=% X", l.name, n, err, p[:n]) + return n, err +} + // FIXME: Taken from net/http/httputil/reverseproxy.go as singleJoiningSlash is not exported to be re-used. // See-also: https://github.com/golang/go/issues/44290 func singleJoiningSlash(a, b string) string { diff --git a/staging/src/k8s.io/apiserver/go.mod b/staging/src/k8s.io/apiserver/go.mod index 43cfcc7d952..96d53939be7 100644 --- a/staging/src/k8s.io/apiserver/go.mod +++ b/staging/src/k8s.io/apiserver/go.mod @@ -16,6 +16,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/gofuzz v1.2.0 github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f @@ -77,7 +78,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/btree v1.0.1 // indirect - github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go new file mode 100644 index 00000000000..c38a2ad604b --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go @@ -0,0 +1,433 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + gwebsocket "github.com/gorilla/websocket" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" + utilnet "k8s.io/apimachinery/pkg/util/net" + constants "k8s.io/apimachinery/pkg/util/portforward" + "k8s.io/client-go/tools/portforward" + "k8s.io/klog/v2" +) + +// TunnelingHandler is a handler which tunnels SPDY through WebSockets. +type TunnelingHandler struct { + // Used to communicate between upstream SPDY and downstream tunnel. + upgradeHandler http.Handler +} + +// NewTunnelingHandler is used to create the tunnel between an upstream +// SPDY connection and a downstream tunneling connection through the stored +// UpgradeAwareProxy. +func NewTunnelingHandler(upgradeHandler http.Handler) *TunnelingHandler { + return &TunnelingHandler{upgradeHandler: upgradeHandler} +} + +// ServeHTTP uses the upgradeHandler to tunnel between a downstream tunneling +// connection and an upstream SPDY connection. The tunneling connection is +// a wrapped WebSockets connection which communicates SPDY framed data. +func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + klog.V(4).Infoln("TunnelingHandler ServeHTTP") + + spdyProtocols := spdyProtocolsFromWebsocketProtocols(req) + if len(spdyProtocols) == 0 { + http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest) + return + } + + spdyRequest := createSPDYRequest(req, spdyProtocols...) + + writer := &tunnelingResponseWriter{ + w: w, + conn: &headerInterceptingConn{ + initializableConn: &tunnelingWebsocketUpgraderConn{ + w: w, + req: req, + }, + }, + } + + klog.V(4).Infoln("Tunnel spdy through websockets using the UpgradeAwareProxy") + h.upgradeHandler.ServeHTTP(writer, spdyRequest) +} + +// createSPDYRequest modifies the passed request to remove +// WebSockets headers and add SPDY upgrade information, including +// spdy protocols acceptable to the client. +func createSPDYRequest(req *http.Request, spdyProtocols ...string) *http.Request { + clone := utilnet.CloneRequest(req) + // Clean up the websocket headers from the http request. + clone.Header.Del(wsstream.WebSocketProtocolHeader) + clone.Header.Del("Sec-Websocket-Key") + clone.Header.Del("Sec-Websocket-Version") + clone.Header.Del(httpstream.HeaderUpgrade) + // Update the http request for an upstream SPDY upgrade. + clone.Method = "POST" + clone.Body = nil // Remove the request body which is unused. + clone.Header.Set(httpstream.HeaderUpgrade, spdy.HeaderSpdy31) + clone.Header.Del(httpstream.HeaderProtocolVersion) + for i := range spdyProtocols { + clone.Header.Add(httpstream.HeaderProtocolVersion, spdyProtocols[i]) + } + return clone +} + +// spdyProtocolsFromWebsocketProtocols returns a list of spdy protocols by filtering +// to Kubernetes websocket subprotocols prefixed with "SPDY/3.1+", then removing the prefix +func spdyProtocolsFromWebsocketProtocols(req *http.Request) []string { + var spdyProtocols []string + for _, protocol := range gwebsocket.Subprotocols(req) { + if strings.HasPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, constants.KubernetesSuffix) { + spdyProtocols = append(spdyProtocols, strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)) + } + } + return spdyProtocols +} + +var _ http.ResponseWriter = &tunnelingResponseWriter{} +var _ http.Hijacker = &tunnelingResponseWriter{} + +// tunnelingResponseWriter implements the http.ResponseWriter and http.Hijacker interfaces. +// Only non-upgrade responses can be written using WriteHeader() and Write(). +// Once Write or WriteHeader is called, Hijack returns an error. +// Once Hijack is called, Write, WriteHeader, and Hijack return errors. +type tunnelingResponseWriter struct { + // w is used to delegate Header(), WriteHeader(), and Write() calls + w http.ResponseWriter + // conn is returned from Hijack() + conn net.Conn + // mu guards writes + mu sync.Mutex + // wrote tracks whether WriteHeader or Write has been called + written bool + // hijacked tracks whether Hijack has been called + hijacked bool +} + +// Hijack returns a delegate "net.Conn". +// An error is returned if Write(), WriteHeader(), or Hijack() was previously called. +// The returned bufio.ReadWriter is always nil. +func (w *tunnelingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + w.mu.Lock() + defer w.mu.Unlock() + if w.written { + klog.Errorf("Hijack called after write") + return nil, nil, errors.New("connection has already been written to") + } + if w.hijacked { + klog.Errorf("Hijack called after hijack") + return nil, nil, errors.New("connection has already been hijacked") + } + w.hijacked = true + klog.V(6).Infof("Hijack returning websocket tunneling net.Conn") + return w.conn, nil, nil +} + +// Header is delegated to the stored "http.ResponseWriter". +func (w *tunnelingResponseWriter) Header() http.Header { + return w.w.Header() +} + +// Write is delegated to the stored "http.ResponseWriter". +func (w *tunnelingResponseWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + if w.hijacked { + klog.Errorf("Write called after hijack") + return 0, http.ErrHijacked + } + w.written = true + return w.w.Write(p) +} + +// WriteHeader is delegated to the stored "http.ResponseWriter". +func (w *tunnelingResponseWriter) WriteHeader(statusCode int) { + w.mu.Lock() + defer w.mu.Unlock() + if w.written { + klog.Errorf("WriteHeader called after write") + return + } + if w.hijacked { + klog.Errorf("WriteHeader called after hijack") + return + } + w.written = true + + if statusCode == http.StatusSwitchingProtocols { + // 101 upgrade responses must come via the hijacked connection, not WriteHeader + klog.Errorf("WriteHeader called with 101 upgrade") + http.Error(w.w, "unexpected upgrade", http.StatusInternalServerError) + return + } + + // pass through non-upgrade responses we don't need to translate + w.w.WriteHeader(statusCode) +} + +// headerInterceptingConn wraps the tunneling "net.Conn" to drain the +// HTTP response status/headers from the upstream SPDY connection, then use +// that to decide how to initialize the delegate connection for writes. +type headerInterceptingConn struct { + // initializableConn is delegated to for all net.Conn methods. + // initializableConn.Write() is not called until response headers have been read + // and initializableConn#InitializeWrite() has been called with the result. + initializableConn + + lock sync.Mutex + headerBuffer []byte + initialized bool +} + +// initializableConn is a connection that will be initialized before any calls to Write are made +type initializableConn interface { + net.Conn + InitializeWrite(backendResponse *http.Response) error +} + +const maxHeaderBytes = 1 << 20 + +// Write intercepts to initially swallow the HTTP response, then +// delegate to the tunneling "net.Conn" once the response has been +// seen and processed. +func (h *headerInterceptingConn) Write(b []byte) (int, error) { + h.lock.Lock() + defer h.lock.Unlock() + + if h.initialized { + return h.initializableConn.Write(b) + } + + // Write into the headerBuffer, then attempt to parse the bytes + // as an http response. + if len(h.headerBuffer)+len(b) > maxHeaderBytes { + return 0, fmt.Errorf("header size limit exceeded") + } + h.headerBuffer = append(h.headerBuffer, b...) + bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer)) + resp, err := http.ReadResponse(bufferedReader, nil) + if errors.Is(err, io.ErrUnexpectedEOF) { + // don't yet have a complete set of headers + return len(b), nil + } + if err != nil { + klog.Errorf("invalid headers: %v", err) + return len(b), err + } + resp.Body.Close() //nolint:errcheck + + h.headerBuffer = nil + err = h.initializableConn.InitializeWrite(resp) + h.initialized = true + if err != nil { + return len(b), err + } + + // Copy any remaining buffered data to the underlying conn + remainingBuffer, _ := io.ReadAll(bufferedReader) + if len(remainingBuffer) > 0 { + _, err = h.initializableConn.Write(remainingBuffer) + } + return len(b), err +} + +type tunnelingWebsocketUpgraderConn struct { + // req is the websocket request, used for upgrading + req *http.Request + // w is the websocket writer, used for upgrading and writing error responses + w http.ResponseWriter + + // lock guards conn and err + lock sync.RWMutex + // if conn is non-nil, InitializeWrite succeeded + conn net.Conn + // if err is non-nil, InitializeWrite failed or Close was called before InitializeWrite + err error +} + +func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) { + // make sure we close a connection we open in error cases + var conn net.Conn + defer func() { + if err != nil && conn != nil { + conn.Close() //nolint:errcheck + } + }() + + u.lock.Lock() + defer u.lock.Unlock() + if u.conn != nil { + return fmt.Errorf("InitializeWrite already called") + } + if u.err != nil { + return u.err + } + + if backendResponse.StatusCode == http.StatusSwitchingProtocols { + connectionHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderConnection)) + upgradeHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderUpgrade)) + if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) { + klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header) + u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response") + http.Error(u.w, u.err.Error(), http.StatusInternalServerError) + return u.err + } + + // Translate the server's chosen SPDY protocol into the tunneled websocket protocol for the handshake + var serverWebsocketProtocols []string + if backendSPDYProtocol := strings.TrimSpace(backendResponse.Header.Get(httpstream.HeaderProtocolVersion)); backendSPDYProtocol != "" { + serverWebsocketProtocols = []string{constants.WebsocketsSPDYTunnelingPrefix + backendSPDYProtocol} + } else { + serverWebsocketProtocols = []string{} + } + + // Try to upgrade the websocket connection. + // Beyond this point, we don't need to write errors to the response. + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: serverWebsocketProtocols, + } + conn, err := upgrader.Upgrade(u.w, u.req, nil) + if err != nil { + klog.Errorf("error upgrading websocket connection: %v", err) + u.err = err + return u.err + } + + klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol()) + u.conn = portforward.NewTunnelingConnection("server", conn) + return nil + } + + // anything other than an upgrade should pass through the backend response + + // try to hijack + conn, _, err = u.w.(http.Hijacker).Hijack() + if err != nil { + klog.Errorf("Unable to hijack response: %v", err) + u.err = err + return u.err + } + // replay the backend response to the hijacked conn + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck + err = backendResponse.Write(conn) + if err != nil { + u.err = err + return u.err + } + u.conn = conn + return nil +} + +func (u *tunnelingWebsocketUpgraderConn) Read(b []byte) (n int, err error) { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.Read(b) + } + if u.err != nil { + return 0, u.err + } + // return empty read without blocking until we are initialized + return 0, nil +} +func (u *tunnelingWebsocketUpgraderConn) Write(b []byte) (n int, err error) { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.Write(b) + } + if u.err != nil { + return 0, u.err + } + return 0, fmt.Errorf("Write called before Initialize") +} +func (u *tunnelingWebsocketUpgraderConn) Close() error { + u.lock.Lock() + defer u.lock.Unlock() + if u.conn != nil { + return u.conn.Close() + } + if u.err != nil { + return u.err + } + // record that we closed so we don't write again or try to initialize + u.err = fmt.Errorf("connection closed") + // write a response + http.Error(u.w, u.err.Error(), http.StatusInternalServerError) + return nil +} +func (u *tunnelingWebsocketUpgraderConn) LocalAddr() net.Addr { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.LocalAddr() + } + return noopAddr{} +} +func (u *tunnelingWebsocketUpgraderConn) RemoteAddr() net.Addr { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.RemoteAddr() + } + return noopAddr{} +} +func (u *tunnelingWebsocketUpgraderConn) SetDeadline(t time.Time) error { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.SetDeadline(t) + } + return nil +} +func (u *tunnelingWebsocketUpgraderConn) SetReadDeadline(t time.Time) error { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.SetReadDeadline(t) + } + return nil +} +func (u *tunnelingWebsocketUpgraderConn) SetWriteDeadline(t time.Time) error { + u.lock.RLock() + defer u.lock.RUnlock() + if u.conn != nil { + return u.conn.SetWriteDeadline(t) + } + return nil +} + +type noopAddr struct{} + +func (n noopAddr) Network() string { return "" } +func (n noopAddr) String() string { return "" } diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go new file mode 100644 index 00000000000..858ad6c8f87 --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go @@ -0,0 +1,197 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "bytes" + "crypto/rand" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + constants "k8s.io/apimachinery/pkg/util/portforward" + "k8s.io/apimachinery/pkg/util/proxy" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/apiserver/pkg/registry/rest" + restconfig "k8s.io/client-go/rest" + "k8s.io/client-go/tools/portforward" +) + +func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) { + // Create fake upstream SPDY server, with channel receiving SPDY streams. + streamChan := make(chan httpstream.Stream) + defer close(streamChan) + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + // Create fake upstream SPDY server. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + _, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name}) + require.NoError(t, err) + upgrader := spdy.NewResponseUpgrader() + conn := upgrader.UpgradeResponse(w, req, justQueueStream(streamChan)) + require.NotNil(t, conn) + defer conn.Close() //nolint:errcheck + <-stopServerChan + })) + defer spdyServer.Close() + // Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then + // create TunnelingHandler by injecting upgrade handler. Create TunnelingServer. + url, err := url.Parse(spdyServer.URL) + require.NoError(t, err) + transport, err := fakeTransport() + require.NoError(t, err) + upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{})) + tunnelingHandler := NewTunnelingHandler(upgradeHandler) + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + tunnelingHandler.ServeHTTP(w, req) + })) + defer tunnelingServer.Close() + // Create SPDY client connection containing a TunnelingConnection by upgrading + // a request to TunnelingHandler using new portforward version 2. + tunnelingURL, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host}) + require.NoError(t, err) + spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name) + require.NoError(t, err) + assert.Equal(t, constants.PortForwardV1Name, protocol) + defer spdyClient.Close() //nolint:errcheck + // Create a SPDY client stream, which will queue a SPDY server stream + // on the stream creation channel. Send random data on the client stream + // reading off the SPDY server stream, and validating it was tunneled. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + _, err = rand.Read(randomData) + require.NoError(t, err) + var actual []byte + go func() { + clientStream, err := spdyClient.CreateStream(http.Header{}) + require.NoError(t, err) + _, err = io.Copy(clientStream, bytes.NewReader(randomData)) + require.NoError(t, err) + clientStream.Close() //nolint:errcheck + }() + select { + case serverStream := <-streamChan: + actual, err = io.ReadAll(serverStream) + require.NoError(t, err) + defer serverStream.Close() //nolint:errcheck + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("timeout waiting for spdy stream to arrive on channel.") + } + assert.Equal(t, randomData, actual, "error validating tunneled random data") +} + +const responseStr = `HTTP/1.1 101 Switching Protocols +Date: Sun, 25 Feb 2024 08:09:25 GMT +X-App-Protocol: portforward.k8s.io + +` + +const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols +Date: Sun, 25 Feb 2024 08:09:25 GMT +X-App-Protocol: portforward.k8s.io + +This is extra data. +` + +const invalidResponseStr = `INVALID/1.1 101 Switching Protocols +Date: Sun, 25 Feb 2024 08:09:25 GMT +X-App-Protocol: portforward.k8s.io + +` + +func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) { + // Basic http response is intercepted correctly; no extra data sent to net.Conn. + testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}} + hic := &headerInterceptingConn{initializableConn: testConnConstructor} + _, err := hic.Write([]byte(responseStr)) + require.NoError(t, err) + assert.True(t, hic.initialized, "successfully parsed http response headers") + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol")) + assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn") + // Extra data after response headers should be sent to net.Conn. + hic = &headerInterceptingConn{initializableConn: testConnConstructor} + _, err = hic.Write([]byte(responseWithExtraStr)) + require.NoError(t, err) + assert.True(t, hic.initialized) + assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status) + assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn") + // Invalid response returns error. + hic = &headerInterceptingConn{initializableConn: testConnConstructor} + _, err = hic.Write([]byte(invalidResponseStr)) + assert.Error(t, err, "expected error from invalid http response") +} + +type mockConnInitializer struct { + resp *http.Response + *mockConn +} + +func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error { + m.resp = backendResponse + return nil +} + +// mockConn implements "net.Conn" interface. +var _ net.Conn = &mockConn{} + +type mockConn struct { + written []byte +} + +func (mc *mockConn) Write(p []byte) (int, error) { + mc.written = make([]byte, len(p)) + copy(mc.written, p) + return len(mc.written), nil +} + +func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil } +func (mc *mockConn) Close() error { return nil } +func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (mc *mockConn) SetDeadline(t time.Time) error { return nil } +func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// fakeResponder implements "rest.Responder" interface. +var _ rest.Responder = &fakeResponder{} + +type fakeResponder struct{} + +func (fr *fakeResponder) Object(statusCode int, obj runtime.Object) {} +func (fr *fakeResponder) Error(err error) {} + +// justQueueStream skips the usual stream validation before +// queueing the stream on the stream channel. +func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { + streams <- stream + return nil + } +} diff --git a/staging/src/k8s.io/client-go/tools/portforward/fallback_dialer.go b/staging/src/k8s.io/client-go/tools/portforward/fallback_dialer.go new file mode 100644 index 00000000000..8fb74a41857 --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/portforward/fallback_dialer.go @@ -0,0 +1,57 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/klog/v2" +) + +var _ httpstream.Dialer = &fallbackDialer{} + +// fallbackDialer encapsulates a primary and secondary dialer, including +// the boolean function to determine if the primary dialer failed. Implements +// the httpstream.Dialer interface. +type fallbackDialer struct { + primary httpstream.Dialer + secondary httpstream.Dialer + shouldFallback func(error) bool +} + +// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers, +// as well as the boolean function to determine if the primary dialer failed. +func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer { + return &fallbackDialer{ + primary: primary, + secondary: secondary, + shouldFallback: shouldFallback, + } +} + +// Dial is the single function necessary to implement the "httpstream.Dialer" interface. +// It takes the protocol version strings to request, returning an the upgraded +// httstream.Connection and the negotiated protocol version accepted. If the initial +// primary dialer fails, this function attempts the secondary dialer. Returns an error +// if one occurs. +func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { + conn, version, err := f.primary.Dial(protocols...) + if err != nil && f.shouldFallback(err) { + klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err) + return f.secondary.Dial(protocols...) + } + return conn, version, err +} diff --git a/staging/src/k8s.io/client-go/tools/portforward/fallback_dialer_test.go b/staging/src/k8s.io/client-go/tools/portforward/fallback_dialer_test.go new file mode 100644 index 00000000000..4680fa298da --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/portforward/fallback_dialer_test.go @@ -0,0 +1,53 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFallbackDialer(t *testing.T) { + protocol := "v6.fake.k8s.io" + // If "shouldFallback" is false, then only primary should be dialed. + primary := &fakeDialer{dialed: false} + primary.negotiatedProtocol = protocol + secondary := &fakeDialer{dialed: false} + fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse) + _, negotiated, err := fallbackDialer.Dial(protocol) + assert.True(t, primary.dialed, "no fallback; primary should have dialed") + assert.Equal(t, protocol, negotiated, "") + assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed") + assert.Nil(t, err, "error should be nil") + // If "shouldFallback" is true, then primary AND secondary should be dialed. + primary.dialed = false // reset dialed field + primary.err = fmt.Errorf("bad handshake") + secondary.dialed = false // reset dialed field + secondary.negotiatedProtocol = protocol + fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue) + _, negotiated, err = fallbackDialer.Dial(protocol) + assert.True(t, primary.dialed, "fallback; primary should have dialed (first)") + assert.True(t, secondary.dialed, "fallback; secondary should have dialed") + assert.Equal(t, protocol, negotiated) + assert.Nil(t, err) +} + +func alwaysTrue(err error) bool { return true } + +func alwaysFalse(err error) bool { return false } diff --git a/staging/src/k8s.io/client-go/tools/portforward/portforward.go b/staging/src/k8s.io/client-go/tools/portforward/portforward.go index b581043f6ee..83ef3e929b3 100644 --- a/staging/src/k8s.io/client-go/tools/portforward/portforward.go +++ b/staging/src/k8s.io/client-go/tools/portforward/portforward.go @@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name) + var protocol string + pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } defer pf.streamConn.Close() + if protocol != PortForwardProtocolV1Name { + return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol) + } return pf.forward() } diff --git a/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go b/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go index 3c90a3fde51..075a22e62a1 100644 --- a/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go +++ b/staging/src/k8s.io/client-go/tools/portforward/portforward_test.go @@ -430,7 +430,8 @@ func TestGetListener(t *testing.T) { func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) { dialer := &fakeDialer{ - conn: newFakeConnection(), + conn: newFakeConnection(), + negotiatedProtocol: PortForwardProtocolV1Name, } stopChan := make(chan struct{}) @@ -570,7 +571,8 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) { func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) { dialer := &fakeDialer{ - conn: newFakeConnection(), + conn: newFakeConnection(), + negotiatedProtocol: PortForwardProtocolV1Name, } stopChan := make(chan struct{}) @@ -601,7 +603,8 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) { func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) { dialer := &fakeDialer{ - conn: newFakeConnection(), + conn: newFakeConnection(), + negotiatedProtocol: PortForwardProtocolV1Name, } stopChan := make(chan struct{}) diff --git a/staging/src/k8s.io/client-go/tools/portforward/tunneling_connection.go b/staging/src/k8s.io/client-go/tools/portforward/tunneling_connection.go new file mode 100644 index 00000000000..4c04531b64f --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/portforward/tunneling_connection.go @@ -0,0 +1,153 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + gwebsocket "github.com/gorilla/websocket" + + "k8s.io/klog/v2" +) + +var _ net.Conn = &TunnelingConnection{} + +// TunnelingConnection implements the "httpstream.Connection" interface, wrapping +// a websocket connection that tunnels SPDY. +type TunnelingConnection struct { + name string + conn *gwebsocket.Conn + inProgressMessage io.Reader + closeOnce sync.Once +} + +// NewTunnelingConnection wraps the passed gorilla/websockets connection +// with the TunnelingConnection struct (implementing net.Conn). +func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection { + return &TunnelingConnection{ + name: name, + conn: conn, + } +} + +// Read implements "io.Reader" interface, reading from the stored connection +// into the passed buffer "p". Returns the number of bytes read and an error. +// Can keep track of the "inProgress" messsage from the tunneled connection. +func (c *TunnelingConnection) Read(p []byte) (int, error) { + klog.V(7).Infof("%s: tunneling connection read...", c.name) + defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name) + for { + if c.inProgressMessage == nil { + klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name) + messageType, nextReader, err := c.conn.NextReader() + if err != nil { + closeError := &gwebsocket.CloseError{} + if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure { + return 0, io.EOF + } + klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err) + return 0, err + } + if messageType != gwebsocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type received") + } + c.inProgressMessage = nextReader + } + klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name) + i, err := c.inProgressMessage.Read(p) + if i == 0 && err == io.EOF { + c.inProgressMessage = nil + } else { + klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i]) + return i, err + } + } +} + +// Write implements "io.Writer" interface, copying the data in the passed +// byte array "p" into the stored tunneled connection. Returns the number +// of bytes written and an error. +func (c *TunnelingConnection) Write(p []byte) (n int, err error) { + klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p) + defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name) + w, err := c.conn.NextWriter(gwebsocket.BinaryMessage) + if err != nil { + return 0, err + } + defer func() { + // close, which flushes the message + closeErr := w.Close() + if closeErr != nil && err == nil { + // if closing/flushing errored and we weren't already returning an error, return the close error + err = closeErr + } + }() + + n, err = w.Write(p) + return +} + +// Close implements "io.Closer" interface, signaling the other tunneled connection +// endpoint, and closing the tunneled connection only once. +func (c *TunnelingConnection) Close() error { + var err error + c.closeOnce.Do(func() { + klog.V(7).Infof("%s: tunneling connection Close()...", c.name) + // Signal other endpoint that websocket connection is closing; ignore error. + normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "") + c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck + err = c.conn.Close() + }) + return err +} + +// LocalAddr implements part of the "net.Conn" interface, returning the local +// endpoint network address of the tunneled connection. +func (c *TunnelingConnection) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// LocalAddr implements part of the "net.Conn" interface, returning the remote +// endpoint network address of the tunneled connection. +func (c *TunnelingConnection) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the *absolute* time in the future for both +// read and write deadlines. Returns an error if one occurs. +func (c *TunnelingConnection) SetDeadline(t time.Time) error { + rerr := c.SetReadDeadline(t) + werr := c.SetWriteDeadline(t) + return errors.Join(rerr, werr) +} + +// SetDeadline sets the *absolute* time in the future for the +// read deadlines. Returns an error if one occurs. +func (c *TunnelingConnection) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetDeadline sets the *absolute* time in the future for the +// write deadlines. Returns an error if one occurs. +func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/staging/src/k8s.io/client-go/tools/portforward/tunneling_connection_test.go b/staging/src/k8s.io/client-go/tools/portforward/tunneling_connection_test.go new file mode 100644 index 00000000000..4127f49d438 --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/portforward/tunneling_connection_test.go @@ -0,0 +1,190 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + gwebsocket "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + constants "k8s.io/apimachinery/pkg/util/portforward" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" + "k8s.io/client-go/transport/websocket" +) + +func TestTunnelingConnection_ReadWriteClose(t *testing.T) { + // Stream channel that will receive streams created on upstream SPDY server. + streamChan := make(chan httpstream.Stream) + defer close(streamChan) + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + // Create tunneling connection server endpoint with fake upstream SPDY server. + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, + } + conn, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) + tunnelingConn := NewTunnelingConnection("server", conn) + spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan)) + require.NoError(t, err) + defer spdyConn.Close() //nolint:errcheck + <-stopServerChan + })) + defer tunnelingServer.Close() + // Dial the client tunneling connection to the tunneling server. + url, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host}) + require.NoError(t, err) + spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name) + require.NoError(t, err) + assert.Equal(t, constants.PortForwardV1Name, protocol) + defer spdyClient.Close() //nolint:errcheck + // Create a SPDY client stream, which will queue a SPDY server stream + // on the stream creation channel. Send data on the client stream + // reading off the SPDY server stream, and validating it was tunneled. + expected := "This is a test tunneling SPDY data through websockets." + var actual []byte + go func() { + clientStream, err := spdyClient.CreateStream(http.Header{}) + require.NoError(t, err) + _, err = io.Copy(clientStream, strings.NewReader(expected)) + require.NoError(t, err) + clientStream.Close() //nolint:errcheck + }() + select { + case serverStream := <-streamChan: + actual, err = io.ReadAll(serverStream) + require.NoError(t, err) + defer serverStream.Close() //nolint:errcheck + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("timeout waiting for spdy stream to arrive on channel.") + } + assert.Equal(t, expected, string(actual), "error validating tunneled string") +} + +func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) { + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, + } + conn, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) + <-stopServerChan + })) + defer tunnelingServer.Close() + // Create the client side tunneling connection. + url, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + tConn, err := dialForTunnelingConnection(url) + require.NoError(t, err, "error creating client tunneling connection") + defer tConn.Close() //nolint:errcheck + // Validate "LocalAddr()" and "RemoteAddr()" + localAddr := tConn.LocalAddr() + remoteAddr := tConn.RemoteAddr() + assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP") + assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP") + _, err = net.ResolveTCPAddr("tcp", localAddr.String()) + assert.NoError(t, err, "tunneling connection local addr should parse") + _, err = net.ResolveTCPAddr("tcp", remoteAddr.String()) + assert.NoError(t, err, "tunneling connection remote addr should parse") +} + +func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) { + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, + } + conn, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) + <-stopServerChan + })) + defer tunnelingServer.Close() + // Create the client side tunneling connection. + url, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + tConn, err := dialForTunnelingConnection(url) + require.NoError(t, err, "error creating client tunneling connection") + defer tConn.Close() //nolint:errcheck + // Validate the read and write deadlines. + err = tConn.SetReadDeadline(time.Time{}) + assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") + err = tConn.SetWriteDeadline(time.Time{}) + assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") + err = tConn.SetDeadline(time.Time{}) + assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") + err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0)) + assert.NoError(t, err, "setting deadline 10 year from now succeeds") + err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0)) + assert.NoError(t, err, "setting deadline 10 year from now succeeds") + err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0)) + assert.NoError(t, err, "setting deadline 10 year from now succeeds") +} + +// dialForTunnelingConnection upgrades a request at the passed "url", creating +// a websocket connection. Returns the TunnelingConnection injected with the +// websocket connection or an error if one occurs. +func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) { + req, err := http.NewRequest("GET", url.String(), nil) + if err != nil { + return nil, err + } + // Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol. + tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1} + transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host}) + if err != nil { + return nil, err + } + conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...) + if err != nil { + return nil, err + } + return NewTunnelingConnection("client", conn), nil +} + +func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { + streams <- stream + return nil + } +} diff --git a/staging/src/k8s.io/client-go/tools/portforward/tunneling_dialer.go b/staging/src/k8s.io/client-go/tools/portforward/tunneling_dialer.go new file mode 100644 index 00000000000..2bef5ecd720 --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/portforward/tunneling_dialer.go @@ -0,0 +1,93 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + constants "k8s.io/apimachinery/pkg/util/portforward" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport/websocket" + "k8s.io/klog/v2" +) + +const PingPeriod = 10 * time.Second + +// tunnelingDialer implements "httpstream.Dial" interface +type tunnelingDialer struct { + url *url.URL + transport http.RoundTripper + holder websocket.ConnectionHolder +} + +// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer" +// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function +// returns an error if one occurs. +func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) { + transport, holder, err := websocket.RoundTripperFor(config) + if err != nil { + return nil, err + } + return &tunnelingDialer{ + url: url, + transport: transport, + holder: holder, + }, nil +} + +// Dial upgrades to a tunneling streaming connection, returning a SPDY connection +// containing a WebSockets connection (which implements "net.Conn"). Also +// returns the protocol negotiated, or an error. +func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { + // There is no passed context, so skip the context when creating request for now. + // Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17). + req, err := http.NewRequest("GET", d.url.String(), nil) + if err != nil { + return nil, "", err + } + // Add the spdy tunneling prefix to the requested protocols. The tunneling + // handler will know how to negotiate these protocols. + tunnelingProtocols := []string{} + for _, protocol := range protocols { + tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol + tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol) + } + klog.V(4).Infoln("Before WebSocket Upgrade Connection...") + conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...) + if err != nil { + return nil, "", err + } + if conn == nil { + return nil, "", fmt.Errorf("negotiated websocket connection is nil") + } + protocol := conn.Subprotocol() + protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) + klog.V(4).Infof("negotiated protocol: %s", protocol) + + // Wrap the websocket connection which implements "net.Conn". + tConn := NewTunnelingConnection("client", conn) + // Create SPDY connection injecting the previously created tunneling connection. + spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod) + + return spdyConn, protocol, err +} diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward.go b/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward.go index 31e1eef7bc9..f02003546bb 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward.go @@ -31,6 +31,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/cli-runtime/pkg/genericiooptions" "k8s.io/client-go/kubernetes/scheme" @@ -50,7 +51,7 @@ import ( type PortForwardOptions struct { Namespace string PodName string - RESTClient *restclient.RESTClient + RESTClient restclient.Interface Config *restclient.Config PodClient corev1client.PodsGetter Address []string @@ -99,11 +100,7 @@ const ( ) func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command { - opts := &PortForwardOptions{ - PortForwarder: &defaultPortForwarder{ - IOStreams: streams, - }, - } + opts := NewDefaultPortForwardOptions(streams) cmd := &cobra.Command{ Use: "port-forward TYPE/NAME [options] [LOCAL_PORT:]REMOTE_PORT [...[LOCAL_PORT_N:]REMOTE_PORT_N]", DisableFlagsInUseLine: true, @@ -123,6 +120,14 @@ func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *c return cmd } +func NewDefaultPortForwardOptions(streams genericiooptions.IOStreams) *PortForwardOptions { + return &PortForwardOptions{ + PortForwarder: &defaultPortForwarder{ + IOStreams: streams, + }, + } +} + type portForwarder interface { ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error } @@ -137,6 +142,14 @@ func (f *defaultPortForwarder) ForwardPorts(method string, url *url.URL, opts Po return err } dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, method, url) + if cmdutil.PortForwardWebsockets.IsEnabled() { + tunnelingDialer, err := portforward.NewSPDYOverWebsocketDialer(url, opts.Config) + if err != nil { + return err + } + // First attempt tunneling (websocket) dialer, then fallback to spdy dialer. + dialer = portforward.NewFallbackDialer(tunnelingDialer, dialer, httpstream.IsUpgradeFailure) + } fw, err := portforward.NewOnAddresses(dialer, opts.Address, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.Out, f.ErrOut) if err != nil { return err @@ -387,7 +400,14 @@ func (o PortForwardOptions) Validate() error { // RunPortForward implements all the necessary functionality for port-forward cmd. func (o PortForwardOptions) RunPortForward() error { - pod, err := o.PodClient.Pods(o.Namespace).Get(context.TODO(), o.PodName, metav1.GetOptions{}) + return o.RunPortForwardContext(context.Background()) +} + +// RunPortForwardContext implements all the necessary functionality for port-forward cmd. +// It ends portforwarding when an error is received from the backend, or an os.Interrupt +// signal is received, or the provided context is done. +func (o PortForwardOptions) RunPortForwardContext(ctx context.Context) error { + pod, err := o.PodClient.Pods(o.Namespace).Get(ctx, o.PodName, metav1.GetOptions{}) if err != nil { return err } @@ -401,7 +421,10 @@ func (o PortForwardOptions) RunPortForward() error { defer signal.Stop(signals) go func() { - <-signals + select { + case <-signals: + case <-ctx.Done(): + } if o.StopChannel != nil { close(o.StopChannel) } diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward_test.go b/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward_test.go index fb2252c54db..2aee31e7d47 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward_test.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward_test.go @@ -17,6 +17,7 @@ limitations under the License. package portforward import ( + "context" "fmt" "net/http" "net/url" @@ -101,6 +102,8 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) { } opts := &PortForwardOptions{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() cmd := NewCmdPortForward(tf, genericiooptions.NewTestIOStreamsDiscard()) cmd.Run = func(cmd *cobra.Command, args []string) { if err = opts.Complete(tf, cmd, args); err != nil { @@ -110,7 +113,7 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) { if err = opts.Validate(); err != nil { return } - err = opts.RunPortForward() + err = opts.RunPortForwardContext(ctx) } for name, value := range flags { diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go b/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go index fe45d118563..d9e401fa493 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go @@ -430,6 +430,7 @@ const ( InteractiveDelete FeatureGate = "KUBECTL_INTERACTIVE_DELETE" OpenAPIV3Patch FeatureGate = "KUBECTL_OPENAPIV3_PATCH" RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS" + PortForwardWebsockets FeatureGate = "KUBECTL_PORT_FORWARD_WEBSOCKETS" ) // IsEnabled returns true iff environment variable is set to true. diff --git a/test/integration/apiserver/portforward/main_test.go b/test/integration/apiserver/portforward/main_test.go new file mode 100644 index 00000000000..b7a62225a9d --- /dev/null +++ b/test/integration/apiserver/portforward/main_test.go @@ -0,0 +1,27 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "testing" + + "k8s.io/kubernetes/test/integration/framework" +) + +func TestMain(m *testing.M) { + framework.EtcdMain(m.Run) +} diff --git a/test/integration/apiserver/portforward/portforward_test.go b/test/integration/apiserver/portforward/portforward_test.go new file mode 100644 index 00000000000..e18c8ea2694 --- /dev/null +++ b/test/integration/apiserver/portforward/portforward_test.go @@ -0,0 +1,228 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + utilfeature "k8s.io/apiserver/pkg/util/feature" + "k8s.io/cli-runtime/pkg/genericiooptions" + "k8s.io/client-go/kubernetes" + featuregatetesting "k8s.io/component-base/featuregate/testing" + "k8s.io/kubectl/pkg/cmd/portforward" + kubeletportforward "k8s.io/kubelet/pkg/cri/streaming/portforward" + kastesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing" + kubefeatures "k8s.io/kubernetes/pkg/features" + + "k8s.io/kubernetes/test/integration/framework" +) + +const remotePort = "8765" + +func TestPortforward(t *testing.T) { + defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, kubefeatures.PortForwardWebsockets, true)() + t.Setenv("KUBECTL_PORT_FORWARD_WEBSOCKETS", "true") + + var podName string + var podUID types.UID + backendServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + t.Logf("backend saw request: %v", req.URL.String()) + kubeletportforward.ServePortForward( + w, + req, + &dummyPortForwarder{t: t}, + podName, + podUID, + &kubeletportforward.V4Options{}, + wait.ForeverTestTimeout, // idle timeout + remotecommand.DefaultStreamCreationTimeout, // stream creation timeout + []string{kubeletportforward.ProtocolV1Name}, + ) + })) + defer backendServer.Close() + backendURL, _ := url.Parse(backendServer.URL) + backendHost := backendURL.Hostname() + backendPort, _ := strconv.Atoi(backendURL.Port()) + + etcd := framework.SharedEtcd() + server := kastesting.StartTestServerOrDie(t, nil, []string{"--disable-admission-plugins=ServiceAccount"}, etcd) + defer server.TearDownFn() + + adminClient, err := kubernetes.NewForConfig(server.ClientConfig) + require.NoError(t, err) + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "mynode"}, + Status: corev1.NodeStatus{ + DaemonEndpoints: corev1.NodeDaemonEndpoints{KubeletEndpoint: corev1.DaemonEndpoint{Port: int32(backendPort)}}, + Addresses: []corev1.NodeAddress{{Type: corev1.NodeInternalIP, Address: backendHost}}, + }, + } + if _, err := adminClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}); err != nil { + t.Fatal(err) + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Namespace: "default", Name: "mypod"}, + Spec: corev1.PodSpec{ + NodeName: "mynode", + Containers: []corev1.Container{{Name: "test", Image: "test"}}, + }, + } + if _, err := adminClient.CoreV1().Pods("default").Create(context.Background(), pod, metav1.CreateOptions{}); err != nil { + t.Fatal(err) + } + if _, err := adminClient.CoreV1().Pods("default").Patch(context.Background(), "mypod", types.MergePatchType, []byte(`{"status":{"phase":"Running"}}`), metav1.PatchOptions{}, "status"); err != nil { + t.Fatal(err) + } + + // local port missing asks os to find random open port. + // Example: ":8000" (local = random, remote = 8000) + localRemotePort := fmt.Sprintf(":%s", remotePort) + streams, _, out, errOut := genericiooptions.NewTestIOStreams() + portForwardOptions := portforward.NewDefaultPortForwardOptions(streams) + portForwardOptions.Namespace = "default" + portForwardOptions.PodName = "mypod" + portForwardOptions.RESTClient = adminClient.CoreV1().RESTClient() + portForwardOptions.Config = server.ClientConfig + portForwardOptions.PodClient = adminClient.CoreV1() + portForwardOptions.Address = []string{"127.0.0.1"} + portForwardOptions.Ports = []string{localRemotePort} + portForwardOptions.StopChannel = make(chan struct{}, 1) + portForwardOptions.ReadyChannel = make(chan struct{}) + + if err := portForwardOptions.Validate(); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + if err := portForwardOptions.RunPortForwardContext(ctx); err != nil { + t.Error(err) + } + }() + + t.Log("waiting for port forward to be ready") + select { + case <-portForwardOptions.ReadyChannel: + t.Log("port forward was ready") + case <-time.After(wait.ForeverTestTimeout): + t.Error("port forward was never ready") + } + + // Parse out the randomly selected local port from "out" stream. + localPort, err := parsePort(out.String()) + require.NoError(t, err) + t.Logf("Local Port: %s", localPort) + + timeoutContext, cleanupTimeoutContext := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cleanupTimeoutContext() + testReq, _ := http.NewRequest("GET", fmt.Sprintf("http://127.0.0.1:%s/test", localPort), nil) + testReq = testReq.WithContext(timeoutContext) + testResp, err := http.DefaultClient.Do(testReq) + if err != nil { + t.Error(err) + } else { + t.Log(testResp.StatusCode) + data, err := io.ReadAll(testResp.Body) + if err != nil { + t.Error(err) + } else { + t.Log("client saw response:", string(data)) + } + if string(data) != fmt.Sprintf("request to %s was ok", remotePort) { + t.Errorf("unexpected data") + } + if testResp.StatusCode != 200 { + t.Error("expected success") + } + } + + cancel() + + wg.Wait() + t.Logf("stdout: %s", out.String()) + t.Logf("stderr: %s", errOut.String()) +} + +// parsePort parses out the local port from the port-forward output string. +// This should work for both IP4 and IP6 addresses. +// +// Example: "Forwarding from 127.0.0.1:8000 -> 4000", returns "8000". +func parsePort(forwardAddr string) (string, error) { + parts := strings.Split(forwardAddr, " ") + if len(parts) != 5 { + return "", fmt.Errorf("unable to parse local port from stdout: %s", forwardAddr) + } + // parts[2] = "127.0.0.1:" + _, localPort, err := net.SplitHostPort(parts[2]) + if err != nil { + return "", fmt.Errorf("unable to parse local port: %w", err) + } + return localPort, nil +} + +type dummyPortForwarder struct { + t *testing.T +} + +func (d *dummyPortForwarder) PortForward(ctx context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { + d.t.Logf("handling port forward request for %d", port) + + req, err := http.ReadRequest(bufio.NewReader(stream)) + if err != nil { + d.t.Logf("error reading request: %v", err) + return err + } + d.t.Log(req.URL.String()) + defer req.Body.Close() //nolint:errcheck + + resp := &http.Response{ + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: io.NopCloser(bytes.NewBufferString(fmt.Sprintf("request to %d was ok", port))), + } + resp.Write(stream) //nolint:errcheck + return stream.Close() +}