Merge pull request #103177 from arkbriar/support_cancelable_exec_stream

Support cancelable SPDY executor stream

Kubernetes-commit: 3cf75a2f760b8093f7c97f26b4b2b059f3777bec
This commit is contained in:
Kubernetes Publisher 2022-11-02 19:47:36 -07:00
commit bc6266d159
2 changed files with 172 additions and 30 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
package remotecommand package remotecommand
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -27,7 +28,7 @@ import (
"k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/remotecommand" "k8s.io/apimachinery/pkg/util/remotecommand"
restclient "k8s.io/client-go/rest" restclient "k8s.io/client-go/rest"
spdy "k8s.io/client-go/transport/spdy" "k8s.io/client-go/transport/spdy"
) )
// StreamOptions holds information pertaining to the current streaming session: // StreamOptions holds information pertaining to the current streaming session:
@ -43,11 +44,16 @@ type StreamOptions struct {
// Executor is an interface for transporting shell-style streams. // Executor is an interface for transporting shell-style streams.
type Executor interface { type Executor interface {
// Stream initiates the transport of the standard shell streams. It will transport any // Deprecated: use StreamWithContext instead to avoid possible resource leaks.
// non-nil stream to a remote system, and return an error if a problem occurs. If tty // See https://github.com/kubernetes/kubernetes/pull/103177 for details.
// is set, the stderr stream is not used (raw TTY manages stdout and stderr over the
// stdout stream).
Stream(options StreamOptions) error Stream(options StreamOptions) error
// StreamWithContext initiates the transport of the standard shell streams. It will
// transport any non-nil stream to a remote system, and return an error if a problem
// occurs. If tty is set, the stderr stream is not used (raw TTY manages stdout and
// stderr over the stdout stream).
// The context controls the entire lifetime of stream execution.
StreamWithContext(ctx context.Context, options StreamOptions) error
} }
type streamCreator interface { type streamCreator interface {
@ -106,9 +112,14 @@ func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgr
// Stream opens a protocol streamer to the server and streams until a client closes // Stream opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects. // the connection or the server disconnects.
func (e *streamExecutor) Stream(options StreamOptions) error { func (e *streamExecutor) Stream(options StreamOptions) error {
req, err := http.NewRequest(e.method, e.url.String(), nil) return e.StreamWithContext(context.Background(), options)
}
// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it.
func (e *streamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) {
req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil)
if err != nil { if err != nil {
return fmt.Errorf("error creating request: %v", err) return nil, nil, fmt.Errorf("error creating request: %v", err)
} }
conn, protocol, err := spdy.Negotiate( conn, protocol, err := spdy.Negotiate(
@ -118,9 +129,8 @@ func (e *streamExecutor) Stream(options StreamOptions) error {
e.protocols..., e.protocols...,
) )
if err != nil { if err != nil {
return err return nil, nil, err
} }
defer conn.Close()
var streamer streamProtocolHandler var streamer streamProtocolHandler
@ -138,5 +148,35 @@ func (e *streamExecutor) Stream(options StreamOptions) error {
streamer = newStreamProtocolV1(options) streamer = newStreamProtocolV1(options)
} }
return streamer.stream(conn) return conn, streamer, nil
}
// StreamWithContext opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects or the context is done.
func (e *streamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error {
conn, streamer, err := e.newConnectionAndStream(ctx, options)
if err != nil {
return err
}
defer conn.Close()
panicChan := make(chan any, 1)
errorChan := make(chan error, 1)
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
errorChan <- streamer.stream(conn)
}()
select {
case p := <-panicChan:
panic(p)
case err := <-errorChan:
return err
case <-ctx.Done():
return ctx.Err()
}
} }

View File

@ -17,9 +17,17 @@ limitations under the License.
package remotecommand package remotecommand
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -28,12 +36,6 @@ import (
remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand" remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest" "k8s.io/client-go/rest"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
) )
type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error
@ -50,6 +52,17 @@ type streamAndReply struct {
replySent <-chan struct{} replySent <-chan struct{}
} }
type fakeEmptyDataPty struct {
}
func (s *fakeEmptyDataPty) Read(p []byte) (int, error) {
return len(p), nil
}
func (s *fakeEmptyDataPty) Write(p []byte) (int, error) {
return len(p), nil
}
type fakeMassiveDataPty struct{} type fakeMassiveDataPty struct{}
func (s *fakeMassiveDataPty) Read(p []byte) (int, error) { func (s *fakeMassiveDataPty) Read(p []byte) (int, error) {
@ -107,6 +120,7 @@ func writeMassiveData(stdStream io.Writer) struct{} { // write to stdin or stdou
func TestSPDYExecutorStream(t *testing.T) { func TestSPDYExecutorStream(t *testing.T) {
tests := []struct { tests := []struct {
timeout time.Duration
name string name string
options StreamOptions options StreamOptions
expectError string expectError string
@ -130,23 +144,40 @@ func TestSPDYExecutorStream(t *testing.T) {
expectError: "", expectError: "",
attacher: fakeMassiveDataAttacher, attacher: fakeMassiveDataAttacher,
}, },
{
timeout: 500 * time.Millisecond,
name: "timeoutTest",
options: StreamOptions{
Stdin: &fakeMassiveDataPty{},
Stderr: &fakeMassiveDataPty{},
},
expectError: context.DeadlineExceeded.Error(),
attacher: fakeMassiveDataAttacher,
},
} }
for _, test := range tests { for _, test := range tests {
server := newTestHTTPServer(test.attacher, &test.options) t.Run(test.name, func(t *testing.T) {
server := newTestHTTPServer(test.attacher, &test.options)
defer server.Close()
err := attach2Server(server.URL, test.options) ctx, cancel := context.Background(), func() {}
gotError := "" if test.timeout > 0 {
if err != nil { ctx, cancel = context.WithTimeout(ctx, test.timeout)
gotError = err.Error() }
} defer cancel()
if test.expectError != gotError {
t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError)
}
server.Close() err := attach2Server(ctx, server.URL, test.options)
gotError := ""
if err != nil {
gotError = err.Error()
}
if test.expectError != gotError {
t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError)
}
})
} }
} }
func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
@ -170,16 +201,16 @@ func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
return server return server
} }
func attach2Server(rawURL string, options StreamOptions) error { func attach2Server(ctx context.Context, rawURL string, options StreamOptions) error {
uri, _ := url.Parse(rawURL) uri, _ := url.Parse(rawURL)
exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri) exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
if err != nil { if err != nil {
return err return err
} }
e := make(chan error) e := make(chan error, 1)
go func(e chan error) { go func(e chan error) {
e <- exec.Stream(options) e <- exec.StreamWithContext(ctx, options)
}(e) }(e)
select { select {
case err := <-e: case err := <-e:
@ -263,3 +294,74 @@ func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) err
return err return err
} }
} }
// writeDetector provides a helper method to block until the underlying writer written.
type writeDetector struct {
written chan bool
closed bool
io.Writer
}
func newWriterDetector(w io.Writer) *writeDetector {
return &writeDetector{
written: make(chan bool),
Writer: w,
}
}
func (w *writeDetector) BlockUntilWritten() {
<-w.written
}
func (w *writeDetector) Write(p []byte) (n int, err error) {
if !w.closed {
close(w.written)
w.closed = true
}
return w.Writer.Write(p)
}
// `Executor.StreamWithContext` starts a goroutine in the background to do the streaming
// and expects the deferred close of the connection leads to the exit of the goroutine on cancellation.
// This test verifies that works.
func TestStreamExitsAfterConnectionIsClosed(t *testing.T) {
writeDetector := newWriterDetector(&fakeEmptyDataPty{})
options := StreamOptions{
Stdin: &fakeEmptyDataPty{},
Stdout: writeDetector,
}
server := newTestHTTPServer(fakeMassiveDataAttacher, &options)
ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancelFn()
uri, _ := url.Parse(server.URL)
exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
if err != nil {
t.Fatal(err)
}
streamExec := exec.(*streamExecutor)
conn, streamer, err := streamExec.newConnectionAndStream(ctx, options)
if err != nil {
t.Fatal(err)
}
errorChan := make(chan error)
go func() {
errorChan <- streamer.stream(conn)
}()
// Wait until stream goroutine starts.
writeDetector.BlockUntilWritten()
// Close the connection
conn.Close()
select {
case <-time.After(1 * time.Second):
t.Fatalf("expect stream to be closed after connection is closed.")
case <-errorChan:
return
}
}