mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-27 05:27:21 +00:00
Merge pull request #15961 from ncdc/stream-protocol-negotiation
Auto commit by PR queue bot
This commit is contained in:
commit
4f17b4b39c
@ -29,6 +29,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
"k8s.io/kubernetes/pkg/api"
|
"k8s.io/kubernetes/pkg/api"
|
||||||
|
"k8s.io/kubernetes/pkg/kubelet/portforward"
|
||||||
"k8s.io/kubernetes/pkg/util"
|
"k8s.io/kubernetes/pkg/util"
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
)
|
)
|
||||||
@ -128,7 +129,7 @@ func (pf *PortForwarder) ForwardPorts() error {
|
|||||||
defer pf.Close()
|
defer pf.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
pf.streamConn, err = pf.dialer.Dial()
|
pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error upgrading connection: %s", err)
|
return fmt.Errorf("error upgrading connection: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -41,11 +41,12 @@ type fakeDialer struct {
|
|||||||
dialed bool
|
dialed bool
|
||||||
conn httpstream.Connection
|
conn httpstream.Connection
|
||||||
err error
|
err error
|
||||||
|
negotiatedProtocol string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *fakeDialer) Dial() (httpstream.Connection, error) {
|
func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||||
d.dialed = true
|
d.dialed = true
|
||||||
return d.conn, d.err
|
return d.conn, d.negotiatedProtocol, d.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParsePortsAndNew(t *testing.T) {
|
func TestParsePortsAndNew(t *testing.T) {
|
||||||
|
@ -19,14 +19,12 @@ package remotecommand
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/api"
|
"github.com/golang/glog"
|
||||||
|
|
||||||
client "k8s.io/kubernetes/pkg/client/unversioned"
|
client "k8s.io/kubernetes/pkg/client/unversioned"
|
||||||
"k8s.io/kubernetes/pkg/util"
|
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||||
)
|
)
|
||||||
@ -97,155 +95,100 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection.
|
// Dial opens a connection to a remote server and attempts to negotiate a SPDY
|
||||||
func (e *streamExecutor) Dial() (httpstream.Connection, error) {
|
// connection. Upon success, it returns the connection and the protocol
|
||||||
client := &http.Client{Transport: e.transport}
|
// selected by the server.
|
||||||
|
func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) {
|
||||||
|
transport := e.transport
|
||||||
|
// TODO consider removing this and reusing client.TransportFor above to get this for free
|
||||||
|
switch {
|
||||||
|
case bool(glog.V(9)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders)
|
||||||
|
case bool(glog.V(8)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders)
|
||||||
|
case bool(glog.V(7)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus)
|
||||||
|
case bool(glog.V(6)):
|
||||||
|
transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO the client probably shouldn't be created here, as it doesn't allow
|
||||||
|
// flexibility to allow callers to configure it.
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
req, err := http.NewRequest(e.method, e.url.String(), nil)
|
req, err := http.NewRequest(e.method, e.url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating request: %s", err)
|
return nil, "", fmt.Errorf("error creating request: %v", err)
|
||||||
|
}
|
||||||
|
for i := range protocols {
|
||||||
|
req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error sending request: %s", err)
|
return nil, "", fmt.Errorf("error sending request: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// TODO: handle protocol selection in the future
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
return e.upgrader.NewConnection(resp)
|
return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := e.upgrader.NewConnection(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The SPDY subprotocol "channel.k8s.io" is used for remote command
|
||||||
|
// attachment/execution. This represents the initial unversioned subprotocol,
|
||||||
|
// which has the known bugs http://issues.k8s.io/13394 and
|
||||||
|
// http://issues.k8s.io/13395.
|
||||||
|
StreamProtocolV1Name = "channel.k8s.io"
|
||||||
|
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
|
||||||
|
// attachment/execution. It is the second version of the subprotocol and
|
||||||
|
// resolves the issues present in the first version.
|
||||||
|
StreamProtocolV2Name = "v2.channel.k8s.io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type streamProtocolHandler interface {
|
||||||
|
stream(httpstream.Connection) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream opens a protocol streamer to the server and streams until a client closes
|
// Stream opens a protocol streamer to the server and streams until a client closes
|
||||||
// the connection or the server disconnects.
|
// the connection or the server disconnects.
|
||||||
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
|
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
|
||||||
conn, err := e.Dial()
|
supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name}
|
||||||
|
conn, protocol, err := e.Dial(supportedProtocols...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
// TODO: negotiate protocols
|
|
||||||
streamer := &streamProtocol{
|
var streamer streamProtocolHandler
|
||||||
|
|
||||||
|
switch protocol {
|
||||||
|
case StreamProtocolV2Name:
|
||||||
|
streamer = &streamProtocolV2{
|
||||||
stdin: stdin,
|
stdin: stdin,
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
stderr: stderr,
|
stderr: stderr,
|
||||||
tty: tty,
|
tty: tty,
|
||||||
}
|
}
|
||||||
|
case "":
|
||||||
|
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name)
|
||||||
|
fallthrough
|
||||||
|
case StreamProtocolV1Name:
|
||||||
|
streamer = &streamProtocolV1{
|
||||||
|
stdin: stdin,
|
||||||
|
stdout: stdout,
|
||||||
|
stderr: stderr,
|
||||||
|
tty: tty,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return streamer.stream(conn)
|
return streamer.stream(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamProtocol struct {
|
|
||||||
stdin io.Reader
|
|
||||||
stdout io.Writer
|
|
||||||
stderr io.Writer
|
|
||||||
tty bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *streamProtocol) stream(conn httpstream.Connection) error {
|
|
||||||
headers := http.Header{}
|
|
||||||
|
|
||||||
// set up error stream
|
|
||||||
errorChan := make(chan error)
|
|
||||||
headers.Set(api.StreamType, api.StreamTypeError)
|
|
||||||
errorStream, err := conn.CreateStream(headers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
message, err := ioutil.ReadAll(errorStream)
|
|
||||||
switch {
|
|
||||||
case err != nil && err != io.EOF:
|
|
||||||
errorChan <- fmt.Errorf("error reading from error stream: %s", err)
|
|
||||||
case len(message) > 0:
|
|
||||||
errorChan <- fmt.Errorf("error executing remote command: %s", message)
|
|
||||||
default:
|
|
||||||
errorChan <- nil
|
|
||||||
}
|
|
||||||
close(errorChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var once sync.Once
|
|
||||||
|
|
||||||
// set up stdin stream
|
|
||||||
if e.stdin != nil {
|
|
||||||
headers.Set(api.StreamType, api.StreamTypeStdin)
|
|
||||||
remoteStdin, err := conn.CreateStream(headers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy from client's stdin to container's stdin
|
|
||||||
go func() {
|
|
||||||
// if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i <pod> -- cat`, make sure
|
|
||||||
// we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise
|
|
||||||
// the executed command will remain running.
|
|
||||||
defer once.Do(func() { remoteStdin.Close() })
|
|
||||||
|
|
||||||
if _, err := io.Copy(remoteStdin, e.stdin); err != nil {
|
|
||||||
util.HandleError(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// read from remoteStdin until the stream is closed. this is essential to
|
|
||||||
// be able to exit interactive sessions cleanly and not leak goroutines or
|
|
||||||
// hang the client's terminal.
|
|
||||||
//
|
|
||||||
// go-dockerclient's current hijack implementation
|
|
||||||
// (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564)
|
|
||||||
// waits for all three streams (stdin/stdout/stderr) to finish copying
|
|
||||||
// before returning. When hijack finishes copying stdout/stderr, it calls
|
|
||||||
// Close() on its side of remoteStdin, which allows this copy to complete.
|
|
||||||
// When that happens, we must Close() on our side of remoteStdin, to
|
|
||||||
// allow the copy in hijack to complete, and hijack to return.
|
|
||||||
go func() {
|
|
||||||
defer once.Do(func() { remoteStdin.Close() })
|
|
||||||
// this "copy" doesn't actually read anything - it's just here to wait for
|
|
||||||
// the server to close remoteStdin.
|
|
||||||
if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil {
|
|
||||||
util.HandleError(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// set up stdout stream
|
|
||||||
if e.stdout != nil {
|
|
||||||
headers.Set(api.StreamType, api.StreamTypeStdout)
|
|
||||||
remoteStdout, err := conn.CreateStream(headers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
if _, err := io.Copy(e.stdout, remoteStdout); err != nil {
|
|
||||||
util.HandleError(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// set up stderr stream
|
|
||||||
if e.stderr != nil && !e.tty {
|
|
||||||
headers.Set(api.StreamType, api.StreamTypeStderr)
|
|
||||||
remoteStderr, err := conn.CreateStream(headers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
if _, err := io.Copy(e.stderr, remoteStderr); err != nil {
|
|
||||||
util.HandleError(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// we're waiting for stdout/stderr to finish copying
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// waits for errorStream to finish reading with an error or nil
|
|
||||||
return <-errorChan
|
|
||||||
}
|
|
||||||
|
@ -42,6 +42,13 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if protocol != StreamProtocolV2Name {
|
||||||
|
t.Fatalf("unexpected protocol: %s", protocol)
|
||||||
|
}
|
||||||
streamCh := make(chan httpstream.Stream)
|
streamCh := make(chan httpstream.Stream)
|
||||||
|
|
||||||
upgrader := spdy.NewResponseUpgrader()
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
@ -184,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) {
|
|||||||
url, _ := url.ParseRequestURI(server.URL)
|
url, _ := url.ParseRequestURI(server.URL)
|
||||||
c := client.NewRESTClient(url, "x", nil, -1, -1)
|
c := client.NewRESTClient(url, "x", nil, -1, -1)
|
||||||
req := c.Post().Resource("testing")
|
req := c.Post().Resource("testing")
|
||||||
|
req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name)
|
||||||
req.Param("command", "ls")
|
req.Param("command", "ls")
|
||||||
req.Param("command", "/")
|
req.Param("command", "/")
|
||||||
conf := &client.Config{
|
conf := &client.Config{
|
||||||
@ -347,7 +355,7 @@ func TestDial(t *testing.T) {
|
|||||||
checkResponse: true,
|
checkResponse: true,
|
||||||
conn: &fakeConnection{},
|
conn: &fakeConnection{},
|
||||||
resp: &http.Response{
|
resp: &http.Response{
|
||||||
StatusCode: http.StatusOK,
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
Body: ioutil.NopCloser(&bytes.Buffer{}),
|
Body: ioutil.NopCloser(&bytes.Buffer{}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -363,7 +371,7 @@ func TestDial(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn, err := exec.Dial()
|
conn, protocol, err := exec.Dial("protocol1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -373,4 +381,5 @@ func TestDial(t *testing.T) {
|
|||||||
if !called {
|
if !called {
|
||||||
t.Errorf("wrapper not called")
|
t.Errorf("wrapper not called")
|
||||||
}
|
}
|
||||||
|
_ = protocol
|
||||||
}
|
}
|
||||||
|
130
pkg/client/unversioned/remotecommand/v1.go
Normal file
130
pkg/client/unversioned/remotecommand/v1.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2015 The Kubernetes Authors All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package remotecommand
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/golang/glog"
|
||||||
|
"k8s.io/kubernetes/pkg/api"
|
||||||
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
|
)
|
||||||
|
|
||||||
|
// streamProtocolV1 implements the first version of the streaming exec & attach
|
||||||
|
// protocol. This version has some bugs, such as not being able to detecte when
|
||||||
|
// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and
|
||||||
|
// http://issues.k8s.io/13395 for more details.
|
||||||
|
type streamProtocolV1 struct {
|
||||||
|
stdin io.Reader
|
||||||
|
stdout io.Writer
|
||||||
|
stderr io.Writer
|
||||||
|
tty bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ streamProtocolHandler = &streamProtocolV1{}
|
||||||
|
|
||||||
|
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
|
||||||
|
doneChan := make(chan struct{}, 2)
|
||||||
|
errorChan := make(chan error)
|
||||||
|
|
||||||
|
cp := func(s string, dst io.Writer, src io.Reader) {
|
||||||
|
glog.V(6).Infof("Copying %s", s)
|
||||||
|
defer glog.V(6).Infof("Done copying %s", s)
|
||||||
|
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
|
||||||
|
glog.Errorf("Error copying %s: %v", s, err)
|
||||||
|
}
|
||||||
|
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
|
||||||
|
doneChan <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeError)
|
||||||
|
errorStream, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
message, err := ioutil.ReadAll(errorStream)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(message) > 0 {
|
||||||
|
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer errorStream.Reset()
|
||||||
|
|
||||||
|
if e.stdin != nil {
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStdin)
|
||||||
|
remoteStdin, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer remoteStdin.Reset()
|
||||||
|
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
|
||||||
|
// because stdin is not closed until the process exits. If we try to call
|
||||||
|
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
|
||||||
|
// exit when the process exits, instead.
|
||||||
|
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
|
||||||
|
}
|
||||||
|
|
||||||
|
waitCount := 0
|
||||||
|
completedStreams := 0
|
||||||
|
|
||||||
|
if e.stdout != nil {
|
||||||
|
waitCount++
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStdout)
|
||||||
|
remoteStdout, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer remoteStdout.Reset()
|
||||||
|
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.stderr != nil && !e.tty {
|
||||||
|
waitCount++
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStderr)
|
||||||
|
remoteStderr, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer remoteStderr.Reset()
|
||||||
|
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
completedStreams++
|
||||||
|
if completedStreams == waitCount {
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
case err := <-errorChan:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
151
pkg/client/unversioned/remotecommand/v2.go
Normal file
151
pkg/client/unversioned/remotecommand/v2.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2015 The Kubernetes Authors All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package remotecommand
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"k8s.io/kubernetes/pkg/api"
|
||||||
|
"k8s.io/kubernetes/pkg/util"
|
||||||
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
|
)
|
||||||
|
|
||||||
|
// streamProtocolV2 implements version 2 of the streaming protocol for attach
|
||||||
|
// and exec. The original streaming protocol was unversioned. As a result, this
|
||||||
|
// version is referred to as version 2, even though it is the first actual
|
||||||
|
// numbered version.
|
||||||
|
type streamProtocolV2 struct {
|
||||||
|
stdin io.Reader
|
||||||
|
stdout io.Writer
|
||||||
|
stderr io.Writer
|
||||||
|
tty bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ streamProtocolHandler = &streamProtocolV2{}
|
||||||
|
|
||||||
|
func (e *streamProtocolV2) stream(conn httpstream.Connection) error {
|
||||||
|
headers := http.Header{}
|
||||||
|
|
||||||
|
// set up error stream
|
||||||
|
errorChan := make(chan error)
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeError)
|
||||||
|
errorStream, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
message, err := ioutil.ReadAll(errorStream)
|
||||||
|
switch {
|
||||||
|
case err != nil && err != io.EOF:
|
||||||
|
errorChan <- fmt.Errorf("error reading from error stream: %s", err)
|
||||||
|
case len(message) > 0:
|
||||||
|
errorChan <- fmt.Errorf("error executing remote command: %s", message)
|
||||||
|
default:
|
||||||
|
errorChan <- nil
|
||||||
|
}
|
||||||
|
close(errorChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
// set up stdin stream
|
||||||
|
if e.stdin != nil {
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStdin)
|
||||||
|
remoteStdin, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy from client's stdin to container's stdin
|
||||||
|
go func() {
|
||||||
|
// if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i <pod> -- cat`, make sure
|
||||||
|
// we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise
|
||||||
|
// the executed command will remain running.
|
||||||
|
defer once.Do(func() { remoteStdin.Close() })
|
||||||
|
|
||||||
|
if _, err := io.Copy(remoteStdin, e.stdin); err != nil {
|
||||||
|
util.HandleError(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// read from remoteStdin until the stream is closed. this is essential to
|
||||||
|
// be able to exit interactive sessions cleanly and not leak goroutines or
|
||||||
|
// hang the client's terminal.
|
||||||
|
//
|
||||||
|
// go-dockerclient's current hijack implementation
|
||||||
|
// (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564)
|
||||||
|
// waits for all three streams (stdin/stdout/stderr) to finish copying
|
||||||
|
// before returning. When hijack finishes copying stdout/stderr, it calls
|
||||||
|
// Close() on its side of remoteStdin, which allows this copy to complete.
|
||||||
|
// When that happens, we must Close() on our side of remoteStdin, to
|
||||||
|
// allow the copy in hijack to complete, and hijack to return.
|
||||||
|
go func() {
|
||||||
|
defer once.Do(func() { remoteStdin.Close() })
|
||||||
|
// this "copy" doesn't actually read anything - it's just here to wait for
|
||||||
|
// the server to close remoteStdin.
|
||||||
|
if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil {
|
||||||
|
util.HandleError(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up stdout stream
|
||||||
|
if e.stdout != nil {
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStdout)
|
||||||
|
remoteStdout, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if _, err := io.Copy(e.stdout, remoteStdout); err != nil {
|
||||||
|
util.HandleError(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up stderr stream
|
||||||
|
if e.stderr != nil && !e.tty {
|
||||||
|
headers.Set(api.StreamType, api.StreamTypeStderr)
|
||||||
|
remoteStderr, err := conn.CreateStream(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if _, err := io.Copy(e.stderr, remoteStderr); err != nil {
|
||||||
|
util.HandleError(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// we're waiting for stdout/stderr to finish copying
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// waits for errorStream to finish reading with an error or nil
|
||||||
|
return <-errorChan
|
||||||
|
}
|
21
pkg/kubelet/portforward/constants.go
Normal file
21
pkg/kubelet/portforward/constants.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2015 The Kubernetes Authors All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// package portforward contains server-side logic for handling port forwarding requests.
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
// The subprotocol "portforward.k8s.io" is used for port forwarding.
|
||||||
|
const PortForwardProtocolV1Name = "portforward.k8s.io"
|
@ -44,9 +44,11 @@ import (
|
|||||||
"k8s.io/kubernetes/pkg/api/validation"
|
"k8s.io/kubernetes/pkg/api/validation"
|
||||||
"k8s.io/kubernetes/pkg/auth/authenticator"
|
"k8s.io/kubernetes/pkg/auth/authenticator"
|
||||||
"k8s.io/kubernetes/pkg/auth/authorizer"
|
"k8s.io/kubernetes/pkg/auth/authorizer"
|
||||||
|
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||||
"k8s.io/kubernetes/pkg/healthz"
|
"k8s.io/kubernetes/pkg/healthz"
|
||||||
"k8s.io/kubernetes/pkg/httplog"
|
"k8s.io/kubernetes/pkg/httplog"
|
||||||
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
||||||
|
"k8s.io/kubernetes/pkg/kubelet/portforward"
|
||||||
"k8s.io/kubernetes/pkg/types"
|
"k8s.io/kubernetes/pkg/types"
|
||||||
"k8s.io/kubernetes/pkg/util"
|
"k8s.io/kubernetes/pkg/util"
|
||||||
"k8s.io/kubernetes/pkg/util/flushwriter"
|
"k8s.io/kubernetes/pkg/util/flushwriter"
|
||||||
@ -685,6 +687,13 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
|
|||||||
return streams[0], streams[1], streams[2], streams[3], conn, tty, true
|
return streams[0], streams[1], streams[2], streams[3], conn, tty, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}
|
||||||
|
_, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name)
|
||||||
|
// negotiated protocol isn't used server side at the moment, but could be in the future
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, false, false
|
||||||
|
}
|
||||||
|
|
||||||
streamCh := make(chan httpstream.Stream)
|
streamCh := make(chan httpstream.Stream)
|
||||||
|
|
||||||
upgrader := spdy.NewResponseUpgrader()
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
@ -779,6 +788,15 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
|||||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
||||||
// handled by a single invocation of ServePortForward.
|
// handled by a single invocation of ServePortForward.
|
||||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||||
|
supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name}
|
||||||
|
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name)
|
||||||
|
// negotiated protocol isn't currently used server side, but could be in the future
|
||||||
|
if err != nil {
|
||||||
|
// Handshake writes the error to the client
|
||||||
|
util.HandleError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
streamChan := make(chan httpstream.Stream, 1)
|
streamChan := make(chan httpstream.Stream, 1)
|
||||||
|
|
||||||
glog.V(5).Infof("Upgrading port forward response")
|
glog.V(5).Infof("Upgrading port forward response")
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
package httpstream
|
package httpstream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -26,6 +27,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
HeaderConnection = "Connection"
|
HeaderConnection = "Connection"
|
||||||
HeaderUpgrade = "Upgrade"
|
HeaderUpgrade = "Upgrade"
|
||||||
|
HeaderProtocolVersion = "X-Stream-Protocol-Version"
|
||||||
|
HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewStreamHandler defines a function that is called when a new Stream is
|
// NewStreamHandler defines a function that is called when a new Stream is
|
||||||
@ -39,7 +42,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil }
|
|||||||
|
|
||||||
// Dialer knows how to open a streaming connection to a server.
|
// Dialer knows how to open a streaming connection to a server.
|
||||||
type Dialer interface {
|
type Dialer interface {
|
||||||
Dial() (Connection, error)
|
|
||||||
|
// Dial opens a streaming connection to a server using one of the protocols
|
||||||
|
// specified (in order of most preferred to least preferred).
|
||||||
|
Dial(protocols ...string) (Connection, string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
|
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
|
||||||
@ -56,7 +62,7 @@ type UpgradeRoundTripper interface {
|
|||||||
// add streaming support to them.
|
// add streaming support to them.
|
||||||
type ResponseUpgrader interface {
|
type ResponseUpgrader interface {
|
||||||
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
|
||||||
// streams. newStreamHandler will be called synchronously whenever the
|
// streams. newStreamHandler will be called asynchronously whenever the
|
||||||
// other end of the upgraded connection creates a new stream.
|
// other end of the upgraded connection creates a new stream.
|
||||||
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
|
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
|
||||||
}
|
}
|
||||||
@ -96,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
|
||||||
|
for i := range clientProtocols {
|
||||||
|
for j := range serverProtocols {
|
||||||
|
if clientProtocols[i] == serverProtocols[j] {
|
||||||
|
return clientProtocols[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handshake performs a subprotocol negotiation. If the client did not request
|
||||||
|
// a specific subprotocol, defaultProtocol is used. If the client did request a
|
||||||
|
// subprotocol, Handshake will select the first common value found in
|
||||||
|
// serverProtocols. If a match is found, Handshake adds a response header
|
||||||
|
// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
|
||||||
|
// returned, along with a response header containing the list of protocols the
|
||||||
|
// server can accept.
|
||||||
|
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) {
|
||||||
|
clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
|
||||||
|
if len(clientProtocols) == 0 {
|
||||||
|
// Kube 1.0 client that didn't support subprotocol negotiation
|
||||||
|
// TODO remove this defaulting logic once Kube 1.0 is no longer supported
|
||||||
|
w.Header().Add(HeaderProtocolVersion, defaultProtocol)
|
||||||
|
return defaultProtocol, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
|
||||||
|
if len(negotiatedProtocol) == 0 {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
for i := range serverProtocols {
|
||||||
|
w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
|
||||||
|
return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
|
||||||
|
return negotiatedProtocol, nil
|
||||||
|
}
|
||||||
|
120
pkg/util/httpstream/httpstream_test.go
Normal file
120
pkg/util/httpstream/httpstream_test.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2015 The Kubernetes Authors All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package httpstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type responseWriter struct {
|
||||||
|
header http.Header
|
||||||
|
statusCode *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResponseWriter() *responseWriter {
|
||||||
|
return &responseWriter{
|
||||||
|
header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriter) Header() http.Header {
|
||||||
|
return r.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriter) WriteHeader(code int) {
|
||||||
|
r.statusCode = &code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriter) Write([]byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshake(t *testing.T) {
|
||||||
|
defaultProtocol := "default"
|
||||||
|
|
||||||
|
tests := map[string]struct {
|
||||||
|
clientProtocols []string
|
||||||
|
serverProtocols []string
|
||||||
|
expectedProtocol string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
"no client protocols": {
|
||||||
|
clientProtocols: []string{},
|
||||||
|
serverProtocols: []string{"a", "b"},
|
||||||
|
expectedProtocol: defaultProtocol,
|
||||||
|
},
|
||||||
|
"no common protocol": {
|
||||||
|
clientProtocols: []string{"c"},
|
||||||
|
serverProtocols: []string{"a", "b"},
|
||||||
|
expectedProtocol: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
"common protocol": {
|
||||||
|
clientProtocols: []string{"b"},
|
||||||
|
serverProtocols: []string{"a", "b"},
|
||||||
|
expectedProtocol: "b",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
req, err := http.NewRequest("GET", "http://www.example.com/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%s: error creating request: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range test.clientProtocols {
|
||||||
|
req.Header.Add(HeaderProtocolVersion, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := newResponseWriter()
|
||||||
|
negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol)
|
||||||
|
|
||||||
|
// verify negotiated protocol
|
||||||
|
if e, a := test.expectedProtocol, negotiated; e != a {
|
||||||
|
t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.expectError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("%s: expected error but did not get one", name)
|
||||||
|
}
|
||||||
|
if w.statusCode == nil {
|
||||||
|
t.Errorf("%s: expected w.statusCode to be set", name)
|
||||||
|
} else if e, a := http.StatusForbidden, *w.statusCode; e != a {
|
||||||
|
t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
|
||||||
|
}
|
||||||
|
if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
|
||||||
|
t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !test.expectError && err != nil {
|
||||||
|
t.Errorf("%s: unexpected error: %v", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if w.statusCode != nil {
|
||||||
|
t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify response headers
|
||||||
|
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
|
||||||
|
t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -21,7 +21,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/golang/glog"
|
"k8s.io/kubernetes/pkg/util"
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -46,15 +46,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
|
|||||||
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
|
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
|
||||||
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
|
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
|
||||||
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
|
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
|
||||||
w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header)))
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
hijacker, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
w.Write([]byte("Unable to upgrade: unable to hijack response"))
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,13 +64,13 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
|
|||||||
|
|
||||||
conn, _, err := hijacker.Hijack()
|
conn, _, err := hijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
glog.Errorf("Unable to upgrade: error hijacking response: %v", err)
|
util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
spdyConn, err := NewServerConnection(conn, newStreamHandler)
|
spdyConn, err := NewServerConnection(conn, newStreamHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err)
|
util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user