mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-12 21:36:24 +00:00
portforward: tunnel spdy through websockets
This commit is contained in:
parent
89cbd94e68
commit
8b447d8c97
@ -619,6 +619,13 @@ const (
|
|||||||
// Enable users to specify when a Pod is ready for scheduling.
|
// Enable users to specify when a Pod is ready for scheduling.
|
||||||
PodSchedulingReadiness featuregate.Feature = "PodSchedulingReadiness"
|
PodSchedulingReadiness featuregate.Feature = "PodSchedulingReadiness"
|
||||||
|
|
||||||
|
// owner: @seans3
|
||||||
|
// kep: http://kep.k8s.io/4006
|
||||||
|
// alpha: v1.30
|
||||||
|
//
|
||||||
|
// Enables PortForward to be proxied with a websocket client
|
||||||
|
PortForwardWebsockets featuregate.Feature = "PortForwardWebsockets"
|
||||||
|
|
||||||
// owner: @jessfraz
|
// owner: @jessfraz
|
||||||
// alpha: v1.12
|
// alpha: v1.12
|
||||||
//
|
//
|
||||||
@ -1101,6 +1108,8 @@ var defaultKubernetesFeatureGates = map[featuregate.Feature]featuregate.FeatureS
|
|||||||
|
|
||||||
PodSchedulingReadiness: {Default: true, PreRelease: featuregate.GA, LockToDefault: true}, // GA in 1.30; remove in 1.32
|
PodSchedulingReadiness: {Default: true, PreRelease: featuregate.GA, LockToDefault: true}, // GA in 1.30; remove in 1.32
|
||||||
|
|
||||||
|
PortForwardWebsockets: {Default: false, PreRelease: featuregate.Alpha},
|
||||||
|
|
||||||
ProcMountType: {Default: false, PreRelease: featuregate.Alpha},
|
ProcMountType: {Default: false, PreRelease: featuregate.Alpha},
|
||||||
|
|
||||||
QOSReserved: {Default: false, PreRelease: featuregate.Alpha},
|
QOSReserved: {Default: false, PreRelease: featuregate.Alpha},
|
||||||
|
@ -242,7 +242,12 @@ func (r *PortForwardREST) Connect(ctx context.Context, name string, opts runtime
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil
|
handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder)
|
||||||
|
if utilfeature.DefaultFeatureGate.Enabled(features.PortForwardWebsockets) {
|
||||||
|
tunnelingHandler := translator.NewTunnelingHandler(handler)
|
||||||
|
handler = translator.NewTranslatingHandler(handler, tunnelingHandler, wsstream.IsWebSocketRequestWithTunnelingProtocol)
|
||||||
|
}
|
||||||
|
return handler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler {
|
func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler {
|
||||||
|
@ -27,6 +27,7 @@ import (
|
|||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
|
|
||||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/portforward"
|
||||||
"k8s.io/apimachinery/pkg/util/remotecommand"
|
"k8s.io/apimachinery/pkg/util/remotecommand"
|
||||||
"k8s.io/apimachinery/pkg/util/runtime"
|
"k8s.io/apimachinery/pkg/util/runtime"
|
||||||
"k8s.io/klog/v2"
|
"k8s.io/klog/v2"
|
||||||
@ -106,6 +107,23 @@ func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsWebSocketRequestWithTunnelingProtocol returns true if the request contains headers
|
||||||
|
// identifying that it is requesting a websocket upgrade with a tunneling protocol;
|
||||||
|
// false otherwise.
|
||||||
|
func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool {
|
||||||
|
if !IsWebSocketRequest(req) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
|
||||||
|
for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
|
||||||
|
if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
|
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
|
||||||
// read and write deadlines are pushed every time a new message is received.
|
// read and write deadlines are pushed every time a new message is received.
|
||||||
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
|
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
|
||||||
@ -301,6 +319,12 @@ func protocolSupportsStreamClose(protocol string) bool {
|
|||||||
return protocol == remotecommand.StreamProtocolV5Name
|
return protocol == remotecommand.StreamProtocolV5Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// protocolSupportsWebsocketTunneling returns true if the passed protocol
|
||||||
|
// is a tunneled Kubernetes spdy protocol; false otherwise.
|
||||||
|
func protocolSupportsWebsocketTunneling(protocol string) bool {
|
||||||
|
return strings.HasPrefix(protocol, portforward.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, portforward.KubernetesSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
// handle implements a websocket handler.
|
// handle implements a websocket handler.
|
||||||
func (conn *Conn) handle(ws *websocket.Conn) {
|
func (conn *Conn) handle(ws *websocket.Conn) {
|
||||||
conn.initialize(ws)
|
conn.initialize(ws)
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2016 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
const (
|
||||||
|
PortForwardV1Name = "portforward.k8s.io"
|
||||||
|
WebsocketsSPDYTunnelingPrefix = "SPDY/3.1+"
|
||||||
|
KubernetesSuffix = ".k8s.io"
|
||||||
|
WebsocketsSPDYTunnelingPortForwardV1 = WebsocketsSPDYTunnelingPrefix + PortForwardV1Name
|
||||||
|
)
|
@ -36,6 +36,7 @@ import (
|
|||||||
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
||||||
|
|
||||||
"github.com/mxk/go-flowrate/flowrate"
|
"github.com/mxk/go-flowrate/flowrate"
|
||||||
|
|
||||||
"k8s.io/klog/v2"
|
"k8s.io/klog/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -336,6 +337,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
|
|||||||
clone.Host = h.Location.Host
|
clone.Host = h.Location.Host
|
||||||
}
|
}
|
||||||
clone.URL = &location
|
clone.URL = &location
|
||||||
|
klog.V(6).Infof("UpgradeAwareProxy: dialing for SPDY upgrade with headers: %v", clone.Header)
|
||||||
backendConn, err = h.DialForUpgrade(clone)
|
backendConn, err = h.DialForUpgrade(clone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
klog.V(6).Infof("Proxy connection error: %v", err)
|
klog.V(6).Infof("Proxy connection error: %v", err)
|
||||||
@ -370,13 +372,13 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
|
|||||||
// hijacking should be the last step in the upgrade.
|
// hijacking should be the last step in the upgrade.
|
||||||
requestHijacker, ok := w.(http.Hijacker)
|
requestHijacker, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
klog.V(6).Infof("Unable to hijack response writer: %T", w)
|
klog.Errorf("Unable to hijack response writer: %T", w)
|
||||||
h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
|
h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
requestHijackedConn, _, err := requestHijacker.Hijack()
|
requestHijackedConn, _, err := requestHijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
klog.V(6).Infof("Unable to hijack response: %v", err)
|
klog.Errorf("Unable to hijack response: %v", err)
|
||||||
h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
|
h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -420,7 +422,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
|
|||||||
} else {
|
} else {
|
||||||
writer = backendConn
|
writer = backendConn
|
||||||
}
|
}
|
||||||
_, err := io.Copy(writer, requestHijackedConn)
|
_, err := io.Copy(writer, &loggingReader{name: "client->backend", delegate: requestHijackedConn})
|
||||||
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
klog.Errorf("Error proxying data from client to backend: %v", err)
|
klog.Errorf("Error proxying data from client to backend: %v", err)
|
||||||
}
|
}
|
||||||
@ -434,7 +436,7 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
|
|||||||
} else {
|
} else {
|
||||||
reader = backendConn
|
reader = backendConn
|
||||||
}
|
}
|
||||||
_, err := io.Copy(requestHijackedConn, reader)
|
_, err := io.Copy(requestHijackedConn, &loggingReader{name: "backend->client", delegate: reader})
|
||||||
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
klog.Errorf("Error proxying data from backend to client: %v", err)
|
klog.Errorf("Error proxying data from backend to client: %v", err)
|
||||||
}
|
}
|
||||||
@ -452,6 +454,18 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loggingReader logs the bytes read from the "delegate" with a "name" prefix.
|
||||||
|
type loggingReader struct {
|
||||||
|
name string
|
||||||
|
delegate io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *loggingReader) Read(p []byte) (int, error) {
|
||||||
|
n, err := l.delegate.Read(p)
|
||||||
|
klog.V(8).Infof("%s: %d bytes, err=%v, bytes=% X", l.name, n, err, p[:n])
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
// FIXME: Taken from net/http/httputil/reverseproxy.go as singleJoiningSlash is not exported to be re-used.
|
// FIXME: Taken from net/http/httputil/reverseproxy.go as singleJoiningSlash is not exported to be re-used.
|
||||||
// See-also: https://github.com/golang/go/issues/44290
|
// See-also: https://github.com/golang/go/issues/44290
|
||||||
func singleJoiningSlash(a, b string) string {
|
func singleJoiningSlash(a, b string) string {
|
||||||
|
@ -16,6 +16,7 @@ require (
|
|||||||
github.com/google/go-cmp v0.6.0
|
github.com/google/go-cmp v0.6.0
|
||||||
github.com/google/gofuzz v1.2.0
|
github.com/google/gofuzz v1.2.0
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
|
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
|
||||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
|
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
|
||||||
@ -77,7 +78,6 @@ require (
|
|||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
github.com/google/btree v1.0.1 // indirect
|
github.com/google/btree v1.0.1 // indirect
|
||||||
github.com/gorilla/websocket v1.5.0 // indirect
|
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
|
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
|
||||||
|
433
staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go
Normal file
433
staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel.go
Normal file
@ -0,0 +1,433 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gwebsocket "github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
|
||||||
|
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||||
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||||
|
"k8s.io/client-go/tools/portforward"
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TunnelingHandler is a handler which tunnels SPDY through WebSockets.
|
||||||
|
type TunnelingHandler struct {
|
||||||
|
// Used to communicate between upstream SPDY and downstream tunnel.
|
||||||
|
upgradeHandler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelingHandler is used to create the tunnel between an upstream
|
||||||
|
// SPDY connection and a downstream tunneling connection through the stored
|
||||||
|
// UpgradeAwareProxy.
|
||||||
|
func NewTunnelingHandler(upgradeHandler http.Handler) *TunnelingHandler {
|
||||||
|
return &TunnelingHandler{upgradeHandler: upgradeHandler}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP uses the upgradeHandler to tunnel between a downstream tunneling
|
||||||
|
// connection and an upstream SPDY connection. The tunneling connection is
|
||||||
|
// a wrapped WebSockets connection which communicates SPDY framed data.
|
||||||
|
func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
klog.V(4).Infoln("TunnelingHandler ServeHTTP")
|
||||||
|
|
||||||
|
spdyProtocols := spdyProtocolsFromWebsocketProtocols(req)
|
||||||
|
if len(spdyProtocols) == 0 {
|
||||||
|
http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
spdyRequest := createSPDYRequest(req, spdyProtocols...)
|
||||||
|
|
||||||
|
writer := &tunnelingResponseWriter{
|
||||||
|
w: w,
|
||||||
|
conn: &headerInterceptingConn{
|
||||||
|
initializableConn: &tunnelingWebsocketUpgraderConn{
|
||||||
|
w: w,
|
||||||
|
req: req,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
klog.V(4).Infoln("Tunnel spdy through websockets using the UpgradeAwareProxy")
|
||||||
|
h.upgradeHandler.ServeHTTP(writer, spdyRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSPDYRequest modifies the passed request to remove
|
||||||
|
// WebSockets headers and add SPDY upgrade information, including
|
||||||
|
// spdy protocols acceptable to the client.
|
||||||
|
func createSPDYRequest(req *http.Request, spdyProtocols ...string) *http.Request {
|
||||||
|
clone := utilnet.CloneRequest(req)
|
||||||
|
// Clean up the websocket headers from the http request.
|
||||||
|
clone.Header.Del(wsstream.WebSocketProtocolHeader)
|
||||||
|
clone.Header.Del("Sec-Websocket-Key")
|
||||||
|
clone.Header.Del("Sec-Websocket-Version")
|
||||||
|
clone.Header.Del(httpstream.HeaderUpgrade)
|
||||||
|
// Update the http request for an upstream SPDY upgrade.
|
||||||
|
clone.Method = "POST"
|
||||||
|
clone.Body = nil // Remove the request body which is unused.
|
||||||
|
clone.Header.Set(httpstream.HeaderUpgrade, spdy.HeaderSpdy31)
|
||||||
|
clone.Header.Del(httpstream.HeaderProtocolVersion)
|
||||||
|
for i := range spdyProtocols {
|
||||||
|
clone.Header.Add(httpstream.HeaderProtocolVersion, spdyProtocols[i])
|
||||||
|
}
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// spdyProtocolsFromWebsocketProtocols returns a list of spdy protocols by filtering
|
||||||
|
// to Kubernetes websocket subprotocols prefixed with "SPDY/3.1+", then removing the prefix
|
||||||
|
func spdyProtocolsFromWebsocketProtocols(req *http.Request) []string {
|
||||||
|
var spdyProtocols []string
|
||||||
|
for _, protocol := range gwebsocket.Subprotocols(req) {
|
||||||
|
if strings.HasPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, constants.KubernetesSuffix) {
|
||||||
|
spdyProtocols = append(spdyProtocols, strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return spdyProtocols
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.ResponseWriter = &tunnelingResponseWriter{}
|
||||||
|
var _ http.Hijacker = &tunnelingResponseWriter{}
|
||||||
|
|
||||||
|
// tunnelingResponseWriter implements the http.ResponseWriter and http.Hijacker interfaces.
|
||||||
|
// Only non-upgrade responses can be written using WriteHeader() and Write().
|
||||||
|
// Once Write or WriteHeader is called, Hijack returns an error.
|
||||||
|
// Once Hijack is called, Write, WriteHeader, and Hijack return errors.
|
||||||
|
type tunnelingResponseWriter struct {
|
||||||
|
// w is used to delegate Header(), WriteHeader(), and Write() calls
|
||||||
|
w http.ResponseWriter
|
||||||
|
// conn is returned from Hijack()
|
||||||
|
conn net.Conn
|
||||||
|
// mu guards writes
|
||||||
|
mu sync.Mutex
|
||||||
|
// wrote tracks whether WriteHeader or Write has been called
|
||||||
|
written bool
|
||||||
|
// hijacked tracks whether Hijack has been called
|
||||||
|
hijacked bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack returns a delegate "net.Conn".
|
||||||
|
// An error is returned if Write(), WriteHeader(), or Hijack() was previously called.
|
||||||
|
// The returned bufio.ReadWriter is always nil.
|
||||||
|
func (w *tunnelingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
if w.written {
|
||||||
|
klog.Errorf("Hijack called after write")
|
||||||
|
return nil, nil, errors.New("connection has already been written to")
|
||||||
|
}
|
||||||
|
if w.hijacked {
|
||||||
|
klog.Errorf("Hijack called after hijack")
|
||||||
|
return nil, nil, errors.New("connection has already been hijacked")
|
||||||
|
}
|
||||||
|
w.hijacked = true
|
||||||
|
klog.V(6).Infof("Hijack returning websocket tunneling net.Conn")
|
||||||
|
return w.conn, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header is delegated to the stored "http.ResponseWriter".
|
||||||
|
func (w *tunnelingResponseWriter) Header() http.Header {
|
||||||
|
return w.w.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is delegated to the stored "http.ResponseWriter".
|
||||||
|
func (w *tunnelingResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
if w.hijacked {
|
||||||
|
klog.Errorf("Write called after hijack")
|
||||||
|
return 0, http.ErrHijacked
|
||||||
|
}
|
||||||
|
w.written = true
|
||||||
|
return w.w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader is delegated to the stored "http.ResponseWriter".
|
||||||
|
func (w *tunnelingResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
if w.written {
|
||||||
|
klog.Errorf("WriteHeader called after write")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if w.hijacked {
|
||||||
|
klog.Errorf("WriteHeader called after hijack")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.written = true
|
||||||
|
|
||||||
|
if statusCode == http.StatusSwitchingProtocols {
|
||||||
|
// 101 upgrade responses must come via the hijacked connection, not WriteHeader
|
||||||
|
klog.Errorf("WriteHeader called with 101 upgrade")
|
||||||
|
http.Error(w.w, "unexpected upgrade", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// pass through non-upgrade responses we don't need to translate
|
||||||
|
w.w.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// headerInterceptingConn wraps the tunneling "net.Conn" to drain the
|
||||||
|
// HTTP response status/headers from the upstream SPDY connection, then use
|
||||||
|
// that to decide how to initialize the delegate connection for writes.
|
||||||
|
type headerInterceptingConn struct {
|
||||||
|
// initializableConn is delegated to for all net.Conn methods.
|
||||||
|
// initializableConn.Write() is not called until response headers have been read
|
||||||
|
// and initializableConn#InitializeWrite() has been called with the result.
|
||||||
|
initializableConn
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
headerBuffer []byte
|
||||||
|
initialized bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializableConn is a connection that will be initialized before any calls to Write are made
|
||||||
|
type initializableConn interface {
|
||||||
|
net.Conn
|
||||||
|
InitializeWrite(backendResponse *http.Response) error
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxHeaderBytes = 1 << 20
|
||||||
|
|
||||||
|
// Write intercepts to initially swallow the HTTP response, then
|
||||||
|
// delegate to the tunneling "net.Conn" once the response has been
|
||||||
|
// seen and processed.
|
||||||
|
func (h *headerInterceptingConn) Write(b []byte) (int, error) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
if h.initialized {
|
||||||
|
return h.initializableConn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write into the headerBuffer, then attempt to parse the bytes
|
||||||
|
// as an http response.
|
||||||
|
if len(h.headerBuffer)+len(b) > maxHeaderBytes {
|
||||||
|
return 0, fmt.Errorf("header size limit exceeded")
|
||||||
|
}
|
||||||
|
h.headerBuffer = append(h.headerBuffer, b...)
|
||||||
|
bufferedReader := bufio.NewReader(bytes.NewReader(h.headerBuffer))
|
||||||
|
resp, err := http.ReadResponse(bufferedReader, nil)
|
||||||
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
// don't yet have a complete set of headers
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
klog.Errorf("invalid headers: %v", err)
|
||||||
|
return len(b), err
|
||||||
|
}
|
||||||
|
resp.Body.Close() //nolint:errcheck
|
||||||
|
|
||||||
|
h.headerBuffer = nil
|
||||||
|
err = h.initializableConn.InitializeWrite(resp)
|
||||||
|
h.initialized = true
|
||||||
|
if err != nil {
|
||||||
|
return len(b), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy any remaining buffered data to the underlying conn
|
||||||
|
remainingBuffer, _ := io.ReadAll(bufferedReader)
|
||||||
|
if len(remainingBuffer) > 0 {
|
||||||
|
_, err = h.initializableConn.Write(remainingBuffer)
|
||||||
|
}
|
||||||
|
return len(b), err
|
||||||
|
}
|
||||||
|
|
||||||
|
type tunnelingWebsocketUpgraderConn struct {
|
||||||
|
// req is the websocket request, used for upgrading
|
||||||
|
req *http.Request
|
||||||
|
// w is the websocket writer, used for upgrading and writing error responses
|
||||||
|
w http.ResponseWriter
|
||||||
|
|
||||||
|
// lock guards conn and err
|
||||||
|
lock sync.RWMutex
|
||||||
|
// if conn is non-nil, InitializeWrite succeeded
|
||||||
|
conn net.Conn
|
||||||
|
// if err is non-nil, InitializeWrite failed or Close was called before InitializeWrite
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.Response) (err error) {
|
||||||
|
// make sure we close a connection we open in error cases
|
||||||
|
var conn net.Conn
|
||||||
|
defer func() {
|
||||||
|
if err != nil && conn != nil {
|
||||||
|
conn.Close() //nolint:errcheck
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
u.lock.Lock()
|
||||||
|
defer u.lock.Unlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return fmt.Errorf("InitializeWrite already called")
|
||||||
|
}
|
||||||
|
if u.err != nil {
|
||||||
|
return u.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if backendResponse.StatusCode == http.StatusSwitchingProtocols {
|
||||||
|
connectionHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderConnection))
|
||||||
|
upgradeHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderUpgrade))
|
||||||
|
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) {
|
||||||
|
klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header)
|
||||||
|
u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response")
|
||||||
|
http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
|
||||||
|
return u.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Translate the server's chosen SPDY protocol into the tunneled websocket protocol for the handshake
|
||||||
|
var serverWebsocketProtocols []string
|
||||||
|
if backendSPDYProtocol := strings.TrimSpace(backendResponse.Header.Get(httpstream.HeaderProtocolVersion)); backendSPDYProtocol != "" {
|
||||||
|
serverWebsocketProtocols = []string{constants.WebsocketsSPDYTunnelingPrefix + backendSPDYProtocol}
|
||||||
|
} else {
|
||||||
|
serverWebsocketProtocols = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to upgrade the websocket connection.
|
||||||
|
// Beyond this point, we don't need to write errors to the response.
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: serverWebsocketProtocols,
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(u.w, u.req, nil)
|
||||||
|
if err != nil {
|
||||||
|
klog.Errorf("error upgrading websocket connection: %v", err)
|
||||||
|
u.err = err
|
||||||
|
return u.err
|
||||||
|
}
|
||||||
|
|
||||||
|
klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol())
|
||||||
|
u.conn = portforward.NewTunnelingConnection("server", conn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// anything other than an upgrade should pass through the backend response
|
||||||
|
|
||||||
|
// try to hijack
|
||||||
|
conn, _, err = u.w.(http.Hijacker).Hijack()
|
||||||
|
if err != nil {
|
||||||
|
klog.Errorf("Unable to hijack response: %v", err)
|
||||||
|
u.err = err
|
||||||
|
return u.err
|
||||||
|
}
|
||||||
|
// replay the backend response to the hijacked conn
|
||||||
|
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) //nolint:errcheck
|
||||||
|
err = backendResponse.Write(conn)
|
||||||
|
if err != nil {
|
||||||
|
u.err = err
|
||||||
|
return u.err
|
||||||
|
}
|
||||||
|
u.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) Read(b []byte) (n int, err error) {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.Read(b)
|
||||||
|
}
|
||||||
|
if u.err != nil {
|
||||||
|
return 0, u.err
|
||||||
|
}
|
||||||
|
// return empty read without blocking until we are initialized
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) Write(b []byte) (n int, err error) {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.Write(b)
|
||||||
|
}
|
||||||
|
if u.err != nil {
|
||||||
|
return 0, u.err
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("Write called before Initialize")
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) Close() error {
|
||||||
|
u.lock.Lock()
|
||||||
|
defer u.lock.Unlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.Close()
|
||||||
|
}
|
||||||
|
if u.err != nil {
|
||||||
|
return u.err
|
||||||
|
}
|
||||||
|
// record that we closed so we don't write again or try to initialize
|
||||||
|
u.err = fmt.Errorf("connection closed")
|
||||||
|
// write a response
|
||||||
|
http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) LocalAddr() net.Addr {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
return noopAddr{}
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) RemoteAddr() net.Addr {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
return noopAddr{}
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) SetDeadline(t time.Time) error {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) SetReadDeadline(t time.Time) error {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *tunnelingWebsocketUpgraderConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
u.lock.RLock()
|
||||||
|
defer u.lock.RUnlock()
|
||||||
|
if u.conn != nil {
|
||||||
|
return u.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopAddr struct{}
|
||||||
|
|
||||||
|
func (n noopAddr) Network() string { return "" }
|
||||||
|
func (n noopAddr) String() string { return "" }
|
197
staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go
Normal file
197
staging/src/k8s.io/apiserver/pkg/util/proxy/streamtunnel_test.go
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/runtime"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||||
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||||
|
"k8s.io/apimachinery/pkg/util/proxy"
|
||||||
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
"k8s.io/apiserver/pkg/registry/rest"
|
||||||
|
restconfig "k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/tools/portforward"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
|
||||||
|
// Create fake upstream SPDY server, with channel receiving SPDY streams.
|
||||||
|
streamChan := make(chan httpstream.Stream)
|
||||||
|
defer close(streamChan)
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
// Create fake upstream SPDY server.
|
||||||
|
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
|
||||||
|
require.NoError(t, err)
|
||||||
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
|
conn := upgrader.UpgradeResponse(w, req, justQueueStream(streamChan))
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer spdyServer.Close()
|
||||||
|
// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then
|
||||||
|
// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer.
|
||||||
|
url, err := url.Parse(spdyServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
transport, err := fakeTransport()
|
||||||
|
require.NoError(t, err)
|
||||||
|
upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
|
||||||
|
tunnelingHandler := NewTunnelingHandler(upgradeHandler)
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
tunnelingHandler.ServeHTTP(w, req)
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Create SPDY client connection containing a TunnelingConnection by upgrading
|
||||||
|
// a request to TunnelingHandler using new portforward version 2.
|
||||||
|
tunnelingURL, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
|
||||||
|
require.NoError(t, err)
|
||||||
|
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, constants.PortForwardV1Name, protocol)
|
||||||
|
defer spdyClient.Close() //nolint:errcheck
|
||||||
|
// Create a SPDY client stream, which will queue a SPDY server stream
|
||||||
|
// on the stream creation channel. Send random data on the client stream
|
||||||
|
// reading off the SPDY server stream, and validating it was tunneled.
|
||||||
|
randomSize := 1024 * 1024
|
||||||
|
randomData := make([]byte, randomSize)
|
||||||
|
_, err = rand.Read(randomData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var actual []byte
|
||||||
|
go func() {
|
||||||
|
clientStream, err := spdyClient.CreateStream(http.Header{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = io.Copy(clientStream, bytes.NewReader(randomData))
|
||||||
|
require.NoError(t, err)
|
||||||
|
clientStream.Close() //nolint:errcheck
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case serverStream := <-streamChan:
|
||||||
|
actual, err = io.ReadAll(serverStream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer serverStream.Close() //nolint:errcheck
|
||||||
|
case <-time.After(wait.ForeverTestTimeout):
|
||||||
|
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
|
||||||
|
}
|
||||||
|
assert.Equal(t, randomData, actual, "error validating tunneled random data")
|
||||||
|
}
|
||||||
|
|
||||||
|
const responseStr = `HTTP/1.1 101 Switching Protocols
|
||||||
|
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
||||||
|
X-App-Protocol: portforward.k8s.io
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
const responseWithExtraStr = `HTTP/1.1 101 Switching Protocols
|
||||||
|
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
||||||
|
X-App-Protocol: portforward.k8s.io
|
||||||
|
|
||||||
|
This is extra data.
|
||||||
|
`
|
||||||
|
|
||||||
|
const invalidResponseStr = `INVALID/1.1 101 Switching Protocols
|
||||||
|
Date: Sun, 25 Feb 2024 08:09:25 GMT
|
||||||
|
X-App-Protocol: portforward.k8s.io
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
|
||||||
|
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
|
||||||
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}}
|
||||||
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err := hic.Write([]byte(responseStr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
||||||
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
|
||||||
|
assert.Equal(t, 0, len(testConnConstructor.mockConn.written), "no extra data written to net.Conn")
|
||||||
|
// Extra data after response headers should be sent to net.Conn.
|
||||||
|
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err = hic.Write([]byte(responseWithExtraStr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, hic.initialized)
|
||||||
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
||||||
|
assert.Equal(t, "This is extra data.\n", string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
||||||
|
// Invalid response returns error.
|
||||||
|
hic = &headerInterceptingConn{initializableConn: testConnConstructor}
|
||||||
|
_, err = hic.Write([]byte(invalidResponseStr))
|
||||||
|
assert.Error(t, err, "expected error from invalid http response")
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockConnInitializer struct {
|
||||||
|
resp *http.Response
|
||||||
|
*mockConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response) error {
|
||||||
|
m.resp = backendResponse
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockConn implements "net.Conn" interface.
|
||||||
|
var _ net.Conn = &mockConn{}
|
||||||
|
|
||||||
|
type mockConn struct {
|
||||||
|
written []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockConn) Write(p []byte) (int, error) {
|
||||||
|
mc.written = make([]byte, len(p))
|
||||||
|
copy(mc.written, p)
|
||||||
|
return len(mc.written), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
|
||||||
|
func (mc *mockConn) Close() error { return nil }
|
||||||
|
func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||||
|
func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||||
|
func (mc *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
|
// fakeResponder implements "rest.Responder" interface.
|
||||||
|
var _ rest.Responder = &fakeResponder{}
|
||||||
|
|
||||||
|
type fakeResponder struct{}
|
||||||
|
|
||||||
|
func (fr *fakeResponder) Object(statusCode int, obj runtime.Object) {}
|
||||||
|
func (fr *fakeResponder) Error(err error) {}
|
||||||
|
|
||||||
|
// justQueueStream skips the usual stream validation before
|
||||||
|
// queueing the stream on the stream channel.
|
||||||
|
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||||
|
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||||
|
streams <- stream
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ httpstream.Dialer = &fallbackDialer{}
|
||||||
|
|
||||||
|
// fallbackDialer encapsulates a primary and secondary dialer, including
|
||||||
|
// the boolean function to determine if the primary dialer failed. Implements
|
||||||
|
// the httpstream.Dialer interface.
|
||||||
|
type fallbackDialer struct {
|
||||||
|
primary httpstream.Dialer
|
||||||
|
secondary httpstream.Dialer
|
||||||
|
shouldFallback func(error) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers,
|
||||||
|
// as well as the boolean function to determine if the primary dialer failed.
|
||||||
|
func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer {
|
||||||
|
return &fallbackDialer{
|
||||||
|
primary: primary,
|
||||||
|
secondary: secondary,
|
||||||
|
shouldFallback: shouldFallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial is the single function necessary to implement the "httpstream.Dialer" interface.
|
||||||
|
// It takes the protocol version strings to request, returning an the upgraded
|
||||||
|
// httstream.Connection and the negotiated protocol version accepted. If the initial
|
||||||
|
// primary dialer fails, this function attempts the secondary dialer. Returns an error
|
||||||
|
// if one occurs.
|
||||||
|
func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||||
|
conn, version, err := f.primary.Dial(protocols...)
|
||||||
|
if err != nil && f.shouldFallback(err) {
|
||||||
|
klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
|
||||||
|
return f.secondary.Dial(protocols...)
|
||||||
|
}
|
||||||
|
return conn, version, err
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFallbackDialer(t *testing.T) {
|
||||||
|
protocol := "v6.fake.k8s.io"
|
||||||
|
// If "shouldFallback" is false, then only primary should be dialed.
|
||||||
|
primary := &fakeDialer{dialed: false}
|
||||||
|
primary.negotiatedProtocol = protocol
|
||||||
|
secondary := &fakeDialer{dialed: false}
|
||||||
|
fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse)
|
||||||
|
_, negotiated, err := fallbackDialer.Dial(protocol)
|
||||||
|
assert.True(t, primary.dialed, "no fallback; primary should have dialed")
|
||||||
|
assert.Equal(t, protocol, negotiated, "")
|
||||||
|
assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
|
||||||
|
assert.Nil(t, err, "error should be nil")
|
||||||
|
// If "shouldFallback" is true, then primary AND secondary should be dialed.
|
||||||
|
primary.dialed = false // reset dialed field
|
||||||
|
primary.err = fmt.Errorf("bad handshake")
|
||||||
|
secondary.dialed = false // reset dialed field
|
||||||
|
secondary.negotiatedProtocol = protocol
|
||||||
|
fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue)
|
||||||
|
_, negotiated, err = fallbackDialer.Dial(protocol)
|
||||||
|
assert.True(t, primary.dialed, "fallback; primary should have dialed (first)")
|
||||||
|
assert.True(t, secondary.dialed, "fallback; secondary should have dialed")
|
||||||
|
assert.Equal(t, protocol, negotiated)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func alwaysTrue(err error) bool { return true }
|
||||||
|
|
||||||
|
func alwaysFalse(err error) bool { return false }
|
@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error {
|
|||||||
defer pf.Close()
|
defer pf.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
|
var protocol string
|
||||||
|
pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error upgrading connection: %s", err)
|
return fmt.Errorf("error upgrading connection: %s", err)
|
||||||
}
|
}
|
||||||
defer pf.streamConn.Close()
|
defer pf.streamConn.Close()
|
||||||
|
if protocol != PortForwardProtocolV1Name {
|
||||||
|
return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
|
||||||
|
}
|
||||||
|
|
||||||
return pf.forward()
|
return pf.forward()
|
||||||
}
|
}
|
||||||
|
@ -430,7 +430,8 @@ func TestGetListener(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
||||||
dialer := &fakeDialer{
|
dialer := &fakeDialer{
|
||||||
conn: newFakeConnection(),
|
conn: newFakeConnection(),
|
||||||
|
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
@ -570,7 +571,8 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
|
|||||||
|
|
||||||
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
||||||
dialer := &fakeDialer{
|
dialer := &fakeDialer{
|
||||||
conn: newFakeConnection(),
|
conn: newFakeConnection(),
|
||||||
|
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
@ -601,7 +603,8 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
|||||||
|
|
||||||
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
|
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
|
||||||
dialer := &fakeDialer{
|
dialer := &fakeDialer{
|
||||||
conn: newFakeConnection(),
|
conn: newFakeConnection(),
|
||||||
|
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
|
@ -0,0 +1,153 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2023 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gwebsocket "github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ net.Conn = &TunnelingConnection{}
|
||||||
|
|
||||||
|
// TunnelingConnection implements the "httpstream.Connection" interface, wrapping
|
||||||
|
// a websocket connection that tunnels SPDY.
|
||||||
|
type TunnelingConnection struct {
|
||||||
|
name string
|
||||||
|
conn *gwebsocket.Conn
|
||||||
|
inProgressMessage io.Reader
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelingConnection wraps the passed gorilla/websockets connection
|
||||||
|
// with the TunnelingConnection struct (implementing net.Conn).
|
||||||
|
func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
|
||||||
|
return &TunnelingConnection{
|
||||||
|
name: name,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements "io.Reader" interface, reading from the stored connection
|
||||||
|
// into the passed buffer "p". Returns the number of bytes read and an error.
|
||||||
|
// Can keep track of the "inProgress" messsage from the tunneled connection.
|
||||||
|
func (c *TunnelingConnection) Read(p []byte) (int, error) {
|
||||||
|
klog.V(7).Infof("%s: tunneling connection read...", c.name)
|
||||||
|
defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name)
|
||||||
|
for {
|
||||||
|
if c.inProgressMessage == nil {
|
||||||
|
klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name)
|
||||||
|
messageType, nextReader, err := c.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
closeError := &gwebsocket.CloseError{}
|
||||||
|
if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if messageType != gwebsocket.BinaryMessage {
|
||||||
|
return 0, fmt.Errorf("invalid message type received")
|
||||||
|
}
|
||||||
|
c.inProgressMessage = nextReader
|
||||||
|
}
|
||||||
|
klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name)
|
||||||
|
i, err := c.inProgressMessage.Read(p)
|
||||||
|
if i == 0 && err == io.EOF {
|
||||||
|
c.inProgressMessage = nil
|
||||||
|
} else {
|
||||||
|
klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements "io.Writer" interface, copying the data in the passed
|
||||||
|
// byte array "p" into the stored tunneled connection. Returns the number
|
||||||
|
// of bytes written and an error.
|
||||||
|
func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
|
||||||
|
klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p)
|
||||||
|
defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name)
|
||||||
|
w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// close, which flushes the message
|
||||||
|
closeErr := w.Close()
|
||||||
|
if closeErr != nil && err == nil {
|
||||||
|
// if closing/flushing errored and we weren't already returning an error, return the close error
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
n, err = w.Write(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements "io.Closer" interface, signaling the other tunneled connection
|
||||||
|
// endpoint, and closing the tunneled connection only once.
|
||||||
|
func (c *TunnelingConnection) Close() error {
|
||||||
|
var err error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
klog.V(7).Infof("%s: tunneling connection Close()...", c.name)
|
||||||
|
// Signal other endpoint that websocket connection is closing; ignore error.
|
||||||
|
normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
|
||||||
|
c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck
|
||||||
|
err = c.conn.Close()
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements part of the "net.Conn" interface, returning the local
|
||||||
|
// endpoint network address of the tunneled connection.
|
||||||
|
func (c *TunnelingConnection) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements part of the "net.Conn" interface, returning the remote
|
||||||
|
// endpoint network address of the tunneled connection.
|
||||||
|
func (c *TunnelingConnection) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the *absolute* time in the future for both
|
||||||
|
// read and write deadlines. Returns an error if one occurs.
|
||||||
|
func (c *TunnelingConnection) SetDeadline(t time.Time) error {
|
||||||
|
rerr := c.SetReadDeadline(t)
|
||||||
|
werr := c.SetWriteDeadline(t)
|
||||||
|
return errors.Join(rerr, werr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the *absolute* time in the future for the
|
||||||
|
// read deadlines. Returns an error if one occurs.
|
||||||
|
func (c *TunnelingConnection) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the *absolute* time in the future for the
|
||||||
|
// write deadlines. Returns an error if one occurs.
|
||||||
|
func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
@ -0,0 +1,190 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gwebsocket "github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||||
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||||
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/transport/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
|
||||||
|
// Stream channel that will receive streams created on upstream SPDY server.
|
||||||
|
streamChan := make(chan httpstream.Stream)
|
||||||
|
defer close(streamChan)
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
// Create tunneling connection server endpoint with fake upstream SPDY server.
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||||
|
tunnelingConn := NewTunnelingConnection("server", conn)
|
||||||
|
spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer spdyConn.Close() //nolint:errcheck
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Dial the client tunneling connection to the tunneling server.
|
||||||
|
url, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
|
||||||
|
require.NoError(t, err)
|
||||||
|
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, constants.PortForwardV1Name, protocol)
|
||||||
|
defer spdyClient.Close() //nolint:errcheck
|
||||||
|
// Create a SPDY client stream, which will queue a SPDY server stream
|
||||||
|
// on the stream creation channel. Send data on the client stream
|
||||||
|
// reading off the SPDY server stream, and validating it was tunneled.
|
||||||
|
expected := "This is a test tunneling SPDY data through websockets."
|
||||||
|
var actual []byte
|
||||||
|
go func() {
|
||||||
|
clientStream, err := spdyClient.CreateStream(http.Header{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = io.Copy(clientStream, strings.NewReader(expected))
|
||||||
|
require.NoError(t, err)
|
||||||
|
clientStream.Close() //nolint:errcheck
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case serverStream := <-streamChan:
|
||||||
|
actual, err = io.ReadAll(serverStream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer serverStream.Close() //nolint:errcheck
|
||||||
|
case <-time.After(wait.ForeverTestTimeout):
|
||||||
|
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, string(actual), "error validating tunneled string")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Create the client side tunneling connection.
|
||||||
|
url, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tConn, err := dialForTunnelingConnection(url)
|
||||||
|
require.NoError(t, err, "error creating client tunneling connection")
|
||||||
|
defer tConn.Close() //nolint:errcheck
|
||||||
|
// Validate "LocalAddr()" and "RemoteAddr()"
|
||||||
|
localAddr := tConn.LocalAddr()
|
||||||
|
remoteAddr := tConn.RemoteAddr()
|
||||||
|
assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
|
||||||
|
assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
|
||||||
|
_, err = net.ResolveTCPAddr("tcp", localAddr.String())
|
||||||
|
assert.NoError(t, err, "tunneling connection local addr should parse")
|
||||||
|
_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
|
||||||
|
assert.NoError(t, err, "tunneling connection remote addr should parse")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Create the client side tunneling connection.
|
||||||
|
url, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tConn, err := dialForTunnelingConnection(url)
|
||||||
|
require.NoError(t, err, "error creating client tunneling connection")
|
||||||
|
defer tConn.Close() //nolint:errcheck
|
||||||
|
// Validate the read and write deadlines.
|
||||||
|
err = tConn.SetReadDeadline(time.Time{})
|
||||||
|
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||||
|
err = tConn.SetWriteDeadline(time.Time{})
|
||||||
|
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||||
|
err = tConn.SetDeadline(time.Time{})
|
||||||
|
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||||
|
err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
|
||||||
|
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||||
|
err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
|
||||||
|
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||||
|
err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
|
||||||
|
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialForTunnelingConnection upgrades a request at the passed "url", creating
|
||||||
|
// a websocket connection. Returns the TunnelingConnection injected with the
|
||||||
|
// websocket connection or an error if one occurs.
|
||||||
|
func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
|
||||||
|
req, err := http.NewRequest("GET", url.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
|
||||||
|
tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
|
||||||
|
transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewTunnelingConnection("client", conn), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||||
|
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||||
|
streams <- stream
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,93 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2023 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||||
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||||
|
restclient "k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/transport/websocket"
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PingPeriod = 10 * time.Second
|
||||||
|
|
||||||
|
// tunnelingDialer implements "httpstream.Dial" interface
|
||||||
|
type tunnelingDialer struct {
|
||||||
|
url *url.URL
|
||||||
|
transport http.RoundTripper
|
||||||
|
holder websocket.ConnectionHolder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
|
||||||
|
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
|
||||||
|
// returns an error if one occurs.
|
||||||
|
func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
|
||||||
|
transport, holder, err := websocket.RoundTripperFor(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tunnelingDialer{
|
||||||
|
url: url,
|
||||||
|
transport: transport,
|
||||||
|
holder: holder,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial upgrades to a tunneling streaming connection, returning a SPDY connection
|
||||||
|
// containing a WebSockets connection (which implements "net.Conn"). Also
|
||||||
|
// returns the protocol negotiated, or an error.
|
||||||
|
func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||||
|
// There is no passed context, so skip the context when creating request for now.
|
||||||
|
// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
|
||||||
|
req, err := http.NewRequest("GET", d.url.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
// Add the spdy tunneling prefix to the requested protocols. The tunneling
|
||||||
|
// handler will know how to negotiate these protocols.
|
||||||
|
tunnelingProtocols := []string{}
|
||||||
|
for _, protocol := range protocols {
|
||||||
|
tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
|
||||||
|
tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
|
||||||
|
}
|
||||||
|
klog.V(4).Infoln("Before WebSocket Upgrade Connection...")
|
||||||
|
conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return nil, "", fmt.Errorf("negotiated websocket connection is nil")
|
||||||
|
}
|
||||||
|
protocol := conn.Subprotocol()
|
||||||
|
protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
|
||||||
|
klog.V(4).Infof("negotiated protocol: %s", protocol)
|
||||||
|
|
||||||
|
// Wrap the websocket connection which implements "net.Conn".
|
||||||
|
tConn := NewTunnelingConnection("client", conn)
|
||||||
|
// Create SPDY connection injecting the previously created tunneling connection.
|
||||||
|
spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)
|
||||||
|
|
||||||
|
return spdyConn, protocol, err
|
||||||
|
}
|
@ -31,6 +31,7 @@ import (
|
|||||||
|
|
||||||
corev1 "k8s.io/api/core/v1"
|
corev1 "k8s.io/api/core/v1"
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
"k8s.io/apimachinery/pkg/util/sets"
|
"k8s.io/apimachinery/pkg/util/sets"
|
||||||
"k8s.io/cli-runtime/pkg/genericiooptions"
|
"k8s.io/cli-runtime/pkg/genericiooptions"
|
||||||
"k8s.io/client-go/kubernetes/scheme"
|
"k8s.io/client-go/kubernetes/scheme"
|
||||||
@ -50,7 +51,7 @@ import (
|
|||||||
type PortForwardOptions struct {
|
type PortForwardOptions struct {
|
||||||
Namespace string
|
Namespace string
|
||||||
PodName string
|
PodName string
|
||||||
RESTClient *restclient.RESTClient
|
RESTClient restclient.Interface
|
||||||
Config *restclient.Config
|
Config *restclient.Config
|
||||||
PodClient corev1client.PodsGetter
|
PodClient corev1client.PodsGetter
|
||||||
Address []string
|
Address []string
|
||||||
@ -99,11 +100,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command {
|
func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command {
|
||||||
opts := &PortForwardOptions{
|
opts := NewDefaultPortForwardOptions(streams)
|
||||||
PortForwarder: &defaultPortForwarder{
|
|
||||||
IOStreams: streams,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "port-forward TYPE/NAME [options] [LOCAL_PORT:]REMOTE_PORT [...[LOCAL_PORT_N:]REMOTE_PORT_N]",
|
Use: "port-forward TYPE/NAME [options] [LOCAL_PORT:]REMOTE_PORT [...[LOCAL_PORT_N:]REMOTE_PORT_N]",
|
||||||
DisableFlagsInUseLine: true,
|
DisableFlagsInUseLine: true,
|
||||||
@ -123,6 +120,14 @@ func NewCmdPortForward(f cmdutil.Factory, streams genericiooptions.IOStreams) *c
|
|||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewDefaultPortForwardOptions(streams genericiooptions.IOStreams) *PortForwardOptions {
|
||||||
|
return &PortForwardOptions{
|
||||||
|
PortForwarder: &defaultPortForwarder{
|
||||||
|
IOStreams: streams,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type portForwarder interface {
|
type portForwarder interface {
|
||||||
ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error
|
ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error
|
||||||
}
|
}
|
||||||
@ -137,6 +142,14 @@ func (f *defaultPortForwarder) ForwardPorts(method string, url *url.URL, opts Po
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, method, url)
|
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, method, url)
|
||||||
|
if cmdutil.PortForwardWebsockets.IsEnabled() {
|
||||||
|
tunnelingDialer, err := portforward.NewSPDYOverWebsocketDialer(url, opts.Config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// First attempt tunneling (websocket) dialer, then fallback to spdy dialer.
|
||||||
|
dialer = portforward.NewFallbackDialer(tunnelingDialer, dialer, httpstream.IsUpgradeFailure)
|
||||||
|
}
|
||||||
fw, err := portforward.NewOnAddresses(dialer, opts.Address, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.Out, f.ErrOut)
|
fw, err := portforward.NewOnAddresses(dialer, opts.Address, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.Out, f.ErrOut)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -387,7 +400,14 @@ func (o PortForwardOptions) Validate() error {
|
|||||||
|
|
||||||
// RunPortForward implements all the necessary functionality for port-forward cmd.
|
// RunPortForward implements all the necessary functionality for port-forward cmd.
|
||||||
func (o PortForwardOptions) RunPortForward() error {
|
func (o PortForwardOptions) RunPortForward() error {
|
||||||
pod, err := o.PodClient.Pods(o.Namespace).Get(context.TODO(), o.PodName, metav1.GetOptions{})
|
return o.RunPortForwardContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunPortForwardContext implements all the necessary functionality for port-forward cmd.
|
||||||
|
// It ends portforwarding when an error is received from the backend, or an os.Interrupt
|
||||||
|
// signal is received, or the provided context is done.
|
||||||
|
func (o PortForwardOptions) RunPortForwardContext(ctx context.Context) error {
|
||||||
|
pod, err := o.PodClient.Pods(o.Namespace).Get(ctx, o.PodName, metav1.GetOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -401,7 +421,10 @@ func (o PortForwardOptions) RunPortForward() error {
|
|||||||
defer signal.Stop(signals)
|
defer signal.Stop(signals)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-signals
|
select {
|
||||||
|
case <-signals:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
if o.StopChannel != nil {
|
if o.StopChannel != nil {
|
||||||
close(o.StopChannel)
|
close(o.StopChannel)
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
package portforward
|
package portforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -101,6 +102,8 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
opts := &PortForwardOptions{}
|
opts := &PortForwardOptions{}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
cmd := NewCmdPortForward(tf, genericiooptions.NewTestIOStreamsDiscard())
|
cmd := NewCmdPortForward(tf, genericiooptions.NewTestIOStreamsDiscard())
|
||||||
cmd.Run = func(cmd *cobra.Command, args []string) {
|
cmd.Run = func(cmd *cobra.Command, args []string) {
|
||||||
if err = opts.Complete(tf, cmd, args); err != nil {
|
if err = opts.Complete(tf, cmd, args); err != nil {
|
||||||
@ -110,7 +113,7 @@ func testPortForward(t *testing.T, flags map[string]string, args []string) {
|
|||||||
if err = opts.Validate(); err != nil {
|
if err = opts.Validate(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = opts.RunPortForward()
|
err = opts.RunPortForwardContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, value := range flags {
|
for name, value := range flags {
|
||||||
|
@ -430,6 +430,7 @@ const (
|
|||||||
InteractiveDelete FeatureGate = "KUBECTL_INTERACTIVE_DELETE"
|
InteractiveDelete FeatureGate = "KUBECTL_INTERACTIVE_DELETE"
|
||||||
OpenAPIV3Patch FeatureGate = "KUBECTL_OPENAPIV3_PATCH"
|
OpenAPIV3Patch FeatureGate = "KUBECTL_OPENAPIV3_PATCH"
|
||||||
RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS"
|
RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS"
|
||||||
|
PortForwardWebsockets FeatureGate = "KUBECTL_PORT_FORWARD_WEBSOCKETS"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IsEnabled returns true iff environment variable is set to true.
|
// IsEnabled returns true iff environment variable is set to true.
|
||||||
|
27
test/integration/apiserver/portforward/main_test.go
Normal file
27
test/integration/apiserver/portforward/main_test.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"k8s.io/kubernetes/test/integration/framework"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
framework.EtcdMain(m.Run)
|
||||||
|
}
|
228
test/integration/apiserver/portforward/portforward_test.go
Normal file
228
test/integration/apiserver/portforward/portforward_test.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
corev1 "k8s.io/api/core/v1"
|
||||||
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
"k8s.io/apimachinery/pkg/types"
|
||||||
|
"k8s.io/apimachinery/pkg/util/remotecommand"
|
||||||
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
utilfeature "k8s.io/apiserver/pkg/util/feature"
|
||||||
|
"k8s.io/cli-runtime/pkg/genericiooptions"
|
||||||
|
"k8s.io/client-go/kubernetes"
|
||||||
|
featuregatetesting "k8s.io/component-base/featuregate/testing"
|
||||||
|
"k8s.io/kubectl/pkg/cmd/portforward"
|
||||||
|
kubeletportforward "k8s.io/kubelet/pkg/cri/streaming/portforward"
|
||||||
|
kastesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing"
|
||||||
|
kubefeatures "k8s.io/kubernetes/pkg/features"
|
||||||
|
|
||||||
|
"k8s.io/kubernetes/test/integration/framework"
|
||||||
|
)
|
||||||
|
|
||||||
|
const remotePort = "8765"
|
||||||
|
|
||||||
|
func TestPortforward(t *testing.T) {
|
||||||
|
defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, kubefeatures.PortForwardWebsockets, true)()
|
||||||
|
t.Setenv("KUBECTL_PORT_FORWARD_WEBSOCKETS", "true")
|
||||||
|
|
||||||
|
var podName string
|
||||||
|
var podUID types.UID
|
||||||
|
backendServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
t.Logf("backend saw request: %v", req.URL.String())
|
||||||
|
kubeletportforward.ServePortForward(
|
||||||
|
w,
|
||||||
|
req,
|
||||||
|
&dummyPortForwarder{t: t},
|
||||||
|
podName,
|
||||||
|
podUID,
|
||||||
|
&kubeletportforward.V4Options{},
|
||||||
|
wait.ForeverTestTimeout, // idle timeout
|
||||||
|
remotecommand.DefaultStreamCreationTimeout, // stream creation timeout
|
||||||
|
[]string{kubeletportforward.ProtocolV1Name},
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
defer backendServer.Close()
|
||||||
|
backendURL, _ := url.Parse(backendServer.URL)
|
||||||
|
backendHost := backendURL.Hostname()
|
||||||
|
backendPort, _ := strconv.Atoi(backendURL.Port())
|
||||||
|
|
||||||
|
etcd := framework.SharedEtcd()
|
||||||
|
server := kastesting.StartTestServerOrDie(t, nil, []string{"--disable-admission-plugins=ServiceAccount"}, etcd)
|
||||||
|
defer server.TearDownFn()
|
||||||
|
|
||||||
|
adminClient, err := kubernetes.NewForConfig(server.ClientConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
node := &corev1.Node{
|
||||||
|
ObjectMeta: metav1.ObjectMeta{Name: "mynode"},
|
||||||
|
Status: corev1.NodeStatus{
|
||||||
|
DaemonEndpoints: corev1.NodeDaemonEndpoints{KubeletEndpoint: corev1.DaemonEndpoint{Port: int32(backendPort)}},
|
||||||
|
Addresses: []corev1.NodeAddress{{Type: corev1.NodeInternalIP, Address: backendHost}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := adminClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pod := &corev1.Pod{
|
||||||
|
ObjectMeta: metav1.ObjectMeta{Namespace: "default", Name: "mypod"},
|
||||||
|
Spec: corev1.PodSpec{
|
||||||
|
NodeName: "mynode",
|
||||||
|
Containers: []corev1.Container{{Name: "test", Image: "test"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := adminClient.CoreV1().Pods("default").Create(context.Background(), pod, metav1.CreateOptions{}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := adminClient.CoreV1().Pods("default").Patch(context.Background(), "mypod", types.MergePatchType, []byte(`{"status":{"phase":"Running"}}`), metav1.PatchOptions{}, "status"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// local port missing asks os to find random open port.
|
||||||
|
// Example: ":8000" (local = random, remote = 8000)
|
||||||
|
localRemotePort := fmt.Sprintf(":%s", remotePort)
|
||||||
|
streams, _, out, errOut := genericiooptions.NewTestIOStreams()
|
||||||
|
portForwardOptions := portforward.NewDefaultPortForwardOptions(streams)
|
||||||
|
portForwardOptions.Namespace = "default"
|
||||||
|
portForwardOptions.PodName = "mypod"
|
||||||
|
portForwardOptions.RESTClient = adminClient.CoreV1().RESTClient()
|
||||||
|
portForwardOptions.Config = server.ClientConfig
|
||||||
|
portForwardOptions.PodClient = adminClient.CoreV1()
|
||||||
|
portForwardOptions.Address = []string{"127.0.0.1"}
|
||||||
|
portForwardOptions.Ports = []string{localRemotePort}
|
||||||
|
portForwardOptions.StopChannel = make(chan struct{}, 1)
|
||||||
|
portForwardOptions.ReadyChannel = make(chan struct{})
|
||||||
|
|
||||||
|
if err := portForwardOptions.Validate(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := portForwardOptions.RunPortForwardContext(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Log("waiting for port forward to be ready")
|
||||||
|
select {
|
||||||
|
case <-portForwardOptions.ReadyChannel:
|
||||||
|
t.Log("port forward was ready")
|
||||||
|
case <-time.After(wait.ForeverTestTimeout):
|
||||||
|
t.Error("port forward was never ready")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse out the randomly selected local port from "out" stream.
|
||||||
|
localPort, err := parsePort(out.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("Local Port: %s", localPort)
|
||||||
|
|
||||||
|
timeoutContext, cleanupTimeoutContext := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
||||||
|
defer cleanupTimeoutContext()
|
||||||
|
testReq, _ := http.NewRequest("GET", fmt.Sprintf("http://127.0.0.1:%s/test", localPort), nil)
|
||||||
|
testReq = testReq.WithContext(timeoutContext)
|
||||||
|
testResp, err := http.DefaultClient.Do(testReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
t.Log(testResp.StatusCode)
|
||||||
|
data, err := io.ReadAll(testResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
t.Log("client saw response:", string(data))
|
||||||
|
}
|
||||||
|
if string(data) != fmt.Sprintf("request to %s was ok", remotePort) {
|
||||||
|
t.Errorf("unexpected data")
|
||||||
|
}
|
||||||
|
if testResp.StatusCode != 200 {
|
||||||
|
t.Error("expected success")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
t.Logf("stdout: %s", out.String())
|
||||||
|
t.Logf("stderr: %s", errOut.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePort parses out the local port from the port-forward output string.
|
||||||
|
// This should work for both IP4 and IP6 addresses.
|
||||||
|
//
|
||||||
|
// Example: "Forwarding from 127.0.0.1:8000 -> 4000", returns "8000".
|
||||||
|
func parsePort(forwardAddr string) (string, error) {
|
||||||
|
parts := strings.Split(forwardAddr, " ")
|
||||||
|
if len(parts) != 5 {
|
||||||
|
return "", fmt.Errorf("unable to parse local port from stdout: %s", forwardAddr)
|
||||||
|
}
|
||||||
|
// parts[2] = "127.0.0.1:<LOCAL_PORT>"
|
||||||
|
_, localPort, err := net.SplitHostPort(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("unable to parse local port: %w", err)
|
||||||
|
}
|
||||||
|
return localPort, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dummyPortForwarder struct {
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dummyPortForwarder) PortForward(ctx context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
|
||||||
|
d.t.Logf("handling port forward request for %d", port)
|
||||||
|
|
||||||
|
req, err := http.ReadRequest(bufio.NewReader(stream))
|
||||||
|
if err != nil {
|
||||||
|
d.t.Logf("error reading request: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d.t.Log(req.URL.String())
|
||||||
|
defer req.Body.Close() //nolint:errcheck
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString(fmt.Sprintf("request to %d was ok", port))),
|
||||||
|
}
|
||||||
|
resp.Write(stream) //nolint:errcheck
|
||||||
|
return stream.Close()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user