diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/BUILD b/staging/src/k8s.io/client-go/tools/remotecommand/BUILD index b3eb433c01d..7a3c20f4365 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/BUILD +++ b/staging/src/k8s.io/client-go/tools/remotecommand/BUILD @@ -9,14 +9,20 @@ load( go_test( name = "go_default_test", srcs = [ + "remotecommand_test.go", "v2_test.go", "v4_test.go", ], embed = [":go_default_library"], deps = [ "//staging/src/k8s.io/api/core/v1:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/api/errors:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/httpstream:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/util/remotecommand:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/wait:go_default_library", + "//staging/src/k8s.io/client-go/rest:go_default_library", ], ) diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go b/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go new file mode 100644 index 00000000000..77e309d6199 --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "encoding/json" + "errors" + "io" + "io/ioutil" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error +type streamContext struct { + conn io.Closer + stdinStream io.ReadCloser + stdoutStream io.WriteCloser + stderrStream io.WriteCloser + writeStatus func(status *apierrors.StatusError) error +} + +type streamAndReply struct { + httpstream.Stream + replySent <-chan struct{} +} + +type fakeMassiveDataPty struct{} + +func (s *fakeMassiveDataPty) Read(p []byte) (int, error) { + time.Sleep(time.Duration(1) * time.Second) + return copy(p, []byte{}), errors.New("client crashed after 1 second") +} + +func (s *fakeMassiveDataPty) Write(p []byte) (int, error) { + time.Sleep(time.Duration(1) * time.Second) + return len(p), errors.New("return err") +} + +func fakeMassiveDataAttacher(stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan TerminalSize) error { + + copyDone := make(chan struct{}, 3) + + if stdin == nil { + return errors.New("stdin is requested") // we need stdin to notice the conn break + } + + go func() { + io.Copy(ioutil.Discard, stdin) + copyDone <- struct{}{} + }() + + go func() { + if stdout == nil { + return + } + copyDone <- writeMassiveData(stdout) + }() + + go func() { + if stderr == nil { + return + } + copyDone <- writeMassiveData(stderr) + }() + + select { + case <-copyDone: + return nil + } +} + +func writeMassiveData(stdStream io.Writer) struct{} { // write to stdin or stdout + for { + _, err := io.Copy(stdStream, strings.NewReader("something")) + if err != nil && err.Error() != "EOF" { + break + } + } + return struct{}{} +} + +func TestSPDYExecutorStream(t *testing.T) { + tests := []struct { + name string + options StreamOptions + expectError string + attacher AttachFunc + }{ + { + name: "stdoutBlockTest", + options: StreamOptions{ + Stdin: &fakeMassiveDataPty{}, + Stdout: &fakeMassiveDataPty{}, + }, + expectError: "", + attacher: fakeMassiveDataAttacher, + }, + { + name: "stderrBlockTest", + options: StreamOptions{ + Stdin: &fakeMassiveDataPty{}, + Stderr: &fakeMassiveDataPty{}, + }, + expectError: "", + attacher: fakeMassiveDataAttacher, + }, + } + + for _, test := range tests { + server := newTestHTTPServer(test.attacher, &test.options) + + err := attach2Server(server.URL, test.options) + gotError := "" + if err != nil { + gotError = err.Error() + } + if test.expectError != gotError { + t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError) + } + + server.Close() + } + +} + +func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + ctx, err := createHTTPStreams(writer, request, options) + if err != nil { + return + } + defer ctx.conn.Close() + + // handle input output + err = f(ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, false, nil) + if err != nil { + ctx.writeStatus(apierrors.NewInternalError(err)) + } else { + ctx.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{ + Status: metav1.StatusSuccess, + }}) + } + })) + return server +} + +func attach2Server(rawURL string, options StreamOptions) error { + uri, _ := url.Parse(rawURL) + exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri) + if err != nil { + return err + } + + e := make(chan error) + go func(e chan error) { + e <- exec.Stream(options) + }(e) + select { + case err := <-e: + return err + case <-time.After(wait.ForeverTestTimeout): + return errors.New("execute timeout") + } +} + +// simplify createHttpStreams , only support StreamProtocolV4Name +func createHTTPStreams(w http.ResponseWriter, req *http.Request, opts *StreamOptions) (*streamContext, error) { + _, err := httpstream.Handshake(req, w, []string{remotecommandconsts.StreamProtocolV4Name}) + if err != nil { + return nil, err + } + + upgrader := spdy.NewResponseUpgrader() + streamCh := make(chan streamAndReply) + conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- streamAndReply{Stream: stream, replySent: replySent} + return nil + }) + ctx := &streamContext{ + conn: conn, + } + + // wait for stream + replyChan := make(chan struct{}, 1) + receivedStreams := 0 + expectedStreams := 1 + if opts.Stdout != nil { + expectedStreams++ + } + if opts.Stdin != nil { + expectedStreams++ + } + if opts.Stderr != nil { + expectedStreams++ + } +WaitForStreams: + for { + select { + case stream := <-streamCh: + streamType := stream.Headers().Get(v1.StreamType) + switch streamType { + case v1.StreamTypeError: + replyChan <- struct{}{} + ctx.writeStatus = v4WriteStatusFunc(stream) + case v1.StreamTypeStdout: + replyChan <- struct{}{} + ctx.stdoutStream = stream + case v1.StreamTypeStdin: + replyChan <- struct{}{} + ctx.stdinStream = stream + case v1.StreamTypeStderr: + replyChan <- struct{}{} + ctx.stderrStream = stream + default: + // add other stream ... + return nil, errors.New("unimplemented stream type") + } + case <-replyChan: + receivedStreams++ + if receivedStreams == expectedStreams { + break WaitForStreams + } + } + } + + return ctx, nil +} + +func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error { + return func(status *apierrors.StatusError) error { + bs, err := json.Marshal(status.Status()) + if err != nil { + return err + } + _, err = stream.Write(bs) + return err + } +} diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/v2.go b/staging/src/k8s.io/client-go/tools/remotecommand/v2.go index 4b0001502a1..2f5561c942e 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/v2.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/v2.go @@ -142,6 +142,10 @@ func (p *streamProtocolV2) copyStdout(wg *sync.WaitGroup) { go func() { defer runtime.HandleCrash() defer wg.Done() + // make sure, packet in queue can be consumed. + // block in queue may lead to deadlock in conn.server + // issue: https://github.com/kubernetes/kubernetes/issues/96339 + defer io.Copy(ioutil.Discard, p.remoteStdout) if _, err := io.Copy(p.Stdout, p.remoteStdout); err != nil { runtime.HandleError(err) @@ -158,6 +162,7 @@ func (p *streamProtocolV2) copyStderr(wg *sync.WaitGroup) { go func() { defer runtime.HandleCrash() defer wg.Done() + defer io.Copy(ioutil.Discard, p.remoteStderr) if _, err := io.Copy(p.Stderr, p.remoteStderr); err != nil { runtime.HandleError(err)