Add streaming subprotocol negotiation

Add streaming subprotocol negotiation for exec, attach, and port
forwarding. Restore previous (buggy) exec functionality as an
unspecified/unversioned subprotocol so newer kubectl clients can work
against 1.0.x kubelets.
This commit is contained in:
Andy Goldstein 2015-10-20 08:21:07 -04:00
parent d3862d453f
commit 3d1cafc2c3
9 changed files with 243 additions and 43 deletions

View File

@ -122,13 +122,16 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P
}, nil
}
// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding.
const PortForwardProtocolV1Name = "portforward.k8s.io"
// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, err = pf.dialer.Dial()
pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name})
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}

View File

@ -38,14 +38,15 @@ import (
)
type fakeDialer struct {
dialed bool
conn httpstream.Connection
err error
dialed bool
conn httpstream.Connection
err error
negotiatedProtocol string
}
func (d *fakeDialer) Dial() (httpstream.Connection, error) {
func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) {
d.dialed = true
return d.conn, d.err
return d.conn, d.negotiatedProtocol, d.err
}
func TestParsePortsAndNew(t *testing.T) {

View File

@ -24,6 +24,8 @@ import (
"net/url"
"sync"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/api"
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/util"
@ -97,51 +99,207 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou
}, nil
}
// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection.
func (e *streamExecutor) Dial() (httpstream.Connection, error) {
client := &http.Client{Transport: e.transport}
// Dial opens a connection to a remote server and attempts to negotiate a SPDY
// connection. Upon success, it returns the connection and the protocol
// selected by the server.
func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) {
transport := e.transport
// TODO consider removing this and reusing client.TransportFor above to get this for free
switch {
case bool(glog.V(9)):
transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders)
case bool(glog.V(8)):
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders)
case bool(glog.V(7)):
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus)
case bool(glog.V(6)):
transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
}
client := &http.Client{Transport: transport}
req, err := http.NewRequest(e.method, e.url.String(), nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %s", err)
return nil, "", fmt.Errorf("error creating request: %v", err)
}
for i := range protocols {
req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i])
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("error sending request: %s", err)
return nil, "", fmt.Errorf("error sending request: %v", err)
}
defer resp.Body.Close()
// TODO: handle protocol selection in the future
return e.upgrader.NewConnection(resp)
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode))
}
conn, err := e.upgrader.NewConnection(resp)
if err != nil {
return nil, "", err
}
return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil
}
const (
// The SPDY subprotocol "channel.k8s.io" is used for remote command
// attachment/execution. This represents the initial unversioned subprotocol,
// which has the known bugs http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395.
StreamProtocolV1Name = "channel.k8s.io"
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
// attachment/execution. It is the second version of the subprotocol and
// resolves the issues present in the first version.
StreamProtocolV2Name = "v2.channel.k8s.io"
)
type streamProtocolHandler interface {
stream(httpstream.Connection) error
}
// Stream opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects.
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
conn, err := e.Dial()
conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name})
if err != nil {
return err
}
defer conn.Close()
// TODO: negotiate protocols
streamer := &streamProtocol{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
var streamer streamProtocolHandler
switch protocol {
case StreamProtocolV2Name:
streamer = &streamProtocolV2{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
}
case "":
glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned")
// TODO restore v1
streamer = &streamProtocolV1{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
}
}
return streamer.stream(conn)
}
type streamProtocol struct {
type streamProtocolV1 struct {
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
}
func (e *streamProtocol) stream(conn httpstream.Connection) error {
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
doneChan := make(chan struct{}, 2)
errorChan := make(chan error)
cp := func(s string, dst io.Writer, src io.Reader) {
glog.V(4).Infof("Copying %s", s)
defer glog.V(4).Infof("Done copying %s", s)
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
glog.Errorf("Error copying %s: %v", s, err)
}
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
doneChan <- struct{}{}
}
}
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(headers)
if err != nil {
return err
}
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
return
}
if len(message) > 0 {
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
return
}
}()
defer errorStream.Reset()
if e.stdin != nil {
headers.Set(api.StreamType, api.StreamTypeStdin)
remoteStdin, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdin.Reset()
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
// because stdin is not closed until the process exits. If we try to call
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
// exit when the process exits, instead.
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
}
waitCount := 0
completedStreams := 0
if e.stdout != nil {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStdout)
remoteStdout, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdout.Reset()
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
}
if e.stderr != nil && !e.tty {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStderr)
remoteStderr, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStderr.Reset()
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
}
Loop:
for {
select {
case <-doneChan:
completedStreams++
if completedStreams == waitCount {
break Loop
}
case err := <-errorChan:
return err
}
}
return nil
}
// streamProtocolV2 implements version 2 of the streaming protocol for attach
// and exec. The original streaming protocol was unversioned. As a result, this
// version is referred to as version 2, even though it is the first actual
// numbered version.
type streamProtocolV2 struct {
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
}
func (e *streamProtocolV2) stream(conn httpstream.Connection) error {
headers := http.Header{}
// set up error stream

View File

@ -45,7 +45,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error {
streamCh <- stream
return nil
})
@ -57,6 +57,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
return
}
defer conn.Close()
_ = protocol
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
receivedStreams := 0
@ -347,7 +348,7 @@ func TestDial(t *testing.T) {
checkResponse: true,
conn: &fakeConnection{},
resp: &http.Response{
StatusCode: http.StatusOK,
StatusCode: http.StatusSwitchingProtocols,
Body: ioutil.NopCloser(&bytes.Buffer{}),
},
}
@ -363,7 +364,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Fatal(err)
}
conn, err := exec.Dial()
conn, protocol, err := exec.Dial([]string{"a", "b"})
if err != nil {
t.Fatal(err)
}
@ -373,4 +374,5 @@ func TestDial(t *testing.T) {
if !called {
t.Errorf("wrapper not called")
}
_ = protocol
}

View File

@ -44,6 +44,8 @@ import (
"k8s.io/kubernetes/pkg/api/validation"
"k8s.io/kubernetes/pkg/auth/authenticator"
"k8s.io/kubernetes/pkg/auth/authorizer"
"k8s.io/kubernetes/pkg/client/unversioned/portforward"
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
"k8s.io/kubernetes/pkg/healthz"
"k8s.io/kubernetes/pkg/httplog"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
@ -688,7 +690,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error {
streamCh <- stream
return nil
})
@ -699,6 +701,9 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
// if we weren't successful in upgrading.
return nil, nil, nil, nil, nil, false, false
}
if len(protocol) == 0 {
protocol = remotecommand.StreamProtocolV1Name
}
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
@ -783,12 +788,14 @@ func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder Po
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan))
if conn == nil {
return
}
defer conn.Close()
_ = protocol
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)

View File

@ -24,8 +24,9 @@ import (
)
const (
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderProtocolVersion = "X-Stream-Protocol-Version"
)
// NewStreamHandler defines a function that is called when a new Stream is
@ -39,7 +40,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil }
// Dialer knows how to open a streaming connection to a server.
type Dialer interface {
Dial() (Connection, error)
// Dial opens a streaming connection to a server using one of the protocols
// specified (in order of most preferred to least preferred).
Dial(protocols []string) (Connection, string, error)
}
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
@ -58,7 +62,7 @@ type ResponseUpgrader interface {
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// other end of the upgraded connection creates a new stream.
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string)
}
// Connection represents an upgraded HTTP connection.

View File

@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) {
streamCh := make(chan httpstream.Stream)
responseUpgrader := NewResponseUpgrader()
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error {
streamCh <- s
return nil
})

View File

@ -39,23 +39,47 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader {
return responseUpgrader{}
}
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
for i := range clientProtocols {
for j := range serverProtocols {
if clientProtocols[i] == serverProtocols[j] {
return clientProtocols[i]
}
}
}
return ""
}
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// other end of the upgraded connection creates a new stream.
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) {
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header)))
w.WriteHeader(http.StatusBadRequest)
return nil
fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
return nil, ""
}
hijacker, ok := w.(http.Hijacker)
if !ok {
w.Write([]byte("Unable to upgrade: unable to hijack response"))
w.WriteHeader(http.StatusInternalServerError)
return nil
fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
return nil, ""
}
var negotiatedProtocol string
clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)]
if len(clientProtocols) > 0 {
negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols)
if len(negotiatedProtocol) > 0 {
w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol)
} else {
w.WriteHeader(http.StatusForbidden)
fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols)
return nil, ""
}
}
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
@ -64,15 +88,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
conn, _, err := hijacker.Hijack()
if err != nil {
glog.Errorf("Unable to upgrade: error hijacking response: %v", err)
return nil
glog.Errorf("unable to upgrade: error hijacking response: %v", err)
return nil, ""
}
spdyConn, err := NewServerConnection(conn, newStreamHandler)
if err != nil {
glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err)
return nil
glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)
return nil, ""
}
return spdyConn
return spdyConn, negotiatedProtocol
}

View File

@ -53,7 +53,8 @@ func TestUpgradeResponse(t *testing.T) {
for i, testCase := range testCases {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
upgrader := NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, nil)
conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil)
_ = protocol
haveErr := conn == nil
if e, a := testCase.shouldError, haveErr; e != a {
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)