mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-20 09:05:26 +00:00
Add streaming subprotocol negotiation
Add streaming subprotocol negotiation for exec, attach, and port forwarding. Restore previous (buggy) exec functionality as an unspecified/unversioned subprotocol so newer kubectl clients can work against 1.0.x kubelets.
This commit is contained in:
parent
d3862d453f
commit
3d1cafc2c3
@ -122,13 +122,16 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding.
|
||||||
|
const PortForwardProtocolV1Name = "portforward.k8s.io"
|
||||||
|
|
||||||
// ForwardPorts formats and executes a port forwarding request. The connection will remain
|
// ForwardPorts formats and executes a port forwarding request. The connection will remain
|
||||||
// open until stopChan is closed.
|
// open until stopChan is closed.
|
||||||
func (pf *PortForwarder) ForwardPorts() error {
|
func (pf *PortForwarder) ForwardPorts() error {
|
||||||
defer pf.Close()
|
defer pf.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pf.streamConn, err = pf.dialer.Dial()
|
pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error upgrading connection: %s", err)
|
return fmt.Errorf("error upgrading connection: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -38,14 +38,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type fakeDialer struct {
|
type fakeDialer struct {
|
||||||
dialed bool
|
dialed bool
|
||||||
conn httpstream.Connection
|
conn httpstream.Connection
|
||||||
err error
|
err error
|
||||||
|
negotiatedProtocol string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *fakeDialer) Dial() (httpstream.Connection, error) {
|
func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) {
|
||||||
d.dialed = true
|
d.dialed = true
|
||||||
return d.conn, d.err
|
return d.conn, d.negotiatedProtocol, d.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParsePortsAndNew(t *testing.T) {
|
func TestParsePortsAndNew(t *testing.T) {
|
||||||
|
@ -24,6 +24,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/golang/glog"
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/api"
|
"k8s.io/kubernetes/pkg/api"
|
||||||
client "k8s.io/kubernetes/pkg/client/unversioned"
|
client "k8s.io/kubernetes/pkg/client/unversioned"
|
||||||
"k8s.io/kubernetes/pkg/util"
|
"k8s.io/kubernetes/pkg/util"
|
||||||
@ -97,51 +99,207 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection.
|
// Dial opens a connection to a remote server and attempts to negotiate a SPDY
|
||||||
func (e *streamExecutor) Dial() (httpstream.Connection, error) {
|
// connection. Upon success, it returns the connection and the protocol
|
||||||
client := &http.Client{Transport: e.transport}
|
// selected by the server.
|
||||||
|
func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) {
|
||||||
|
transport := e.transport
|
||||||
|
// TODO consider removing this and reusing client.TransportFor above to get this for free
|
||||||
|
switch {
|
||||||
|
case bool(glog.V(9)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders)
|
||||||
|
case bool(glog.V(8)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders)
|
||||||
|
case bool(glog.V(7)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus)
|
||||||
|
case bool(glog.V(6)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
req, err := http.NewRequest(e.method, e.url.String(), nil)
|
req, err := http.NewRequest(e.method, e.url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating request: %s", err)
|
return nil, "", fmt.Errorf("error creating request: %v", err)
|
||||||
|
}
|
||||||
|
for i := range protocols {
|
||||||
|
req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error sending request: %s", err)
|
return nil, "", fmt.Errorf("error sending request: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// TODO: handle protocol selection in the future
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
return e.upgrader.NewConnection(resp)
|
return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := e.upgrader.NewConnection(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The SPDY subprotocol "channel.k8s.io" is used for remote command
|
||||||
|
// attachment/execution. This represents the initial unversioned subprotocol,
|
||||||
|
// which has the known bugs http://issues.k8s.io/13394 and
|
||||||
|
// http://issues.k8s.io/13395.
|
||||||
|
StreamProtocolV1Name = "channel.k8s.io"
|
||||||
|
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
|
||||||
|
// attachment/execution. It is the second version of the subprotocol and
|
||||||
|
// resolves the issues present in the first version.
|
||||||
|
StreamProtocolV2Name = "v2.channel.k8s.io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type streamProtocolHandler interface {
|
||||||
|
stream(httpstream.Connection) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream opens a protocol streamer to the server and streams until a client closes
|
// Stream opens a protocol streamer to the server and streams until a client closes
|
||||||
// the connection or the server disconnects.
|
// the connection or the server disconnects.
|
||||||
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
|
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
|
||||||
conn, err := e.Dial()
|
conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
// TODO: negotiate protocols
|
|
||||||
streamer := &streamProtocol{
|
var streamer streamProtocolHandler
|
||||||
stdin: stdin,
|
|
||||||
stdout: stdout,
|
switch protocol {
|
||||||
stderr: stderr,
|
case StreamProtocolV2Name:
|
||||||
tty: tty,
|
streamer = &streamProtocolV2{
|
||||||
|
stdin: stdin,
|
||||||
|
stdout: stdout,
|
||||||
|
stderr: stderr,
|
||||||
|
tty: tty,
|
||||||
|
}
|
||||||
|
case "":
|
||||||
|
glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned")
|
||||||
|
// TODO restore v1
|
||||||
|
streamer = &streamProtocolV1{
|
||||||
|
stdin: stdin,
|
||||||
|
stdout: stdout,
|
||||||
|
stderr: stderr,
|
||||||
|
tty: tty,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return streamer.stream(conn)
|
return streamer.stream(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamProtocol struct {
|
type streamProtocolV1 struct {
|
||||||
stdin io.Reader
|
stdin io.Reader
|
||||||
stdout io.Writer
|
stdout io.Writer
|
||||||
stderr io.Writer
|
stderr io.Writer
|
||||||
tty bool
|
tty bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *streamProtocol) stream(conn httpstream.Connection) error {
|
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
|
||||||
|
doneChan := make(chan struct{}, 2)
|
||||||
|
errorChan := make(chan error)
|
||||||
|
|
||||||
|
cp := func(s string, dst io.Writer, src io.Reader) {
|
||||||
|
glog.V(4).Infof("Copying %s", s)
|
||||||
|
defer glog.V(4).Infof("Done copying %s", s)
|
||||||
|
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
|
||||||
|
glog.Errorf("Error copying %s: %v", s, err)
|
||||||
|
}
|
||||||
|
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
|
||||||
|
doneChan <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeError)
|
||||||
|
errorStream, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
message, err := ioutil.ReadAll(errorStream)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(message) > 0 {
|
||||||
|
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer errorStream.Reset()
|
||||||
|
|
||||||
|
if e.stdin != nil {
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStdin)
|
||||||
|
remoteStdin, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer remoteStdin.Reset()
|
||||||
|
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
|
||||||
|
// because stdin is not closed until the process exits. If we try to call
|
||||||
|
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
|
||||||
|
// exit when the process exits, instead.
|
||||||
|
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
|
||||||
|
}
|
||||||
|
|
||||||
|
waitCount := 0
|
||||||
|
completedStreams := 0
|
||||||
|
|
||||||
|
if e.stdout != nil {
|
||||||
|
waitCount++
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStdout)
|
||||||
|
remoteStdout, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer remoteStdout.Reset()
|
||||||
|
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.stderr != nil && !e.tty {
|
||||||
|
waitCount++
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStderr)
|
||||||
|
remoteStderr, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer remoteStderr.Reset()
|
||||||
|
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
completedStreams++
|
||||||
|
if completedStreams == waitCount {
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
case err := <-errorChan:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamProtocolV2 implements version 2 of the streaming protocol for attach
|
||||||
|
// and exec. The original streaming protocol was unversioned. As a result, this
|
||||||
|
// version is referred to as version 2, even though it is the first actual
|
||||||
|
// numbered version.
|
||||||
|
type streamProtocolV2 struct {
|
||||||
|
stdin io.Reader
|
||||||
|
stdout io.Writer
|
||||||
|
stderr io.Writer
|
||||||
|
tty bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *streamProtocolV2) stream(conn httpstream.Connection) error {
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
|
|
||||||
// set up error stream
|
// set up error stream
|
||||||
|
@ -45,7 +45,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
|
|||||||
streamCh := make(chan httpstream.Stream)
|
streamCh := make(chan httpstream.Stream)
|
||||||
|
|
||||||
upgrader := spdy.NewResponseUpgrader()
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
|
conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error {
|
||||||
streamCh <- stream
|
streamCh <- stream
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -57,6 +57,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
_ = protocol
|
||||||
|
|
||||||
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
|
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
|
||||||
receivedStreams := 0
|
receivedStreams := 0
|
||||||
@ -347,7 +348,7 @@ func TestDial(t *testing.T) {
|
|||||||
checkResponse: true,
|
checkResponse: true,
|
||||||
conn: &fakeConnection{},
|
conn: &fakeConnection{},
|
||||||
resp: &http.Response{
|
resp: &http.Response{
|
||||||
StatusCode: http.StatusOK,
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
Body: ioutil.NopCloser(&bytes.Buffer{}),
|
Body: ioutil.NopCloser(&bytes.Buffer{}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -363,7 +364,7 @@ func TestDial(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn, err := exec.Dial()
|
conn, protocol, err := exec.Dial([]string{"a", "b"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -373,4 +374,5 @@ func TestDial(t *testing.T) {
|
|||||||
if !called {
|
if !called {
|
||||||
t.Errorf("wrapper not called")
|
t.Errorf("wrapper not called")
|
||||||
}
|
}
|
||||||
|
_ = protocol
|
||||||
}
|
}
|
||||||
|
@ -44,6 +44,8 @@ import (
|
|||||||
"k8s.io/kubernetes/pkg/api/validation"
|
"k8s.io/kubernetes/pkg/api/validation"
|
||||||
"k8s.io/kubernetes/pkg/auth/authenticator"
|
"k8s.io/kubernetes/pkg/auth/authenticator"
|
||||||
"k8s.io/kubernetes/pkg/auth/authorizer"
|
"k8s.io/kubernetes/pkg/auth/authorizer"
|
||||||
|
"k8s.io/kubernetes/pkg/client/unversioned/portforward"
|
||||||
|
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||||
"k8s.io/kubernetes/pkg/healthz"
|
"k8s.io/kubernetes/pkg/healthz"
|
||||||
"k8s.io/kubernetes/pkg/httplog"
|
"k8s.io/kubernetes/pkg/httplog"
|
||||||
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
||||||
@ -688,7 +690,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
|
|||||||
streamCh := make(chan httpstream.Stream)
|
streamCh := make(chan httpstream.Stream)
|
||||||
|
|
||||||
upgrader := spdy.NewResponseUpgrader()
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
|
conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error {
|
||||||
streamCh <- stream
|
streamCh <- stream
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -699,6 +701,9 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
|
|||||||
// if we weren't successful in upgrading.
|
// if we weren't successful in upgrading.
|
||||||
return nil, nil, nil, nil, nil, false, false
|
return nil, nil, nil, nil, nil, false, false
|
||||||
}
|
}
|
||||||
|
if len(protocol) == 0 {
|
||||||
|
protocol = remotecommand.StreamProtocolV1Name
|
||||||
|
}
|
||||||
|
|
||||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||||
|
|
||||||
@ -783,12 +788,14 @@ func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder Po
|
|||||||
|
|
||||||
glog.V(5).Infof("Upgrading port forward response")
|
glog.V(5).Infof("Upgrading port forward response")
|
||||||
upgrader := spdy.NewResponseUpgrader()
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan))
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
_ = protocol
|
||||||
|
|
||||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||||
conn.SetIdleTimeout(idleTimeout)
|
conn.SetIdleTimeout(idleTimeout)
|
||||||
|
|
||||||
|
@ -24,8 +24,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HeaderConnection = "Connection"
|
HeaderConnection = "Connection"
|
||||||
HeaderUpgrade = "Upgrade"
|
HeaderUpgrade = "Upgrade"
|
||||||
|
HeaderProtocolVersion = "X-Stream-Protocol-Version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewStreamHandler defines a function that is called when a new Stream is
|
// NewStreamHandler defines a function that is called when a new Stream is
|
||||||
@ -39,7 +40,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil }
|
|||||||
|
|
||||||
// Dialer knows how to open a streaming connection to a server.
|
// Dialer knows how to open a streaming connection to a server.
|
||||||
type Dialer interface {
|
type Dialer interface {
|
||||||
Dial() (Connection, error)
|
|
||||||
|
// Dial opens a streaming connection to a server using one of the protocols
|
||||||
|
// specified (in order of most preferred to least preferred).
|
||||||
|
Dial(protocols []string) (Connection, string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
|
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
|
||||||
@ -58,7 +62,7 @@ type ResponseUpgrader interface {
|
|||||||
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
||||||
// streams. newStreamHandler will be called synchronously whenever the
|
// streams. newStreamHandler will be called synchronously whenever the
|
||||||
// other end of the upgraded connection creates a new stream.
|
// other end of the upgraded connection creates a new stream.
|
||||||
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
|
UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection represents an upgraded HTTP connection.
|
// Connection represents an upgraded HTTP connection.
|
||||||
|
@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) {
|
|||||||
streamCh := make(chan httpstream.Stream)
|
streamCh := make(chan httpstream.Stream)
|
||||||
|
|
||||||
responseUpgrader := NewResponseUpgrader()
|
responseUpgrader := NewResponseUpgrader()
|
||||||
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
|
spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error {
|
||||||
streamCh <- s
|
streamCh <- s
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -39,23 +39,47 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader {
|
|||||||
return responseUpgrader{}
|
return responseUpgrader{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
|
||||||
|
for i := range clientProtocols {
|
||||||
|
for j := range serverProtocols {
|
||||||
|
if clientProtocols[i] == serverProtocols[j] {
|
||||||
|
return clientProtocols[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
||||||
// streams. newStreamHandler will be called synchronously whenever the
|
// streams. newStreamHandler will be called synchronously whenever the
|
||||||
// other end of the upgraded connection creates a new stream.
|
// other end of the upgraded connection creates a new stream.
|
||||||
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
|
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) {
|
||||||
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
|
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
|
||||||
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
|
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
|
||||||
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
|
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
|
||||||
w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header)))
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return nil
|
fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
|
||||||
|
return nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
hijacker, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
w.Write([]byte("Unable to upgrade: unable to hijack response"))
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return nil
|
fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var negotiatedProtocol string
|
||||||
|
clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)]
|
||||||
|
if len(clientProtocols) > 0 {
|
||||||
|
negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols)
|
||||||
|
if len(negotiatedProtocol) > 0 {
|
||||||
|
w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols)
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
|
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
|
||||||
@ -64,15 +88,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
|
|||||||
|
|
||||||
conn, _, err := hijacker.Hijack()
|
conn, _, err := hijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
glog.Errorf("Unable to upgrade: error hijacking response: %v", err)
|
glog.Errorf("unable to upgrade: error hijacking response: %v", err)
|
||||||
return nil
|
return nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
spdyConn, err := NewServerConnection(conn, newStreamHandler)
|
spdyConn, err := NewServerConnection(conn, newStreamHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err)
|
glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)
|
||||||
return nil
|
return nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return spdyConn
|
return spdyConn, negotiatedProtocol
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,8 @@ func TestUpgradeResponse(t *testing.T) {
|
|||||||
for i, testCase := range testCases {
|
for i, testCase := range testCases {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
upgrader := NewResponseUpgrader()
|
upgrader := NewResponseUpgrader()
|
||||||
conn := upgrader.UpgradeResponse(w, req, nil)
|
conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil)
|
||||||
|
_ = protocol
|
||||||
haveErr := conn == nil
|
haveErr := conn == nil
|
||||||
if e, a := testCase.shouldError, haveErr; e != a {
|
if e, a := testCase.shouldError, haveErr; e != a {
|
||||||
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
|
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)
|
||||||
|
Loading…
Reference in New Issue
Block a user