Address code review comments

This commit is contained in:
Andy Goldstein 2015-10-21 20:42:40 -04:00
parent 6c7b519619
commit ff9883d9ec
10 changed files with 107 additions and 64 deletions

View File

@ -29,6 +29,7 @@ import (
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/kubelet"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
)
@ -122,16 +123,13 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P
}, nil
}
// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding.
const PortForwardProtocolV1Name = "portforward.k8s.io"
// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name})
pf.streamConn, _, err = pf.dialer.Dial(kubelet.PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}

View File

@ -44,7 +44,7 @@ type fakeDialer struct {
negotiatedProtocol string
}
func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) {
func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
d.dialed = true
return d.conn, d.negotiatedProtocol, d.err
}

View File

@ -98,7 +98,7 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou
// Dial opens a connection to a remote server and attempts to negotiate a SPDY
// connection. Upon success, it returns the connection and the protocol
// selected by the server.
func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) {
func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) {
transport := e.transport
// TODO consider removing this and reusing client.TransportFor above to get this for free
switch {
@ -111,6 +111,9 @@ func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string
case bool(glog.V(6)):
transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
}
// TODO the client probably shouldn't be created here, as it doesn't allow
// flexibility to allow callers to configure it.
client := &http.Client{Transport: transport}
req, err := http.NewRequest(e.method, e.url.String(), nil)
@ -158,7 +161,8 @@ type streamProtocolHandler interface {
// Stream opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects.
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name})
supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name}
conn, protocol, err := e.Dial(supportedProtocols...)
if err != nil {
return err
}
@ -175,8 +179,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b
tty: tty,
}
case "":
glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned")
// TODO restore v1
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to unversioned")
streamer = &streamProtocolV1{
stdin: stdin,
stdout: stdout,

View File

@ -42,10 +42,17 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
if err != nil {
t.Fatal(err)
}
if protocol != StreamProtocolV2Name {
t.Fatalf("unexpected protocol: %s", protocol)
}
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error {
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error {
streamCh <- stream
return nil
})
@ -57,7 +64,6 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
return
}
defer conn.Close()
_ = protocol
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
receivedStreams := 0
@ -185,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) {
url, _ := url.ParseRequestURI(server.URL)
c := client.NewRESTClient(url, "x", nil, -1, -1)
req := c.Post().Resource("testing")
req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name)
req.Param("command", "ls")
req.Param("command", "/")
conf := &client.Config{
@ -364,7 +371,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Fatal(err)
}
conn, protocol, err := exec.Dial([]string{"a", "b"})
conn, protocol, err := exec.Dial("protocol1")
if err != nil {
t.Fatal(err)
}

View File

@ -27,6 +27,10 @@ import (
"k8s.io/kubernetes/pkg/util/httpstream"
)
// streamProtocolV1 implements the first version of the streaming exec & attach
// protocol. This version has some bugs, such as not being able to detecte when
// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395 for more details.
type streamProtocolV1 struct {
stdin io.Reader
stdout io.Writer
@ -41,8 +45,8 @@ func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
errorChan := make(chan error)
cp := func(s string, dst io.Writer, src io.Reader) {
glog.V(4).Infof("Copying %s", s)
defer glog.V(4).Infof("Done copying %s", s)
glog.V(6).Infof("Copying %s", s)
defer glog.V(6).Infof("Done copying %s", s)
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
glog.Errorf("Error copying %s: %v", s, err)
}

View File

@ -44,7 +44,6 @@ import (
"k8s.io/kubernetes/pkg/api/validation"
"k8s.io/kubernetes/pkg/auth/authenticator"
"k8s.io/kubernetes/pkg/auth/authorizer"
"k8s.io/kubernetes/pkg/client/unversioned/portforward"
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
"k8s.io/kubernetes/pkg/healthz"
"k8s.io/kubernetes/pkg/httplog"
@ -687,10 +686,17 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
return streams[0], streams[1], streams[2], streams[3], conn, tty, true
}
supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}
_, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name)
// negotiated protocol isn't used server side at the moment, but could be in the future
if err != nil {
return nil, nil, nil, nil, nil, false, false
}
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error {
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
streamCh <- stream
return nil
})
@ -701,9 +707,6 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
// if we weren't successful in upgrading.
return nil, nil, nil, nil, nil, false, false
}
if len(protocol) == 0 {
protocol = remotecommand.StreamProtocolV1Name
}
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
@ -778,24 +781,34 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
}
// The subprotocol "portforward.k8s.io" is used for port forwarding.
const PortForwardProtocolV1Name = "portforward.k8s.io"
// ServePortForward handles a port forwarding request. A single request is
// kept alive as long as the client is still alive and the connection has not
// been timed out due to idleness. This function handles multiple forwarded
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, PortForwardProtocolV1Name)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
util.HandleError(err)
return
}
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan))
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
if conn == nil {
return
}
defer conn.Close()
_ = protocol
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)

View File

@ -17,6 +17,7 @@ limitations under the License.
package httpstream
import (
"fmt"
"io"
"net/http"
"strings"
@ -24,9 +25,10 @@ import (
)
const (
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderProtocolVersion = "X-Stream-Protocol-Version"
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderProtocolVersion = "X-Stream-Protocol-Version"
HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
)
// NewStreamHandler defines a function that is called when a new Stream is
@ -43,7 +45,7 @@ type Dialer interface {
// Dial opens a streaming connection to a server using one of the protocols
// specified (in order of most preferred to least preferred).
Dial(protocols []string) (Connection, string, error)
Dial(protocols ...string) (Connection, string, error)
}
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
@ -60,9 +62,9 @@ type UpgradeRoundTripper interface {
// add streaming support to them.
type ResponseUpgrader interface {
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// streams. newStreamHandler will be called asynchronously whenever the
// other end of the upgraded connection creates a new stream.
UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string)
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
}
// Connection represents an upgraded HTTP connection.
@ -100,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool {
}
return false
}
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
for i := range clientProtocols {
for j := range serverProtocols {
if clientProtocols[i] == serverProtocols[j] {
return clientProtocols[i]
}
}
}
return ""
}
// Handshake performs a subprotocol negotiation. If the client did not request
// a specific subprotocol, defaultProtocol is used. If the client did request a
// subprotocol, Handshake will select the first common value found in
// serverProtocols. If a match is found, Handshake adds a response header
// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
// returned, along with a response header containing the list of protocols the
// server can accept.
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) {
clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
if len(clientProtocols) == 0 {
// Kube 1.0 client that didn't support subprotocol negotiation
// TODO remove this defaulting logic once Kube 1.0 is no longer supported
w.Header().Add(HeaderProtocolVersion, defaultProtocol)
return defaultProtocol, nil
}
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
if len(negotiatedProtocol) == 0 {
w.WriteHeader(http.StatusForbidden)
for i := range serverProtocols {
w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
}
fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols)
}
w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
return negotiatedProtocol, nil
}

View File

@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) {
streamCh := make(chan httpstream.Stream)
responseUpgrader := NewResponseUpgrader()
spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error {
spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error {
streamCh <- s
return nil
})

View File

@ -21,7 +21,7 @@ import (
"net/http"
"strings"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
)
@ -39,47 +39,23 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader {
return responseUpgrader{}
}
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
for i := range clientProtocols {
for j := range serverProtocols {
if clientProtocols[i] == serverProtocols[j] {
return clientProtocols[i]
}
}
}
return ""
}
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// other end of the upgraded connection creates a new stream.
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) {
func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection {
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
return nil, ""
return nil
}
hijacker, ok := w.(http.Hijacker)
if !ok {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
return nil, ""
}
var negotiatedProtocol string
clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)]
if len(clientProtocols) > 0 {
negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols)
if len(negotiatedProtocol) > 0 {
w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol)
} else {
w.WriteHeader(http.StatusForbidden)
fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols)
return nil, ""
}
return nil
}
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
@ -88,15 +64,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
conn, _, err := hijacker.Hijack()
if err != nil {
glog.Errorf("unable to upgrade: error hijacking response: %v", err)
return nil, ""
util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err))
return nil
}
spdyConn, err := NewServerConnection(conn, newStreamHandler)
if err != nil {
glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)
return nil, ""
util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err))
return nil
}
return spdyConn, negotiatedProtocol
return spdyConn
}

View File

@ -53,8 +53,7 @@ func TestUpgradeResponse(t *testing.T) {
for i, testCase := range testCases {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
upgrader := NewResponseUpgrader()
conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil)
_ = protocol
conn := upgrader.UpgradeResponse(w, req, nil)
haveErr := conn == nil
if e, a := testCase.shouldError, haveErr; e != a {
t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)