diff --git a/go.mod b/go.mod index 67962fa4..03222b56 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( golang.org/x/time v0.3.0 google.golang.org/protobuf v1.31.0 k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 - k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac + k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439 k8s.io/klog/v2 v2.100.1 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 k8s.io/utils v0.0.0-20230726121419-3b25d923346b @@ -49,6 +49,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sys v0.13.0 // indirect @@ -61,5 +62,5 @@ require ( replace ( k8s.io/api => k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 - k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac + k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439 ) diff --git a/go.sum b/go.sum index 1ded1d9f..aaa3577e 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= github.com/onsi/gomega v1.28.0 h1:i2rg/p9n/UqIDAMFUJ6qIUUMcsqOuUHgbpbu235Vr1c= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= @@ -149,8 +151,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 h1:U7xcM/WBTkLV+TjNciuW7l+oXM2OHd5/TmVnPKyrmpA= k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8/go.mod h1:mgYOiLIgrQcsuVxrBI6Pplk91r3sl5ZJ7eUx7UBMTkY= -k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac h1:x3g6c1u7CtRoraBlRP2JThB3aHz7vw4FZFXRZsvoIoc= -k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac/go.mod h1:mdlGhJWO1mhVzQXm1Lx7D1BvvBIVKlRVy0vvl1LwGjg= +k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439 h1:/oxbLzC7mkHNdeFI8AMsTPTwudQu7sz7rnPGIxv2yqM= +k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439/go.mod h1:mdlGhJWO1mhVzQXm1Lx7D1BvvBIVKlRVy0vvl1LwGjg= k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= diff --git a/tools/remotecommand/fallback.go b/tools/remotecommand/fallback.go new file mode 100644 index 00000000..4846cdb5 --- /dev/null +++ b/tools/remotecommand/fallback.go @@ -0,0 +1,57 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" +) + +var _ Executor = &fallbackExecutor{} + +type fallbackExecutor struct { + primary Executor + secondary Executor + shouldFallback func(error) bool +} + +// NewFallbackExecutor creates an Executor that first attempts to use the +// WebSocketExecutor, falling back to the legacy SPDYExecutor if the initial +// websocket "StreamWithContext" call fails. +// func NewFallbackExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { +func NewFallbackExecutor(primary, secondary Executor, shouldFallback func(error) bool) (Executor, error) { + return &fallbackExecutor{ + primary: primary, + secondary: secondary, + shouldFallback: shouldFallback, + }, nil +} + +// Stream is deprecated. Please use "StreamWithContext". +func (f *fallbackExecutor) Stream(options StreamOptions) error { + return f.StreamWithContext(context.Background(), options) +} + +// StreamWithContext initially attempts to call "StreamWithContext" using the +// primary executor, falling back to calling the secondary executor if the +// initial primary call to upgrade to a websocket connection fails. +func (f *fallbackExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + err := f.primary.StreamWithContext(ctx, options) + if f.shouldFallback(err) { + return f.secondary.StreamWithContext(ctx, options) + } + return err +} diff --git a/tools/remotecommand/fallback_test.go b/tools/remotecommand/fallback_test.go new file mode 100644 index 00000000..70049857 --- /dev/null +++ b/tools/remotecommand/fallback_test.go @@ -0,0 +1,227 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "bytes" + "context" + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" +) + +func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + require.NoError(t, err) + })) + defer websocketServer.Close() + + // Now create the fallback client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation) + require.NoError(t, err) + // Never fallback, so always use the websocketExecutor, which succeeds against websocket server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false }) + require.NoError(t, err) + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + } + + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) { + // Create fake SPDY server. Copy received STDIN data back onto STDOUT stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var stdin, stdout bytes.Buffer + ctx, err := createHTTPStreams(w, req, &StreamOptions{ + Stdin: &stdin, + Stdout: &stdout, + }) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer ctx.conn.Close() + _, err = io.Copy(ctx.stdoutStream, ctx.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + })) + defer spdyServer.Close() + + spdyLocation, err := url.Parse(spdyServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation) + require.NoError(t, err) + // Always fallback to spdyExecutor, and spdyExecutor succeeds against fake spdy server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true }) + require.NoError(t, err) + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + } + + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + require.NoError(t, err) + })) + defer websocketServer.Close() + + // Now create the fallback client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation) + require.NoError(t, err) + // Always fallback to spdyExecutor, but spdyExecutor fails against websocket server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true }) + require.NoError(t, err) + // Update the websocket executor to request remote command v4, which is unsupported. + fallbackExec, ok := exec.(*fallbackExecutor) + assert.True(t, ok, "error casting executor as fallbackExecutor") + websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor) + assert.True(t, ok, "error casting executor as websocket executor") + // Set the attempted subprotocol version to V4; websocket server only accepts V5. + websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name} + + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Ensure secondary executor returned an error. + require.Error(t, err) + } +} diff --git a/tools/remotecommand/spdy.go b/tools/remotecommand/spdy.go index 76ea946b..c2bfcf8a 100644 --- a/tools/remotecommand/spdy.go +++ b/tools/remotecommand/spdy.go @@ -34,9 +34,10 @@ type spdyStreamExecutor struct { upgrader spdy.Upgrader transport http.RoundTripper - method string - url *url.URL - protocols []string + method string + url *url.URL + protocols []string + rejectRedirects bool // if true, receiving redirect from upstream is an error } // NewSPDYExecutor connects to the provided server and upgrades the connection to @@ -49,6 +50,20 @@ func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Ex return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url) } +// NewSPDYExecutorRejectRedirects returns an Executor that will upgrade the future +// connection to a SPDY bi-directional streaming connection when calling "Stream" (deprecated) +// or "StreamWithContext" (preferred). Additionally, if the upstream server returns a redirect +// during the attempted upgrade in these "Stream" calls, an error is returned. +func NewSPDYExecutorRejectRedirects(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { + executor, err := NewSPDYExecutorForTransports(transport, upgrader, method, url) + if err != nil { + return nil, err + } + spdyExecutor := executor.(*spdyStreamExecutor) + spdyExecutor.rejectRedirects = true + return spdyExecutor, nil +} + // NewSPDYExecutorForTransports connects to the provided server using the given transport, // upgrades the response using the given upgrader to multiplexed bidirectional streams. func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { @@ -88,9 +103,15 @@ func (e *spdyStreamExecutor) newConnectionAndStream(ctx context.Context, options return nil, nil, fmt.Errorf("error creating request: %v", err) } + client := http.Client{Transport: e.transport} + if e.rejectRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return fmt.Errorf("redirect not allowed") + } + } conn, protocol, err := spdy.Negotiate( e.upgrader, - &http.Client{Transport: e.transport}, + &client, req, e.protocols..., ) diff --git a/tools/remotecommand/spdy_test.go b/tools/remotecommand/spdy_test.go index c11177a0..1b1cf749 100644 --- a/tools/remotecommand/spdy_test.go +++ b/tools/remotecommand/spdy_test.go @@ -183,6 +183,7 @@ func TestSPDYExecutorStream(t *testing.T) { } func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { + //nolint:errcheck server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx, err := createHTTPStreams(writer, request, options) if err != nil { @@ -381,7 +382,7 @@ func TestStreamRandomData(t *testing.T) { } defer ctx.conn.Close() - io.Copy(ctx.stdoutStream, ctx.stdinStream) + io.Copy(ctx.stdoutStream, ctx.stdinStream) //nolint:errcheck })) defer server.Close() diff --git a/tools/remotecommand/websocket.go b/tools/remotecommand/websocket.go index 48e52092..a60986de 100644 --- a/tools/remotecommand/websocket.go +++ b/tools/remotecommand/websocket.go @@ -85,22 +85,26 @@ type wsStreamExecutor struct { heartbeatDeadline time.Duration } -// NewWebSocketExecutor allows to execute commands via a WebSocket connection. func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) { + // Only supports V5 protocol for correct version skew functionality. + // Previous api servers will proxy upgrade requests to legacy websocket + // servers on container runtimes which support V1-V4. These legacy + // websocket servers will not handle the new CLOSE signal. + return NewWebSocketExecutorForProtocols(config, method, url, remotecommand.StreamProtocolV5Name) +} + +// NewWebSocketExecutorForProtocols allows to execute commands via a WebSocket connection. +func NewWebSocketExecutorForProtocols(config *restclient.Config, method, url string, protocols ...string) (Executor, error) { transport, upgrader, err := websocket.RoundTripperFor(config) if err != nil { return nil, fmt.Errorf("error creating websocket transports: %v", err) } return &wsStreamExecutor{ - transport: transport, - upgrader: upgrader, - method: method, - url: url, - // Only supports V5 protocol for correct version skew functionality. - // Previous api servers will proxy upgrade requests to legacy websocket - // servers on container runtimes which support V1-V4. These legacy - // websocket servers will not handle the new CLOSE signal. - protocols: []string{remotecommand.StreamProtocolV5Name}, + transport: transport, + upgrader: upgrader, + method: method, + url: url, + protocols: protocols, heartbeatPeriod: pingPeriod, heartbeatDeadline: pingReadDeadline, }, nil @@ -177,10 +181,12 @@ func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options Stream } type wsStreamCreator struct { - conn *gwebsocket.Conn + conn *gwebsocket.Conn + // Protects writing to websocket connection; reading is lock-free connWriteLock sync.Mutex - streams map[byte]*stream - streamsMu sync.Mutex + // map of stream id to stream; multiple streams read/write the connection + streams map[byte]*stream + streamsMu sync.Mutex } func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { @@ -226,7 +232,7 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, return s, nil } -// readDemuxLoop is the reading processor for this endpoint of the websocket +// readDemuxLoop is the lock-free reading processor for this endpoint of the websocket // connection. This loop reads the connection, and demultiplexes the data // into one of the individual stream pipes (by checking the stream id). This // loop can *not* be run concurrently, because there can only be one websocket diff --git a/tools/remotecommand/websocket_test.go b/tools/remotecommand/websocket_test.go index 2895ba54..61df2b77 100644 --- a/tools/remotecommand/websocket_test.go +++ b/tools/remotecommand/websocket_test.go @@ -74,7 +74,7 @@ func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -149,7 +149,7 @@ func TestWebSocketClient_DifferentBufferSizes(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -223,7 +223,7 @@ func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -304,7 +304,7 @@ func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -377,7 +377,7 @@ func TestWebSocketClient_MultipleReadChannels(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -479,7 +479,7 @@ func TestWebSocketClient_ErrorStream(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -637,7 +637,7 @@ func TestWebSocketClient_MultipleWriteChannels(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -723,7 +723,7 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -766,11 +766,14 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) { func TestWebSocketClient_BadHandshake(t *testing.T) { // Create fake WebSocket server (supports V5 subprotocol). websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) - if err != nil { - t.Fatalf("error on webSocketServerStreams: %v", err) + // Bad handshake means websocket server will not completely initialize. + _, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err == nil { + t.Fatalf("expected error, but received none.") + } + if !strings.Contains(err.Error(), "websocket server finished before becoming ready") { + t.Errorf("expected websocket server error, but got: %v", err) } - defer conns.conn.Close() })) defer websocketServer.Close() @@ -779,7 +782,7 @@ func TestWebSocketClient_BadHandshake(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -831,7 +834,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -909,7 +912,7 @@ func TestWebSocketClient_TextMessageTypeError(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -970,7 +973,7 @@ func TestWebSocketClient_EmptyMessageHandled(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -1009,14 +1012,14 @@ func TestWebSocketClient_ExecutorErrors(t *testing.T) { ExecProvider: &clientcmdapi.ExecConfig{}, AuthProvider: &clientcmdapi.AuthProviderConfig{}, } - _, err := NewWebSocketExecutor(&config, "POST", "http://localhost") + _, err := NewWebSocketExecutor(&config, "GET", "http://localhost") if err == nil { t.Errorf("expecting executor constructor error, but received none.") } else if !strings.Contains(err.Error(), "error creating websocket transports") { t.Errorf("expecting error creating transports, got (%s)", err.Error()) } // Verify that a nil context will cause an error in StreamWithContext - exec, err := NewWebSocketExecutor(&rest.Config{}, "POST", "http://localhost") + exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost") if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -1316,7 +1319,16 @@ func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *opti resizeStream: streams[remotecommand.StreamResize], } - wsStreams.writeStatus = v4WriteStatusFunc(streams[remotecommand.StreamErr]) + wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error { + return func(status *apierrors.StatusError) error { + bs, err := json.Marshal(status.Status()) + if err != nil { + return err + } + _, err = stream.Write(bs) + return err + } + }(streams[remotecommand.StreamErr]) return wsStreams, nil } diff --git a/transport/spdy/spdy.go b/transport/spdy/spdy.go index f50b68e5..9fddc6c5 100644 --- a/transport/spdy/spdy.go +++ b/transport/spdy/spdy.go @@ -43,11 +43,15 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er if config.Proxy != nil { proxy = config.Proxy } - upgradeRoundTripper := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{ - TLS: tlsConfig, - Proxier: proxy, - PingPeriod: time.Second * 5, + upgradeRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{ + TLS: tlsConfig, + Proxier: proxy, + PingPeriod: time.Second * 5, + UpgradeTransport: nil, }) + if err != nil { + return nil, nil, err + } wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper) if err != nil { return nil, nil, err diff --git a/transport/websocket/roundtripper.go b/transport/websocket/roundtripper.go index e2a4a8ab..010f916b 100644 --- a/transport/websocket/roundtripper.go +++ b/transport/websocket/roundtripper.go @@ -108,10 +108,7 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response } wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) if err != nil { - if err != gwebsocket.ErrBadHandshake { - return nil, err - } - return nil, fmt.Errorf("unable to upgrade connection: %v", err) + return nil, &httpstream.UpgradeFailureError{Cause: err} } rt.Conn = wsConn @@ -155,7 +152,7 @@ func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http. req.Header[httpstream.HeaderProtocolVersion] = protocols resp, err := rt.RoundTrip(req) if err != nil { - return nil, fmt.Errorf("error sending request: %v", err) + return nil, err } err = resp.Body.Close() if err != nil { diff --git a/transport/websocket/roundtripper_test.go b/transport/websocket/roundtripper_test.go index 168d5d55..16bfbf57 100644 --- a/transport/websocket/roundtripper_test.go +++ b/transport/websocket/roundtripper_test.go @@ -49,7 +49,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) @@ -67,18 +67,17 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) { // Create fake WebSocket server. websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - conns, err := webSocketServerStreams(req, w) - if err != nil { - t.Fatalf("error on webSocketServerStreams: %v", err) - } - defer conns.conn.Close() + // Bad handshake means websocket server will not completely initialize. + _, err := webSocketServerStreams(req, w) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "websocket server finished before becoming ready")) })) defer websocketServer.Close() // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) @@ -105,7 +104,7 @@ func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) { // Create the websocket roundtripper and call "Negotiate" to create websocket connection. websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err)