mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-19 08:40:42 +00:00
Refactor PortForward server methods into the portforward package
This commit is contained in:
parent
85207190f5
commit
52ed57ec3b
@ -33,7 +33,7 @@ go_test(
|
|||||||
deps = [
|
deps = [
|
||||||
"//pkg/client/restclient:go_default_library",
|
"//pkg/client/restclient:go_default_library",
|
||||||
"//pkg/client/unversioned/remotecommand:go_default_library",
|
"//pkg/client/unversioned/remotecommand:go_default_library",
|
||||||
"//pkg/kubelet/server:go_default_library",
|
"//pkg/kubelet/server/portforward:go_default_library",
|
||||||
"//pkg/types:go_default_library",
|
"//pkg/types:go_default_library",
|
||||||
"//pkg/util/httpstream:go_default_library",
|
"//pkg/util/httpstream:go_default_library",
|
||||||
],
|
],
|
||||||
|
@ -33,7 +33,7 @@ import (
|
|||||||
|
|
||||||
"k8s.io/kubernetes/pkg/client/restclient"
|
"k8s.io/kubernetes/pkg/client/restclient"
|
||||||
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||||
kubeletserver "k8s.io/kubernetes/pkg/kubelet/server"
|
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
|
||||||
"k8s.io/kubernetes/pkg/types"
|
"k8s.io/kubernetes/pkg/types"
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
)
|
)
|
||||||
@ -206,7 +206,7 @@ func TestGetListener(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// fakePortForwarder simulates port forwarding for testing. It implements
|
// fakePortForwarder simulates port forwarding for testing. It implements
|
||||||
// kubeletserver.PortForwarder.
|
// portforward.PortForwarder.
|
||||||
type fakePortForwarder struct {
|
type fakePortForwarder struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
// stores data expected from the stream per port
|
// stores data expected from the stream per port
|
||||||
@ -217,7 +217,7 @@ type fakePortForwarder struct {
|
|||||||
send map[uint16]string
|
send map[uint16]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ kubeletserver.PortForwarder = &fakePortForwarder{}
|
var _ portforward.PortForwarder = &fakePortForwarder{}
|
||||||
|
|
||||||
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
@ -252,7 +252,7 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF
|
|||||||
received: make(map[uint16]string),
|
received: make(map[uint16]string),
|
||||||
send: serverSends,
|
send: serverSends,
|
||||||
}
|
}
|
||||||
kubeletserver.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
|
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
|
||||||
|
|
||||||
for port, expected := range expectedFromClient {
|
for port, expected := range expectedFromClient {
|
||||||
actual, ok := pf.received[port]
|
actual, ok := pf.received[port]
|
||||||
|
@ -37,10 +37,7 @@ go_library(
|
|||||||
"//pkg/types:go_default_library",
|
"//pkg/types:go_default_library",
|
||||||
"//pkg/util/configz:go_default_library",
|
"//pkg/util/configz:go_default_library",
|
||||||
"//pkg/util/flushwriter:go_default_library",
|
"//pkg/util/flushwriter:go_default_library",
|
||||||
"//pkg/util/httpstream:go_default_library",
|
|
||||||
"//pkg/util/httpstream/spdy:go_default_library",
|
|
||||||
"//pkg/util/limitwriter:go_default_library",
|
"//pkg/util/limitwriter:go_default_library",
|
||||||
"//pkg/util/runtime:go_default_library",
|
|
||||||
"//pkg/util/term:go_default_library",
|
"//pkg/util/term:go_default_library",
|
||||||
"//pkg/volume:go_default_library",
|
"//pkg/volume:go_default_library",
|
||||||
"//vendor:github.com/emicklei/go-restful",
|
"//vendor:github.com/emicklei/go-restful",
|
||||||
|
@ -12,6 +12,28 @@ load(
|
|||||||
|
|
||||||
go_library(
|
go_library(
|
||||||
name = "go_default_library",
|
name = "go_default_library",
|
||||||
srcs = ["constants.go"],
|
srcs = [
|
||||||
|
"constants.go",
|
||||||
|
"portforward.go",
|
||||||
|
],
|
||||||
tags = ["automanaged"],
|
tags = ["automanaged"],
|
||||||
|
deps = [
|
||||||
|
"//pkg/api:go_default_library",
|
||||||
|
"//pkg/types:go_default_library",
|
||||||
|
"//pkg/util/httpstream:go_default_library",
|
||||||
|
"//pkg/util/httpstream/spdy:go_default_library",
|
||||||
|
"//pkg/util/runtime:go_default_library",
|
||||||
|
"//vendor:github.com/golang/glog",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
go_test(
|
||||||
|
name = "go_default_test",
|
||||||
|
srcs = ["portforward_test.go"],
|
||||||
|
library = "go_default_library",
|
||||||
|
tags = ["automanaged"],
|
||||||
|
deps = [
|
||||||
|
"//pkg/api:go_default_library",
|
||||||
|
"//pkg/util/httpstream:go_default_library",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
323
pkg/kubelet/server/portforward/portforward.go
Normal file
323
pkg/kubelet/server/portforward/portforward.go
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2016 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/glog"
|
||||||
|
|
||||||
|
"k8s.io/kubernetes/pkg/api"
|
||||||
|
"k8s.io/kubernetes/pkg/types"
|
||||||
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
|
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||||
|
utilruntime "k8s.io/kubernetes/pkg/util/runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PortForwarder knows how to forward content from a data stream to/from a port
|
||||||
|
// in a pod.
|
||||||
|
type PortForwarder interface {
|
||||||
|
// PortForwarder copies data between a data stream and a port in a pod.
|
||||||
|
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServePortForward handles a port forwarding request. A single request is
|
||||||
|
// kept alive as long as the client is still alive and the connection has not
|
||||||
|
// been timed out due to idleness. This function handles multiple forwarded
|
||||||
|
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
||||||
|
// handled by a single invocation of ServePortForward.
|
||||||
|
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||||
|
supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
|
||||||
|
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
||||||
|
// negotiated protocol isn't currently used server side, but could be in the future
|
||||||
|
if err != nil {
|
||||||
|
// Handshake writes the error to the client
|
||||||
|
utilruntime.HandleError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
streamChan := make(chan httpstream.Stream, 1)
|
||||||
|
|
||||||
|
glog.V(5).Infof("Upgrading port forward response")
|
||||||
|
upgrader := spdy.NewResponseUpgrader()
|
||||||
|
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
||||||
|
if conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||||
|
conn.SetIdleTimeout(idleTimeout)
|
||||||
|
|
||||||
|
h := &portForwardStreamHandler{
|
||||||
|
conn: conn,
|
||||||
|
streamChan: streamChan,
|
||||||
|
streamPairs: make(map[string]*portForwardStreamPair),
|
||||||
|
streamCreationTimeout: streamCreationTimeout,
|
||||||
|
pod: podName,
|
||||||
|
uid: uid,
|
||||||
|
forwarder: portForwarder,
|
||||||
|
}
|
||||||
|
h.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
|
||||||
|
// forward streams. It checks each stream's port and stream type headers,
|
||||||
|
// rejecting any streams that with missing or invalid values. Each valid
|
||||||
|
// stream is sent to the streams channel.
|
||||||
|
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||||
|
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||||
|
// make sure it has a valid port header
|
||||||
|
portString := stream.Headers().Get(api.PortHeader)
|
||||||
|
if len(portString) == 0 {
|
||||||
|
return fmt.Errorf("%q header is required", api.PortHeader)
|
||||||
|
}
|
||||||
|
port, err := strconv.ParseUint(portString, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||||
|
}
|
||||||
|
if port < 1 {
|
||||||
|
return fmt.Errorf("port %q must be > 0", portString)
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure it has a valid stream type header
|
||||||
|
streamType := stream.Headers().Get(api.StreamType)
|
||||||
|
if len(streamType) == 0 {
|
||||||
|
return fmt.Errorf("%q header is required", api.StreamType)
|
||||||
|
}
|
||||||
|
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
||||||
|
return fmt.Errorf("invalid stream type %q", streamType)
|
||||||
|
}
|
||||||
|
|
||||||
|
streams <- stream
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// portForwardStreamHandler is capable of processing multiple port forward
|
||||||
|
// requests over a single httpstream.Connection.
|
||||||
|
type portForwardStreamHandler struct {
|
||||||
|
conn httpstream.Connection
|
||||||
|
streamChan chan httpstream.Stream
|
||||||
|
streamPairsLock sync.RWMutex
|
||||||
|
streamPairs map[string]*portForwardStreamPair
|
||||||
|
streamCreationTimeout time.Duration
|
||||||
|
pod string
|
||||||
|
uid types.UID
|
||||||
|
forwarder PortForwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
|
||||||
|
// new pair if one does not yet exist for the requestID. The returned bool is
|
||||||
|
// true if the pair was created.
|
||||||
|
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
|
||||||
|
h.streamPairsLock.Lock()
|
||||||
|
defer h.streamPairsLock.Unlock()
|
||||||
|
|
||||||
|
if p, ok := h.streamPairs[requestID]; ok {
|
||||||
|
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
|
||||||
|
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
||||||
|
|
||||||
|
p := newPortForwardPair(requestID)
|
||||||
|
h.streamPairs[requestID] = p
|
||||||
|
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// monitorStreamPair waits for the pair to receive both its error and data
|
||||||
|
// streams, or for the timeout to expire (whichever happens first), and then
|
||||||
|
// removes the pair.
|
||||||
|
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
||||||
|
utilruntime.HandleError(err)
|
||||||
|
p.printError(err.Error())
|
||||||
|
case <-p.complete:
|
||||||
|
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
||||||
|
}
|
||||||
|
h.removeStreamPair(p.requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
||||||
|
// exists.
|
||||||
|
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
|
||||||
|
h.streamPairsLock.RLock()
|
||||||
|
defer h.streamPairsLock.RUnlock()
|
||||||
|
|
||||||
|
_, ok := h.streamPairs[requestID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
||||||
|
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
|
||||||
|
h.streamPairsLock.Lock()
|
||||||
|
defer h.streamPairsLock.Unlock()
|
||||||
|
|
||||||
|
delete(h.streamPairs, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestID returns the request id for stream.
|
||||||
|
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
|
||||||
|
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
||||||
|
if len(requestID) == 0 {
|
||||||
|
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
||||||
|
// If we get here, it's because the connection came from an older client
|
||||||
|
// that isn't generating the request id header
|
||||||
|
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
||||||
|
//
|
||||||
|
// This is a best-effort attempt at supporting older clients.
|
||||||
|
//
|
||||||
|
// When there aren't concurrent new forwarded connections, each connection
|
||||||
|
// will have a pair of streams (data, error), and the stream IDs will be
|
||||||
|
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
||||||
|
// the stream ID into a pseudo-request id by taking the stream type and
|
||||||
|
// using id = stream.Identifier() when the stream type is error,
|
||||||
|
// and id = stream.Identifier() - 2 when it's data.
|
||||||
|
//
|
||||||
|
// NOTE: this only works when there are not concurrent new streams from
|
||||||
|
// multiple forwarded connections; it's a best-effort attempt at supporting
|
||||||
|
// old clients that don't generate request ids. If there are concurrent
|
||||||
|
// new connections, it's possible that 1 connection gets streams whose IDs
|
||||||
|
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
||||||
|
streamType := stream.Headers().Get(api.StreamType)
|
||||||
|
switch streamType {
|
||||||
|
case api.StreamTypeError:
|
||||||
|
requestID = strconv.Itoa(int(stream.Identifier()))
|
||||||
|
case api.StreamTypeData:
|
||||||
|
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
||||||
|
}
|
||||||
|
return requestID
|
||||||
|
}
|
||||||
|
|
||||||
|
// run is the main loop for the portForwardStreamHandler. It processes new
|
||||||
|
// streams, invoking portForward for each complete stream pair. The loop exits
|
||||||
|
// when the httpstream.Connection is closed.
|
||||||
|
func (h *portForwardStreamHandler) run() {
|
||||||
|
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.conn.CloseChan():
|
||||||
|
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
||||||
|
break Loop
|
||||||
|
case stream := <-h.streamChan:
|
||||||
|
requestID := h.requestID(stream)
|
||||||
|
streamType := stream.Headers().Get(api.StreamType)
|
||||||
|
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
||||||
|
|
||||||
|
p, created := h.getStreamPair(requestID)
|
||||||
|
if created {
|
||||||
|
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
||||||
|
}
|
||||||
|
if complete, err := p.add(stream); err != nil {
|
||||||
|
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
||||||
|
utilruntime.HandleError(errors.New(msg))
|
||||||
|
p.printError(msg)
|
||||||
|
} else if complete {
|
||||||
|
go h.portForward(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
|
||||||
|
// function for the given stream pair.
|
||||||
|
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
|
||||||
|
defer p.dataStream.Close()
|
||||||
|
defer p.errorStream.Close()
|
||||||
|
|
||||||
|
portString := p.dataStream.Headers().Get(api.PortHeader)
|
||||||
|
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||||
|
|
||||||
|
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||||
|
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
||||||
|
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
||||||
|
utilruntime.HandleError(msg)
|
||||||
|
fmt.Fprint(p.errorStream, msg.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// portForwardStreamPair represents the error and data streams for a port
|
||||||
|
// forwarding request.
|
||||||
|
type portForwardStreamPair struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
requestID string
|
||||||
|
dataStream httpstream.Stream
|
||||||
|
errorStream httpstream.Stream
|
||||||
|
complete chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newPortForwardPair creates a new portForwardStreamPair.
|
||||||
|
func newPortForwardPair(requestID string) *portForwardStreamPair {
|
||||||
|
return &portForwardStreamPair{
|
||||||
|
requestID: requestID,
|
||||||
|
complete: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds the stream to the portForwardStreamPair. If the pair already
|
||||||
|
// contains a stream for the new stream's type, an error is returned. add
|
||||||
|
// returns true if both the data and error streams for this pair have been
|
||||||
|
// received.
|
||||||
|
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
|
||||||
|
p.lock.Lock()
|
||||||
|
defer p.lock.Unlock()
|
||||||
|
|
||||||
|
switch stream.Headers().Get(api.StreamType) {
|
||||||
|
case api.StreamTypeError:
|
||||||
|
if p.errorStream != nil {
|
||||||
|
return false, errors.New("error stream already assigned")
|
||||||
|
}
|
||||||
|
p.errorStream = stream
|
||||||
|
case api.StreamTypeData:
|
||||||
|
if p.dataStream != nil {
|
||||||
|
return false, errors.New("data stream already assigned")
|
||||||
|
}
|
||||||
|
p.dataStream = stream
|
||||||
|
}
|
||||||
|
|
||||||
|
complete := p.errorStream != nil && p.dataStream != nil
|
||||||
|
if complete {
|
||||||
|
close(p.complete)
|
||||||
|
}
|
||||||
|
return complete, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// printError writes s to p.errorStream if p.errorStream has been set.
|
||||||
|
func (p *portForwardStreamPair) printError(s string) {
|
||||||
|
p.lock.RLock()
|
||||||
|
defer p.lock.RUnlock()
|
||||||
|
if p.errorStream != nil {
|
||||||
|
fmt.Fprint(p.errorStream, s)
|
||||||
|
}
|
||||||
|
}
|
246
pkg/kubelet/server/portforward/portforward_test.go
Normal file
246
pkg/kubelet/server/portforward/portforward_test.go
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2016 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/kubernetes/pkg/api"
|
||||||
|
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPortForwardStreamReceived(t *testing.T) {
|
||||||
|
tests := map[string]struct {
|
||||||
|
port string
|
||||||
|
streamType string
|
||||||
|
expectedError string
|
||||||
|
}{
|
||||||
|
"missing port": {
|
||||||
|
expectedError: `"port" header is required`,
|
||||||
|
},
|
||||||
|
"unable to parse port": {
|
||||||
|
port: "abc",
|
||||||
|
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
|
||||||
|
},
|
||||||
|
"negative port": {
|
||||||
|
port: "-1",
|
||||||
|
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
|
||||||
|
},
|
||||||
|
"missing stream type": {
|
||||||
|
port: "80",
|
||||||
|
expectedError: `"streamType" header is required`,
|
||||||
|
},
|
||||||
|
"valid port with error stream": {
|
||||||
|
port: "80",
|
||||||
|
streamType: "error",
|
||||||
|
},
|
||||||
|
"valid port with data stream": {
|
||||||
|
port: "80",
|
||||||
|
streamType: "data",
|
||||||
|
},
|
||||||
|
"invalid stream type": {
|
||||||
|
port: "80",
|
||||||
|
streamType: "foo",
|
||||||
|
expectedError: `invalid stream type "foo"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, test := range tests {
|
||||||
|
streams := make(chan httpstream.Stream, 1)
|
||||||
|
f := portForwardStreamReceived(streams)
|
||||||
|
stream := newFakeHttpStream()
|
||||||
|
if len(test.port) > 0 {
|
||||||
|
stream.headers.Set("port", test.port)
|
||||||
|
}
|
||||||
|
if len(test.streamType) > 0 {
|
||||||
|
stream.headers.Set("streamType", test.streamType)
|
||||||
|
}
|
||||||
|
replySent := make(chan struct{})
|
||||||
|
err := f(stream, replySent)
|
||||||
|
close(replySent)
|
||||||
|
if len(test.expectedError) > 0 {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
|
||||||
|
}
|
||||||
|
if e, a := test.expectedError, err.Error(); e != a {
|
||||||
|
t.Errorf("%s: expected err=%q, got %q", name, e, a)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s: unexpected error %v", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s := <-streams; s != stream {
|
||||||
|
t.Errorf("%s: expected stream %#v, got %#v", name, stream, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeHttpStream struct {
|
||||||
|
headers http.Header
|
||||||
|
id uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeHttpStream() *fakeHttpStream {
|
||||||
|
return &fakeHttpStream{
|
||||||
|
headers: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ httpstream.Stream = &fakeHttpStream{}
|
||||||
|
|
||||||
|
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeHttpStream) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeHttpStream) Reset() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeHttpStream) Headers() http.Header {
|
||||||
|
return s.headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeHttpStream) Identifier() uint32 {
|
||||||
|
return s.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStreamPair(t *testing.T) {
|
||||||
|
timeout := make(chan time.Time)
|
||||||
|
|
||||||
|
h := &portForwardStreamHandler{
|
||||||
|
streamPairs: make(map[string]*portForwardStreamPair),
|
||||||
|
}
|
||||||
|
|
||||||
|
// test adding a new entry
|
||||||
|
p, created := h.getStreamPair("1")
|
||||||
|
if p == nil {
|
||||||
|
t.Fatalf("unexpected nil pair")
|
||||||
|
}
|
||||||
|
if !created {
|
||||||
|
t.Fatal("expected created=true")
|
||||||
|
}
|
||||||
|
if p.dataStream != nil {
|
||||||
|
t.Errorf("unexpected non-nil data stream")
|
||||||
|
}
|
||||||
|
if p.errorStream != nil {
|
||||||
|
t.Errorf("unexpected non-nil error stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
// start the monitor for this pair
|
||||||
|
monitorDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
h.monitorStreamPair(p, timeout)
|
||||||
|
close(monitorDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !h.hasStreamPair("1") {
|
||||||
|
t.Fatal("This should still be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure we can retrieve an existing entry
|
||||||
|
p2, created := h.getStreamPair("1")
|
||||||
|
if created {
|
||||||
|
t.Fatal("expected created=false")
|
||||||
|
}
|
||||||
|
if p != p2 {
|
||||||
|
t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removed via complete
|
||||||
|
dataStream := newFakeHttpStream()
|
||||||
|
dataStream.headers.Set(api.StreamType, api.StreamTypeData)
|
||||||
|
complete, err := p.add(dataStream)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error adding data stream to pair: %v", err)
|
||||||
|
}
|
||||||
|
if complete {
|
||||||
|
t.Fatalf("unexpected complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
errorStream := newFakeHttpStream()
|
||||||
|
errorStream.headers.Set(api.StreamType, api.StreamTypeError)
|
||||||
|
complete, err = p.add(errorStream)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error adding error stream to pair: %v", err)
|
||||||
|
}
|
||||||
|
if !complete {
|
||||||
|
t.Fatal("unexpected incomplete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure monitorStreamPair completed
|
||||||
|
<-monitorDone
|
||||||
|
|
||||||
|
// make sure the pair was removed
|
||||||
|
if h.hasStreamPair("1") {
|
||||||
|
t.Fatal("expected removal of pair after both data and error streams received")
|
||||||
|
}
|
||||||
|
|
||||||
|
// removed via timeout
|
||||||
|
p, created = h.getStreamPair("2")
|
||||||
|
if !created {
|
||||||
|
t.Fatal("expected created=true")
|
||||||
|
}
|
||||||
|
if p == nil {
|
||||||
|
t.Fatal("expected p not to be nil")
|
||||||
|
}
|
||||||
|
monitorDone = make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
h.monitorStreamPair(p, timeout)
|
||||||
|
close(monitorDone)
|
||||||
|
}()
|
||||||
|
// cause the timeout
|
||||||
|
close(timeout)
|
||||||
|
// make sure monitorStreamPair completed
|
||||||
|
<-monitorDone
|
||||||
|
if h.hasStreamPair("2") {
|
||||||
|
t.Fatal("expected stream pair to be removed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestID(t *testing.T) {
|
||||||
|
h := &portForwardStreamHandler{}
|
||||||
|
|
||||||
|
s := newFakeHttpStream()
|
||||||
|
s.headers.Set(api.StreamType, api.StreamTypeError)
|
||||||
|
s.id = 1
|
||||||
|
if e, a := "1", h.requestID(s); e != a {
|
||||||
|
t.Errorf("expected %q, got %q", e, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.headers.Set(api.StreamType, api.StreamTypeData)
|
||||||
|
s.id = 3
|
||||||
|
if e, a := "1", h.requestID(s); e != a {
|
||||||
|
t.Errorf("expected %q, got %q", e, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.id = 7
|
||||||
|
s.headers.Set(api.PortForwardRequestIDHeader, "2")
|
||||||
|
if e, a := "2", h.requestID(s); e != a {
|
||||||
|
t.Errorf("expected %q, got %q", e, a)
|
||||||
|
}
|
||||||
|
}
|
@ -18,7 +18,6 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -27,7 +26,6 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
restful "github.com/emicklei/go-restful"
|
restful "github.com/emicklei/go-restful"
|
||||||
@ -54,10 +52,7 @@ import (
|
|||||||
"k8s.io/kubernetes/pkg/types"
|
"k8s.io/kubernetes/pkg/types"
|
||||||
"k8s.io/kubernetes/pkg/util/configz"
|
"k8s.io/kubernetes/pkg/util/configz"
|
||||||
"k8s.io/kubernetes/pkg/util/flushwriter"
|
"k8s.io/kubernetes/pkg/util/flushwriter"
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
|
||||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
|
||||||
"k8s.io/kubernetes/pkg/util/limitwriter"
|
"k8s.io/kubernetes/pkg/util/limitwriter"
|
||||||
utilruntime "k8s.io/kubernetes/pkg/util/runtime"
|
|
||||||
"k8s.io/kubernetes/pkg/util/term"
|
"k8s.io/kubernetes/pkg/util/term"
|
||||||
"k8s.io/kubernetes/pkg/volume"
|
"k8s.io/kubernetes/pkg/volume"
|
||||||
)
|
)
|
||||||
@ -650,13 +645,6 @@ func writeJsonResponse(response *restful.Response, data []byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// PortForwarder knows how to forward content from a data stream to/from a port
|
|
||||||
// in a pod.
|
|
||||||
type PortForwarder interface {
|
|
||||||
// PortForwarder copies data between a data stream and a port in a pod.
|
|
||||||
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPortForward handles a new restful port forward request. It determines the
|
// getPortForward handles a new restful port forward request. It determines the
|
||||||
// pod name and uid and then calls ServePortForward.
|
// pod name and uid and then calls ServePortForward.
|
||||||
func (s *Server) getPortForward(request *restful.Request, response *restful.Response) {
|
func (s *Server) getPortForward(request *restful.Request, response *restful.Response) {
|
||||||
@ -669,288 +657,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
|||||||
|
|
||||||
podName := kubecontainer.GetPodFullName(pod)
|
podName := kubecontainer.GetPodFullName(pod)
|
||||||
|
|
||||||
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout)
|
portforward.ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout)
|
||||||
}
|
|
||||||
|
|
||||||
// ServePortForward handles a port forwarding request. A single request is
|
|
||||||
// kept alive as long as the client is still alive and the connection has not
|
|
||||||
// been timed out due to idleness. This function handles multiple forwarded
|
|
||||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
|
||||||
// handled by a single invocation of ServePortForward.
|
|
||||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
|
||||||
supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name}
|
|
||||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
|
||||||
// negotiated protocol isn't currently used server side, but could be in the future
|
|
||||||
if err != nil {
|
|
||||||
// Handshake writes the error to the client
|
|
||||||
utilruntime.HandleError(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
streamChan := make(chan httpstream.Stream, 1)
|
|
||||||
|
|
||||||
glog.V(5).Infof("Upgrading port forward response")
|
|
||||||
upgrader := spdy.NewResponseUpgrader()
|
|
||||||
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
|
||||||
if conn == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
|
||||||
conn.SetIdleTimeout(idleTimeout)
|
|
||||||
|
|
||||||
h := &portForwardStreamHandler{
|
|
||||||
conn: conn,
|
|
||||||
streamChan: streamChan,
|
|
||||||
streamPairs: make(map[string]*portForwardStreamPair),
|
|
||||||
streamCreationTimeout: streamCreationTimeout,
|
|
||||||
pod: podName,
|
|
||||||
uid: uid,
|
|
||||||
forwarder: portForwarder,
|
|
||||||
}
|
|
||||||
h.run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
|
|
||||||
// forward streams. It checks each stream's port and stream type headers,
|
|
||||||
// rejecting any streams that with missing or invalid values. Each valid
|
|
||||||
// stream is sent to the streams channel.
|
|
||||||
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
|
||||||
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
|
||||||
// make sure it has a valid port header
|
|
||||||
portString := stream.Headers().Get(api.PortHeader)
|
|
||||||
if len(portString) == 0 {
|
|
||||||
return fmt.Errorf("%q header is required", api.PortHeader)
|
|
||||||
}
|
|
||||||
port, err := strconv.ParseUint(portString, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
|
||||||
}
|
|
||||||
if port < 1 {
|
|
||||||
return fmt.Errorf("port %q must be > 0", portString)
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure it has a valid stream type header
|
|
||||||
streamType := stream.Headers().Get(api.StreamType)
|
|
||||||
if len(streamType) == 0 {
|
|
||||||
return fmt.Errorf("%q header is required", api.StreamType)
|
|
||||||
}
|
|
||||||
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
|
||||||
return fmt.Errorf("invalid stream type %q", streamType)
|
|
||||||
}
|
|
||||||
|
|
||||||
streams <- stream
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// portForwardStreamHandler is capable of processing multiple port forward
|
|
||||||
// requests over a single httpstream.Connection.
|
|
||||||
type portForwardStreamHandler struct {
|
|
||||||
conn httpstream.Connection
|
|
||||||
streamChan chan httpstream.Stream
|
|
||||||
streamPairsLock sync.RWMutex
|
|
||||||
streamPairs map[string]*portForwardStreamPair
|
|
||||||
streamCreationTimeout time.Duration
|
|
||||||
pod string
|
|
||||||
uid types.UID
|
|
||||||
forwarder PortForwarder
|
|
||||||
}
|
|
||||||
|
|
||||||
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
|
|
||||||
// new pair if one does not yet exist for the requestID. The returned bool is
|
|
||||||
// true if the pair was created.
|
|
||||||
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
|
|
||||||
h.streamPairsLock.Lock()
|
|
||||||
defer h.streamPairsLock.Unlock()
|
|
||||||
|
|
||||||
if p, ok := h.streamPairs[requestID]; ok {
|
|
||||||
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
|
||||||
return p, false
|
|
||||||
}
|
|
||||||
|
|
||||||
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
|
||||||
|
|
||||||
p := newPortForwardPair(requestID)
|
|
||||||
h.streamPairs[requestID] = p
|
|
||||||
|
|
||||||
return p, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// monitorStreamPair waits for the pair to receive both its error and data
|
|
||||||
// streams, or for the timeout to expire (whichever happens first), and then
|
|
||||||
// removes the pair.
|
|
||||||
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
|
||||||
utilruntime.HandleError(err)
|
|
||||||
p.printError(err.Error())
|
|
||||||
case <-p.complete:
|
|
||||||
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
|
||||||
}
|
|
||||||
h.removeStreamPair(p.requestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
|
||||||
// exists.
|
|
||||||
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
|
|
||||||
h.streamPairsLock.RLock()
|
|
||||||
defer h.streamPairsLock.RUnlock()
|
|
||||||
|
|
||||||
_, ok := h.streamPairs[requestID]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
|
||||||
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
|
|
||||||
h.streamPairsLock.Lock()
|
|
||||||
defer h.streamPairsLock.Unlock()
|
|
||||||
|
|
||||||
delete(h.streamPairs, requestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestID returns the request id for stream.
|
|
||||||
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
|
|
||||||
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
|
||||||
if len(requestID) == 0 {
|
|
||||||
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
|
||||||
// If we get here, it's because the connection came from an older client
|
|
||||||
// that isn't generating the request id header
|
|
||||||
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
|
||||||
//
|
|
||||||
// This is a best-effort attempt at supporting older clients.
|
|
||||||
//
|
|
||||||
// When there aren't concurrent new forwarded connections, each connection
|
|
||||||
// will have a pair of streams (data, error), and the stream IDs will be
|
|
||||||
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
|
||||||
// the stream ID into a pseudo-request id by taking the stream type and
|
|
||||||
// using id = stream.Identifier() when the stream type is error,
|
|
||||||
// and id = stream.Identifier() - 2 when it's data.
|
|
||||||
//
|
|
||||||
// NOTE: this only works when there are not concurrent new streams from
|
|
||||||
// multiple forwarded connections; it's a best-effort attempt at supporting
|
|
||||||
// old clients that don't generate request ids. If there are concurrent
|
|
||||||
// new connections, it's possible that 1 connection gets streams whose IDs
|
|
||||||
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
|
||||||
streamType := stream.Headers().Get(api.StreamType)
|
|
||||||
switch streamType {
|
|
||||||
case api.StreamTypeError:
|
|
||||||
requestID = strconv.Itoa(int(stream.Identifier()))
|
|
||||||
case api.StreamTypeData:
|
|
||||||
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
|
||||||
}
|
|
||||||
return requestID
|
|
||||||
}
|
|
||||||
|
|
||||||
// run is the main loop for the portForwardStreamHandler. It processes new
|
|
||||||
// streams, invoking portForward for each complete stream pair. The loop exits
|
|
||||||
// when the httpstream.Connection is closed.
|
|
||||||
func (h *portForwardStreamHandler) run() {
|
|
||||||
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
|
||||||
Loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-h.conn.CloseChan():
|
|
||||||
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
|
||||||
break Loop
|
|
||||||
case stream := <-h.streamChan:
|
|
||||||
requestID := h.requestID(stream)
|
|
||||||
streamType := stream.Headers().Get(api.StreamType)
|
|
||||||
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
|
||||||
|
|
||||||
p, created := h.getStreamPair(requestID)
|
|
||||||
if created {
|
|
||||||
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
|
||||||
}
|
|
||||||
if complete, err := p.add(stream); err != nil {
|
|
||||||
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
|
||||||
utilruntime.HandleError(errors.New(msg))
|
|
||||||
p.printError(msg)
|
|
||||||
} else if complete {
|
|
||||||
go h.portForward(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
|
|
||||||
// function for the given stream pair.
|
|
||||||
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
|
|
||||||
defer p.dataStream.Close()
|
|
||||||
defer p.errorStream.Close()
|
|
||||||
|
|
||||||
portString := p.dataStream.Headers().Get(api.PortHeader)
|
|
||||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
|
||||||
|
|
||||||
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
|
||||||
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
|
||||||
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
|
||||||
utilruntime.HandleError(msg)
|
|
||||||
fmt.Fprint(p.errorStream, msg.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// portForwardStreamPair represents the error and data streams for a port
|
|
||||||
// forwarding request.
|
|
||||||
type portForwardStreamPair struct {
|
|
||||||
lock sync.RWMutex
|
|
||||||
requestID string
|
|
||||||
dataStream httpstream.Stream
|
|
||||||
errorStream httpstream.Stream
|
|
||||||
complete chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// newPortForwardPair creates a new portForwardStreamPair.
|
|
||||||
func newPortForwardPair(requestID string) *portForwardStreamPair {
|
|
||||||
return &portForwardStreamPair{
|
|
||||||
requestID: requestID,
|
|
||||||
complete: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// add adds the stream to the portForwardStreamPair. If the pair already
|
|
||||||
// contains a stream for the new stream's type, an error is returned. add
|
|
||||||
// returns true if both the data and error streams for this pair have been
|
|
||||||
// received.
|
|
||||||
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
|
|
||||||
p.lock.Lock()
|
|
||||||
defer p.lock.Unlock()
|
|
||||||
|
|
||||||
switch stream.Headers().Get(api.StreamType) {
|
|
||||||
case api.StreamTypeError:
|
|
||||||
if p.errorStream != nil {
|
|
||||||
return false, errors.New("error stream already assigned")
|
|
||||||
}
|
|
||||||
p.errorStream = stream
|
|
||||||
case api.StreamTypeData:
|
|
||||||
if p.dataStream != nil {
|
|
||||||
return false, errors.New("data stream already assigned")
|
|
||||||
}
|
|
||||||
p.dataStream = stream
|
|
||||||
}
|
|
||||||
|
|
||||||
complete := p.errorStream != nil && p.dataStream != nil
|
|
||||||
if complete {
|
|
||||||
close(p.complete)
|
|
||||||
}
|
|
||||||
return complete, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// printError writes s to p.errorStream if p.errorStream has been set.
|
|
||||||
func (p *portForwardStreamPair) printError(s string) {
|
|
||||||
p.lock.RLock()
|
|
||||||
defer p.lock.RUnlock()
|
|
||||||
if p.errorStream != nil {
|
|
||||||
fmt.Fprint(p.errorStream, s)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP responds to HTTP requests on the Kubelet.
|
// ServeHTTP responds to HTTP requests on the Kubelet.
|
||||||
|
@ -1530,223 +1530,3 @@ func TestServePortForward(t *testing.T) {
|
|||||||
<-portForwardFuncDone
|
<-portForwardFuncDone
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeHttpStream struct {
|
|
||||||
headers http.Header
|
|
||||||
id uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeHttpStream() *fakeHttpStream {
|
|
||||||
return &fakeHttpStream{
|
|
||||||
headers: make(http.Header),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ httpstream.Stream = &fakeHttpStream{}
|
|
||||||
|
|
||||||
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeHttpStream) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeHttpStream) Reset() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeHttpStream) Headers() http.Header {
|
|
||||||
return s.headers
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeHttpStream) Identifier() uint32 {
|
|
||||||
return s.id
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPortForwardStreamReceived(t *testing.T) {
|
|
||||||
tests := map[string]struct {
|
|
||||||
port string
|
|
||||||
streamType string
|
|
||||||
expectedError string
|
|
||||||
}{
|
|
||||||
"missing port": {
|
|
||||||
expectedError: `"port" header is required`,
|
|
||||||
},
|
|
||||||
"unable to parse port": {
|
|
||||||
port: "abc",
|
|
||||||
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
|
|
||||||
},
|
|
||||||
"negative port": {
|
|
||||||
port: "-1",
|
|
||||||
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
|
|
||||||
},
|
|
||||||
"missing stream type": {
|
|
||||||
port: "80",
|
|
||||||
expectedError: `"streamType" header is required`,
|
|
||||||
},
|
|
||||||
"valid port with error stream": {
|
|
||||||
port: "80",
|
|
||||||
streamType: "error",
|
|
||||||
},
|
|
||||||
"valid port with data stream": {
|
|
||||||
port: "80",
|
|
||||||
streamType: "data",
|
|
||||||
},
|
|
||||||
"invalid stream type": {
|
|
||||||
port: "80",
|
|
||||||
streamType: "foo",
|
|
||||||
expectedError: `invalid stream type "foo"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, test := range tests {
|
|
||||||
streams := make(chan httpstream.Stream, 1)
|
|
||||||
f := portForwardStreamReceived(streams)
|
|
||||||
stream := newFakeHttpStream()
|
|
||||||
if len(test.port) > 0 {
|
|
||||||
stream.headers.Set("port", test.port)
|
|
||||||
}
|
|
||||||
if len(test.streamType) > 0 {
|
|
||||||
stream.headers.Set("streamType", test.streamType)
|
|
||||||
}
|
|
||||||
replySent := make(chan struct{})
|
|
||||||
err := f(stream, replySent)
|
|
||||||
close(replySent)
|
|
||||||
if len(test.expectedError) > 0 {
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
|
|
||||||
}
|
|
||||||
if e, a := test.expectedError, err.Error(); e != a {
|
|
||||||
t.Errorf("%s: expected err=%q, got %q", name, e, a)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s: unexpected error %v", name, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if s := <-streams; s != stream {
|
|
||||||
t.Errorf("%s: expected stream %#v, got %#v", name, stream, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStreamPair(t *testing.T) {
|
|
||||||
timeout := make(chan time.Time)
|
|
||||||
|
|
||||||
h := &portForwardStreamHandler{
|
|
||||||
streamPairs: make(map[string]*portForwardStreamPair),
|
|
||||||
}
|
|
||||||
|
|
||||||
// test adding a new entry
|
|
||||||
p, created := h.getStreamPair("1")
|
|
||||||
if p == nil {
|
|
||||||
t.Fatalf("unexpected nil pair")
|
|
||||||
}
|
|
||||||
if !created {
|
|
||||||
t.Fatal("expected created=true")
|
|
||||||
}
|
|
||||||
if p.dataStream != nil {
|
|
||||||
t.Errorf("unexpected non-nil data stream")
|
|
||||||
}
|
|
||||||
if p.errorStream != nil {
|
|
||||||
t.Errorf("unexpected non-nil error stream")
|
|
||||||
}
|
|
||||||
|
|
||||||
// start the monitor for this pair
|
|
||||||
monitorDone := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
h.monitorStreamPair(p, timeout)
|
|
||||||
close(monitorDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !h.hasStreamPair("1") {
|
|
||||||
t.Fatal("This should still be true")
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure we can retrieve an existing entry
|
|
||||||
p2, created := h.getStreamPair("1")
|
|
||||||
if created {
|
|
||||||
t.Fatal("expected created=false")
|
|
||||||
}
|
|
||||||
if p != p2 {
|
|
||||||
t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removed via complete
|
|
||||||
dataStream := newFakeHttpStream()
|
|
||||||
dataStream.headers.Set(api.StreamType, api.StreamTypeData)
|
|
||||||
complete, err := p.add(dataStream)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error adding data stream to pair: %v", err)
|
|
||||||
}
|
|
||||||
if complete {
|
|
||||||
t.Fatalf("unexpected complete")
|
|
||||||
}
|
|
||||||
|
|
||||||
errorStream := newFakeHttpStream()
|
|
||||||
errorStream.headers.Set(api.StreamType, api.StreamTypeError)
|
|
||||||
complete, err = p.add(errorStream)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error adding error stream to pair: %v", err)
|
|
||||||
}
|
|
||||||
if !complete {
|
|
||||||
t.Fatal("unexpected incomplete")
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure monitorStreamPair completed
|
|
||||||
<-monitorDone
|
|
||||||
|
|
||||||
// make sure the pair was removed
|
|
||||||
if h.hasStreamPair("1") {
|
|
||||||
t.Fatal("expected removal of pair after both data and error streams received")
|
|
||||||
}
|
|
||||||
|
|
||||||
// removed via timeout
|
|
||||||
p, created = h.getStreamPair("2")
|
|
||||||
if !created {
|
|
||||||
t.Fatal("expected created=true")
|
|
||||||
}
|
|
||||||
if p == nil {
|
|
||||||
t.Fatal("expected p not to be nil")
|
|
||||||
}
|
|
||||||
monitorDone = make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
h.monitorStreamPair(p, timeout)
|
|
||||||
close(monitorDone)
|
|
||||||
}()
|
|
||||||
// cause the timeout
|
|
||||||
close(timeout)
|
|
||||||
// make sure monitorStreamPair completed
|
|
||||||
<-monitorDone
|
|
||||||
if h.hasStreamPair("2") {
|
|
||||||
t.Fatal("expected stream pair to be removed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestID(t *testing.T) {
|
|
||||||
h := &portForwardStreamHandler{}
|
|
||||||
|
|
||||||
s := newFakeHttpStream()
|
|
||||||
s.headers.Set(api.StreamType, api.StreamTypeError)
|
|
||||||
s.id = 1
|
|
||||||
if e, a := "1", h.requestID(s); e != a {
|
|
||||||
t.Errorf("expected %q, got %q", e, a)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.headers.Set(api.StreamType, api.StreamTypeData)
|
|
||||||
s.id = 3
|
|
||||||
if e, a := "1", h.requestID(s); e != a {
|
|
||||||
t.Errorf("expected %q, got %q", e, a)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.id = 7
|
|
||||||
s.headers.Set(api.PortForwardRequestIDHeader, "2")
|
|
||||||
if e, a := "2", h.requestID(s); e != a {
|
|
||||||
t.Errorf("expected %q, got %q", e, a)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user