mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-19 00:31:00 +00:00
Refactor PortForward server methods into the portforward package
This commit is contained in:
parent
85207190f5
commit
52ed57ec3b
@ -33,7 +33,7 @@ go_test(
|
||||
deps = [
|
||||
"//pkg/client/restclient:go_default_library",
|
||||
"//pkg/client/unversioned/remotecommand:go_default_library",
|
||||
"//pkg/kubelet/server:go_default_library",
|
||||
"//pkg/kubelet/server/portforward:go_default_library",
|
||||
"//pkg/types:go_default_library",
|
||||
"//pkg/util/httpstream:go_default_library",
|
||||
],
|
||||
|
@ -33,7 +33,7 @@ import (
|
||||
|
||||
"k8s.io/kubernetes/pkg/client/restclient"
|
||||
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||
kubeletserver "k8s.io/kubernetes/pkg/kubelet/server"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
@ -206,7 +206,7 @@ func TestGetListener(t *testing.T) {
|
||||
}
|
||||
|
||||
// fakePortForwarder simulates port forwarding for testing. It implements
|
||||
// kubeletserver.PortForwarder.
|
||||
// portforward.PortForwarder.
|
||||
type fakePortForwarder struct {
|
||||
lock sync.Mutex
|
||||
// stores data expected from the stream per port
|
||||
@ -217,7 +217,7 @@ type fakePortForwarder struct {
|
||||
send map[uint16]string
|
||||
}
|
||||
|
||||
var _ kubeletserver.PortForwarder = &fakePortForwarder{}
|
||||
var _ portforward.PortForwarder = &fakePortForwarder{}
|
||||
|
||||
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
defer stream.Close()
|
||||
@ -252,7 +252,7 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF
|
||||
received: make(map[uint16]string),
|
||||
send: serverSends,
|
||||
}
|
||||
kubeletserver.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
|
||||
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
|
||||
|
||||
for port, expected := range expectedFromClient {
|
||||
actual, ok := pf.received[port]
|
||||
|
@ -37,10 +37,7 @@ go_library(
|
||||
"//pkg/types:go_default_library",
|
||||
"//pkg/util/configz:go_default_library",
|
||||
"//pkg/util/flushwriter:go_default_library",
|
||||
"//pkg/util/httpstream:go_default_library",
|
||||
"//pkg/util/httpstream/spdy:go_default_library",
|
||||
"//pkg/util/limitwriter:go_default_library",
|
||||
"//pkg/util/runtime:go_default_library",
|
||||
"//pkg/util/term:go_default_library",
|
||||
"//pkg/volume:go_default_library",
|
||||
"//vendor:github.com/emicklei/go-restful",
|
||||
|
@ -12,6 +12,28 @@ load(
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["constants.go"],
|
||||
srcs = [
|
||||
"constants.go",
|
||||
"portforward.go",
|
||||
],
|
||||
tags = ["automanaged"],
|
||||
deps = [
|
||||
"//pkg/api:go_default_library",
|
||||
"//pkg/types:go_default_library",
|
||||
"//pkg/util/httpstream:go_default_library",
|
||||
"//pkg/util/httpstream/spdy:go_default_library",
|
||||
"//pkg/util/runtime:go_default_library",
|
||||
"//vendor:github.com/golang/glog",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = ["portforward_test.go"],
|
||||
library = "go_default_library",
|
||||
tags = ["automanaged"],
|
||||
deps = [
|
||||
"//pkg/api:go_default_library",
|
||||
"//pkg/util/httpstream:go_default_library",
|
||||
],
|
||||
)
|
||||
|
323
pkg/kubelet/server/portforward/portforward.go
Normal file
323
pkg/kubelet/server/portforward/portforward.go
Normal file
@ -0,0 +1,323 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/glog"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
utilruntime "k8s.io/kubernetes/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// PortForwarder knows how to forward content from a data stream to/from a port
|
||||
// in a pod.
|
||||
type PortForwarder interface {
|
||||
// PortForwarder copies data between a data stream and a port in a pod.
|
||||
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
||||
}
|
||||
|
||||
// ServePortForward handles a port forwarding request. A single request is
|
||||
// kept alive as long as the client is still alive and the connection has not
|
||||
// been timed out due to idleness. This function handles multiple forwarded
|
||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
||||
// handled by a single invocation of ServePortForward.
|
||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||
supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
|
||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
||||
// negotiated protocol isn't currently used server side, but could be in the future
|
||||
if err != nil {
|
||||
// Handshake writes the error to the client
|
||||
utilruntime.HandleError(err)
|
||||
return
|
||||
}
|
||||
|
||||
streamChan := make(chan httpstream.Stream, 1)
|
||||
|
||||
glog.V(5).Infof("Upgrading port forward response")
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
conn: conn,
|
||||
streamChan: streamChan,
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
streamCreationTimeout: streamCreationTimeout,
|
||||
pod: podName,
|
||||
uid: uid,
|
||||
forwarder: portForwarder,
|
||||
}
|
||||
h.run()
|
||||
}
|
||||
|
||||
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
|
||||
// forward streams. It checks each stream's port and stream type headers,
|
||||
// rejecting any streams that with missing or invalid values. Each valid
|
||||
// stream is sent to the streams channel.
|
||||
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
// make sure it has a valid port header
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
if len(portString) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.PortHeader)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("port %q must be > 0", portString)
|
||||
}
|
||||
|
||||
// make sure it has a valid stream type header
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
if len(streamType) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.StreamType)
|
||||
}
|
||||
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
||||
return fmt.Errorf("invalid stream type %q", streamType)
|
||||
}
|
||||
|
||||
streams <- stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamHandler is capable of processing multiple port forward
|
||||
// requests over a single httpstream.Connection.
|
||||
type portForwardStreamHandler struct {
|
||||
conn httpstream.Connection
|
||||
streamChan chan httpstream.Stream
|
||||
streamPairsLock sync.RWMutex
|
||||
streamPairs map[string]*portForwardStreamPair
|
||||
streamCreationTimeout time.Duration
|
||||
pod string
|
||||
uid types.UID
|
||||
forwarder PortForwarder
|
||||
}
|
||||
|
||||
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
|
||||
// new pair if one does not yet exist for the requestID. The returned bool is
|
||||
// true if the pair was created.
|
||||
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
if p, ok := h.streamPairs[requestID]; ok {
|
||||
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
||||
return p, false
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
||||
|
||||
p := newPortForwardPair(requestID)
|
||||
h.streamPairs[requestID] = p
|
||||
|
||||
return p, true
|
||||
}
|
||||
|
||||
// monitorStreamPair waits for the pair to receive both its error and data
|
||||
// streams, or for the timeout to expire (whichever happens first), and then
|
||||
// removes the pair.
|
||||
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
|
||||
select {
|
||||
case <-timeout:
|
||||
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
||||
utilruntime.HandleError(err)
|
||||
p.printError(err.Error())
|
||||
case <-p.complete:
|
||||
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
||||
}
|
||||
h.removeStreamPair(p.requestID)
|
||||
}
|
||||
|
||||
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
||||
// exists.
|
||||
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
|
||||
h.streamPairsLock.RLock()
|
||||
defer h.streamPairsLock.RUnlock()
|
||||
|
||||
_, ok := h.streamPairs[requestID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
||||
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
delete(h.streamPairs, requestID)
|
||||
}
|
||||
|
||||
// requestID returns the request id for stream.
|
||||
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
|
||||
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
||||
if len(requestID) == 0 {
|
||||
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
||||
// If we get here, it's because the connection came from an older client
|
||||
// that isn't generating the request id header
|
||||
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
||||
//
|
||||
// This is a best-effort attempt at supporting older clients.
|
||||
//
|
||||
// When there aren't concurrent new forwarded connections, each connection
|
||||
// will have a pair of streams (data, error), and the stream IDs will be
|
||||
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
||||
// the stream ID into a pseudo-request id by taking the stream type and
|
||||
// using id = stream.Identifier() when the stream type is error,
|
||||
// and id = stream.Identifier() - 2 when it's data.
|
||||
//
|
||||
// NOTE: this only works when there are not concurrent new streams from
|
||||
// multiple forwarded connections; it's a best-effort attempt at supporting
|
||||
// old clients that don't generate request ids. If there are concurrent
|
||||
// new connections, it's possible that 1 connection gets streams whose IDs
|
||||
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()))
|
||||
case api.StreamTypeData:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
||||
}
|
||||
return requestID
|
||||
}
|
||||
|
||||
// run is the main loop for the portForwardStreamHandler. It processes new
|
||||
// streams, invoking portForward for each complete stream pair. The loop exits
|
||||
// when the httpstream.Connection is closed.
|
||||
func (h *portForwardStreamHandler) run() {
|
||||
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-h.conn.CloseChan():
|
||||
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
||||
break Loop
|
||||
case stream := <-h.streamChan:
|
||||
requestID := h.requestID(stream)
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
||||
|
||||
p, created := h.getStreamPair(requestID)
|
||||
if created {
|
||||
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
||||
}
|
||||
if complete, err := p.add(stream); err != nil {
|
||||
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
||||
utilruntime.HandleError(errors.New(msg))
|
||||
p.printError(msg)
|
||||
} else if complete {
|
||||
go h.portForward(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
|
||||
// function for the given stream pair.
|
||||
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
|
||||
defer p.dataStream.Close()
|
||||
defer p.errorStream.Close()
|
||||
|
||||
portString := p.dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
||||
utilruntime.HandleError(msg)
|
||||
fmt.Fprint(p.errorStream, msg.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamPair represents the error and data streams for a port
|
||||
// forwarding request.
|
||||
type portForwardStreamPair struct {
|
||||
lock sync.RWMutex
|
||||
requestID string
|
||||
dataStream httpstream.Stream
|
||||
errorStream httpstream.Stream
|
||||
complete chan struct{}
|
||||
}
|
||||
|
||||
// newPortForwardPair creates a new portForwardStreamPair.
|
||||
func newPortForwardPair(requestID string) *portForwardStreamPair {
|
||||
return &portForwardStreamPair{
|
||||
requestID: requestID,
|
||||
complete: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// add adds the stream to the portForwardStreamPair. If the pair already
|
||||
// contains a stream for the new stream's type, an error is returned. add
|
||||
// returns true if both the data and error streams for this pair have been
|
||||
// received.
|
||||
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
switch stream.Headers().Get(api.StreamType) {
|
||||
case api.StreamTypeError:
|
||||
if p.errorStream != nil {
|
||||
return false, errors.New("error stream already assigned")
|
||||
}
|
||||
p.errorStream = stream
|
||||
case api.StreamTypeData:
|
||||
if p.dataStream != nil {
|
||||
return false, errors.New("data stream already assigned")
|
||||
}
|
||||
p.dataStream = stream
|
||||
}
|
||||
|
||||
complete := p.errorStream != nil && p.dataStream != nil
|
||||
if complete {
|
||||
close(p.complete)
|
||||
}
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// printError writes s to p.errorStream if p.errorStream has been set.
|
||||
func (p *portForwardStreamPair) printError(s string) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
if p.errorStream != nil {
|
||||
fmt.Fprint(p.errorStream, s)
|
||||
}
|
||||
}
|
246
pkg/kubelet/server/portforward/portforward_test.go
Normal file
246
pkg/kubelet/server/portforward/portforward_test.go
Normal file
@ -0,0 +1,246 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
func TestPortForwardStreamReceived(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
port string
|
||||
streamType string
|
||||
expectedError string
|
||||
}{
|
||||
"missing port": {
|
||||
expectedError: `"port" header is required`,
|
||||
},
|
||||
"unable to parse port": {
|
||||
port: "abc",
|
||||
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
|
||||
},
|
||||
"negative port": {
|
||||
port: "-1",
|
||||
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
|
||||
},
|
||||
"missing stream type": {
|
||||
port: "80",
|
||||
expectedError: `"streamType" header is required`,
|
||||
},
|
||||
"valid port with error stream": {
|
||||
port: "80",
|
||||
streamType: "error",
|
||||
},
|
||||
"valid port with data stream": {
|
||||
port: "80",
|
||||
streamType: "data",
|
||||
},
|
||||
"invalid stream type": {
|
||||
port: "80",
|
||||
streamType: "foo",
|
||||
expectedError: `invalid stream type "foo"`,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
streams := make(chan httpstream.Stream, 1)
|
||||
f := portForwardStreamReceived(streams)
|
||||
stream := newFakeHttpStream()
|
||||
if len(test.port) > 0 {
|
||||
stream.headers.Set("port", test.port)
|
||||
}
|
||||
if len(test.streamType) > 0 {
|
||||
stream.headers.Set("streamType", test.streamType)
|
||||
}
|
||||
replySent := make(chan struct{})
|
||||
err := f(stream, replySent)
|
||||
close(replySent)
|
||||
if len(test.expectedError) > 0 {
|
||||
if err == nil {
|
||||
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
|
||||
}
|
||||
if e, a := test.expectedError, err.Error(); e != a {
|
||||
t.Errorf("%s: expected err=%q, got %q", name, e, a)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("%s: unexpected error %v", name, err)
|
||||
continue
|
||||
}
|
||||
if s := <-streams; s != stream {
|
||||
t.Errorf("%s: expected stream %#v, got %#v", name, stream, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHttpStream struct {
|
||||
headers http.Header
|
||||
id uint32
|
||||
}
|
||||
|
||||
func newFakeHttpStream() *fakeHttpStream {
|
||||
return &fakeHttpStream{
|
||||
headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
var _ httpstream.Stream = &fakeHttpStream{}
|
||||
|
||||
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Headers() http.Header {
|
||||
return s.headers
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Identifier() uint32 {
|
||||
return s.id
|
||||
}
|
||||
|
||||
func TestGetStreamPair(t *testing.T) {
|
||||
timeout := make(chan time.Time)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
}
|
||||
|
||||
// test adding a new entry
|
||||
p, created := h.getStreamPair("1")
|
||||
if p == nil {
|
||||
t.Fatalf("unexpected nil pair")
|
||||
}
|
||||
if !created {
|
||||
t.Fatal("expected created=true")
|
||||
}
|
||||
if p.dataStream != nil {
|
||||
t.Errorf("unexpected non-nil data stream")
|
||||
}
|
||||
if p.errorStream != nil {
|
||||
t.Errorf("unexpected non-nil error stream")
|
||||
}
|
||||
|
||||
// start the monitor for this pair
|
||||
monitorDone := make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
close(monitorDone)
|
||||
}()
|
||||
|
||||
if !h.hasStreamPair("1") {
|
||||
t.Fatal("This should still be true")
|
||||
}
|
||||
|
||||
// make sure we can retrieve an existing entry
|
||||
p2, created := h.getStreamPair("1")
|
||||
if created {
|
||||
t.Fatal("expected created=false")
|
||||
}
|
||||
if p != p2 {
|
||||
t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2)
|
||||
}
|
||||
|
||||
// removed via complete
|
||||
dataStream := newFakeHttpStream()
|
||||
dataStream.headers.Set(api.StreamType, api.StreamTypeData)
|
||||
complete, err := p.add(dataStream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding data stream to pair: %v", err)
|
||||
}
|
||||
if complete {
|
||||
t.Fatalf("unexpected complete")
|
||||
}
|
||||
|
||||
errorStream := newFakeHttpStream()
|
||||
errorStream.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
complete, err = p.add(errorStream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding error stream to pair: %v", err)
|
||||
}
|
||||
if !complete {
|
||||
t.Fatal("unexpected incomplete")
|
||||
}
|
||||
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
|
||||
// make sure the pair was removed
|
||||
if h.hasStreamPair("1") {
|
||||
t.Fatal("expected removal of pair after both data and error streams received")
|
||||
}
|
||||
|
||||
// removed via timeout
|
||||
p, created = h.getStreamPair("2")
|
||||
if !created {
|
||||
t.Fatal("expected created=true")
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatal("expected p not to be nil")
|
||||
}
|
||||
monitorDone = make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
close(monitorDone)
|
||||
}()
|
||||
// cause the timeout
|
||||
close(timeout)
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
if h.hasStreamPair("2") {
|
||||
t.Fatal("expected stream pair to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
h := &portForwardStreamHandler{}
|
||||
|
||||
s := newFakeHttpStream()
|
||||
s.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
s.id = 1
|
||||
if e, a := "1", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
|
||||
s.headers.Set(api.StreamType, api.StreamTypeData)
|
||||
s.id = 3
|
||||
if e, a := "1", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
|
||||
s.id = 7
|
||||
s.headers.Set(api.PortForwardRequestIDHeader, "2")
|
||||
if e, a := "2", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
}
|
@ -18,7 +18,6 @@ package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -27,7 +26,6 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
restful "github.com/emicklei/go-restful"
|
||||
@ -54,10 +52,7 @@ import (
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util/configz"
|
||||
"k8s.io/kubernetes/pkg/util/flushwriter"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
"k8s.io/kubernetes/pkg/util/limitwriter"
|
||||
utilruntime "k8s.io/kubernetes/pkg/util/runtime"
|
||||
"k8s.io/kubernetes/pkg/util/term"
|
||||
"k8s.io/kubernetes/pkg/volume"
|
||||
)
|
||||
@ -650,13 +645,6 @@ func writeJsonResponse(response *restful.Response, data []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// PortForwarder knows how to forward content from a data stream to/from a port
|
||||
// in a pod.
|
||||
type PortForwarder interface {
|
||||
// PortForwarder copies data between a data stream and a port in a pod.
|
||||
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
||||
}
|
||||
|
||||
// getPortForward handles a new restful port forward request. It determines the
|
||||
// pod name and uid and then calls ServePortForward.
|
||||
func (s *Server) getPortForward(request *restful.Request, response *restful.Response) {
|
||||
@ -669,288 +657,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
||||
|
||||
podName := kubecontainer.GetPodFullName(pod)
|
||||
|
||||
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout)
|
||||
}
|
||||
|
||||
// ServePortForward handles a port forwarding request. A single request is
|
||||
// kept alive as long as the client is still alive and the connection has not
|
||||
// been timed out due to idleness. This function handles multiple forwarded
|
||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
||||
// handled by a single invocation of ServePortForward.
|
||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||
supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name}
|
||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
||||
// negotiated protocol isn't currently used server side, but could be in the future
|
||||
if err != nil {
|
||||
// Handshake writes the error to the client
|
||||
utilruntime.HandleError(err)
|
||||
return
|
||||
}
|
||||
|
||||
streamChan := make(chan httpstream.Stream, 1)
|
||||
|
||||
glog.V(5).Infof("Upgrading port forward response")
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
conn: conn,
|
||||
streamChan: streamChan,
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
streamCreationTimeout: streamCreationTimeout,
|
||||
pod: podName,
|
||||
uid: uid,
|
||||
forwarder: portForwarder,
|
||||
}
|
||||
h.run()
|
||||
}
|
||||
|
||||
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
|
||||
// forward streams. It checks each stream's port and stream type headers,
|
||||
// rejecting any streams that with missing or invalid values. Each valid
|
||||
// stream is sent to the streams channel.
|
||||
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
// make sure it has a valid port header
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
if len(portString) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.PortHeader)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("port %q must be > 0", portString)
|
||||
}
|
||||
|
||||
// make sure it has a valid stream type header
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
if len(streamType) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.StreamType)
|
||||
}
|
||||
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
||||
return fmt.Errorf("invalid stream type %q", streamType)
|
||||
}
|
||||
|
||||
streams <- stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamHandler is capable of processing multiple port forward
|
||||
// requests over a single httpstream.Connection.
|
||||
type portForwardStreamHandler struct {
|
||||
conn httpstream.Connection
|
||||
streamChan chan httpstream.Stream
|
||||
streamPairsLock sync.RWMutex
|
||||
streamPairs map[string]*portForwardStreamPair
|
||||
streamCreationTimeout time.Duration
|
||||
pod string
|
||||
uid types.UID
|
||||
forwarder PortForwarder
|
||||
}
|
||||
|
||||
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
|
||||
// new pair if one does not yet exist for the requestID. The returned bool is
|
||||
// true if the pair was created.
|
||||
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
if p, ok := h.streamPairs[requestID]; ok {
|
||||
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
||||
return p, false
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
||||
|
||||
p := newPortForwardPair(requestID)
|
||||
h.streamPairs[requestID] = p
|
||||
|
||||
return p, true
|
||||
}
|
||||
|
||||
// monitorStreamPair waits for the pair to receive both its error and data
|
||||
// streams, or for the timeout to expire (whichever happens first), and then
|
||||
// removes the pair.
|
||||
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
|
||||
select {
|
||||
case <-timeout:
|
||||
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
||||
utilruntime.HandleError(err)
|
||||
p.printError(err.Error())
|
||||
case <-p.complete:
|
||||
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
||||
}
|
||||
h.removeStreamPair(p.requestID)
|
||||
}
|
||||
|
||||
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
||||
// exists.
|
||||
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
|
||||
h.streamPairsLock.RLock()
|
||||
defer h.streamPairsLock.RUnlock()
|
||||
|
||||
_, ok := h.streamPairs[requestID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
||||
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
delete(h.streamPairs, requestID)
|
||||
}
|
||||
|
||||
// requestID returns the request id for stream.
|
||||
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
|
||||
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
||||
if len(requestID) == 0 {
|
||||
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
||||
// If we get here, it's because the connection came from an older client
|
||||
// that isn't generating the request id header
|
||||
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
||||
//
|
||||
// This is a best-effort attempt at supporting older clients.
|
||||
//
|
||||
// When there aren't concurrent new forwarded connections, each connection
|
||||
// will have a pair of streams (data, error), and the stream IDs will be
|
||||
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
||||
// the stream ID into a pseudo-request id by taking the stream type and
|
||||
// using id = stream.Identifier() when the stream type is error,
|
||||
// and id = stream.Identifier() - 2 when it's data.
|
||||
//
|
||||
// NOTE: this only works when there are not concurrent new streams from
|
||||
// multiple forwarded connections; it's a best-effort attempt at supporting
|
||||
// old clients that don't generate request ids. If there are concurrent
|
||||
// new connections, it's possible that 1 connection gets streams whose IDs
|
||||
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()))
|
||||
case api.StreamTypeData:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
||||
}
|
||||
return requestID
|
||||
}
|
||||
|
||||
// run is the main loop for the portForwardStreamHandler. It processes new
|
||||
// streams, invoking portForward for each complete stream pair. The loop exits
|
||||
// when the httpstream.Connection is closed.
|
||||
func (h *portForwardStreamHandler) run() {
|
||||
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-h.conn.CloseChan():
|
||||
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
||||
break Loop
|
||||
case stream := <-h.streamChan:
|
||||
requestID := h.requestID(stream)
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
||||
|
||||
p, created := h.getStreamPair(requestID)
|
||||
if created {
|
||||
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
||||
}
|
||||
if complete, err := p.add(stream); err != nil {
|
||||
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
||||
utilruntime.HandleError(errors.New(msg))
|
||||
p.printError(msg)
|
||||
} else if complete {
|
||||
go h.portForward(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
|
||||
// function for the given stream pair.
|
||||
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
|
||||
defer p.dataStream.Close()
|
||||
defer p.errorStream.Close()
|
||||
|
||||
portString := p.dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
||||
utilruntime.HandleError(msg)
|
||||
fmt.Fprint(p.errorStream, msg.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamPair represents the error and data streams for a port
|
||||
// forwarding request.
|
||||
type portForwardStreamPair struct {
|
||||
lock sync.RWMutex
|
||||
requestID string
|
||||
dataStream httpstream.Stream
|
||||
errorStream httpstream.Stream
|
||||
complete chan struct{}
|
||||
}
|
||||
|
||||
// newPortForwardPair creates a new portForwardStreamPair.
|
||||
func newPortForwardPair(requestID string) *portForwardStreamPair {
|
||||
return &portForwardStreamPair{
|
||||
requestID: requestID,
|
||||
complete: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// add adds the stream to the portForwardStreamPair. If the pair already
|
||||
// contains a stream for the new stream's type, an error is returned. add
|
||||
// returns true if both the data and error streams for this pair have been
|
||||
// received.
|
||||
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
switch stream.Headers().Get(api.StreamType) {
|
||||
case api.StreamTypeError:
|
||||
if p.errorStream != nil {
|
||||
return false, errors.New("error stream already assigned")
|
||||
}
|
||||
p.errorStream = stream
|
||||
case api.StreamTypeData:
|
||||
if p.dataStream != nil {
|
||||
return false, errors.New("data stream already assigned")
|
||||
}
|
||||
p.dataStream = stream
|
||||
}
|
||||
|
||||
complete := p.errorStream != nil && p.dataStream != nil
|
||||
if complete {
|
||||
close(p.complete)
|
||||
}
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// printError writes s to p.errorStream if p.errorStream has been set.
|
||||
func (p *portForwardStreamPair) printError(s string) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
if p.errorStream != nil {
|
||||
fmt.Fprint(p.errorStream, s)
|
||||
}
|
||||
portforward.ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout)
|
||||
}
|
||||
|
||||
// ServeHTTP responds to HTTP requests on the Kubelet.
|
||||
|
@ -1530,223 +1530,3 @@ func TestServePortForward(t *testing.T) {
|
||||
<-portForwardFuncDone
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHttpStream struct {
|
||||
headers http.Header
|
||||
id uint32
|
||||
}
|
||||
|
||||
func newFakeHttpStream() *fakeHttpStream {
|
||||
return &fakeHttpStream{
|
||||
headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
var _ httpstream.Stream = &fakeHttpStream{}
|
||||
|
||||
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Headers() http.Header {
|
||||
return s.headers
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Identifier() uint32 {
|
||||
return s.id
|
||||
}
|
||||
|
||||
func TestPortForwardStreamReceived(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
port string
|
||||
streamType string
|
||||
expectedError string
|
||||
}{
|
||||
"missing port": {
|
||||
expectedError: `"port" header is required`,
|
||||
},
|
||||
"unable to parse port": {
|
||||
port: "abc",
|
||||
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
|
||||
},
|
||||
"negative port": {
|
||||
port: "-1",
|
||||
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
|
||||
},
|
||||
"missing stream type": {
|
||||
port: "80",
|
||||
expectedError: `"streamType" header is required`,
|
||||
},
|
||||
"valid port with error stream": {
|
||||
port: "80",
|
||||
streamType: "error",
|
||||
},
|
||||
"valid port with data stream": {
|
||||
port: "80",
|
||||
streamType: "data",
|
||||
},
|
||||
"invalid stream type": {
|
||||
port: "80",
|
||||
streamType: "foo",
|
||||
expectedError: `invalid stream type "foo"`,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
streams := make(chan httpstream.Stream, 1)
|
||||
f := portForwardStreamReceived(streams)
|
||||
stream := newFakeHttpStream()
|
||||
if len(test.port) > 0 {
|
||||
stream.headers.Set("port", test.port)
|
||||
}
|
||||
if len(test.streamType) > 0 {
|
||||
stream.headers.Set("streamType", test.streamType)
|
||||
}
|
||||
replySent := make(chan struct{})
|
||||
err := f(stream, replySent)
|
||||
close(replySent)
|
||||
if len(test.expectedError) > 0 {
|
||||
if err == nil {
|
||||
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
|
||||
}
|
||||
if e, a := test.expectedError, err.Error(); e != a {
|
||||
t.Errorf("%s: expected err=%q, got %q", name, e, a)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("%s: unexpected error %v", name, err)
|
||||
continue
|
||||
}
|
||||
if s := <-streams; s != stream {
|
||||
t.Errorf("%s: expected stream %#v, got %#v", name, stream, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamPair(t *testing.T) {
|
||||
timeout := make(chan time.Time)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
}
|
||||
|
||||
// test adding a new entry
|
||||
p, created := h.getStreamPair("1")
|
||||
if p == nil {
|
||||
t.Fatalf("unexpected nil pair")
|
||||
}
|
||||
if !created {
|
||||
t.Fatal("expected created=true")
|
||||
}
|
||||
if p.dataStream != nil {
|
||||
t.Errorf("unexpected non-nil data stream")
|
||||
}
|
||||
if p.errorStream != nil {
|
||||
t.Errorf("unexpected non-nil error stream")
|
||||
}
|
||||
|
||||
// start the monitor for this pair
|
||||
monitorDone := make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
close(monitorDone)
|
||||
}()
|
||||
|
||||
if !h.hasStreamPair("1") {
|
||||
t.Fatal("This should still be true")
|
||||
}
|
||||
|
||||
// make sure we can retrieve an existing entry
|
||||
p2, created := h.getStreamPair("1")
|
||||
if created {
|
||||
t.Fatal("expected created=false")
|
||||
}
|
||||
if p != p2 {
|
||||
t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2)
|
||||
}
|
||||
|
||||
// removed via complete
|
||||
dataStream := newFakeHttpStream()
|
||||
dataStream.headers.Set(api.StreamType, api.StreamTypeData)
|
||||
complete, err := p.add(dataStream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding data stream to pair: %v", err)
|
||||
}
|
||||
if complete {
|
||||
t.Fatalf("unexpected complete")
|
||||
}
|
||||
|
||||
errorStream := newFakeHttpStream()
|
||||
errorStream.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
complete, err = p.add(errorStream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding error stream to pair: %v", err)
|
||||
}
|
||||
if !complete {
|
||||
t.Fatal("unexpected incomplete")
|
||||
}
|
||||
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
|
||||
// make sure the pair was removed
|
||||
if h.hasStreamPair("1") {
|
||||
t.Fatal("expected removal of pair after both data and error streams received")
|
||||
}
|
||||
|
||||
// removed via timeout
|
||||
p, created = h.getStreamPair("2")
|
||||
if !created {
|
||||
t.Fatal("expected created=true")
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatal("expected p not to be nil")
|
||||
}
|
||||
monitorDone = make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
close(monitorDone)
|
||||
}()
|
||||
// cause the timeout
|
||||
close(timeout)
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
if h.hasStreamPair("2") {
|
||||
t.Fatal("expected stream pair to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
h := &portForwardStreamHandler{}
|
||||
|
||||
s := newFakeHttpStream()
|
||||
s.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
s.id = 1
|
||||
if e, a := "1", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
|
||||
s.headers.Set(api.StreamType, api.StreamTypeData)
|
||||
s.id = 3
|
||||
if e, a := "1", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
|
||||
s.id = 7
|
||||
s.headers.Set(api.PortForwardRequestIDHeader, "2")
|
||||
if e, a := "2", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user