portforward: tunnel spdy through websockets

Kubernetes-commit: 8b447d8c97e8823b4308eb91cf7d75693e867c61
This commit is contained in:
Sean Sullivan 2024-02-21 08:56:07 +00:00 committed by Kubernetes Publisher
parent 08128e0dfa
commit 271d034e86
7 changed files with 557 additions and 4 deletions

View File

@ -0,0 +1,57 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/klog/v2"
)
var _ httpstream.Dialer = &fallbackDialer{}
// fallbackDialer encapsulates a primary and secondary dialer, including
// the boolean function to determine if the primary dialer failed. Implements
// the httpstream.Dialer interface.
type fallbackDialer struct {
primary httpstream.Dialer
secondary httpstream.Dialer
shouldFallback func(error) bool
}
// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers,
// as well as the boolean function to determine if the primary dialer failed.
func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer {
return &fallbackDialer{
primary: primary,
secondary: secondary,
shouldFallback: shouldFallback,
}
}
// Dial is the single function necessary to implement the "httpstream.Dialer" interface.
// It takes the protocol version strings to request, returning an the upgraded
// httstream.Connection and the negotiated protocol version accepted. If the initial
// primary dialer fails, this function attempts the secondary dialer. Returns an error
// if one occurs.
func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
conn, version, err := f.primary.Dial(protocols...)
if err != nil && f.shouldFallback(err) {
klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
return f.secondary.Dial(protocols...)
}
return conn, version, err
}

View File

@ -0,0 +1,53 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFallbackDialer(t *testing.T) {
protocol := "v6.fake.k8s.io"
// If "shouldFallback" is false, then only primary should be dialed.
primary := &fakeDialer{dialed: false}
primary.negotiatedProtocol = protocol
secondary := &fakeDialer{dialed: false}
fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse)
_, negotiated, err := fallbackDialer.Dial(protocol)
assert.True(t, primary.dialed, "no fallback; primary should have dialed")
assert.Equal(t, protocol, negotiated, "")
assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
assert.Nil(t, err, "error should be nil")
// If "shouldFallback" is true, then primary AND secondary should be dialed.
primary.dialed = false // reset dialed field
primary.err = fmt.Errorf("bad handshake")
secondary.dialed = false // reset dialed field
secondary.negotiatedProtocol = protocol
fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue)
_, negotiated, err = fallbackDialer.Dial(protocol)
assert.True(t, primary.dialed, "fallback; primary should have dialed (first)")
assert.True(t, secondary.dialed, "fallback; secondary should have dialed")
assert.Equal(t, protocol, negotiated)
assert.Nil(t, err)
}
func alwaysTrue(err error) bool { return true }
func alwaysFalse(err error) bool { return false }

View File

@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
var protocol string
pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
if protocol != PortForwardProtocolV1Name {
return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
}
return pf.forward()
}

View File

@ -430,7 +430,8 @@ func TestGetListener(t *testing.T) {
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
dialer := &fakeDialer{
conn: newFakeConnection(),
conn: newFakeConnection(),
negotiatedProtocol: PortForwardProtocolV1Name,
}
stopChan := make(chan struct{})
@ -570,7 +571,8 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
dialer := &fakeDialer{
conn: newFakeConnection(),
conn: newFakeConnection(),
negotiatedProtocol: PortForwardProtocolV1Name,
}
stopChan := make(chan struct{})
@ -601,7 +603,8 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
dialer := &fakeDialer{
conn: newFakeConnection(),
conn: newFakeConnection(),
negotiatedProtocol: PortForwardProtocolV1Name,
}
stopChan := make(chan struct{})

View File

@ -0,0 +1,153 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"io"
"net"
"sync"
"time"
gwebsocket "github.com/gorilla/websocket"
"k8s.io/klog/v2"
)
var _ net.Conn = &TunnelingConnection{}
// TunnelingConnection implements the "httpstream.Connection" interface, wrapping
// a websocket connection that tunnels SPDY.
type TunnelingConnection struct {
name string
conn *gwebsocket.Conn
inProgressMessage io.Reader
closeOnce sync.Once
}
// NewTunnelingConnection wraps the passed gorilla/websockets connection
// with the TunnelingConnection struct (implementing net.Conn).
func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
return &TunnelingConnection{
name: name,
conn: conn,
}
}
// Read implements "io.Reader" interface, reading from the stored connection
// into the passed buffer "p". Returns the number of bytes read and an error.
// Can keep track of the "inProgress" messsage from the tunneled connection.
func (c *TunnelingConnection) Read(p []byte) (int, error) {
klog.V(7).Infof("%s: tunneling connection read...", c.name)
defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name)
for {
if c.inProgressMessage == nil {
klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name)
messageType, nextReader, err := c.conn.NextReader()
if err != nil {
closeError := &gwebsocket.CloseError{}
if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
return 0, io.EOF
}
klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err)
return 0, err
}
if messageType != gwebsocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type received")
}
c.inProgressMessage = nextReader
}
klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name)
i, err := c.inProgressMessage.Read(p)
if i == 0 && err == io.EOF {
c.inProgressMessage = nil
} else {
klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
return i, err
}
}
}
// Write implements "io.Writer" interface, copying the data in the passed
// byte array "p" into the stored tunneled connection. Returns the number
// of bytes written and an error.
func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p)
defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name)
w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
if err != nil {
return 0, err
}
defer func() {
// close, which flushes the message
closeErr := w.Close()
if closeErr != nil && err == nil {
// if closing/flushing errored and we weren't already returning an error, return the close error
err = closeErr
}
}()
n, err = w.Write(p)
return
}
// Close implements "io.Closer" interface, signaling the other tunneled connection
// endpoint, and closing the tunneled connection only once.
func (c *TunnelingConnection) Close() error {
var err error
c.closeOnce.Do(func() {
klog.V(7).Infof("%s: tunneling connection Close()...", c.name)
// Signal other endpoint that websocket connection is closing; ignore error.
normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck
err = c.conn.Close()
})
return err
}
// LocalAddr implements part of the "net.Conn" interface, returning the local
// endpoint network address of the tunneled connection.
func (c *TunnelingConnection) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// LocalAddr implements part of the "net.Conn" interface, returning the remote
// endpoint network address of the tunneled connection.
func (c *TunnelingConnection) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the *absolute* time in the future for both
// read and write deadlines. Returns an error if one occurs.
func (c *TunnelingConnection) SetDeadline(t time.Time) error {
rerr := c.SetReadDeadline(t)
werr := c.SetWriteDeadline(t)
return errors.Join(rerr, werr)
}
// SetDeadline sets the *absolute* time in the future for the
// read deadlines. Returns an error if one occurs.
func (c *TunnelingConnection) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetDeadline sets the *absolute* time in the future for the
// write deadlines. Returns an error if one occurs.
func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@ -0,0 +1,190 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
gwebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/portforward"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
"k8s.io/client-go/transport/websocket"
)
func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
// Stream channel that will receive streams created on upstream SPDY server.
streamChan := make(chan httpstream.Stream)
defer close(streamChan)
stopServerChan := make(chan struct{})
defer close(stopServerChan)
// Create tunneling connection server endpoint with fake upstream SPDY server.
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
}
conn, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
defer conn.Close() //nolint:errcheck
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
tunnelingConn := NewTunnelingConnection("server", conn)
spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
require.NoError(t, err)
defer spdyConn.Close() //nolint:errcheck
<-stopServerChan
}))
defer tunnelingServer.Close()
// Dial the client tunneling connection to the tunneling server.
url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
require.NoError(t, err)
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
require.NoError(t, err)
assert.Equal(t, constants.PortForwardV1Name, protocol)
defer spdyClient.Close() //nolint:errcheck
// Create a SPDY client stream, which will queue a SPDY server stream
// on the stream creation channel. Send data on the client stream
// reading off the SPDY server stream, and validating it was tunneled.
expected := "This is a test tunneling SPDY data through websockets."
var actual []byte
go func() {
clientStream, err := spdyClient.CreateStream(http.Header{})
require.NoError(t, err)
_, err = io.Copy(clientStream, strings.NewReader(expected))
require.NoError(t, err)
clientStream.Close() //nolint:errcheck
}()
select {
case serverStream := <-streamChan:
actual, err = io.ReadAll(serverStream)
require.NoError(t, err)
defer serverStream.Close() //nolint:errcheck
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
}
assert.Equal(t, expected, string(actual), "error validating tunneled string")
}
func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
stopServerChan := make(chan struct{})
defer close(stopServerChan)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
}
conn, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
defer conn.Close() //nolint:errcheck
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
<-stopServerChan
}))
defer tunnelingServer.Close()
// Create the client side tunneling connection.
url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
tConn, err := dialForTunnelingConnection(url)
require.NoError(t, err, "error creating client tunneling connection")
defer tConn.Close() //nolint:errcheck
// Validate "LocalAddr()" and "RemoteAddr()"
localAddr := tConn.LocalAddr()
remoteAddr := tConn.RemoteAddr()
assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
_, err = net.ResolveTCPAddr("tcp", localAddr.String())
assert.NoError(t, err, "tunneling connection local addr should parse")
_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
assert.NoError(t, err, "tunneling connection remote addr should parse")
}
func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
stopServerChan := make(chan struct{})
defer close(stopServerChan)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var upgrader = gwebsocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
}
conn, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
defer conn.Close() //nolint:errcheck
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
<-stopServerChan
}))
defer tunnelingServer.Close()
// Create the client side tunneling connection.
url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
tConn, err := dialForTunnelingConnection(url)
require.NoError(t, err, "error creating client tunneling connection")
defer tConn.Close() //nolint:errcheck
// Validate the read and write deadlines.
err = tConn.SetReadDeadline(time.Time{})
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
err = tConn.SetWriteDeadline(time.Time{})
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
err = tConn.SetDeadline(time.Time{})
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
}
// dialForTunnelingConnection upgrades a request at the passed "url", creating
// a websocket connection. Returns the TunnelingConnection injected with the
// websocket connection or an error if one occurs.
func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return nil, err
}
// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
if err != nil {
return nil, err
}
conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
if err != nil {
return nil, err
}
return NewTunnelingConnection("client", conn), nil
}
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
streams <- stream
return nil
}
}

View File

@ -0,0 +1,93 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/portforward"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/transport/websocket"
"k8s.io/klog/v2"
)
const PingPeriod = 10 * time.Second
// tunnelingDialer implements "httpstream.Dial" interface
type tunnelingDialer struct {
url *url.URL
transport http.RoundTripper
holder websocket.ConnectionHolder
}
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
// returns an error if one occurs.
func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
transport, holder, err := websocket.RoundTripperFor(config)
if err != nil {
return nil, err
}
return &tunnelingDialer{
url: url,
transport: transport,
holder: holder,
}, nil
}
// Dial upgrades to a tunneling streaming connection, returning a SPDY connection
// containing a WebSockets connection (which implements "net.Conn"). Also
// returns the protocol negotiated, or an error.
func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
// There is no passed context, so skip the context when creating request for now.
// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
req, err := http.NewRequest("GET", d.url.String(), nil)
if err != nil {
return nil, "", err
}
// Add the spdy tunneling prefix to the requested protocols. The tunneling
// handler will know how to negotiate these protocols.
tunnelingProtocols := []string{}
for _, protocol := range protocols {
tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
}
klog.V(4).Infoln("Before WebSocket Upgrade Connection...")
conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
if err != nil {
return nil, "", err
}
if conn == nil {
return nil, "", fmt.Errorf("negotiated websocket connection is nil")
}
protocol := conn.Subprotocol()
protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
klog.V(4).Infof("negotiated protocol: %s", protocol)
// Wrap the websocket connection which implements "net.Conn".
tConn := NewTunnelingConnection("client", conn)
// Create SPDY connection injecting the previously created tunneling connection.
spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)
return spdyConn, protocol, err
}