From 6fddb0e83ac582f7a33c43ea7ac2a073f9ff98c5 Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Fri, 23 Oct 2015 14:09:41 -0400 Subject: [PATCH] Add httpstream.Handshake unit test --- .../remotecommand/remotecommand.go | 4 +- pkg/util/httpstream/httpstream_test.go | 120 ++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 pkg/util/httpstream/httpstream_test.go diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 99e914abebf..4feb953d95f 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -179,7 +179,9 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b tty: tty, } case "": - glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to unversioned") + glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name) + fallthrough + case StreamProtocolV1Name: streamer = &streamProtocolV1{ stdin: stdin, stdout: stdout, diff --git a/pkg/util/httpstream/httpstream_test.go b/pkg/util/httpstream/httpstream_test.go new file mode 100644 index 00000000000..7a1bbaefb89 --- /dev/null +++ b/pkg/util/httpstream/httpstream_test.go @@ -0,0 +1,120 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package httpstream + +import ( + "net/http" + "reflect" + "testing" +) + +type responseWriter struct { + header http.Header + statusCode *int +} + +func newResponseWriter() *responseWriter { + return &responseWriter{ + header: make(http.Header), + } +} + +func (r *responseWriter) Header() http.Header { + return r.header +} + +func (r *responseWriter) WriteHeader(code int) { + r.statusCode = &code +} + +func (r *responseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func TestHandshake(t *testing.T) { + defaultProtocol := "default" + + tests := map[string]struct { + clientProtocols []string + serverProtocols []string + expectedProtocol string + expectError bool + }{ + "no client protocols": { + clientProtocols: []string{}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: defaultProtocol, + }, + "no common protocol": { + clientProtocols: []string{"c"}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: "", + expectError: true, + }, + "common protocol": { + clientProtocols: []string{"b"}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: "b", + }, + } + + for name, test := range tests { + req, err := http.NewRequest("GET", "http://www.example.com/", nil) + if err != nil { + t.Fatalf("%s: error creating request: %v", name, err) + } + + for _, p := range test.clientProtocols { + req.Header.Add(HeaderProtocolVersion, p) + } + + w := newResponseWriter() + negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol) + + // verify negotiated protocol + if e, a := test.expectedProtocol, negotiated; e != a { + t.Errorf("%s: protocol: expected %q, got %q", name, e, a) + } + + if test.expectError { + if err == nil { + t.Errorf("%s: expected error but did not get one", name) + } + if w.statusCode == nil { + t.Errorf("%s: expected w.statusCode to be set", name) + } else if e, a := http.StatusForbidden, *w.statusCode; e != a { + t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a) + } + if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) { + t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a) + } + continue + } + if !test.expectError && err != nil { + t.Errorf("%s: unexpected error: %v", name, err) + continue + } + if w.statusCode != nil { + t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode) + } + + // verify response headers + if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) { + t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a) + } + } +}