mirror of
https://github.com/kubernetes/client-go.git
synced 2025-07-30 22:35:10 +00:00
portforward: tunnel spdy through websockets
Kubernetes-commit: 8b447d8c97e8823b4308eb91cf7d75693e867c61
This commit is contained in:
parent
08128e0dfa
commit
271d034e86
57
tools/portforward/fallback_dialer.go
Normal file
57
tools/portforward/fallback_dialer.go
Normal file
@ -0,0 +1,57 @@
|
||||
/*
|
||||
Copyright 2024 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/klog/v2"
|
||||
)
|
||||
|
||||
var _ httpstream.Dialer = &fallbackDialer{}
|
||||
|
||||
// fallbackDialer encapsulates a primary and secondary dialer, including
|
||||
// the boolean function to determine if the primary dialer failed. Implements
|
||||
// the httpstream.Dialer interface.
|
||||
type fallbackDialer struct {
|
||||
primary httpstream.Dialer
|
||||
secondary httpstream.Dialer
|
||||
shouldFallback func(error) bool
|
||||
}
|
||||
|
||||
// NewFallbackDialer creates the fallbackDialer with the primary and secondary dialers,
|
||||
// as well as the boolean function to determine if the primary dialer failed.
|
||||
func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func(error) bool) httpstream.Dialer {
|
||||
return &fallbackDialer{
|
||||
primary: primary,
|
||||
secondary: secondary,
|
||||
shouldFallback: shouldFallback,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial is the single function necessary to implement the "httpstream.Dialer" interface.
|
||||
// It takes the protocol version strings to request, returning an the upgraded
|
||||
// httstream.Connection and the negotiated protocol version accepted. If the initial
|
||||
// primary dialer fails, this function attempts the secondary dialer. Returns an error
|
||||
// if one occurs.
|
||||
func (f *fallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||
conn, version, err := f.primary.Dial(protocols...)
|
||||
if err != nil && f.shouldFallback(err) {
|
||||
klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
|
||||
return f.secondary.Dial(protocols...)
|
||||
}
|
||||
return conn, version, err
|
||||
}
|
53
tools/portforward/fallback_dialer_test.go
Normal file
53
tools/portforward/fallback_dialer_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copyright 2024 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFallbackDialer(t *testing.T) {
|
||||
protocol := "v6.fake.k8s.io"
|
||||
// If "shouldFallback" is false, then only primary should be dialed.
|
||||
primary := &fakeDialer{dialed: false}
|
||||
primary.negotiatedProtocol = protocol
|
||||
secondary := &fakeDialer{dialed: false}
|
||||
fallbackDialer := NewFallbackDialer(primary, secondary, alwaysFalse)
|
||||
_, negotiated, err := fallbackDialer.Dial(protocol)
|
||||
assert.True(t, primary.dialed, "no fallback; primary should have dialed")
|
||||
assert.Equal(t, protocol, negotiated, "")
|
||||
assert.False(t, secondary.dialed, "no fallback; secondary should *not* have dialed")
|
||||
assert.Nil(t, err, "error should be nil")
|
||||
// If "shouldFallback" is true, then primary AND secondary should be dialed.
|
||||
primary.dialed = false // reset dialed field
|
||||
primary.err = fmt.Errorf("bad handshake")
|
||||
secondary.dialed = false // reset dialed field
|
||||
secondary.negotiatedProtocol = protocol
|
||||
fallbackDialer = NewFallbackDialer(primary, secondary, alwaysTrue)
|
||||
_, negotiated, err = fallbackDialer.Dial(protocol)
|
||||
assert.True(t, primary.dialed, "fallback; primary should have dialed (first)")
|
||||
assert.True(t, secondary.dialed, "fallback; secondary should have dialed")
|
||||
assert.Equal(t, protocol, negotiated)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func alwaysTrue(err error) bool { return true }
|
||||
|
||||
func alwaysFalse(err error) bool { return false }
|
@ -191,11 +191,15 @@ func (pf *PortForwarder) ForwardPorts() error {
|
||||
defer pf.Close()
|
||||
|
||||
var err error
|
||||
pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
|
||||
var protocol string
|
||||
pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error upgrading connection: %s", err)
|
||||
}
|
||||
defer pf.streamConn.Close()
|
||||
if protocol != PortForwardProtocolV1Name {
|
||||
return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
|
||||
}
|
||||
|
||||
return pf.forward()
|
||||
}
|
||||
|
@ -430,7 +430,8 @@ func TestGetListener(t *testing.T) {
|
||||
|
||||
func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
||||
dialer := &fakeDialer{
|
||||
conn: newFakeConnection(),
|
||||
conn: newFakeConnection(),
|
||||
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
@ -570,7 +571,8 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
|
||||
|
||||
func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
||||
dialer := &fakeDialer{
|
||||
conn: newFakeConnection(),
|
||||
conn: newFakeConnection(),
|
||||
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
@ -601,7 +603,8 @@ func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
|
||||
|
||||
func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
|
||||
dialer := &fakeDialer{
|
||||
conn: newFakeConnection(),
|
||||
conn: newFakeConnection(),
|
||||
negotiatedProtocol: PortForwardProtocolV1Name,
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
153
tools/portforward/tunneling_connection.go
Normal file
153
tools/portforward/tunneling_connection.go
Normal file
@ -0,0 +1,153 @@
|
||||
/*
|
||||
Copyright 2023 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gwebsocket "github.com/gorilla/websocket"
|
||||
|
||||
"k8s.io/klog/v2"
|
||||
)
|
||||
|
||||
var _ net.Conn = &TunnelingConnection{}
|
||||
|
||||
// TunnelingConnection implements the "httpstream.Connection" interface, wrapping
|
||||
// a websocket connection that tunnels SPDY.
|
||||
type TunnelingConnection struct {
|
||||
name string
|
||||
conn *gwebsocket.Conn
|
||||
inProgressMessage io.Reader
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewTunnelingConnection wraps the passed gorilla/websockets connection
|
||||
// with the TunnelingConnection struct (implementing net.Conn).
|
||||
func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
|
||||
return &TunnelingConnection{
|
||||
name: name,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements "io.Reader" interface, reading from the stored connection
|
||||
// into the passed buffer "p". Returns the number of bytes read and an error.
|
||||
// Can keep track of the "inProgress" messsage from the tunneled connection.
|
||||
func (c *TunnelingConnection) Read(p []byte) (int, error) {
|
||||
klog.V(7).Infof("%s: tunneling connection read...", c.name)
|
||||
defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name)
|
||||
for {
|
||||
if c.inProgressMessage == nil {
|
||||
klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name)
|
||||
messageType, nextReader, err := c.conn.NextReader()
|
||||
if err != nil {
|
||||
closeError := &gwebsocket.CloseError{}
|
||||
if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
|
||||
return 0, io.EOF
|
||||
}
|
||||
klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err)
|
||||
return 0, err
|
||||
}
|
||||
if messageType != gwebsocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type received")
|
||||
}
|
||||
c.inProgressMessage = nextReader
|
||||
}
|
||||
klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name)
|
||||
i, err := c.inProgressMessage.Read(p)
|
||||
if i == 0 && err == io.EOF {
|
||||
c.inProgressMessage = nil
|
||||
} else {
|
||||
klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements "io.Writer" interface, copying the data in the passed
|
||||
// byte array "p" into the stored tunneled connection. Returns the number
|
||||
// of bytes written and an error.
|
||||
func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
|
||||
klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p)
|
||||
defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name)
|
||||
w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// close, which flushes the message
|
||||
closeErr := w.Close()
|
||||
if closeErr != nil && err == nil {
|
||||
// if closing/flushing errored and we weren't already returning an error, return the close error
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
n, err = w.Write(p)
|
||||
return
|
||||
}
|
||||
|
||||
// Close implements "io.Closer" interface, signaling the other tunneled connection
|
||||
// endpoint, and closing the tunneled connection only once.
|
||||
func (c *TunnelingConnection) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
klog.V(7).Infof("%s: tunneling connection Close()...", c.name)
|
||||
// Signal other endpoint that websocket connection is closing; ignore error.
|
||||
normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
|
||||
c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) //nolint:errcheck
|
||||
err = c.conn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// LocalAddr implements part of the "net.Conn" interface, returning the local
|
||||
// endpoint network address of the tunneled connection.
|
||||
func (c *TunnelingConnection) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalAddr implements part of the "net.Conn" interface, returning the remote
|
||||
// endpoint network address of the tunneled connection.
|
||||
func (c *TunnelingConnection) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline sets the *absolute* time in the future for both
|
||||
// read and write deadlines. Returns an error if one occurs.
|
||||
func (c *TunnelingConnection) SetDeadline(t time.Time) error {
|
||||
rerr := c.SetReadDeadline(t)
|
||||
werr := c.SetWriteDeadline(t)
|
||||
return errors.Join(rerr, werr)
|
||||
}
|
||||
|
||||
// SetDeadline sets the *absolute* time in the future for the
|
||||
// read deadlines. Returns an error if one occurs.
|
||||
func (c *TunnelingConnection) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetDeadline sets the *absolute* time in the future for the
|
||||
// write deadlines. Returns an error if one occurs.
|
||||
func (c *TunnelingConnection) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
190
tools/portforward/tunneling_connection_test.go
Normal file
190
tools/portforward/tunneling_connection_test.go
Normal file
@ -0,0 +1,190 @@
|
||||
/*
|
||||
Copyright 2024 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gwebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
"k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/transport/websocket"
|
||||
)
|
||||
|
||||
func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
|
||||
// Stream channel that will receive streams created on upstream SPDY server.
|
||||
streamChan := make(chan httpstream.Stream)
|
||||
defer close(streamChan)
|
||||
stopServerChan := make(chan struct{})
|
||||
defer close(stopServerChan)
|
||||
// Create tunneling connection server endpoint with fake upstream SPDY server.
|
||||
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
var upgrader = gwebsocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, req, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close() //nolint:errcheck
|
||||
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||
tunnelingConn := NewTunnelingConnection("server", conn)
|
||||
spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
|
||||
require.NoError(t, err)
|
||||
defer spdyConn.Close() //nolint:errcheck
|
||||
<-stopServerChan
|
||||
}))
|
||||
defer tunnelingServer.Close()
|
||||
// Dial the client tunneling connection to the tunneling server.
|
||||
url, err := url.Parse(tunnelingServer.URL)
|
||||
require.NoError(t, err)
|
||||
dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
|
||||
require.NoError(t, err)
|
||||
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, constants.PortForwardV1Name, protocol)
|
||||
defer spdyClient.Close() //nolint:errcheck
|
||||
// Create a SPDY client stream, which will queue a SPDY server stream
|
||||
// on the stream creation channel. Send data on the client stream
|
||||
// reading off the SPDY server stream, and validating it was tunneled.
|
||||
expected := "This is a test tunneling SPDY data through websockets."
|
||||
var actual []byte
|
||||
go func() {
|
||||
clientStream, err := spdyClient.CreateStream(http.Header{})
|
||||
require.NoError(t, err)
|
||||
_, err = io.Copy(clientStream, strings.NewReader(expected))
|
||||
require.NoError(t, err)
|
||||
clientStream.Close() //nolint:errcheck
|
||||
}()
|
||||
select {
|
||||
case serverStream := <-streamChan:
|
||||
actual, err = io.ReadAll(serverStream)
|
||||
require.NoError(t, err)
|
||||
defer serverStream.Close() //nolint:errcheck
|
||||
case <-time.After(wait.ForeverTestTimeout):
|
||||
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
|
||||
}
|
||||
assert.Equal(t, expected, string(actual), "error validating tunneled string")
|
||||
}
|
||||
|
||||
func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
|
||||
stopServerChan := make(chan struct{})
|
||||
defer close(stopServerChan)
|
||||
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
var upgrader = gwebsocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, req, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close() //nolint:errcheck
|
||||
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||
<-stopServerChan
|
||||
}))
|
||||
defer tunnelingServer.Close()
|
||||
// Create the client side tunneling connection.
|
||||
url, err := url.Parse(tunnelingServer.URL)
|
||||
require.NoError(t, err)
|
||||
tConn, err := dialForTunnelingConnection(url)
|
||||
require.NoError(t, err, "error creating client tunneling connection")
|
||||
defer tConn.Close() //nolint:errcheck
|
||||
// Validate "LocalAddr()" and "RemoteAddr()"
|
||||
localAddr := tConn.LocalAddr()
|
||||
remoteAddr := tConn.RemoteAddr()
|
||||
assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
|
||||
assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
|
||||
_, err = net.ResolveTCPAddr("tcp", localAddr.String())
|
||||
assert.NoError(t, err, "tunneling connection local addr should parse")
|
||||
_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
|
||||
assert.NoError(t, err, "tunneling connection remote addr should parse")
|
||||
}
|
||||
|
||||
func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
|
||||
stopServerChan := make(chan struct{})
|
||||
defer close(stopServerChan)
|
||||
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
var upgrader = gwebsocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, req, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close() //nolint:errcheck
|
||||
require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
|
||||
<-stopServerChan
|
||||
}))
|
||||
defer tunnelingServer.Close()
|
||||
// Create the client side tunneling connection.
|
||||
url, err := url.Parse(tunnelingServer.URL)
|
||||
require.NoError(t, err)
|
||||
tConn, err := dialForTunnelingConnection(url)
|
||||
require.NoError(t, err, "error creating client tunneling connection")
|
||||
defer tConn.Close() //nolint:errcheck
|
||||
// Validate the read and write deadlines.
|
||||
err = tConn.SetReadDeadline(time.Time{})
|
||||
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||
err = tConn.SetWriteDeadline(time.Time{})
|
||||
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||
err = tConn.SetDeadline(time.Time{})
|
||||
assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
|
||||
err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
|
||||
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||
err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
|
||||
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||
err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
|
||||
assert.NoError(t, err, "setting deadline 10 year from now succeeds")
|
||||
}
|
||||
|
||||
// dialForTunnelingConnection upgrades a request at the passed "url", creating
|
||||
// a websocket connection. Returns the TunnelingConnection injected with the
|
||||
// websocket connection or an error if one occurs.
|
||||
func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
|
||||
req, err := http.NewRequest("GET", url.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
|
||||
tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
|
||||
transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewTunnelingConnection("client", conn), nil
|
||||
}
|
||||
|
||||
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
streams <- stream
|
||||
return nil
|
||||
}
|
||||
}
|
93
tools/portforward/tunneling_dialer.go
Normal file
93
tools/portforward/tunneling_dialer.go
Normal file
@ -0,0 +1,93 @@
|
||||
/*
|
||||
Copyright 2023 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||
constants "k8s.io/apimachinery/pkg/util/portforward"
|
||||
restclient "k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/transport/websocket"
|
||||
"k8s.io/klog/v2"
|
||||
)
|
||||
|
||||
const PingPeriod = 10 * time.Second
|
||||
|
||||
// tunnelingDialer implements "httpstream.Dial" interface
|
||||
type tunnelingDialer struct {
|
||||
url *url.URL
|
||||
transport http.RoundTripper
|
||||
holder websocket.ConnectionHolder
|
||||
}
|
||||
|
||||
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
|
||||
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
|
||||
// returns an error if one occurs.
|
||||
func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
|
||||
transport, holder, err := websocket.RoundTripperFor(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tunnelingDialer{
|
||||
url: url,
|
||||
transport: transport,
|
||||
holder: holder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Dial upgrades to a tunneling streaming connection, returning a SPDY connection
|
||||
// containing a WebSockets connection (which implements "net.Conn"). Also
|
||||
// returns the protocol negotiated, or an error.
|
||||
func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||
// There is no passed context, so skip the context when creating request for now.
|
||||
// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
|
||||
req, err := http.NewRequest("GET", d.url.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
// Add the spdy tunneling prefix to the requested protocols. The tunneling
|
||||
// handler will know how to negotiate these protocols.
|
||||
tunnelingProtocols := []string{}
|
||||
for _, protocol := range protocols {
|
||||
tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
|
||||
tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
|
||||
}
|
||||
klog.V(4).Infoln("Before WebSocket Upgrade Connection...")
|
||||
conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, "", fmt.Errorf("negotiated websocket connection is nil")
|
||||
}
|
||||
protocol := conn.Subprotocol()
|
||||
protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
|
||||
klog.V(4).Infof("negotiated protocol: %s", protocol)
|
||||
|
||||
// Wrap the websocket connection which implements "net.Conn".
|
||||
tConn := NewTunnelingConnection("client", conn)
|
||||
// Create SPDY connection injecting the previously created tunneling connection.
|
||||
spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)
|
||||
|
||||
return spdyConn, protocol, err
|
||||
}
|
Loading…
Reference in New Issue
Block a user