diff --git a/pkg/client/tests/remotecommand_test.go b/pkg/client/tests/remotecommand_test.go index 910cfce38f8..0e15b2fa9a9 100644 --- a/pkg/client/tests/remotecommand_test.go +++ b/pkg/client/tests/remotecommand_test.go @@ -18,6 +18,7 @@ package tests import ( "bytes" + "context" "errors" "fmt" "io" @@ -252,7 +253,7 @@ func TestStream(t *testing.T) { if err != nil { t.Fatalf("%s: unexpected error: %v", name, err) } - err = e.Stream(remoteclient.StreamOptions{ + err = e.StreamWithContext(context.Background(), remoteclient.StreamOptions{ Stdin: streamIn, Stdout: streamOut, Stderr: streamErr, diff --git a/pkg/kubelet/cri/streaming/server_test.go b/pkg/kubelet/cri/streaming/server_test.go index 156c45d8626..9b88c3fdf07 100644 --- a/pkg/kubelet/cri/streaming/server_test.go +++ b/pkg/kubelet/cri/streaming/server_test.go @@ -17,6 +17,7 @@ limitations under the License. package streaming import ( + "context" "crypto/tls" "io" "net/http" @@ -355,7 +356,7 @@ func runRemoteCommandTest(t *testing.T, commandType string) { Stderr: stderrW, Tty: false, } - require.NoError(t, exec.Stream(opts)) + require.NoError(t, exec.StreamWithContext(context.Background(), opts)) }() go func() { diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand.go b/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand.go index cb39faf7f1a..662a3cb4ac7 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand.go @@ -17,6 +17,7 @@ limitations under the License. package remotecommand import ( + "context" "fmt" "io" "net/http" @@ -27,7 +28,7 @@ import ( "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/remotecommand" restclient "k8s.io/client-go/rest" - spdy "k8s.io/client-go/transport/spdy" + "k8s.io/client-go/transport/spdy" ) // StreamOptions holds information pertaining to the current streaming session: @@ -43,11 +44,16 @@ type StreamOptions struct { // Executor is an interface for transporting shell-style streams. type Executor interface { - // Stream initiates the transport of the standard shell streams. It will transport any - // non-nil stream to a remote system, and return an error if a problem occurs. If tty - // is set, the stderr stream is not used (raw TTY manages stdout and stderr over the - // stdout stream). + // Deprecated: use StreamWithContext instead to avoid possible resource leaks. + // See https://github.com/kubernetes/kubernetes/pull/103177 for details. Stream(options StreamOptions) error + + // StreamWithContext initiates the transport of the standard shell streams. It will + // transport any non-nil stream to a remote system, and return an error if a problem + // occurs. If tty is set, the stderr stream is not used (raw TTY manages stdout and + // stderr over the stdout stream). + // The context controls the entire lifetime of stream execution. + StreamWithContext(ctx context.Context, options StreamOptions) error } type streamCreator interface { @@ -106,9 +112,14 @@ func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgr // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(options StreamOptions) error { - req, err := http.NewRequest(e.method, e.url.String(), nil) + return e.StreamWithContext(context.Background(), options) +} + +// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it. +func (e *streamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) { + req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil) if err != nil { - return fmt.Errorf("error creating request: %v", err) + return nil, nil, fmt.Errorf("error creating request: %v", err) } conn, protocol, err := spdy.Negotiate( @@ -118,9 +129,8 @@ func (e *streamExecutor) Stream(options StreamOptions) error { e.protocols..., ) if err != nil { - return err + return nil, nil, err } - defer conn.Close() var streamer streamProtocolHandler @@ -138,5 +148,35 @@ func (e *streamExecutor) Stream(options StreamOptions) error { streamer = newStreamProtocolV1(options) } - return streamer.stream(conn) + return conn, streamer, nil +} + +// StreamWithContext opens a protocol streamer to the server and streams until a client closes +// the connection or the server disconnects or the context is done. +func (e *streamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + conn, streamer, err := e.newConnectionAndStream(ctx, options) + if err != nil { + return err + } + defer conn.Close() + + panicChan := make(chan any, 1) + errorChan := make(chan error, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + errorChan <- streamer.stream(conn) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-errorChan: + return err + case <-ctx.Done(): + return ctx.Err() + } } diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go b/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go index 7eec4565ed1..9144a14526c 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/remotecommand_test.go @@ -17,9 +17,17 @@ limitations under the License. package remotecommand import ( + "context" "encoding/json" "errors" "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -28,12 +36,6 @@ import ( remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/rest" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" ) type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error @@ -50,6 +52,17 @@ type streamAndReply struct { replySent <-chan struct{} } +type fakeEmptyDataPty struct { +} + +func (s *fakeEmptyDataPty) Read(p []byte) (int, error) { + return len(p), nil +} + +func (s *fakeEmptyDataPty) Write(p []byte) (int, error) { + return len(p), nil +} + type fakeMassiveDataPty struct{} func (s *fakeMassiveDataPty) Read(p []byte) (int, error) { @@ -107,6 +120,7 @@ func writeMassiveData(stdStream io.Writer) struct{} { // write to stdin or stdou func TestSPDYExecutorStream(t *testing.T) { tests := []struct { + timeout time.Duration name string options StreamOptions expectError string @@ -130,23 +144,40 @@ func TestSPDYExecutorStream(t *testing.T) { expectError: "", attacher: fakeMassiveDataAttacher, }, + { + timeout: 500 * time.Millisecond, + name: "timeoutTest", + options: StreamOptions{ + Stdin: &fakeMassiveDataPty{}, + Stderr: &fakeMassiveDataPty{}, + }, + expectError: context.DeadlineExceeded.Error(), + attacher: fakeMassiveDataAttacher, + }, } for _, test := range tests { - server := newTestHTTPServer(test.attacher, &test.options) + t.Run(test.name, func(t *testing.T) { + server := newTestHTTPServer(test.attacher, &test.options) + defer server.Close() - err := attach2Server(server.URL, test.options) - gotError := "" - if err != nil { - gotError = err.Error() - } - if test.expectError != gotError { - t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError) - } + ctx, cancel := context.Background(), func() {} + if test.timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, test.timeout) + } + defer cancel() - server.Close() + err := attach2Server(ctx, server.URL, test.options) + + gotError := "" + if err != nil { + gotError = err.Error() + } + if test.expectError != gotError { + t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError) + } + }) } - } func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { @@ -170,16 +201,16 @@ func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { return server } -func attach2Server(rawURL string, options StreamOptions) error { +func attach2Server(ctx context.Context, rawURL string, options StreamOptions) error { uri, _ := url.Parse(rawURL) exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri) if err != nil { return err } - e := make(chan error) + e := make(chan error, 1) go func(e chan error) { - e <- exec.Stream(options) + e <- exec.StreamWithContext(ctx, options) }(e) select { case err := <-e: @@ -263,3 +294,74 @@ func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) err return err } } + +// writeDetector provides a helper method to block until the underlying writer written. +type writeDetector struct { + written chan bool + closed bool + io.Writer +} + +func newWriterDetector(w io.Writer) *writeDetector { + return &writeDetector{ + written: make(chan bool), + Writer: w, + } +} + +func (w *writeDetector) BlockUntilWritten() { + <-w.written +} + +func (w *writeDetector) Write(p []byte) (n int, err error) { + if !w.closed { + close(w.written) + w.closed = true + } + return w.Writer.Write(p) +} + +// `Executor.StreamWithContext` starts a goroutine in the background to do the streaming +// and expects the deferred close of the connection leads to the exit of the goroutine on cancellation. +// This test verifies that works. +func TestStreamExitsAfterConnectionIsClosed(t *testing.T) { + writeDetector := newWriterDetector(&fakeEmptyDataPty{}) + options := StreamOptions{ + Stdin: &fakeEmptyDataPty{}, + Stdout: writeDetector, + } + server := newTestHTTPServer(fakeMassiveDataAttacher, &options) + + ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancelFn() + + uri, _ := url.Parse(server.URL) + exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri) + if err != nil { + t.Fatal(err) + } + streamExec := exec.(*streamExecutor) + + conn, streamer, err := streamExec.newConnectionAndStream(ctx, options) + if err != nil { + t.Fatal(err) + } + + errorChan := make(chan error) + go func() { + errorChan <- streamer.stream(conn) + }() + + // Wait until stream goroutine starts. + writeDetector.BlockUntilWritten() + + // Close the connection + conn.Close() + + select { + case <-time.After(1 * time.Second): + t.Fatalf("expect stream to be closed after connection is closed.") + case <-errorChan: + return + } +} diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go b/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go index 668524ab659..e403b5208f5 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go @@ -17,6 +17,7 @@ limitations under the License. package attach import ( + "context" "fmt" "io" "net/url" @@ -159,7 +160,7 @@ func (*DefaultRemoteAttach) Attach(method string, url *url.URL, config *restclie if err != nil { return err } - return exec.Stream(remotecommand.StreamOptions{ + return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{ Stdin: stdin, Stdout: stdout, Stderr: stderr, diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go b/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go index f3be195db98..ff7a3750d76 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go @@ -122,7 +122,7 @@ func (*DefaultRemoteExecutor) Execute(method string, url *url.URL, config *restc if err != nil { return err } - return exec.Stream(remotecommand.StreamOptions{ + return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{ Stdin: stdin, Stdout: stdout, Stderr: stderr, diff --git a/test/e2e/framework/pod/exec_util.go b/test/e2e/framework/pod/exec_util.go index 1e3438c2b6d..a88aee2d7dd 100644 --- a/test/e2e/framework/pod/exec_util.go +++ b/test/e2e/framework/pod/exec_util.go @@ -143,7 +143,7 @@ func execute(method string, url *url.URL, config *restclient.Config, stdin io.Re if err != nil { return err } - return exec.Stream(remotecommand.StreamOptions{ + return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{ Stdin: stdin, Stdout: stdout, Stderr: stderr,