From b7d73dd18cf2215a005f79d2a2ca2875528f1419 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Thu, 6 Jul 2023 21:22:07 -0700 Subject: [PATCH] StreamTranslator and FallbackExecutor for WebSockets Kubernetes-commit: 168998e87bfd49a1b0bc6402761fafd5ace3bb3b --- go.mod | 10 +- go.sum | 21 ++- tools/remotecommand/fallback.go | 57 ++++++ tools/remotecommand/fallback_test.go | 227 +++++++++++++++++++++++ tools/remotecommand/spdy.go | 29 ++- tools/remotecommand/spdy_test.go | 3 +- tools/remotecommand/websocket.go | 34 ++-- tools/remotecommand/websocket_test.go | 50 +++-- transport/spdy/spdy.go | 12 +- transport/websocket/roundtripper.go | 7 +- transport/websocket/roundtripper_test.go | 15 +- 11 files changed, 402 insertions(+), 63 deletions(-) create mode 100644 tools/remotecommand/fallback.go create mode 100644 tools/remotecommand/fallback_test.go diff --git a/go.mod b/go.mod index 67962fa4..a084e772 100644 --- a/go.mod +++ b/go.mod @@ -24,8 +24,8 @@ require ( golang.org/x/term v0.13.0 golang.org/x/time v0.3.0 google.golang.org/protobuf v1.31.0 - k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 - k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac + k8s.io/api v0.0.0 + k8s.io/apimachinery v0.0.0 k8s.io/klog/v2 v2.100.1 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 k8s.io/utils v0.0.0-20230726121419-3b25d923346b @@ -49,6 +49,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sys v0.13.0 // indirect @@ -60,6 +61,7 @@ require ( ) replace ( - k8s.io/api => k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 - k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac + k8s.io/api => ../api + k8s.io/apimachinery => ../apimachinery + k8s.io/client-go => ../client-go ) diff --git a/go.sum b/go.sum index 1ded1d9f..820f7bd6 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,8 @@ +cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -17,6 +21,7 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -36,6 +41,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= +github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -53,6 +59,7 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -68,8 +75,12 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= +github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= github.com/onsi/gomega v1.28.0 h1:i2rg/p9n/UqIDAMFUJ6qIUUMcsqOuUHgbpbu235Vr1c= +github.com/onsi/gomega v1.28.0/go.mod h1:A1H2JE76sI14WIP57LMKj7FVfCHx3g3BcZVjJG8bjX8= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -77,6 +88,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -93,8 +105,10 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -126,10 +140,12 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= +golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -147,10 +163,7 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 h1:U7xcM/WBTkLV+TjNciuW7l+oXM2OHd5/TmVnPKyrmpA= -k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8/go.mod h1:mgYOiLIgrQcsuVxrBI6Pplk91r3sl5ZJ7eUx7UBMTkY= -k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac h1:x3g6c1u7CtRoraBlRP2JThB3aHz7vw4FZFXRZsvoIoc= -k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac/go.mod h1:mdlGhJWO1mhVzQXm1Lx7D1BvvBIVKlRVy0vvl1LwGjg= +k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= diff --git a/tools/remotecommand/fallback.go b/tools/remotecommand/fallback.go new file mode 100644 index 00000000..4846cdb5 --- /dev/null +++ b/tools/remotecommand/fallback.go @@ -0,0 +1,57 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" +) + +var _ Executor = &fallbackExecutor{} + +type fallbackExecutor struct { + primary Executor + secondary Executor + shouldFallback func(error) bool +} + +// NewFallbackExecutor creates an Executor that first attempts to use the +// WebSocketExecutor, falling back to the legacy SPDYExecutor if the initial +// websocket "StreamWithContext" call fails. +// func NewFallbackExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { +func NewFallbackExecutor(primary, secondary Executor, shouldFallback func(error) bool) (Executor, error) { + return &fallbackExecutor{ + primary: primary, + secondary: secondary, + shouldFallback: shouldFallback, + }, nil +} + +// Stream is deprecated. Please use "StreamWithContext". +func (f *fallbackExecutor) Stream(options StreamOptions) error { + return f.StreamWithContext(context.Background(), options) +} + +// StreamWithContext initially attempts to call "StreamWithContext" using the +// primary executor, falling back to calling the secondary executor if the +// initial primary call to upgrade to a websocket connection fails. +func (f *fallbackExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + err := f.primary.StreamWithContext(ctx, options) + if f.shouldFallback(err) { + return f.secondary.StreamWithContext(ctx, options) + } + return err +} diff --git a/tools/remotecommand/fallback_test.go b/tools/remotecommand/fallback_test.go new file mode 100644 index 00000000..70049857 --- /dev/null +++ b/tools/remotecommand/fallback_test.go @@ -0,0 +1,227 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "bytes" + "context" + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" +) + +func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + require.NoError(t, err) + })) + defer websocketServer.Close() + + // Now create the fallback client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation) + require.NoError(t, err) + // Never fallback, so always use the websocketExecutor, which succeeds against websocket server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false }) + require.NoError(t, err) + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + } + + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) { + // Create fake SPDY server. Copy received STDIN data back onto STDOUT stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var stdin, stdout bytes.Buffer + ctx, err := createHTTPStreams(w, req, &StreamOptions{ + Stdin: &stdin, + Stdout: &stdout, + }) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer ctx.conn.Close() + _, err = io.Copy(ctx.stdoutStream, ctx.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + })) + defer spdyServer.Close() + + spdyLocation, err := url.Parse(spdyServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation) + require.NoError(t, err) + // Always fallback to spdyExecutor, and spdyExecutor succeeds against fake spdy server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true }) + require.NoError(t, err) + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + } + + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + require.NoError(t, err) + })) + defer websocketServer.Close() + + // Now create the fallback client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation) + require.NoError(t, err) + // Always fallback to spdyExecutor, but spdyExecutor fails against websocket server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true }) + require.NoError(t, err) + // Update the websocket executor to request remote command v4, which is unsupported. + fallbackExec, ok := exec.(*fallbackExecutor) + assert.True(t, ok, "error casting executor as fallbackExecutor") + websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor) + assert.True(t, ok, "error casting executor as websocket executor") + // Set the attempted subprotocol version to V4; websocket server only accepts V5. + websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name} + + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Ensure secondary executor returned an error. + require.Error(t, err) + } +} diff --git a/tools/remotecommand/spdy.go b/tools/remotecommand/spdy.go index 76ea946b..c2bfcf8a 100644 --- a/tools/remotecommand/spdy.go +++ b/tools/remotecommand/spdy.go @@ -34,9 +34,10 @@ type spdyStreamExecutor struct { upgrader spdy.Upgrader transport http.RoundTripper - method string - url *url.URL - protocols []string + method string + url *url.URL + protocols []string + rejectRedirects bool // if true, receiving redirect from upstream is an error } // NewSPDYExecutor connects to the provided server and upgrades the connection to @@ -49,6 +50,20 @@ func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Ex return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url) } +// NewSPDYExecutorRejectRedirects returns an Executor that will upgrade the future +// connection to a SPDY bi-directional streaming connection when calling "Stream" (deprecated) +// or "StreamWithContext" (preferred). Additionally, if the upstream server returns a redirect +// during the attempted upgrade in these "Stream" calls, an error is returned. +func NewSPDYExecutorRejectRedirects(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { + executor, err := NewSPDYExecutorForTransports(transport, upgrader, method, url) + if err != nil { + return nil, err + } + spdyExecutor := executor.(*spdyStreamExecutor) + spdyExecutor.rejectRedirects = true + return spdyExecutor, nil +} + // NewSPDYExecutorForTransports connects to the provided server using the given transport, // upgrades the response using the given upgrader to multiplexed bidirectional streams. func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { @@ -88,9 +103,15 @@ func (e *spdyStreamExecutor) newConnectionAndStream(ctx context.Context, options return nil, nil, fmt.Errorf("error creating request: %v", err) } + client := http.Client{Transport: e.transport} + if e.rejectRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return fmt.Errorf("redirect not allowed") + } + } conn, protocol, err := spdy.Negotiate( e.upgrader, - &http.Client{Transport: e.transport}, + &client, req, e.protocols..., ) diff --git a/tools/remotecommand/spdy_test.go b/tools/remotecommand/spdy_test.go index c11177a0..1b1cf749 100644 --- a/tools/remotecommand/spdy_test.go +++ b/tools/remotecommand/spdy_test.go @@ -183,6 +183,7 @@ func TestSPDYExecutorStream(t *testing.T) { } func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { + //nolint:errcheck server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx, err := createHTTPStreams(writer, request, options) if err != nil { @@ -381,7 +382,7 @@ func TestStreamRandomData(t *testing.T) { } defer ctx.conn.Close() - io.Copy(ctx.stdoutStream, ctx.stdinStream) + io.Copy(ctx.stdoutStream, ctx.stdinStream) //nolint:errcheck })) defer server.Close() diff --git a/tools/remotecommand/websocket.go b/tools/remotecommand/websocket.go index 48e52092..a60986de 100644 --- a/tools/remotecommand/websocket.go +++ b/tools/remotecommand/websocket.go @@ -85,22 +85,26 @@ type wsStreamExecutor struct { heartbeatDeadline time.Duration } -// NewWebSocketExecutor allows to execute commands via a WebSocket connection. func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) { + // Only supports V5 protocol for correct version skew functionality. + // Previous api servers will proxy upgrade requests to legacy websocket + // servers on container runtimes which support V1-V4. These legacy + // websocket servers will not handle the new CLOSE signal. + return NewWebSocketExecutorForProtocols(config, method, url, remotecommand.StreamProtocolV5Name) +} + +// NewWebSocketExecutorForProtocols allows to execute commands via a WebSocket connection. +func NewWebSocketExecutorForProtocols(config *restclient.Config, method, url string, protocols ...string) (Executor, error) { transport, upgrader, err := websocket.RoundTripperFor(config) if err != nil { return nil, fmt.Errorf("error creating websocket transports: %v", err) } return &wsStreamExecutor{ - transport: transport, - upgrader: upgrader, - method: method, - url: url, - // Only supports V5 protocol for correct version skew functionality. - // Previous api servers will proxy upgrade requests to legacy websocket - // servers on container runtimes which support V1-V4. These legacy - // websocket servers will not handle the new CLOSE signal. - protocols: []string{remotecommand.StreamProtocolV5Name}, + transport: transport, + upgrader: upgrader, + method: method, + url: url, + protocols: protocols, heartbeatPeriod: pingPeriod, heartbeatDeadline: pingReadDeadline, }, nil @@ -177,10 +181,12 @@ func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options Stream } type wsStreamCreator struct { - conn *gwebsocket.Conn + conn *gwebsocket.Conn + // Protects writing to websocket connection; reading is lock-free connWriteLock sync.Mutex - streams map[byte]*stream - streamsMu sync.Mutex + // map of stream id to stream; multiple streams read/write the connection + streams map[byte]*stream + streamsMu sync.Mutex } func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { @@ -226,7 +232,7 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, return s, nil } -// readDemuxLoop is the reading processor for this endpoint of the websocket +// readDemuxLoop is the lock-free reading processor for this endpoint of the websocket // connection. This loop reads the connection, and demultiplexes the data // into one of the individual stream pipes (by checking the stream id). This // loop can *not* be run concurrently, because there can only be one websocket diff --git a/tools/remotecommand/websocket_test.go b/tools/remotecommand/websocket_test.go index 2895ba54..61df2b77 100644 --- a/tools/remotecommand/websocket_test.go +++ b/tools/remotecommand/websocket_test.go @@ -74,7 +74,7 @@ func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -149,7 +149,7 @@ func TestWebSocketClient_DifferentBufferSizes(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -223,7 +223,7 @@ func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -304,7 +304,7 @@ func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -377,7 +377,7 @@ func TestWebSocketClient_MultipleReadChannels(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -479,7 +479,7 @@ func TestWebSocketClient_ErrorStream(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -637,7 +637,7 @@ func TestWebSocketClient_MultipleWriteChannels(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -723,7 +723,7 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -766,11 +766,14 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) { func TestWebSocketClient_BadHandshake(t *testing.T) { // Create fake WebSocket server (supports V5 subprotocol). websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) - if err != nil { - t.Fatalf("error on webSocketServerStreams: %v", err) + // Bad handshake means websocket server will not completely initialize. + _, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err == nil { + t.Fatalf("expected error, but received none.") + } + if !strings.Contains(err.Error(), "websocket server finished before becoming ready") { + t.Errorf("expected websocket server error, but got: %v", err) } - defer conns.conn.Close() })) defer websocketServer.Close() @@ -779,7 +782,7 @@ func TestWebSocketClient_BadHandshake(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -831,7 +834,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -909,7 +912,7 @@ func TestWebSocketClient_TextMessageTypeError(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -970,7 +973,7 @@ func TestWebSocketClient_EmptyMessageHandled(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -1009,14 +1012,14 @@ func TestWebSocketClient_ExecutorErrors(t *testing.T) { ExecProvider: &clientcmdapi.ExecConfig{}, AuthProvider: &clientcmdapi.AuthProviderConfig{}, } - _, err := NewWebSocketExecutor(&config, "POST", "http://localhost") + _, err := NewWebSocketExecutor(&config, "GET", "http://localhost") if err == nil { t.Errorf("expecting executor constructor error, but received none.") } else if !strings.Contains(err.Error(), "error creating websocket transports") { t.Errorf("expecting error creating transports, got (%s)", err.Error()) } // Verify that a nil context will cause an error in StreamWithContext - exec, err := NewWebSocketExecutor(&rest.Config{}, "POST", "http://localhost") + exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost") if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -1316,7 +1319,16 @@ func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *opti resizeStream: streams[remotecommand.StreamResize], } - wsStreams.writeStatus = v4WriteStatusFunc(streams[remotecommand.StreamErr]) + wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error { + return func(status *apierrors.StatusError) error { + bs, err := json.Marshal(status.Status()) + if err != nil { + return err + } + _, err = stream.Write(bs) + return err + } + }(streams[remotecommand.StreamErr]) return wsStreams, nil } diff --git a/transport/spdy/spdy.go b/transport/spdy/spdy.go index f50b68e5..9fddc6c5 100644 --- a/transport/spdy/spdy.go +++ b/transport/spdy/spdy.go @@ -43,11 +43,15 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er if config.Proxy != nil { proxy = config.Proxy } - upgradeRoundTripper := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{ - TLS: tlsConfig, - Proxier: proxy, - PingPeriod: time.Second * 5, + upgradeRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{ + TLS: tlsConfig, + Proxier: proxy, + PingPeriod: time.Second * 5, + UpgradeTransport: nil, }) + if err != nil { + return nil, nil, err + } wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper) if err != nil { return nil, nil, err diff --git a/transport/websocket/roundtripper.go b/transport/websocket/roundtripper.go index e2a4a8ab..010f916b 100644 --- a/transport/websocket/roundtripper.go +++ b/transport/websocket/roundtripper.go @@ -108,10 +108,7 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response } wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) if err != nil { - if err != gwebsocket.ErrBadHandshake { - return nil, err - } - return nil, fmt.Errorf("unable to upgrade connection: %v", err) + return nil, &httpstream.UpgradeFailureError{Cause: err} } rt.Conn = wsConn @@ -155,7 +152,7 @@ func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http. req.Header[httpstream.HeaderProtocolVersion] = protocols resp, err := rt.RoundTrip(req) if err != nil { - return nil, fmt.Errorf("error sending request: %v", err) + return nil, err } err = resp.Body.Close() if err != nil { diff --git a/transport/websocket/roundtripper_test.go b/transport/websocket/roundtripper_test.go index 168d5d55..16bfbf57 100644 --- a/transport/websocket/roundtripper_test.go +++ b/transport/websocket/roundtripper_test.go @@ -49,7 +49,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) @@ -67,18 +67,17 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) { // Create fake WebSocket server. websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - conns, err := webSocketServerStreams(req, w) - if err != nil { - t.Fatalf("error on webSocketServerStreams: %v", err) - } - defer conns.conn.Close() + // Bad handshake means websocket server will not completely initialize. + _, err := webSocketServerStreams(req, w) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "websocket server finished before becoming ready")) })) defer websocketServer.Close() // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) @@ -105,7 +104,7 @@ func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) { // Create the websocket roundtripper and call "Negotiate" to create websocket connection. websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err)