mirror of
https://github.com/kubernetes/client-go.git
synced 2025-08-01 07:20:59 +00:00
portforward: tunnel spdy through websockets
Kubernetes-commit: 8b447d8c97e8823b4308eb91cf7d75693e867c61
This commit is contained in:
parent
08128e0dfa
commit
271d034e86
57
tools/portforward/fallback_dialer.go
Normal file
57
tools/portforward/fallback_dialer.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ httpstream.Dialer = &fallbackDialer{}
|
||||||
|
|
||||||
|
// fallbackDialer encapsulates a primary and secondary dialer, including
|
||||||
|
// the boolean function to determine if the primary dialer failed. Implements
|
||||||
|
// the httpstream.Dialer interface.
|
||||||
|
type fallbackDialer struct {
|
||||||
|
primary httpstream.Dialer
|
||||||
|
secondary httpstream.Dialer
|
||||||
|
shouldFallback func(error) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers,
|
||||||
|
// as well as the boolean function to determine if the primary dialer failed.
|
||||||
|
func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer {
|
||||||
|
return &fallbackDialer{
|
||||||
|
primary: primary,
|
||||||
|
secondary: secondary,
|
||||||
|
shouldFallback: shouldFallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial is the single function necessary to implement the "httpstream.Dialer" interface.
|
||||||
|
// It takes the protocol version strings to request, returning an the upgraded
|
||||||
|
// httstream.Connection and the negotiated protocol version accepted. If the initial
|
||||||
|
// primary dialer fails, this function attempts the secondary dialer. Returns an error
|
||||||
|
// if one occurs.
|
||||||
|
func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||||
|
conn, version, err := f.primary.Dial(protocols...)
|
||||||
|
if err != nil && f.shouldFallback(err) {
|
||||||
|
klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
|
||||||
|
return f.secondary.Dial(protocols...)
|
||||||
|
}
|
||||||
|
return conn, version, err
|
||||||
|
}
|
53
tools/portforward/fallback_dialer_test.go
Normal file
53
tools/portforward/fallback_dialer_test.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFallbackDialer(t *testing.T) {
|
||||||
|
protocol := "v6.fake.k8s.io"
|
||||||
|
// If "shouldFallback" is false, then only primary should be dialed.
|
||||||
|
primary := &fakeDialer{dialed: false}
|
||||||
|
primary.negotiatedProtocol = protocol
|
||||||
|
secondary := &fakeDialer{dialed: false}
|
||||||
|
fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse)
|
||||||
|
_, negotiated, err := fallbackDialer.Dial(protocol)
|
||||||
|
assert.True(t, primary.dialed, "no fallback; primary should have dialed")
|
||||||
|
assert.Equal(t, protocol, negotiated, "")
|
||||||
|
assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
|
||||||
|
assert.Nil(t, err, "error should be nil")
|
||||||
|
// If "shouldFallback" is true, then primary AND secondary should be dialed.
|
||||||
|
primary.dialed = false // reset dialed field
|
||||||
|
primary.err = fmt.Errorf("bad handshake")
|
||||||
|
secondary.dialed = false // reset dialed field
|
||||||
|
secondary.negotiatedProtocol = protocol
|
||||||
|
fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue)
|
||||||
|
_, negotiated, err = fallbackDialer.Dial(protocol)
|
||||||
|
assert.True(t, primary.dialed, "fallback; primary should have dialed (first)")
|
||||||
|
assert.True(t, secondary.dialed, "fallback; secondary should have dialed")
|
||||||
|
assert.Equal(t, protocol, negotiated)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func alwaysTrue(err error) bool { return true }
|
||||||
|
|
||||||
|
func alwaysFalse(err error) bool { return false }
|
@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error {
|
|||||||
defer pf.Close()
|
defer pf.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
|
var protocol string
|
||||||
|
pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error upgrading connection: %s", err)
|
return fmt.Errorf("error upgrading connection: %s", err)
|
||||||
}
|
}
|
||||||
defer pf.streamConn.Close()
|
defer pf.streamConn.Close()
|
||||||
|
if protocol != PortForwardProtocolV1Name {
|
||||||
|
return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
|
||||||
|
}
|
||||||
|
|
||||||
return pf.forward()
|
return pf.forward()
|
||||||
}
|
}
|
||||||
|
@ -431,6 +431,7 @@ func TestGetListener(t *testing.T) {
|
|||||||
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
||||||
dialer := &fakeDialer{
|
dialer := &fakeDialer{
|
||||||
conn: newFakeConnection(),
|
conn: newFakeConnection(),
|
||||||
|
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
@ -571,6 +572,7 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
|
|||||||
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
||||||
dialer := &fakeDialer{
|
dialer := &fakeDialer{
|
||||||
conn: newFakeConnection(),
|
conn: newFakeConnection(),
|
||||||
|
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
@ -602,6 +604,7 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
|||||||
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
|
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
|
||||||
dialer := &fakeDialer{
|
dialer := &fakeDialer{
|
||||||
conn: newFakeConnection(),
|
conn: newFakeConnection(),
|
||||||
|
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
|
153
tools/portforward/tunneling_connection.go
Normal file
153
tools/portforward/tunneling_connection.go
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2023 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gwebsocket "github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ net.Conn = &TunnelingConnection{}
|
||||||
|
|
||||||
|
// TunnelingConnection implements the "httpstream.Connection" interface, wrapping
|
||||||
|
// a websocket connection that tunnels SPDY.
|
||||||
|
type TunnelingConnection struct {
|
||||||
|
name string
|
||||||
|
conn *gwebsocket.Conn
|
||||||
|
inProgressMessage io.Reader
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelingConnection wraps the passed gorilla/websockets connection
|
||||||
|
// with the TunnelingConnection struct (implementing net.Conn).
|
||||||
|
func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
|
||||||
|
return &TunnelingConnection{
|
||||||
|
name: name,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements "io.Reader" interface, reading from the stored connection
|
||||||
|
// into the passed buffer "p". Returns the number of bytes read and an error.
|
||||||
|
// Can keep track of the "inProgress" messsage from the tunneled connection.
|
||||||
|
func (c *TunnelingConnection) Read(p []byte) (int, error) {
|
||||||
|
klog.V(7).Infof("%s: tunneling connection read...", c.name)
|
||||||
|
defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name)
|
||||||
|
for {
|
||||||
|
if c.inProgressMessage == nil {
|
||||||
|
klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name)
|
||||||
|
messageType, nextReader, err := c.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
closeError := &gwebsocket.CloseError{}
|
||||||
|
if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if messageType != gwebsocket.BinaryMessage {
|
||||||
|
return 0, fmt.Errorf("invalid message type received")
|
||||||
|
}
|
||||||
|
c.inProgressMessage = nextReader
|
||||||
|
}
|
||||||
|
klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name)
|
||||||
|
i, err := c.inProgressMessage.Read(p)
|
||||||
|
if i == 0 && err == io.EOF {
|
||||||
|
c.inProgressMessage = nil
|
||||||
|
} else {
|
||||||
|
klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements "io.Writer" interface, copying the data in the passed
|
||||||
|
// byte array "p" into the stored tunneled connection. Returns the number
|
||||||
|
// of bytes written and an error.
|
||||||
|
func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
|
||||||
|
klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p)
|
||||||
|
defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name)
|
||||||
|
w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// close, which flushes the message
|
||||||
|
closeErr := w.Close()
|
||||||
|
if closeErr != nil && err == nil {
|
||||||
|
// if closing/flushing errored and we weren't already returning an error, return the close error
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
n, err = w.Write(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements "io.Closer" interface, signaling the other tunneled connection
|
||||||
|
// endpoint, and closing the tunneled connection only once.
|
||||||
|
func (c *TunnelingConnection) Close() error {
|
||||||
|
var err error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
klog.V(7).Infof("%s: tunneling connection Close()...", c.name)
|
||||||
|
// Signal other endpoint that websocket connection is closing; ignore error.
|
||||||
|
normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
|
||||||
|
c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck
|
||||||
|
err = c.conn.Close()
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements part of the "net.Conn" interface, returning the local
|
||||||
|
// endpoint network address of the tunneled connection.
|
||||||
|
func (c *TunnelingConnection) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements part of the "net.Conn" interface, returning the remote
|
||||||
|
// endpoint network address of the tunneled connection.
|
||||||
|
func (c *TunnelingConnection) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the *absolute* time in the future for both
|
||||||
|
// read and write deadlines. Returns an error if one occurs.
|
||||||
|
func (c *TunnelingConnection) SetDeadline(t time.Time) error {
|
||||||
|
rerr := c.SetReadDeadline(t)
|
||||||
|
werr := c.SetWriteDeadline(t)
|
||||||
|
return errors.Join(rerr, werr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the *absolute* time in the future for the
|
||||||
|
// read deadlines. Returns an error if one occurs.
|
||||||
|
func (c *TunnelingConnection) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the *absolute* time in the future for the
|
||||||
|
// write deadlines. Returns an error if one occurs.
|
||||||
|
func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
190
tools/portforward/tunneling_connection_test.go
Normal file
190
tools/portforward/tunneling_connection_test.go
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2024 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gwebsocket "github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||||
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||||
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/transport/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
|
||||||
|
// Stream channel that will receive streams created on upstream SPDY server.
|
||||||
|
streamChan := make(chan httpstream.Stream)
|
||||||
|
defer close(streamChan)
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
// Create tunneling connection server endpoint with fake upstream SPDY server.
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||||
|
tunnelingConn := NewTunnelingConnection("server", conn)
|
||||||
|
spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer spdyConn.Close() //nolint:errcheck
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Dial the client tunneling connection to the tunneling server.
|
||||||
|
url, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
|
||||||
|
require.NoError(t, err)
|
||||||
|
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, constants.PortForwardV1Name, protocol)
|
||||||
|
defer spdyClient.Close() //nolint:errcheck
|
||||||
|
// Create a SPDY client stream, which will queue a SPDY server stream
|
||||||
|
// on the stream creation channel. Send data on the client stream
|
||||||
|
// reading off the SPDY server stream, and validating it was tunneled.
|
||||||
|
expected := "This is a test tunneling SPDY data through websockets."
|
||||||
|
var actual []byte
|
||||||
|
go func() {
|
||||||
|
clientStream, err := spdyClient.CreateStream(http.Header{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = io.Copy(clientStream, strings.NewReader(expected))
|
||||||
|
require.NoError(t, err)
|
||||||
|
clientStream.Close() //nolint:errcheck
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case serverStream := <-streamChan:
|
||||||
|
actual, err = io.ReadAll(serverStream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer serverStream.Close() //nolint:errcheck
|
||||||
|
case <-time.After(wait.ForeverTestTimeout):
|
||||||
|
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, string(actual), "error validating tunneled string")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Create the client side tunneling connection.
|
||||||
|
url, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tConn, err := dialForTunnelingConnection(url)
|
||||||
|
require.NoError(t, err, "error creating client tunneling connection")
|
||||||
|
defer tConn.Close() //nolint:errcheck
|
||||||
|
// Validate "LocalAddr()" and "RemoteAddr()"
|
||||||
|
localAddr := tConn.LocalAddr()
|
||||||
|
remoteAddr := tConn.RemoteAddr()
|
||||||
|
assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
|
||||||
|
assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
|
||||||
|
_, err = net.ResolveTCPAddr("tcp", localAddr.String())
|
||||||
|
assert.NoError(t, err, "tunneling connection local addr should parse")
|
||||||
|
_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
|
||||||
|
assert.NoError(t, err, "tunneling connection remote addr should parse")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
|
||||||
|
stopServerChan := make(chan struct{})
|
||||||
|
defer close(stopServerChan)
|
||||||
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
var upgrader = gwebsocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close() //nolint:errcheck
|
||||||
|
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||||
|
<-stopServerChan
|
||||||
|
}))
|
||||||
|
defer tunnelingServer.Close()
|
||||||
|
// Create the client side tunneling connection.
|
||||||
|
url, err := url.Parse(tunnelingServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tConn, err := dialForTunnelingConnection(url)
|
||||||
|
require.NoError(t, err, "error creating client tunneling connection")
|
||||||
|
defer tConn.Close() //nolint:errcheck
|
||||||
|
// Validate the read and write deadlines.
|
||||||
|
err = tConn.SetReadDeadline(time.Time{})
|
||||||
|
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||||
|
err = tConn.SetWriteDeadline(time.Time{})
|
||||||
|
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||||
|
err = tConn.SetDeadline(time.Time{})
|
||||||
|
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||||
|
err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
|
||||||
|
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||||
|
err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
|
||||||
|
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||||
|
err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
|
||||||
|
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialForTunnelingConnection upgrades a request at the passed "url", creating
|
||||||
|
// a websocket connection. Returns the TunnelingConnection injected with the
|
||||||
|
// websocket connection or an error if one occurs.
|
||||||
|
func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
|
||||||
|
req, err := http.NewRequest("GET", url.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
|
||||||
|
tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
|
||||||
|
transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewTunnelingConnection("client", conn), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||||
|
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||||
|
streams <- stream
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
93
tools/portforward/tunneling_dialer.go
Normal file
93
tools/portforward/tunneling_dialer.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2023 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||||
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||||
|
restclient "k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/transport/websocket"
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PingPeriod = 10 * time.Second
|
||||||
|
|
||||||
|
// tunnelingDialer implements "httpstream.Dial" interface
|
||||||
|
type tunnelingDialer struct {
|
||||||
|
url *url.URL
|
||||||
|
transport http.RoundTripper
|
||||||
|
holder websocket.ConnectionHolder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
|
||||||
|
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
|
||||||
|
// returns an error if one occurs.
|
||||||
|
func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
|
||||||
|
transport, holder, err := websocket.RoundTripperFor(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tunnelingDialer{
|
||||||
|
url: url,
|
||||||
|
transport: transport,
|
||||||
|
holder: holder,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial upgrades to a tunneling streaming connection, returning a SPDY connection
|
||||||
|
// containing a WebSockets connection (which implements "net.Conn"). Also
|
||||||
|
// returns the protocol negotiated, or an error.
|
||||||
|
func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||||
|
// There is no passed context, so skip the context when creating request for now.
|
||||||
|
// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
|
||||||
|
req, err := http.NewRequest("GET", d.url.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
// Add the spdy tunneling prefix to the requested protocols. The tunneling
|
||||||
|
// handler will know how to negotiate these protocols.
|
||||||
|
tunnelingProtocols := []string{}
|
||||||
|
for _, protocol := range protocols {
|
||||||
|
tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
|
||||||
|
tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
|
||||||
|
}
|
||||||
|
klog.V(4).Infoln("Before WebSocket Upgrade Connection...")
|
||||||
|
conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return nil, "", fmt.Errorf("negotiated websocket connection is nil")
|
||||||
|
}
|
||||||
|
protocol := conn.Subprotocol()
|
||||||
|
protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
|
||||||
|
klog.V(4).Infof("negotiated protocol: %s", protocol)
|
||||||
|
|
||||||
|
// Wrap the websocket connection which implements "net.Conn".
|
||||||
|
tConn := NewTunnelingConnection("client", conn)
|
||||||
|
// Create SPDY connection injecting the previously created tunneling connection.
|
||||||
|
spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)
|
||||||
|
|
||||||
|
return spdyConn, protocol, err
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user