diff --git a/tools/portforward/fallback_dialer.go b/tools/portforward/fallback_dialer.go new file mode 100644 index 00000000..8fb74a41 --- /dev/null +++ b/tools/portforward/fallback_dialer.go @@ -0,0 +1,57 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/klog/v2" +) + +var _ httpstream.Dialer = &fallbackDialer{} + +// fallbackDialer encapsulates a primary and secondary dialer, including +// the boolean function to determine if the primary dialer failed. Implements +// the httpstream.Dialer interface. +type fallbackDialer struct { + primary httpstream.Dialer + secondary httpstream.Dialer + shouldFallback func(error) bool +} + +// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers, +// as well as the boolean function to determine if the primary dialer failed. +func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer { + return &fallbackDialer{ + primary: primary, + secondary: secondary, + shouldFallback: shouldFallback, + } +} + +// Dial is the single function necessary to implement the "httpstream.Dialer" interface. +// It takes the protocol version strings to request, returning an the upgraded +// httstream.Connection and the negotiated protocol version accepted. If the initial +// primary dialer fails, this function attempts the secondary dialer. Returns an error +// if one occurs. +func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { + conn, version, err := f.primary.Dial(protocols...) + if err != nil && f.shouldFallback(err) { + klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err) + return f.secondary.Dial(protocols...) + } + return conn, version, err +} diff --git a/tools/portforward/fallback_dialer_test.go b/tools/portforward/fallback_dialer_test.go new file mode 100644 index 00000000..4680fa29 --- /dev/null +++ b/tools/portforward/fallback_dialer_test.go @@ -0,0 +1,53 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFallbackDialer(t *testing.T) { + protocol := "v6.fake.k8s.io" + // If "shouldFallback" is false, then only primary should be dialed. + primary := &fakeDialer{dialed: false} + primary.negotiatedProtocol = protocol + secondary := &fakeDialer{dialed: false} + fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse) + _, negotiated, err := fallbackDialer.Dial(protocol) + assert.True(t, primary.dialed, "no fallback; primary should have dialed") + assert.Equal(t, protocol, negotiated, "") + assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed") + assert.Nil(t, err, "error should be nil") + // If "shouldFallback" is true, then primary AND secondary should be dialed. + primary.dialed = false // reset dialed field + primary.err = fmt.Errorf("bad handshake") + secondary.dialed = false // reset dialed field + secondary.negotiatedProtocol = protocol + fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue) + _, negotiated, err = fallbackDialer.Dial(protocol) + assert.True(t, primary.dialed, "fallback; primary should have dialed (first)") + assert.True(t, secondary.dialed, "fallback; secondary should have dialed") + assert.Equal(t, protocol, negotiated) + assert.Nil(t, err) +} + +func alwaysTrue(err error) bool { return true } + +func alwaysFalse(err error) bool { return false } diff --git a/tools/portforward/portforward.go b/tools/portforward/portforward.go index b581043f..83ef3e92 100644 --- a/tools/portforward/portforward.go +++ b/tools/portforward/portforward.go @@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name) + var protocol string + pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } defer pf.streamConn.Close() + if protocol != PortForwardProtocolV1Name { + return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol) + } return pf.forward() } diff --git a/tools/portforward/portforward_test.go b/tools/portforward/portforward_test.go index 3c90a3fd..075a22e6 100644 --- a/tools/portforward/portforward_test.go +++ b/tools/portforward/portforward_test.go @@ -430,7 +430,8 @@ func TestGetListener(t *testing.T) { func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) { dialer := &fakeDialer{ - conn: newFakeConnection(), + conn: newFakeConnection(), + negotiatedProtocol: PortForwardProtocolV1Name, } stopChan := make(chan struct{}) @@ -570,7 +571,8 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) { func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) { dialer := &fakeDialer{ - conn: newFakeConnection(), + conn: newFakeConnection(), + negotiatedProtocol: PortForwardProtocolV1Name, } stopChan := make(chan struct{}) @@ -601,7 +603,8 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) { func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) { dialer := &fakeDialer{ - conn: newFakeConnection(), + conn: newFakeConnection(), + negotiatedProtocol: PortForwardProtocolV1Name, } stopChan := make(chan struct{}) diff --git a/tools/portforward/tunneling_connection.go b/tools/portforward/tunneling_connection.go new file mode 100644 index 00000000..4c04531b --- /dev/null +++ b/tools/portforward/tunneling_connection.go @@ -0,0 +1,153 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + gwebsocket "github.com/gorilla/websocket" + + "k8s.io/klog/v2" +) + +var _ net.Conn = &TunnelingConnection{} + +// TunnelingConnection implements the "httpstream.Connection" interface, wrapping +// a websocket connection that tunnels SPDY. +type TunnelingConnection struct { + name string + conn *gwebsocket.Conn + inProgressMessage io.Reader + closeOnce sync.Once +} + +// NewTunnelingConnection wraps the passed gorilla/websockets connection +// with the TunnelingConnection struct (implementing net.Conn). +func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection { + return &TunnelingConnection{ + name: name, + conn: conn, + } +} + +// Read implements "io.Reader" interface, reading from the stored connection +// into the passed buffer "p". Returns the number of bytes read and an error. +// Can keep track of the "inProgress" messsage from the tunneled connection. +func (c *TunnelingConnection) Read(p []byte) (int, error) { + klog.V(7).Infof("%s: tunneling connection read...", c.name) + defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name) + for { + if c.inProgressMessage == nil { + klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name) + messageType, nextReader, err := c.conn.NextReader() + if err != nil { + closeError := &gwebsocket.CloseError{} + if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure { + return 0, io.EOF + } + klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err) + return 0, err + } + if messageType != gwebsocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type received") + } + c.inProgressMessage = nextReader + } + klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name) + i, err := c.inProgressMessage.Read(p) + if i == 0 && err == io.EOF { + c.inProgressMessage = nil + } else { + klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i]) + return i, err + } + } +} + +// Write implements "io.Writer" interface, copying the data in the passed +// byte array "p" into the stored tunneled connection. Returns the number +// of bytes written and an error. +func (c *TunnelingConnection) Write(p []byte) (n int, err error) { + klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p) + defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name) + w, err := c.conn.NextWriter(gwebsocket.BinaryMessage) + if err != nil { + return 0, err + } + defer func() { + // close, which flushes the message + closeErr := w.Close() + if closeErr != nil && err == nil { + // if closing/flushing errored and we weren't already returning an error, return the close error + err = closeErr + } + }() + + n, err = w.Write(p) + return +} + +// Close implements "io.Closer" interface, signaling the other tunneled connection +// endpoint, and closing the tunneled connection only once. +func (c *TunnelingConnection) Close() error { + var err error + c.closeOnce.Do(func() { + klog.V(7).Infof("%s: tunneling connection Close()...", c.name) + // Signal other endpoint that websocket connection is closing; ignore error. + normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "") + c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck + err = c.conn.Close() + }) + return err +} + +// LocalAddr implements part of the "net.Conn" interface, returning the local +// endpoint network address of the tunneled connection. +func (c *TunnelingConnection) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// LocalAddr implements part of the "net.Conn" interface, returning the remote +// endpoint network address of the tunneled connection. +func (c *TunnelingConnection) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the *absolute* time in the future for both +// read and write deadlines. Returns an error if one occurs. +func (c *TunnelingConnection) SetDeadline(t time.Time) error { + rerr := c.SetReadDeadline(t) + werr := c.SetWriteDeadline(t) + return errors.Join(rerr, werr) +} + +// SetDeadline sets the *absolute* time in the future for the +// read deadlines. Returns an error if one occurs. +func (c *TunnelingConnection) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetDeadline sets the *absolute* time in the future for the +// write deadlines. Returns an error if one occurs. +func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/tools/portforward/tunneling_connection_test.go b/tools/portforward/tunneling_connection_test.go new file mode 100644 index 00000000..4127f49d --- /dev/null +++ b/tools/portforward/tunneling_connection_test.go @@ -0,0 +1,190 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + gwebsocket "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + constants "k8s.io/apimachinery/pkg/util/portforward" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" + "k8s.io/client-go/transport/websocket" +) + +func TestTunnelingConnection_ReadWriteClose(t *testing.T) { + // Stream channel that will receive streams created on upstream SPDY server. + streamChan := make(chan httpstream.Stream) + defer close(streamChan) + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + // Create tunneling connection server endpoint with fake upstream SPDY server. + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, + } + conn, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) + tunnelingConn := NewTunnelingConnection("server", conn) + spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan)) + require.NoError(t, err) + defer spdyConn.Close() //nolint:errcheck + <-stopServerChan + })) + defer tunnelingServer.Close() + // Dial the client tunneling connection to the tunneling server. + url, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host}) + require.NoError(t, err) + spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name) + require.NoError(t, err) + assert.Equal(t, constants.PortForwardV1Name, protocol) + defer spdyClient.Close() //nolint:errcheck + // Create a SPDY client stream, which will queue a SPDY server stream + // on the stream creation channel. Send data on the client stream + // reading off the SPDY server stream, and validating it was tunneled. + expected := "This is a test tunneling SPDY data through websockets." + var actual []byte + go func() { + clientStream, err := spdyClient.CreateStream(http.Header{}) + require.NoError(t, err) + _, err = io.Copy(clientStream, strings.NewReader(expected)) + require.NoError(t, err) + clientStream.Close() //nolint:errcheck + }() + select { + case serverStream := <-streamChan: + actual, err = io.ReadAll(serverStream) + require.NoError(t, err) + defer serverStream.Close() //nolint:errcheck + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("timeout waiting for spdy stream to arrive on channel.") + } + assert.Equal(t, expected, string(actual), "error validating tunneled string") +} + +func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) { + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, + } + conn, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) + <-stopServerChan + })) + defer tunnelingServer.Close() + // Create the client side tunneling connection. + url, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + tConn, err := dialForTunnelingConnection(url) + require.NoError(t, err, "error creating client tunneling connection") + defer tConn.Close() //nolint:errcheck + // Validate "LocalAddr()" and "RemoteAddr()" + localAddr := tConn.LocalAddr() + remoteAddr := tConn.RemoteAddr() + assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP") + assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP") + _, err = net.ResolveTCPAddr("tcp", localAddr.String()) + assert.NoError(t, err, "tunneling connection local addr should parse") + _, err = net.ResolveTCPAddr("tcp", remoteAddr.String()) + assert.NoError(t, err, "tunneling connection remote addr should parse") +} + +func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) { + stopServerChan := make(chan struct{}) + defer close(stopServerChan) + tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1}, + } + conn, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol()) + <-stopServerChan + })) + defer tunnelingServer.Close() + // Create the client side tunneling connection. + url, err := url.Parse(tunnelingServer.URL) + require.NoError(t, err) + tConn, err := dialForTunnelingConnection(url) + require.NoError(t, err, "error creating client tunneling connection") + defer tConn.Close() //nolint:errcheck + // Validate the read and write deadlines. + err = tConn.SetReadDeadline(time.Time{}) + assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") + err = tConn.SetWriteDeadline(time.Time{}) + assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") + err = tConn.SetDeadline(time.Time{}) + assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline") + err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0)) + assert.NoError(t, err, "setting deadline 10 year from now succeeds") + err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0)) + assert.NoError(t, err, "setting deadline 10 year from now succeeds") + err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0)) + assert.NoError(t, err, "setting deadline 10 year from now succeeds") +} + +// dialForTunnelingConnection upgrades a request at the passed "url", creating +// a websocket connection. Returns the TunnelingConnection injected with the +// websocket connection or an error if one occurs. +func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) { + req, err := http.NewRequest("GET", url.String(), nil) + if err != nil { + return nil, err + } + // Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol. + tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1} + transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host}) + if err != nil { + return nil, err + } + conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...) + if err != nil { + return nil, err + } + return NewTunnelingConnection("client", conn), nil +} + +func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { + streams <- stream + return nil + } +} diff --git a/tools/portforward/tunneling_dialer.go b/tools/portforward/tunneling_dialer.go new file mode 100644 index 00000000..2bef5ecd --- /dev/null +++ b/tools/portforward/tunneling_dialer.go @@ -0,0 +1,93 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + constants "k8s.io/apimachinery/pkg/util/portforward" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport/websocket" + "k8s.io/klog/v2" +) + +const PingPeriod = 10 * time.Second + +// tunnelingDialer implements "httpstream.Dial" interface +type tunnelingDialer struct { + url *url.URL + transport http.RoundTripper + holder websocket.ConnectionHolder +} + +// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer" +// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function +// returns an error if one occurs. +func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) { + transport, holder, err := websocket.RoundTripperFor(config) + if err != nil { + return nil, err + } + return &tunnelingDialer{ + url: url, + transport: transport, + holder: holder, + }, nil +} + +// Dial upgrades to a tunneling streaming connection, returning a SPDY connection +// containing a WebSockets connection (which implements "net.Conn"). Also +// returns the protocol negotiated, or an error. +func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { + // There is no passed context, so skip the context when creating request for now. + // Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17). + req, err := http.NewRequest("GET", d.url.String(), nil) + if err != nil { + return nil, "", err + } + // Add the spdy tunneling prefix to the requested protocols. The tunneling + // handler will know how to negotiate these protocols. + tunnelingProtocols := []string{} + for _, protocol := range protocols { + tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol + tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol) + } + klog.V(4).Infoln("Before WebSocket Upgrade Connection...") + conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...) + if err != nil { + return nil, "", err + } + if conn == nil { + return nil, "", fmt.Errorf("negotiated websocket connection is nil") + } + protocol := conn.Subprotocol() + protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) + klog.V(4).Infof("negotiated protocol: %s", protocol) + + // Wrap the websocket connection which implements "net.Conn". + tConn := NewTunnelingConnection("client", conn) + // Create SPDY connection injecting the previously created tunneling connection. + spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod) + + return spdyConn, protocol, err +}