From fd1e22bd8f6286a1e7265426168dcc1570a855c6 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Thu, 6 Jul 2023 21:22:07 -0700 Subject: [PATCH] WebSocket Client and V5 RemoteCommand Subprotocol Kubernetes-commit: a0d6a815fcc02cbfea1bb22d13a8e896ecbe116c --- go.mod | 10 +- go.sum | 20 +- tools/remotecommand/remotecommand.go | 124 -- tools/remotecommand/spdy.go | 150 ++ .../{remotecommand_test.go => spdy_test.go} | 2 +- tools/remotecommand/v5.go | 35 + tools/remotecommand/websocket.go | 485 ++++++ tools/remotecommand/websocket_test.go | 1303 +++++++++++++++++ transport/websocket/roundtripper.go | 166 +++ transport/websocket/roundtripper_test.go | 140 ++ 10 files changed, 2302 insertions(+), 133 deletions(-) create mode 100644 tools/remotecommand/spdy.go rename tools/remotecommand/{remotecommand_test.go => spdy_test.go} (99%) create mode 100644 tools/remotecommand/v5.go create mode 100644 tools/remotecommand/websocket.go create mode 100644 tools/remotecommand/websocket_test.go create mode 100644 transport/websocket/roundtripper.go create mode 100644 transport/websocket/roundtripper_test.go diff --git a/go.mod b/go.mod index 4f19d7a0..02abf744 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/gofuzz v1.2.0 github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.4.2 github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 github.com/imdario/mergo v0.3.6 github.com/peterbourgon/diskv v2.0.1+incompatible @@ -23,8 +24,8 @@ require ( golang.org/x/term v0.10.0 golang.org/x/time v0.3.0 google.golang.org/protobuf v1.31.0 - k8s.io/api v0.0.0-20230901043046-faec07c7cc89 - k8s.io/apimachinery v0.0.0-20230901041540-0d057e543013 + k8s.io/api v0.0.0 + k8s.io/apimachinery v0.0.0 k8s.io/klog/v2 v2.100.1 k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 @@ -60,6 +61,7 @@ require ( ) replace ( - k8s.io/api => k8s.io/api v0.0.0-20230901043046-faec07c7cc89 - k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20230901041540-0d057e543013 + k8s.io/api => ../api + k8s.io/apimachinery => ../apimachinery + k8s.io/client-go => ../client-go ) diff --git a/go.sum b/go.sum index 490d4e1a..88c5c827 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,8 @@ +cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -17,6 +21,7 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -36,8 +41,10 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= +github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= @@ -51,6 +58,7 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -66,7 +74,9 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= +github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= @@ -76,6 +86,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -92,8 +103,10 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -125,10 +138,12 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= +golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -146,10 +161,7 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.0.0-20230901043046-faec07c7cc89 h1:qgk2nx6RBhIWmGYpmsbP9nuOFK/ZHpzKS2Bk13dYzqU= -k8s.io/api v0.0.0-20230901043046-faec07c7cc89/go.mod h1:qunKCI5HG1/XacwNcIwUHAFWPw6hFmIP5fb13hHJARM= -k8s.io/apimachinery v0.0.0-20230901041540-0d057e543013 h1:SWrddEv3aWWseUfiT42ziJUc8Gck/2Iqltz0Z+ctTB8= -k8s.io/apimachinery v0.0.0-20230901041540-0d057e543013/go.mod h1:4XSLsQ98qjSQA9IF1M+KnXoc+E7kiualOL0cx3GFYZI= +k8s.io/gengo v0.0.0-20210813121822-485abfe95c7c/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 h1:LyMgNKD2P8Wn1iAwQU5OhxCKlKJy0sHc+PcDwFB24dQ= diff --git a/tools/remotecommand/remotecommand.go b/tools/remotecommand/remotecommand.go index 662a3cb4..1ae67729 100644 --- a/tools/remotecommand/remotecommand.go +++ b/tools/remotecommand/remotecommand.go @@ -18,17 +18,10 @@ package remotecommand import ( "context" - "fmt" "io" "net/http" - "net/url" - - "k8s.io/klog/v2" "k8s.io/apimachinery/pkg/util/httpstream" - "k8s.io/apimachinery/pkg/util/remotecommand" - restclient "k8s.io/client-go/rest" - "k8s.io/client-go/transport/spdy" ) // StreamOptions holds information pertaining to the current streaming session: @@ -63,120 +56,3 @@ type streamCreator interface { type streamProtocolHandler interface { stream(conn streamCreator) error } - -// streamExecutor handles transporting standard shell streams over an httpstream connection. -type streamExecutor struct { - upgrader spdy.Upgrader - transport http.RoundTripper - - method string - url *url.URL - protocols []string -} - -// NewSPDYExecutor connects to the provided server and upgrades the connection to -// multiplexed bidirectional streams. -func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { - wrapper, upgradeRoundTripper, err := spdy.RoundTripperFor(config) - if err != nil { - return nil, err - } - return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url) -} - -// NewSPDYExecutorForTransports connects to the provided server using the given transport, -// upgrades the response using the given upgrader to multiplexed bidirectional streams. -func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { - return NewSPDYExecutorForProtocols( - transport, upgrader, method, url, - remotecommand.StreamProtocolV4Name, - remotecommand.StreamProtocolV3Name, - remotecommand.StreamProtocolV2Name, - remotecommand.StreamProtocolV1Name, - ) -} - -// NewSPDYExecutorForProtocols connects to the provided server and upgrades the connection to -// multiplexed bidirectional streams using only the provided protocols. Exposed for testing, most -// callers should use NewSPDYExecutor or NewSPDYExecutorForTransports. -func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL, protocols ...string) (Executor, error) { - return &streamExecutor{ - upgrader: upgrader, - transport: transport, - method: method, - url: url, - protocols: protocols, - }, nil -} - -// Stream opens a protocol streamer to the server and streams until a client closes -// the connection or the server disconnects. -func (e *streamExecutor) Stream(options StreamOptions) error { - return e.StreamWithContext(context.Background(), options) -} - -// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it. -func (e *streamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) { - req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil) - if err != nil { - return nil, nil, fmt.Errorf("error creating request: %v", err) - } - - conn, protocol, err := spdy.Negotiate( - e.upgrader, - &http.Client{Transport: e.transport}, - req, - e.protocols..., - ) - if err != nil { - return nil, nil, err - } - - var streamer streamProtocolHandler - - switch protocol { - case remotecommand.StreamProtocolV4Name: - streamer = newStreamProtocolV4(options) - case remotecommand.StreamProtocolV3Name: - streamer = newStreamProtocolV3(options) - case remotecommand.StreamProtocolV2Name: - streamer = newStreamProtocolV2(options) - case "": - klog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name) - fallthrough - case remotecommand.StreamProtocolV1Name: - streamer = newStreamProtocolV1(options) - } - - return conn, streamer, nil -} - -// StreamWithContext opens a protocol streamer to the server and streams until a client closes -// the connection or the server disconnects or the context is done. -func (e *streamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { - conn, streamer, err := e.newConnectionAndStream(ctx, options) - if err != nil { - return err - } - defer conn.Close() - - panicChan := make(chan any, 1) - errorChan := make(chan error, 1) - go func() { - defer func() { - if p := recover(); p != nil { - panicChan <- p - } - }() - errorChan <- streamer.stream(conn) - }() - - select { - case p := <-panicChan: - panic(p) - case err := <-errorChan: - return err - case <-ctx.Done(): - return ctx.Err() - } -} diff --git a/tools/remotecommand/spdy.go b/tools/remotecommand/spdy.go new file mode 100644 index 00000000..76ea946b --- /dev/null +++ b/tools/remotecommand/spdy.go @@ -0,0 +1,150 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/remotecommand" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport/spdy" + "k8s.io/klog/v2" +) + +// spdyStreamExecutor handles transporting standard shell streams over an httpstream connection. +type spdyStreamExecutor struct { + upgrader spdy.Upgrader + transport http.RoundTripper + + method string + url *url.URL + protocols []string +} + +// NewSPDYExecutor connects to the provided server and upgrades the connection to +// multiplexed bidirectional streams. +func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { + wrapper, upgradeRoundTripper, err := spdy.RoundTripperFor(config) + if err != nil { + return nil, err + } + return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url) +} + +// NewSPDYExecutorForTransports connects to the provided server using the given transport, +// upgrades the response using the given upgrader to multiplexed bidirectional streams. +func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { + return NewSPDYExecutorForProtocols( + transport, upgrader, method, url, + remotecommand.StreamProtocolV5Name, + remotecommand.StreamProtocolV4Name, + remotecommand.StreamProtocolV3Name, + remotecommand.StreamProtocolV2Name, + remotecommand.StreamProtocolV1Name, + ) +} + +// NewSPDYExecutorForProtocols connects to the provided server and upgrades the connection to +// multiplexed bidirectional streams using only the provided protocols. Exposed for testing, most +// callers should use NewSPDYExecutor or NewSPDYExecutorForTransports. +func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL, protocols ...string) (Executor, error) { + return &spdyStreamExecutor{ + upgrader: upgrader, + transport: transport, + method: method, + url: url, + protocols: protocols, + }, nil +} + +// Stream opens a protocol streamer to the server and streams until a client closes +// the connection or the server disconnects. +func (e *spdyStreamExecutor) Stream(options StreamOptions) error { + return e.StreamWithContext(context.Background(), options) +} + +// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it. +func (e *spdyStreamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) { + req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil) + if err != nil { + return nil, nil, fmt.Errorf("error creating request: %v", err) + } + + conn, protocol, err := spdy.Negotiate( + e.upgrader, + &http.Client{Transport: e.transport}, + req, + e.protocols..., + ) + if err != nil { + return nil, nil, err + } + + var streamer streamProtocolHandler + + switch protocol { + case remotecommand.StreamProtocolV5Name: + streamer = newStreamProtocolV5(options) + case remotecommand.StreamProtocolV4Name: + streamer = newStreamProtocolV4(options) + case remotecommand.StreamProtocolV3Name: + streamer = newStreamProtocolV3(options) + case remotecommand.StreamProtocolV2Name: + streamer = newStreamProtocolV2(options) + case "": + klog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name) + fallthrough + case remotecommand.StreamProtocolV1Name: + streamer = newStreamProtocolV1(options) + } + + return conn, streamer, nil +} + +// StreamWithContext opens a protocol streamer to the server and streams until a client closes +// the connection or the server disconnects or the context is done. +func (e *spdyStreamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + conn, streamer, err := e.newConnectionAndStream(ctx, options) + if err != nil { + return err + } + defer conn.Close() + + panicChan := make(chan any, 1) + errorChan := make(chan error, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + errorChan <- streamer.stream(conn) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-errorChan: + return err + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/tools/remotecommand/remotecommand_test.go b/tools/remotecommand/spdy_test.go similarity index 99% rename from tools/remotecommand/remotecommand_test.go rename to tools/remotecommand/spdy_test.go index c59b8270..c11177a0 100644 --- a/tools/remotecommand/remotecommand_test.go +++ b/tools/remotecommand/spdy_test.go @@ -342,7 +342,7 @@ func TestStreamExitsAfterConnectionIsClosed(t *testing.T) { if err != nil { t.Fatal(err) } - streamExec := exec.(*streamExecutor) + streamExec := exec.(*spdyStreamExecutor) conn, streamer, err := streamExec.newConnectionAndStream(ctx, options) if err != nil { diff --git a/tools/remotecommand/v5.go b/tools/remotecommand/v5.go new file mode 100644 index 00000000..4da7bfb1 --- /dev/null +++ b/tools/remotecommand/v5.go @@ -0,0 +1,35 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +// streamProtocolV5 add support for V5 of the remote command subprotocol. +// For the streamProtocolHandler, this version is the same as V4. +type streamProtocolV5 struct { + *streamProtocolV4 +} + +var _ streamProtocolHandler = &streamProtocolV5{} + +func newStreamProtocolV5(options StreamOptions) streamProtocolHandler { + return &streamProtocolV5{ + streamProtocolV4: newStreamProtocolV4(options).(*streamProtocolV4), + } +} + +func (p *streamProtocolV5) stream(conn streamCreator) error { + return p.streamProtocolV4.stream(conn) +} diff --git a/tools/remotecommand/websocket.go b/tools/remotecommand/websocket.go new file mode 100644 index 00000000..9230027c --- /dev/null +++ b/tools/remotecommand/websocket.go @@ -0,0 +1,485 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" + "fmt" + "io" + "net/http" + "sync" + "time" + + gwebsocket "github.com/gorilla/websocket" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/remotecommand" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport/websocket" + "k8s.io/klog/v2" +) + +// writeDeadline defines the time that a write to the websocket connection +// must complete by, otherwise an i/o timeout occurs. The writeDeadline +// has nothing to do with a response from the other websocket connection +// endpoint; only that the message was successfully processed by the +// local websocket connection. The typical write deadline within the websocket +// library is one second. +const writeDeadline = 2 * time.Second + +var ( + _ Executor = &wsStreamExecutor{} + _ streamCreator = &wsStreamCreator{} + _ httpstream.Stream = &stream{} + + streamType2streamID = map[string]byte{ + v1.StreamTypeStdin: remotecommand.StreamStdIn, + v1.StreamTypeStdout: remotecommand.StreamStdOut, + v1.StreamTypeStderr: remotecommand.StreamStdErr, + v1.StreamTypeError: remotecommand.StreamErr, + v1.StreamTypeResize: remotecommand.StreamResize, + } +) + +const ( + // pingPeriod defines how often a heartbeat "ping" message is sent. + pingPeriod = 5 * time.Second + // pingReadDeadline defines the time waiting for a response heartbeat + // "pong" message before a timeout error occurs for websocket reading. + // This duration must always be greater than the "pingPeriod". By defining + // this deadline in terms of the ping period, we are essentially saying + // we can drop "X-1" (e.g. 3-1=2) pings before firing the timeout. + pingReadDeadline = (pingPeriod * 3) + (1 * time.Second) +) + +// wsStreamExecutor handles transporting standard shell streams over an httpstream connection. +type wsStreamExecutor struct { + transport http.RoundTripper + upgrader websocket.ConnectionHolder + method string + url string + // requested protocols in priority order (e.g. v5.channel.k8s.io before v4.channel.k8s.io). + protocols []string + // selected protocol from the handshake process; could be empty string if handshake fails. + negotiated string + // period defines how often a "ping" heartbeat message is sent to the other endpoint. + heartbeatPeriod time.Duration + // deadline defines the amount of time before "pong" response must be received. + heartbeatDeadline time.Duration +} + +// NewWebSocketExecutor allows to execute commands via a WebSocket connection. +func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) { + transport, upgrader, err := websocket.RoundTripperFor(config) + if err != nil { + return nil, fmt.Errorf("error creating websocket transports: %v", err) + } + return &wsStreamExecutor{ + transport: transport, + upgrader: upgrader, + method: method, + url: url, + // Only supports V5 protocol for correct version skew functionality. + // Previous api servers will proxy upgrade requests to legacy websocket + // servers on container runtimes which support V1-V4. These legacy + // websocket servers will not handle the new CLOSE signal. + protocols: []string{remotecommand.StreamProtocolV5Name}, + heartbeatPeriod: pingPeriod, + heartbeatDeadline: pingReadDeadline, + }, nil +} + +// Deprecated: use StreamWithContext instead to avoid possible resource leaks. +// See https://github.com/kubernetes/kubernetes/pull/103177 for details. +func (e *wsStreamExecutor) Stream(options StreamOptions) error { + return e.StreamWithContext(context.Background(), options) +} + +// StreamWithContext upgrades an HTTPRequest to a WebSocket connection, and starts the various +// goroutines to implement the necessary streams over the connection. The "options" parameter +// defines which streams are requested. Returns an error if one occurred. This method is NOT +// safe to run concurrently with the same executor (because of the state stored in the upgrader). +func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + req, err := http.NewRequestWithContext(ctx, e.method, e.url, nil) + if err != nil { + return err + } + conn, err := websocket.Negotiate(e.transport, e.upgrader, req, e.protocols...) + if err != nil { + return err + } + if conn == nil { + panic(fmt.Errorf("websocket connection is nil")) + } + defer conn.Close() + e.negotiated = conn.Subprotocol() + klog.V(4).Infof("The subprotocol is %s", e.negotiated) + + var streamer streamProtocolHandler + switch e.negotiated { + case remotecommand.StreamProtocolV5Name: + streamer = newStreamProtocolV5(options) + case remotecommand.StreamProtocolV4Name: + streamer = newStreamProtocolV4(options) + case remotecommand.StreamProtocolV3Name: + streamer = newStreamProtocolV3(options) + case remotecommand.StreamProtocolV2Name: + streamer = newStreamProtocolV2(options) + case "": + klog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name) + fallthrough + case remotecommand.StreamProtocolV1Name: + streamer = newStreamProtocolV1(options) + } + + panicChan := make(chan any, 1) + errorChan := make(chan error, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + creator := newWSStreamCreator(conn) + go creator.readDemuxLoop( + e.upgrader.DataBufferSize(), + e.heartbeatPeriod, + e.heartbeatDeadline, + ) + errorChan <- streamer.stream(creator) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-errorChan: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +type wsStreamCreator struct { + conn *gwebsocket.Conn + connWriteLock sync.Mutex + streams map[byte]*stream + streamsMu sync.Mutex +} + +func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { + return &wsStreamCreator{ + conn: conn, + streams: map[byte]*stream{}, + } +} + +func (c *wsStreamCreator) getStream(id byte) *stream { + c.streamsMu.Lock() + defer c.streamsMu.Unlock() + return c.streams[id] +} + +func (c *wsStreamCreator) setStream(id byte, s *stream) { + c.streamsMu.Lock() + defer c.streamsMu.Unlock() + c.streams[id] = s +} + +// CreateStream uses id from passed headers to create a stream over "c.conn" connection. +// Returns a Stream structure or nil and an error if one occurred. +func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) { + streamType := headers.Get(v1.StreamType) + id, ok := streamType2streamID[streamType] + if !ok { + return nil, fmt.Errorf("unknown stream type: %s", streamType) + } + if s := c.getStream(id); s != nil { + return nil, fmt.Errorf("duplicate stream for type %s", streamType) + } + reader, writer := io.Pipe() + s := &stream{ + headers: headers, + readPipe: reader, + writePipe: writer, + conn: c.conn, + connWriteLock: &c.connWriteLock, + id: id, + } + c.setStream(id, s) + return s, nil +} + +// readDemuxLoop is the reading processor for this endpoint of the websocket +// connection. This loop reads the connection, and demultiplexes the data +// into one of the individual stream pipes (by checking the stream id). This +// loop can *not* be run concurrently, because there can only be one websocket +// connection reader at a time (a read mutex would provide no benefit). +func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, deadline time.Duration) { + // Initialize and start the ping/pong heartbeat. + h := newHeartbeat(c.conn, period, deadline) + // Set initial timeout for websocket connection reading. + if err := c.conn.SetReadDeadline(time.Now().Add(deadline)); err != nil { + klog.Errorf("Websocket initial setting read deadline failed %v", err) + return + } + go h.start() + // Buffer size must correspond to the same size allocated + // for the read buffer during websocket client creation. A + // difference can cause incomplete connection reads. + readBuffer := make([]byte, bufferSize) + for { + // NextReader() only returns data messages (BinaryMessage or Text + // Message). Even though this call will never return control frames + // such as ping, pong, or close, this call is necessary for these + // message types to be processed. There can only be one reader + // at a time, so this reader loop must *not* be run concurrently; + // there is no lock for reading. Calling "NextReader()" before the + // current reader has been processed will close the current reader. + // If the heartbeat read deadline times out, this "NextReader()" will + // return an i/o error, and error handling will clean up. + messageType, r, err := c.conn.NextReader() + if err != nil { + websocketErr, ok := err.(*gwebsocket.CloseError) + if ok && websocketErr.Code == gwebsocket.CloseNormalClosure { + err = nil // readers will get io.EOF as it's a normal closure + } else { + err = fmt.Errorf("next reader: %w", err) + } + c.closeAllStreamReaders(err) + return + } + // All remote command protocols send/receive only binary data messages. + if messageType != gwebsocket.BinaryMessage { + c.closeAllStreamReaders(fmt.Errorf("unexpected message type: %d", messageType)) + return + } + // It's ok to read just a single byte because the underlying library wraps the actual + // connection with a buffered reader anyway. + _, err = io.ReadFull(r, readBuffer[:1]) + if err != nil { + c.closeAllStreamReaders(fmt.Errorf("read stream id: %w", err)) + return + } + streamID := readBuffer[0] + s := c.getStream(streamID) + if s == nil { + klog.Errorf("Unknown stream id %d, discarding message", streamID) + continue + } + for { + nr, errRead := r.Read(readBuffer) + if nr > 0 { + // Write the data to the stream's pipe. This can block. + _, errWrite := s.writePipe.Write(readBuffer[:nr]) + if errWrite != nil { + // Pipe must have been closed by the stream user. + // Nothing to do, discard the message. + break + } + } + if errRead != nil { + if errRead == io.EOF { + break + } + c.closeAllStreamReaders(fmt.Errorf("read message: %w", err)) + return + } + } + } +} + +// closeAllStreamReaders closes readers in all streams. +// This unblocks all stream.Read() calls. +func (c *wsStreamCreator) closeAllStreamReaders(err error) { + c.streamsMu.Lock() + defer c.streamsMu.Unlock() + for _, s := range c.streams { + // Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. + _ = s.writePipe.CloseWithError(err) + } +} + +type stream struct { + headers http.Header + readPipe *io.PipeReader + writePipe *io.PipeWriter + // conn is used for writing directly into the connection. + // Is nil after Close() / Reset() to prevent future writes. + conn *gwebsocket.Conn + // connWriteLock protects conn against concurrent write operations. There must be a single writer and a single reader only. + // The mutex is shared across all streams because the underlying connection is shared. + connWriteLock *sync.Mutex + id byte +} + +func (s *stream) Read(p []byte) (n int, err error) { + return s.readPipe.Read(p) +} + +// Write writes directly to the underlying WebSocket connection. +func (s *stream) Write(p []byte) (n int, err error) { + klog.V(4).Infof("Write() on stream %d", s.id) + defer klog.V(4).Infof("Write() done on stream %d", s.id) + s.connWriteLock.Lock() + defer s.connWriteLock.Unlock() + if s.conn == nil { + return 0, fmt.Errorf("write on closed stream %d", s.id) + } + err = s.conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + if err != nil { + klog.V(7).Infof("Websocket setting write deadline failed %v", err) + return 0, err + } + // Message writer buffers the message data, so we don't need to do that ourselves. + // Just write id and the data as two separate writes to avoid allocating an intermediate buffer. + w, err := s.conn.NextWriter(gwebsocket.BinaryMessage) + if err != nil { + return 0, err + } + defer func() { + if w != nil { + w.Close() + } + }() + _, err = w.Write([]byte{s.id}) + if err != nil { + return 0, err + } + n, err = w.Write(p) + if err != nil { + return n, err + } + err = w.Close() + w = nil + return n, err +} + +// Close half-closes the stream, indicating this side is finished with the stream. +func (s *stream) Close() error { + klog.V(4).Infof("Close() on stream %d", s.id) + defer klog.V(4).Infof("Close() done on stream %d", s.id) + s.connWriteLock.Lock() + defer s.connWriteLock.Unlock() + if s.conn == nil { + return fmt.Errorf("Close() on already closed stream %d", s.id) + } + // Communicate the CLOSE stream signal to the other websocket endpoint. + err := s.conn.WriteMessage(gwebsocket.BinaryMessage, []byte{remotecommand.StreamClose, s.id}) + s.conn = nil + return err +} + +func (s *stream) Reset() error { + klog.V(4).Infof("Reset() on stream %d", s.id) + defer klog.V(4).Infof("Reset() done on stream %d", s.id) + s.Close() + return s.writePipe.Close() +} + +func (s *stream) Headers() http.Header { + return s.headers +} + +func (s *stream) Identifier() uint32 { + return uint32(s.id) +} + +// heartbeat encasulates data necessary for the websocket ping/pong heartbeat. This +// heartbeat works by setting a read deadline on the websocket connection, then +// pushing this deadline into the future for every successful heartbeat. If the +// heartbeat "pong" fails to respond within the deadline, then the "NextReader()" call +// inside the "readDemuxLoop" will return an i/o error prompting a connection close +// and cleanup. +type heartbeat struct { + conn *gwebsocket.Conn + // period defines how often a "ping" heartbeat message is sent to the other endpoint + period time.Duration + // closing the "closer" channel will clean up the heartbeat timers + closer chan struct{} + // optional data to send with "ping" message + message []byte + // optionally received data message with "pong" message, same as sent with ping + pongMessage []byte +} + +// newHeartbeat creates heartbeat structure encapsulating fields necessary to +// run the websocket connection ping/pong mechanism and sets up handlers on +// the websocket connection. +func newHeartbeat(conn *gwebsocket.Conn, period time.Duration, deadline time.Duration) *heartbeat { + h := &heartbeat{ + conn: conn, + period: period, + closer: make(chan struct{}), + } + // Set up handler for receiving returned "pong" message from other endpoint + // by pushing the read deadline into the future. The "msg" received could + // be empty. + h.conn.SetPongHandler(func(msg string) error { + // Push the read deadline into the future. + klog.V(8).Infof("Pong message received (%s)--resetting read deadline", msg) + err := h.conn.SetReadDeadline(time.Now().Add(deadline)) + if err != nil { + klog.Errorf("Websocket setting read deadline failed %v", err) + return err + } + if len(msg) > 0 { + h.pongMessage = []byte(msg) + } + return nil + }) + // Set up handler to cleanup timers when this endpoint receives "Close" message. + closeHandler := h.conn.CloseHandler() + h.conn.SetCloseHandler(func(code int, text string) error { + close(h.closer) + return closeHandler(code, text) + }) + return h +} + +// setMessage is optional data sent with "ping" heartbeat. According to the websocket RFC +// this data sent with "ping" message should be returned in "pong" message. +func (h *heartbeat) setMessage(msg string) { + h.message = []byte(msg) +} + +// start the heartbeat by setting up necesssary handlers and looping by sending "ping" +// message every "period" until the "closer" channel is closed. +func (h *heartbeat) start() { + // Loop to continually send "ping" message through websocket connection every "period". + t := time.NewTicker(h.period) + defer t.Stop() + for { + select { + case <-h.closer: + klog.V(8).Infof("closed channel--returning") + return + case <-t.C: + // "WriteControl" does not need to be protected by a mutex. According to + // gorilla/websockets library docs: "The Close and WriteControl methods can + // be called concurrently with all other methods." + if err := h.conn.WriteControl(gwebsocket.PingMessage, h.message, time.Now().Add(writeDeadline)); err == nil { + klog.V(8).Infof("Websocket Ping succeeeded") + } else { + klog.Errorf("Websocket Ping failed: %v", err) + // Continue, in case this is a transient failure. + // c.conn.CloseChan above will tell us when the connection is + // actually closed. + } + } + } +} diff --git a/tools/remotecommand/websocket_test.go b/tools/remotecommand/websocket_test.go new file mode 100644 index 00000000..2b0be67c --- /dev/null +++ b/tools/remotecommand/websocket_test.go @@ -0,0 +1,1303 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/json" + "fmt" + "io" + "math" + mrand "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "sync" + "testing" + "time" + + gwebsocket "github.com/gorilla/websocket" + + "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" + "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// TestWebSocketClient_LoopbackStdinToStdout returns random data sent on the STDIN channel +// back down the STDOUT channel. A subsequent comparison checks if the data +// sent on the STDIN channel is the same as the data returned on the STDOUT +// channel. This test can be run many times by the "stress" tool to check +// if there is any data which would cause problems with the WebSocket streams. +func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + })) + defer websocketServer.Close() + + // Now create the WebSocket client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the WebSocket client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Fatalf("error reading the stream: %v", err) + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +// TestWebSocketClient_DifferentBufferSizes runs the previous loopback (STDIN -> STDOUT) test with different +// buffer sizes for reading from the opposite end of the websocket connection (in the websocket server). +func TestWebSocketClient_DifferentBufferSizes(t *testing.T) { + // 1k, 4k, 64k, and 128k buffer sizes for reading STDIN at websocket server endpoint. + // The standard buffer size for io.Copy is 32k. + bufferSizes := []int{1 * 1024, 4 * 1024, 64 * 1024, 128 * 1024} + for _, bufferSize := range bufferSizes { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream, using buffer with size. + buffer := make([]byte, bufferSize) + _, err = io.CopyBuffer(conns.stdoutStream, conns.stdinStream, buffer) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + })) + defer websocketServer.Close() + + // Now create the WebSocket client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the WebSocket client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check all the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } + } +} + +// TestWebSocketClient_LoopbackStdinAsPipe uses a pipe to send random data on the STDIN +// channel, then closes the pipe. The fake server simply returns all STDIN data back +// onto the STDOUT channel, and the received data on STDOUT is validated against the +// random data initially sent. +func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + })) + defer websocketServer.Close() + + // Now create the WebSocket client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the WebSocket client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and it will be written on the STDIN pipe. The same + // data will be returned on the STDOUT channel. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + reader, writer := io.Pipe() + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: reader, + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + // Write the random data onto the pipe connected to STDIN, then close the pipe. + _, err = writer.Write(randomData) + if err != nil { + t.Fatalf("unable to write random data to STDIN pipe: %v", err) + } + writer.Close() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +// TestWebSocketClient_LoopbackStdinToStderr returns random data sent on the STDIN channel +// back down the STDERR channel. A subsequent comparison checks if the data +// sent on the STDIN channel is the same as the data returned on the STDERR +// channel. This test can be run many times by the "stress" tool to check +// if there is any data which would cause problems with the WebSocket streams. +func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDERR stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDERR stream. + _, err = io.Copy(conns.stderrStream, conns.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDERR: %v", err) + } + })) + defer websocketServer.Close() + + // Now create the WebSocket client (executor), and point it to the "websocketServer". + // Must add STDIN and STDERR query params for the WebSocket client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDERR buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stderr bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stderr: &stderr, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } + data, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDERR. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +// TestWebSocketClient_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from +// the websocket connection at the same time. +func TestWebSocketClient_MultipleReadChannels(t *testing.T) { + // Create fake WebSocket server, which uses a TeeReader to copy the same data + // onto the STDOUT stream onto the STDERR stream as well. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // TeeReader copies data read on STDIN onto STDERR. + stdinReader := io.TeeReader(conns.stdinStream, conns.stderrStream) + // Also copy STDIN to STDOUT. + _, err = io.Copy(conns.stdoutStream, stdinReader) + if err != nil { + t.Errorf("error copying STDIN to STDOUT: %v", err) + } + })) + defer websocketServer.Close() + // Now create the WebSocket client (executor), and point it to the "websocketServer". + // Must add stdin, stdout, and stderr query param for the WebSocket client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + "&" + "stderr=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate 1MB of random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT and STDERR buffers. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout, stderr bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + Stderr: &stderr, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } + // Validate the data read from the STDOUT stream is the same as sent on the STDIN stream. + stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Fatalf("error reading the stream: %v", err) + } + if !bytes.Equal(stdoutBytes, randomData) { + t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(randomData)) + } + // Validate the data read from the STDERR stream is the same as sent on the STDIN stream. + stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) + if err != nil { + t.Fatalf("error reading the stream: %v", err) + } + if !bytes.Equal(stderrBytes, randomData) { + t.Errorf("unexpected data received (%d) sent (%d)", len(stderrBytes), len(randomData)) + } +} + +// Returns a random exit code in the range(1-127). +func randomExitCode() int { + errorCode := mrand.Intn(128) + if errorCode == 0 { + errorCode = 1 + } + return errorCode +} + +// TestWebSocketClient_ErrorStream tests the websocket error stream by hard-coding a +// structured non-zero exit code error from the websocket server to the websocket client. +func TestWebSocketClient_ErrorStream(t *testing.T) { + expectedExitCode := randomExitCode() + // Create fake WebSocket server. Returns structured exit code error on error stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + _, err = io.Copy(conns.stderrStream, conns.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDERR: %v", err) + } + // Force an non-zero exit code error returned on the error stream. + err = conns.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{ + Status: metav1.StatusFailure, + Reason: remotecommand.NonZeroExitCodeReason, + Details: &metav1.StatusDetails{ + Causes: []metav1.StatusCause{ + { + Type: remotecommand.ExitCodeCauseType, + Message: fmt.Sprintf("%d", expectedExitCode), + }, + }, + }, + }}) + if err != nil { + t.Fatalf("error writing status: %v", err) + } + })) + defer websocketServer.Close() + + // Now create the WebSocket client (executor), and point it to the "websocketServer". + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + randomData := make([]byte, 256) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stderr bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stderr: &stderr, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + // Expect exit code error on error stream. + if err == nil { + t.Errorf("expected error, but received none") + } + expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode) + // Compare expected error with exit code to actual error. + if expectedError != err.Error() { + t.Errorf("expected error (%s), got (%s)", expectedError, err) + } + } +} + +// fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of +// "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice. +type fakeTerminalSizeQueue struct { + maxSizes int + terminalSizes []TerminalSize +} + +// newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing +// "max" number of random TerminalSizes created. +func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue { + return &fakeTerminalSizeQueue{ + maxSizes: max, + terminalSizes: make([]TerminalSize, 0, max), + } +} + +// Next returns a pointer to the next random TerminalSize, or nil if we have +// already returned "maxSizes" TerminalSizes already. Stores the randomly +// created TerminalSize in "terminalSizes" field for later validation. +func (f *fakeTerminalSizeQueue) Next() *TerminalSize { + if len(f.terminalSizes) >= f.maxSizes { + return nil + } + size := randomTerminalSize() + f.terminalSizes = append(f.terminalSizes, size) + return &size +} + +// randomTerminalSize returns a TerminalSize with random values in the +// range (0-65535) for the fields Width and Height. +func randomTerminalSize() TerminalSize { + randWidth := uint16(mrand.Intn(int(math.Pow(2, 16)))) + randHeight := uint16(mrand.Intn(int(math.Pow(2, 16)))) + return TerminalSize{ + Width: randWidth, + Height: randHeight, + } +} + +// randReader implements the ReadCloser interface, and it continuously +// returns random data until it is closed. Stores number of random +// bytes generated and returned. +type randReader struct { + randBytes []byte + closed bool + lock sync.Mutex +} + +// Read implements the Reader interface filling the passed buffer with +// random data, returning the number of bytes filled and an error +// if one occurs. Return 0 and EOF if the randReader has been closed. +func (r *randReader) Read(b []byte) (int, error) { + r.lock.Lock() + defer r.lock.Unlock() + if r.closed { + return 0, io.EOF + } + n, err := rand.Read(b) + c := bytes.Clone(b) + r.randBytes = append(r.randBytes, c...) + return n, err +} + +// Close implements the Closer interface, setting the close field true. +// Further calls to Read() after Close() will return 0, EOF. Returns +// nil error. +func (r *randReader) Close() (err error) { + r.lock.Lock() + defer r.lock.Unlock() + r.closed = true + return nil +} + +// TestWebSocketClient_MultipleWriteChannels tests two streams (STDIN, TTY resize) writing to the +// websocket connection at the same time to exercise the connection write lock. +func TestWebSocketClient_MultipleWriteChannels(t *testing.T) { + // Create the fake terminal size queue and the actualTerminalSizes which + // will be received at the opposite websocket endpoint. + numSizeQueue := 10000 + sizeQueue := newTerminalSizeQueue(numSizeQueue) + actualTerminalSizes := make([]TerminalSize, 0, numSizeQueue) + // Create ReadCloser sending random data on STDIN stream over websocket connection. + stdinReader := randReader{randBytes: []byte{}, closed: false} + // Create fake WebSocket server, which will receive concurrently the STDIN stream as + // well as the resize stream (TerminalSizes). Store the TerminalSize data from the resize + // stream for subsequent validation. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var wg sync.WaitGroup + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // Create goroutine to loopback the STDIN stream onto the STDOUT stream. + wg.Add(1) + go func() { + _, err := io.Copy(conns.stdoutStream, conns.stdinStream) + if err != nil { + t.Errorf("error copying STDIN to STDOUT: %v", err) + } + wg.Done() + }() + // Read the terminal resize requests, storing them in actualTerminalSizes + for i := 0; i < numSizeQueue; i++ { + actualTerminalSize := <-conns.resizeChan + actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize) + } + stdinReader.Close() // Stops the random STDIN stream generation + wg.Wait() // Wait for all bytes copied from STDIN to STDOUT + })) + defer websocketServer.Close() + // Now create the WebSocket client (executor), and point it to the "websocketServer". + // Must add stdin, stdout, and TTY query param for the WebSocket client request. + websocketServer.URL = websocketServer.URL + "?" + "tty=true" + "&" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: &stdinReader, + Stdout: &stdout, + Tty: true, + TerminalSizeQueue: sizeQueue, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } + // Check the random data sent on STDIN was the same returned on STDOUT *and* + // that a minimum amount of random data was sent and received, ensuring concurrency. + stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Fatalf("error reading the stream: %v", err) + } + if len(stdoutBytes) == 0 { + t.Errorf("No STDOUT bytes processed before resize stream finished: %d", len(stdoutBytes)) + } + if !bytes.Equal(stdoutBytes, stdinReader.randBytes) { + t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(stdinReader.randBytes)) + } + // Validate the random TerminalSizes sent on the resize stream are the same + // as the actual TerminalSizes received at the websocket server. + if len(actualTerminalSizes) != numSizeQueue { + t.Errorf("expected received terminal size window (%d), got (%d)", + numSizeQueue, len(actualTerminalSizes)) + } + for i, actual := range actualTerminalSizes { + expected := sizeQueue.terminalSizes[i] + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected terminal resize window %v, got %v", expected, actual) + } + } +} + +// TestWebSocketClient_ProtocolVersions validates that remote command subprotocol versions V2-V4 +// (V5 is already tested elsewhere) can be negotiated. +func TestWebSocketClient_ProtocolVersions(t *testing.T) { + // Create a raw websocket server that accepts V2-V4 versions of + // the remote command subprotocol. + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + Subprotocols: []string{ + remotecommand.StreamProtocolV4Name, + remotecommand.StreamProtocolV3Name, + remotecommand.StreamProtocolV2Name, + }, + } + // Upgrade a raw websocket server connection. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + t.Fatalf("unable to upgrade to create websocket connection: %v", err) + } + defer conn.Close() + })) + defer websocketServer.Close() + + // Set up the websocket client with the STDOUT stream. + websocketServer.URL = websocketServer.URL + "?" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Iterate through previous remote command protocol versions, validating the + // requested protocol version is the one that is negotiated. + versions := []string{ + remotecommand.StreamProtocolV4Name, + remotecommand.StreamProtocolV3Name, + remotecommand.StreamProtocolV2Name, + } + for _, requestedVersion := range versions { + streamExec := exec.(*wsStreamExecutor) + streamExec.protocols = []string{requestedVersion} + var stdout bytes.Buffer + options := &StreamOptions{ + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case <-errorChan: + // Validate remote command protocol requestedVersion was negotiated. + streamExec := exec.(*wsStreamExecutor) + if requestedVersion != streamExec.negotiated { + t.Fatalf("expected protocol version (%s), got (%s)", requestedVersion, streamExec.negotiated) + } + } + } +} + +// TestWebSocketClient_BadHandshake tests that a "bad handshake" error occurs when +// the WebSocketExecutor attempts to upgrade the connection to a subprotocol version +// (V4) that is not supported by the websocket server (only supports V5). +func TestWebSocketClient_BadHandshake(t *testing.T) { + // Create fake WebSocket server (supports V5 subprotocol). + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + })) + defer websocketServer.Close() + + websocketServer.URL = websocketServer.URL + "?" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + streamExec := exec.(*wsStreamExecutor) + // Set the attempted subprotocol version to V4; websocket server only accepts V5. + streamExec.protocols = []string{remotecommand.StreamProtocolV4Name} + + var stdout bytes.Buffer + options := &StreamOptions{ + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- streamExec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Expecting unable to upgrade connection -- "bad handshake" error. + if err == nil { + t.Errorf("expected error but received none") + } + if !strings.Contains(err.Error(), "bad handshake") { + t.Errorf("expected bad handshake error, got (%s)", err) + } + } +} + +// TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a +// timeout by setting the ping period greater than the deadline. +func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { + // Create fake WebSocket server which blocks. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + // Block server; heartbeat timeout (or test timeout) will fire before this returns. + time.Sleep(1 * time.Second) + })) + defer websocketServer.Close() + // Create websocket client connecting to fake server. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + streamExec := exec.(*wsStreamExecutor) + // Ping period is greater than the ping deadline, forcing the timeout to fire. + pingPeriod := 20 * time.Millisecond + pingDeadline := 5 * time.Millisecond + streamExec.heartbeatPeriod = pingPeriod + streamExec.heartbeatDeadline = pingDeadline + // Send some random data to the websocket server through STDIN. + randomData := make([]byte, 128) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- streamExec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(pingPeriod * 5): + // Give up after about five ping attempts + t.Fatalf("expected heartbeat timeout, got none.") + case err := <-errorChan: + // Expecting heartbeat timeout error. + if err == nil { + t.Fatalf("expected error but received none") + } + if !strings.Contains(err.Error(), "i/o timeout") { + t.Errorf("expected heartbeat timeout error, got (%s)", err) + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } +} + +// TestWebSocketClient_TextMessageTypeError tests when the wrong message type is returned +// from the other websocket endpoint. Remote command protocols use "BinaryMessage", but +// this test hard-codes returning a "TextMessage". +func TestWebSocketClient_TextMessageTypeError(t *testing.T) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + Subprotocols: []string{remotecommand.StreamProtocolV5Name}, + } + // Upgrade a raw websocket server connection. Returns wrong message type "TextMessage". + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + t.Fatalf("unable to upgrade to create websocket connection: %v", err) + } + defer conn.Close() + msg := []byte("test message with wrong message type.") + stdOutMsg := append([]byte{remotecommand.StreamStdOut}, msg...) + // Wrong message type "TextMessage". + err = conn.WriteMessage(gwebsocket.TextMessage, stdOutMsg) + if err != nil { + t.Fatalf("error writing text message to websocket: %v", err) + } + + })) + defer websocketServer.Close() + + // Set up the websocket client with the STDOUT stream. + websocketServer.URL = websocketServer.URL + "?" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Expecting bad message type error. + if err == nil { + t.Fatalf("expected error but received none") + } + if !strings.Contains(err.Error(), "unexpected message type") { + t.Errorf("expected bad message type error, got (%s)", err) + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } +} + +// TestWebSocketClient_EmptyMessageHandled tests that the error of a completely empty message +// is handled correctly. If the message is completely empty, the initial read of the stream id +// should fail (followed by cleanup). +func TestWebSocketClient_EmptyMessageHandled(t *testing.T) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + Subprotocols: []string{remotecommand.StreamProtocolV5Name}, + } + // Upgrade a raw websocket server connection. Returns wrong message type "TextMessage". + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + t.Fatalf("unable to upgrade to create websocket connection: %v", err) + } + defer conn.Close() + // Send completely empty message, including missing initial stream id. + conn.WriteMessage(gwebsocket.BinaryMessage, []byte{}) //nolint:errcheck + })) + defer websocketServer.Close() + + // Set up the websocket client with the STDOUT stream. + websocketServer.URL = websocketServer.URL + "?" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + if err != nil { + t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) + } + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Expecting error reading initial stream id. + if err == nil { + t.Fatalf("expected error but received none") + } + if !strings.Contains(err.Error(), "read stream id") { + t.Errorf("expected error reading stream id, got (%s)", err) + } + // Validate remote command v5 protocol was negotiated. + streamExec := exec.(*wsStreamExecutor) + if remotecommand.StreamProtocolV5Name != streamExec.negotiated { + t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated) + } + } +} + +func TestWebSocketClient_ExecutorErrors(t *testing.T) { + // Invalid config causes transport creation error in websocket executor constructor. + config := rest.Config{ + ExecProvider: &clientcmdapi.ExecConfig{}, + AuthProvider: &clientcmdapi.AuthProviderConfig{}, + } + _, err := NewWebSocketExecutor(&config, "POST", "http://localhost") + if err == nil { + t.Errorf("expecting executor constructor error, but received none.") + } else if !strings.Contains(err.Error(), "error creating websocket transports") { + t.Errorf("expecting error creating transports, got (%s)", err.Error()) + } + // Verify that a nil context will cause an error in StreamWithContext + exec, err := NewWebSocketExecutor(&rest.Config{}, "POST", "http://localhost") + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + var ctx context.Context + errorChan <- exec.StreamWithContext(ctx, StreamOptions{}) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Expecting error with nil context. + if err == nil { + t.Fatalf("expected error but received none") + } + if !strings.Contains(err.Error(), "nil Context") { + t.Errorf("expected nil context error, got (%s)", err) + } + } +} + +func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { + var upgrader = gwebsocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + } + // Upgrade a raw websocket server connection, which automatically responds to Ping. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + t.Fatalf("unable to upgrade to create websocket connection: %v", err) + } + defer conn.Close() + conn.ReadMessage() //nolint:errcheck + })) + defer websocketServer.Close() + // Create a raw websocket client, connecting to the websocket server. + url := strings.ReplaceAll(websocketServer.URL, "http", "ws") + client, _, err := gwebsocket.DefaultDialer.Dial(url, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer client.Close() + // Create a heartbeat using the client websocket connection, and start it. + // "period" is less than "deadline", so ping/pong heartbeat will succceed. + var expectedMsg = "test heartbeat message" + var period = 10 * time.Millisecond + var deadline = 20 * time.Millisecond + heartbeat := newHeartbeat(client, period, deadline) + heartbeat.setMessage(expectedMsg) + // Add a channel to the handler to retrieve the "pong" message. + pongMsgCh := make(chan string) + pongHandler := heartbeat.conn.PongHandler() + heartbeat.conn.SetPongHandler(func(msg string) error { + pongMsgCh <- msg + return pongHandler(msg) + }) + go heartbeat.start() + go client.ReadMessage() //nolint:errcheck + select { + case actualMsg := <-pongMsgCh: + close(heartbeat.closer) + // Validate the received pong message is the same as sent in ping. + if expectedMsg != actualMsg { + t.Errorf("expected received pong message (%s), got (%s)", expectedMsg, actualMsg) + } + case <-time.After(period * 4): + // This case should not happen. + close(heartbeat.closer) + t.Errorf("unexpected heartbeat timeout") + } +} + +func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { + // Validate Stream functions. + c := newWSStreamCreator(nil) + headers := http.Header{} + headers.Set(v1.StreamType, v1.StreamTypeStdin) + s, err := c.CreateStream(headers) + if err != nil { + t.Errorf("unexpected stream creation error: %v", err) + } + expectedStreamID := uint32(remotecommand.StreamStdIn) + actualStreamID := s.Identifier() + if expectedStreamID != actualStreamID { + t.Errorf("expecting stream id (%d), got (%d)", expectedStreamID, actualStreamID) + } + actualHeaders := s.Headers() + if !reflect.DeepEqual(headers, actualHeaders) { + t.Errorf("expecting stream headers (%v), got (%v)", headers, actualHeaders) + } + // Validate stream reset does not return error. + err = s.Reset() + if err != nil { + t.Errorf("unexpected error in stream reset: %v", err) + } + // Validate close with nil connection is an error. + err = s.Close() + if err == nil { + t.Errorf("expecting stream Close error, but received none") + } + if !strings.Contains(err.Error(), "Close() on already closed stream") { + t.Errorf("expected stream close error, got (%s)", err) + } + // Validate write with nil connection is an error. + n, err := s.Write([]byte("not written")) + if n != 0 { + t.Errorf("expected zero bytes written, wrote (%d) instead", n) + } + if err == nil { + t.Errorf("expecting stream Write error, but received none") + } + if !strings.Contains(err.Error(), "write on closed stream") { + t.Errorf("expected stream write error, got (%s)", err) + } + // Validate CreateStream errors -- unknown stream + headers = http.Header{} + headers.Set(v1.StreamType, "UNKNOWN") + _, err = c.CreateStream(headers) + if err == nil { + t.Errorf("expecting CreateStream error, but received none") + } else if !strings.Contains(err.Error(), "unknown stream type") { + t.Errorf("expecting unknown stream type error, got (%s)", err.Error()) + } + // Validate CreateStream errors -- duplicate stream + headers.Set(v1.StreamType, v1.StreamTypeError) + c.streams[remotecommand.StreamErr] = &stream{} + _, err = c.CreateStream(headers) + if err == nil { + t.Errorf("expecting CreateStream error, but received none") + } else if !strings.Contains(err.Error(), "duplicate stream") { + t.Errorf("expecting duplicate stream error, got (%s)", err.Error()) + } +} + +// options contains details about which streams are required for +// remote command execution. +type options struct { + stdin bool + stdout bool + stderr bool + tty bool +} + +// Translates query params in request into options struct. +func streamOptionsFromRequest(req *http.Request) *options { + query := req.URL.Query() + tty := query.Get("tty") == "true" + stdin := query.Get("stdin") == "true" + stdout := query.Get("stdout") == "true" + stderr := query.Get("stderr") == "true" + return &options{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } +} + +// websocketStreams contains the WebSocket connection and streams from a server. +type websocketStreams struct { + conn io.Closer + stdinStream io.ReadCloser + stdoutStream io.WriteCloser + stderrStream io.WriteCloser + writeStatus func(status *apierrors.StatusError) error + resizeStream io.ReadCloser + resizeChan chan TerminalSize + tty bool +} + +// Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed +// in the stream options. +func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) { + conn, err := createWebSocketStreams(req, w, opts) + if err != nil { + return nil, err + } + + if conn.resizeStream != nil { + conn.resizeChan = make(chan TerminalSize) + go handleResizeEvents(req.Context(), conn.resizeStream, conn.resizeChan) + } + + return conn, nil +} + +// Read terminal resize events off of passed stream and queue into passed channel. +func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- TerminalSize) { + defer close(channel) + + decoder := json.NewDecoder(stream) + for { + size := TerminalSize{} + if err := decoder.Decode(&size); err != nil { + break + } + + select { + case channel <- size: + case <-ctx.Done(): + // To avoid leaking this routine, exit if the http request finishes. This path + // would generally be hit if starting the process fails and nothing is started to + // ingest these resize events. + return + } + } +} + +// createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2) +// along with the approximate duplex value. It also creates the error (3) and resize (4) channels. +func createChannels(opts *options) []wsstream.ChannelType { + // open the requested channels, and always open the error channel + channels := make([]wsstream.ChannelType, 5) + channels[remotecommand.StreamStdIn] = readChannel(opts.stdin) + channels[remotecommand.StreamStdOut] = writeChannel(opts.stdout) + channels[remotecommand.StreamStdErr] = writeChannel(opts.stderr) + channels[remotecommand.StreamErr] = wsstream.WriteChannel + channels[remotecommand.StreamResize] = wsstream.ReadChannel + return channels +} + +// readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel. +func readChannel(real bool) wsstream.ChannelType { + if real { + return wsstream.ReadChannel + } + return wsstream.IgnoreChannel +} + +// writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel. +func writeChannel(real bool) wsstream.ChannelType { + if real { + return wsstream.WriteChannel + } + return wsstream.IgnoreChannel +} + +// createWebSocketStreams returns a "channels" struct containing the websocket connection and +// streams needed to perform an exec or an attach. +func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) { + channels := createChannels(opts) + conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ + remotecommand.StreamProtocolV5Name: { + Binary: true, + Channels: channels, + }, + }) + conn.SetIdleTimeout(4 * time.Hour) + // Opening the connection responds to WebSocket client, negotiating + // the WebSocket upgrade connection and the subprotocol. + _, streams, err := conn.Open(w, req) + if err != nil { + return nil, err + } + + // Send an empty message to the lowest writable channel to notify the client the connection is established + //nolint:errcheck + switch { + case opts.stdout: + streams[remotecommand.StreamStdOut].Write([]byte{}) + case opts.stderr: + streams[remotecommand.StreamStdErr].Write([]byte{}) + default: + streams[remotecommand.StreamErr].Write([]byte{}) + } + + wsStreams := &websocketStreams{ + conn: conn, + stdinStream: streams[remotecommand.StreamStdIn], + stdoutStream: streams[remotecommand.StreamStdOut], + stderrStream: streams[remotecommand.StreamStdErr], + tty: opts.tty, + resizeStream: streams[remotecommand.StreamResize], + } + + wsStreams.writeStatus = v4WriteStatusFunc(streams[remotecommand.StreamErr]) + + return wsStreams, nil +} diff --git a/transport/websocket/roundtripper.go b/transport/websocket/roundtripper.go new file mode 100644 index 00000000..e2a4a8ab --- /dev/null +++ b/transport/websocket/roundtripper.go @@ -0,0 +1,166 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package websocket + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + + gwebsocket "github.com/gorilla/websocket" + + "k8s.io/apimachinery/pkg/util/httpstream" + utilnet "k8s.io/apimachinery/pkg/util/net" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport" +) + +var ( + _ utilnet.TLSClientConfigHolder = &RoundTripper{} + _ http.RoundTripper = &RoundTripper{} +) + +// ConnectionHolder defines functions for structure providing +// access to the websocket connection. +type ConnectionHolder interface { + DataBufferSize() int + Connection() *gwebsocket.Conn +} + +// RoundTripper knows how to establish a connection to a remote WebSocket endpoint and make it available for use. +// RoundTripper must not be reused. +type RoundTripper struct { + // TLSConfig holds the TLS configuration settings to use when connecting + // to the remote server. + TLSConfig *tls.Config + + // Proxier specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxier func(req *http.Request) (*url.URL, error) + + // Conn holds the WebSocket connection after a round trip. + Conn *gwebsocket.Conn +} + +// Connection returns the stored websocket connection. +func (rt *RoundTripper) Connection() *gwebsocket.Conn { + return rt.Conn +} + +// DataBufferSize returns the size of buffers for the +// websocket connection. +func (rt *RoundTripper) DataBufferSize() int { + return 32 * 1024 +} + +// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder. +func (rt *RoundTripper) TLSClientConfig() *tls.Config { + return rt.TLSConfig +} + +// RoundTrip connects to the remote websocket using the headers in the request and the TLS +// configuration from the config +func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response, retErr error) { + defer func() { + if request.Body != nil { + err := request.Body.Close() + if retErr == nil { + retErr = err + } + } + }() + + // set the protocol version directly on the dialer from the header + protocolVersions := request.Header[httpstream.HeaderProtocolVersion] + delete(request.Header, httpstream.HeaderProtocolVersion) + + dialer := gwebsocket.Dialer{ + Proxy: rt.Proxier, + TLSClientConfig: rt.TLSConfig, + Subprotocols: protocolVersions, + ReadBufferSize: rt.DataBufferSize() + 1024, // add space for the protocol byte indicating which channel the data is for + WriteBufferSize: rt.DataBufferSize() + 1024, // add space for the protocol byte indicating which channel the data is for + } + switch request.URL.Scheme { + case "https": + request.URL.Scheme = "wss" + case "http": + request.URL.Scheme = "ws" + default: + return nil, fmt.Errorf("unknown url scheme: %s", request.URL.Scheme) + } + wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) + if err != nil { + if err != gwebsocket.ErrBadHandshake { + return nil, err + } + return nil, fmt.Errorf("unable to upgrade connection: %v", err) + } + + rt.Conn = wsConn + + return resp, nil +} + +// RoundTripperFor transforms the passed rest config into a wrapped roundtripper, as well +// as a pointer to the websocket RoundTripper. The websocket RoundTripper contains the +// websocket connection after RoundTrip() on the wrapper. Returns an error if there is +// a problem creating the round trippers. +func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHolder, error) { + transportCfg, err := config.TransportConfig() + if err != nil { + return nil, nil, err + } + tlsConfig, err := transport.TLSConfigFor(transportCfg) + if err != nil { + return nil, nil, err + } + proxy := config.Proxy + if proxy == nil { + proxy = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment) + } + + upgradeRoundTripper := &RoundTripper{ + TLSConfig: tlsConfig, + Proxier: proxy, + } + wrapper, err := transport.HTTPWrappersForConfig(transportCfg, upgradeRoundTripper) + if err != nil { + return nil, nil, err + } + return wrapper, upgradeRoundTripper, nil +} + +// Negotiate opens a connection to a remote server and attempts to negotiate +// a WebSocket connection. Upon success, it returns the negotiated connection. +// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor. +func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) { + req.Header[httpstream.HeaderProtocolVersion] = protocols + resp, err := rt.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("error sending request: %v", err) + } + err = resp.Body.Close() + if err != nil { + connectionInfo.Connection().Close() + return nil, fmt.Errorf("error closing response body: %v", err) + } + return connectionInfo.Connection(), nil +} diff --git a/transport/websocket/roundtripper_test.go b/transport/websocket/roundtripper_test.go new file mode 100644 index 00000000..168d5d55 --- /dev/null +++ b/transport/websocket/roundtripper_test.go @@ -0,0 +1,140 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package websocket + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" + "k8s.io/apimachinery/pkg/util/remotecommand" + restclient "k8s.io/client-go/rest" +) + +func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { + // Create fake WebSocket server. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + })) + defer websocketServer.Close() + + // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + require.NoError(t, err) + rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) + require.NoError(t, err) + requestedProtocol := remotecommand.StreamProtocolV5Name + req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol} + _, err = rt.RoundTrip(req) + require.NoError(t, err) + // WebSocket Connection is stored in websocket RoundTripper. + // Compare the expected negotiated subprotocol with the actual subprotocol. + actualProtocol := wsRt.Connection().Subprotocol() + assert.Equal(t, requestedProtocol, actualProtocol) + +} + +func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) { + // Create fake WebSocket server. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + })) + defer websocketServer.Close() + + // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + require.NoError(t, err) + rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) + require.NoError(t, err) + // Requested subprotocol version 1 is not supported by test websocket server. + requestedProtocol := remotecommand.StreamProtocolV1Name + req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol} + _, err = rt.RoundTrip(req) + // Ensure a "bad handshake" error is returned, since requested protocol is not supported. + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "bad handshake")) +} + +func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) { + // Create fake WebSocket server. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w) + if err != nil { + t.Fatalf("error on webSocketServerStreams: %v", err) + } + defer conns.conn.Close() + })) + defer websocketServer.Close() + + // Create the websocket roundtripper and call "Negotiate" to create websocket connection. + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + require.NoError(t, err) + rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) + require.NoError(t, err) + requestedProtocol := remotecommand.StreamProtocolV5Name + conn, err := Negotiate(rt, wsRt, req, requestedProtocol) + require.NoError(t, err) + // Compare the expected negotiated subprotocol with the actual subprotocol. + actualProtocol := conn.Subprotocol() + assert.Equal(t, requestedProtocol, actualProtocol) +} + +// websocketStreams contains the WebSocket connection and streams from a server. +type websocketStreams struct { + conn io.Closer +} + +func webSocketServerStreams(req *http.Request, w http.ResponseWriter) (*websocketStreams, error) { + conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ + remotecommand.StreamProtocolV5Name: { + Binary: true, + Channels: []wsstream.ChannelType{}, + }, + }) + conn.SetIdleTimeout(4 * time.Hour) + // Opening the connection responds to WebSocket client, negotiating + // the WebSocket upgrade connection and the subprotocol. + _, _, err := conn.Open(w, req) + if err != nil { + return nil, err + } + return &websocketStreams{conn: conn}, nil +}