diff --git a/pkg/kubelet/server/server.go b/pkg/kubelet/server/server.go index 9d225b7d1a0..9b39ddff0b9 100644 --- a/pkg/kubelet/server/server.go +++ b/pkg/kubelet/server/server.go @@ -606,7 +606,7 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response) podFullName := kubecontainer.GetPodFullName(pod) redirect, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts) if err != nil { - response.WriteError(streaming.HTTPStatus(err), err) + streaming.WriteError(err, response.ResponseWriter) return } if redirect != nil { @@ -644,7 +644,7 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) { podFullName := kubecontainer.GetPodFullName(pod) redirect, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts) if err != nil { - response.WriteError(streaming.HTTPStatus(err), err) + streaming.WriteError(err, response.ResponseWriter) return } if redirect != nil { @@ -714,7 +714,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID) if err != nil { - response.WriteError(streaming.HTTPStatus(err), err) + streaming.WriteError(err, response.ResponseWriter) return } if redirect != nil { diff --git a/pkg/kubelet/server/streaming/BUILD b/pkg/kubelet/server/streaming/BUILD index 1c1cfe7784d..6444113c102 100644 --- a/pkg/kubelet/server/streaming/BUILD +++ b/pkg/kubelet/server/streaming/BUILD @@ -12,14 +12,15 @@ go_library( name = "go_default_library", srcs = [ "errors.go", + "request_cache.go", "server.go", ], tags = ["automanaged"], deps = [ - "//pkg/api:go_default_library", "//pkg/kubelet/api/v1alpha1/runtime:go_default_library", "//pkg/kubelet/server/portforward:go_default_library", "//pkg/kubelet/server/remotecommand:go_default_library", + "//pkg/util/clock:go_default_library", "//pkg/util/term:go_default_library", "//vendor:github.com/emicklei/go-restful", "//vendor:google.golang.org/grpc", @@ -30,7 +31,10 @@ go_library( go_test( name = "go_default_test", - srcs = ["server_test.go"], + srcs = [ + "request_cache_test.go", + "server_test.go", + ], library = ":go_default_library", tags = ["automanaged"], deps = [ @@ -43,6 +47,7 @@ go_test( "//vendor:github.com/stretchr/testify/assert", "//vendor:github.com/stretchr/testify/require", "//vendor:k8s.io/client-go/pkg/api", + "//vendor:k8s.io/client-go/pkg/util/clock", ], ) diff --git a/pkg/kubelet/server/streaming/errors.go b/pkg/kubelet/server/streaming/errors.go index 3d957bb1edc..d440ec498db 100644 --- a/pkg/kubelet/server/streaming/errors.go +++ b/pkg/kubelet/server/streaming/errors.go @@ -19,6 +19,7 @@ package streaming import ( "fmt" "net/http" + "strconv" "time" "google.golang.org/grpc" @@ -33,12 +34,27 @@ func ErrorTimeout(op string, timeout time.Duration) error { return grpc.Errorf(codes.DeadlineExceeded, fmt.Sprintf("%s timed out after %s", op, timeout.String())) } -// Translates a CRI streaming error into an HTTP status code. -func HTTPStatus(err error) int { +// The error returned when the maximum number of in-flight requests is exceeded. +func ErrorTooManyInFlight() error { + return grpc.Errorf(codes.ResourceExhausted, "maximum number of in-flight requests exceeded") +} + +// Translates a CRI streaming error into an appropriate HTTP response. +func WriteError(err error, w http.ResponseWriter) error { + var status int switch grpc.Code(err) { case codes.NotFound: - return http.StatusNotFound + status = http.StatusNotFound + case codes.ResourceExhausted: + // We only expect to hit this if there is a DoS, so we just wait the full TTL. + // If this is ever hit in steady-state operations, consider increasing the MaxInFlight requests, + // or plumbing through the time to next expiration. + w.Header().Set("Retry-After", strconv.Itoa(int(CacheTTL.Seconds()))) + status = http.StatusTooManyRequests default: - return http.StatusInternalServerError + status = http.StatusInternalServerError } + w.WriteHeader(status) + _, writeErr := w.Write([]byte(err.Error())) + return writeErr } diff --git a/pkg/kubelet/server/streaming/request_cache.go b/pkg/kubelet/server/streaming/request_cache.go new file mode 100644 index 00000000000..c8b68a464f1 --- /dev/null +++ b/pkg/kubelet/server/streaming/request_cache.go @@ -0,0 +1,146 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package streaming + +import ( + "container/list" + "crypto/rand" + "encoding/base64" + "fmt" + "math" + "sync" + "time" + + "k8s.io/kubernetes/pkg/util/clock" +) + +var ( + // Timeout after which tokens become invalid. + CacheTTL = 1 * time.Minute + // The maximum number of in-flight requests to allow. + MaxInFlight = 1000 + // Length of the random base64 encoded token identifying the request. + TokenLen = 8 +) + +// requestCache caches streaming (exec/attach/port-forward) requests and generates a single-use +// random token for their retrieval. The requestCache is used for building streaming URLs without +// the need to encode every request parameter in the URL. +type requestCache struct { + // clock is used to obtain the current time + clock clock.Clock + + // tokens maps the generate token to the request for fast retrieval. + tokens map[string]*list.Element + // ll maintains an age-ordered request list for faster garbage collection of expired requests. + ll *list.List + + lock sync.Mutex +} + +// Type representing an *ExecRequest, *AttachRequest, or *PortForwardRequest. +type request interface{} + +type cacheEntry struct { + token string + req request + expireTime time.Time +} + +func newRequestCache() *requestCache { + return &requestCache{ + clock: clock.RealClock{}, + ll: list.New(), + tokens: make(map[string]*list.Element), + } +} + +// Insert the given request into the cache and returns the token used for fetching it out. +func (c *requestCache) Insert(req request) (token string, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + // Remove expired entries. + c.gc() + // If the cache is full, reject the request. + if c.ll.Len() == MaxInFlight { + return "", ErrorTooManyInFlight() + } + token, err = c.uniqueToken() + if err != nil { + return "", err + } + ele := c.ll.PushFront(&cacheEntry{token, req, c.clock.Now().Add(CacheTTL)}) + + c.tokens[token] = ele + return token, nil +} + +// Consume the token (remove it from the cache) and return the cached request, if found. +func (c *requestCache) Consume(token string) (req request, found bool) { + c.lock.Lock() + defer c.lock.Unlock() + ele, ok := c.tokens[token] + if !ok { + return nil, false + } + c.ll.Remove(ele) + delete(c.tokens, token) + + entry := ele.Value.(*cacheEntry) + if c.clock.Now().After(entry.expireTime) { + // Entry already expired. + return nil, false + } + return entry.req, true +} + +// uniqueToken generates a random URL-safe token and ensures uniqueness. +func (c *requestCache) uniqueToken() (string, error) { + const maxTries = 10 + // Number of bytes to be TokenLen when base64 encoded. + tokenSize := math.Ceil(float64(TokenLen) * 6 / 8) + rawToken := make([]byte, int(tokenSize)) + for i := 0; i < maxTries; i++ { + if _, err := rand.Read(rawToken); err != nil { + return "", err + } + encoded := base64.RawURLEncoding.EncodeToString(rawToken) + token := encoded[:TokenLen] + // If it's unique, return it. Otherwise retry. + if _, exists := c.tokens[encoded]; !exists { + return token, nil + } + } + return "", fmt.Errorf("failed to generate unique token") +} + +// Must be write-locked prior to calling. +func (c *requestCache) gc() { + now := c.clock.Now() + for c.ll.Len() > 0 { + oldest := c.ll.Back() + entry := oldest.Value.(*cacheEntry) + if !now.After(entry.expireTime) { + return + } + + // Oldest value is expired; remove it. + c.ll.Remove(oldest) + delete(c.tokens, entry.token) + } +} diff --git a/pkg/kubelet/server/streaming/request_cache_test.go b/pkg/kubelet/server/streaming/request_cache_test.go new file mode 100644 index 00000000000..a714a7149ae --- /dev/null +++ b/pkg/kubelet/server/streaming/request_cache_test.go @@ -0,0 +1,221 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package streaming + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "k8s.io/client-go/pkg/util/clock" +) + +func TestInsert(t *testing.T) { + c, _ := newTestCache() + + // Insert normal + oldestTok, err := c.Insert(nextRequest()) + require.NoError(t, err) + assert.Len(t, oldestTok, TokenLen) + assertCacheSize(t, c, 1) + + // Insert until full + for i := 0; i < MaxInFlight-2; i++ { + tok, err := c.Insert(nextRequest()) + require.NoError(t, err) + assert.Len(t, tok, TokenLen) + } + assertCacheSize(t, c, MaxInFlight-1) + + newestReq := nextRequest() + newestTok, err := c.Insert(newestReq) + require.NoError(t, err) + assert.Len(t, newestTok, TokenLen) + assertCacheSize(t, c, MaxInFlight) + require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached") + + // Consume newest token. + req, ok := c.Consume(newestTok) + assert.True(t, ok, "newest request should still be cached") + assert.Equal(t, newestReq, req) + require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached") + + // Insert again (still full) + tok, err := c.Insert(nextRequest()) + require.NoError(t, err) + assert.Len(t, tok, TokenLen) + assertCacheSize(t, c, MaxInFlight) + + // Insert again (should evict) + _, err = c.Insert(nextRequest()) + assert.Error(t, err, "should reject further requests") + errResponse := httptest.NewRecorder() + require.NoError(t, WriteError(err, errResponse)) + assert.Equal(t, errResponse.Code, http.StatusTooManyRequests) + assert.Equal(t, strconv.Itoa(int(CacheTTL.Seconds())), errResponse.HeaderMap.Get("Retry-After")) + + assertCacheSize(t, c, MaxInFlight) + _, ok = c.Consume(oldestTok) + assert.True(t, ok, "oldest request should be valid") +} + +func TestConsume(t *testing.T) { + c, clock := newTestCache() + + { // Insert & consume. + req := nextRequest() + tok, err := c.Insert(req) + require.NoError(t, err) + assertCacheSize(t, c, 1) + + cachedReq, ok := c.Consume(tok) + assert.True(t, ok) + assert.Equal(t, req, cachedReq) + assertCacheSize(t, c, 0) + } + + { // Insert & consume out of order + req1 := nextRequest() + tok1, err := c.Insert(req1) + require.NoError(t, err) + assertCacheSize(t, c, 1) + + req2 := nextRequest() + tok2, err := c.Insert(req2) + require.NoError(t, err) + assertCacheSize(t, c, 2) + + cachedReq2, ok := c.Consume(tok2) + assert.True(t, ok) + assert.Equal(t, req2, cachedReq2) + assertCacheSize(t, c, 1) + + cachedReq1, ok := c.Consume(tok1) + assert.True(t, ok) + assert.Equal(t, req1, cachedReq1) + assertCacheSize(t, c, 0) + } + + { // Consume a second time + req := nextRequest() + tok, err := c.Insert(req) + require.NoError(t, err) + assertCacheSize(t, c, 1) + + cachedReq, ok := c.Consume(tok) + assert.True(t, ok) + assert.Equal(t, req, cachedReq) + assertCacheSize(t, c, 0) + + _, ok = c.Consume(tok) + assert.False(t, ok) + assertCacheSize(t, c, 0) + } + + { // Consume without insert + _, ok := c.Consume("fooBAR") + assert.False(t, ok) + assertCacheSize(t, c, 0) + } + + { // Consume expired + tok, err := c.Insert(nextRequest()) + require.NoError(t, err) + assertCacheSize(t, c, 1) + + clock.Step(2 * CacheTTL) + + _, ok := c.Consume(tok) + assert.False(t, ok) + assertCacheSize(t, c, 0) + } +} + +func TestGC(t *testing.T) { + c, clock := newTestCache() + + // When empty + c.gc() + assertCacheSize(t, c, 0) + + tok1, err := c.Insert(nextRequest()) + require.NoError(t, err) + assertCacheSize(t, c, 1) + clock.Step(10 * time.Second) + tok2, err := c.Insert(nextRequest()) + require.NoError(t, err) + assertCacheSize(t, c, 2) + + // expired: tok1, tok2 + // non-expired: tok3, tok4 + clock.Step(2 * CacheTTL) + tok3, err := c.Insert(nextRequest()) + require.NoError(t, err) + assertCacheSize(t, c, 1) + clock.Step(10 * time.Second) + tok4, err := c.Insert(nextRequest()) + require.NoError(t, err) + assertCacheSize(t, c, 2) + + _, ok := c.Consume(tok1) + assert.False(t, ok) + _, ok = c.Consume(tok2) + assert.False(t, ok) + _, ok = c.Consume(tok3) + assert.True(t, ok) + _, ok = c.Consume(tok4) + assert.True(t, ok) + + // When full, nothing is expired. + for i := 0; i < MaxInFlight; i++ { + _, err := c.Insert(nextRequest()) + require.NoError(t, err) + } + assertCacheSize(t, c, MaxInFlight) + + // When everything is expired + clock.Step(2 * CacheTTL) + _, err = c.Insert(nextRequest()) + require.NoError(t, err) + assertCacheSize(t, c, 1) +} + +func newTestCache() (*requestCache, *clock.FakeClock) { + c := newRequestCache() + fakeClock := clock.NewFakeClock(time.Now()) + c.clock = fakeClock + return c, fakeClock +} + +func assertCacheSize(t *testing.T, cache *requestCache, expectedSize int) { + tokenLen := len(cache.tokens) + llLen := cache.ll.Len() + assert.Equal(t, tokenLen, llLen, "inconsistent cache size! len(tokens)=%d; len(ll)=%d", tokenLen, llLen) + assert.Equal(t, expectedSize, tokenLen, "unexpected cache size!") +} + +var requestUID = 0 + +func nextRequest() interface{} { + requestUID++ + return requestUID +} diff --git a/pkg/kubelet/server/streaming/server.go b/pkg/kubelet/server/streaming/server.go index 4c7b3189076..95e01c492f2 100644 --- a/pkg/kubelet/server/streaming/server.go +++ b/pkg/kubelet/server/streaming/server.go @@ -25,10 +25,12 @@ import ( "path" "time" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + restful "github.com/emicklei/go-restful" "k8s.io/apimachinery/pkg/types" - "k8s.io/kubernetes/pkg/api" runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" "k8s.io/kubernetes/pkg/kubelet/server/portforward" "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" @@ -97,6 +99,7 @@ func NewServer(config Config, runtime Runtime) (Server, error) { s := &server{ config: config, runtime: &criAdapter{runtime}, + cache: newRequestCache(), } if s.config.BaseURL == nil { @@ -114,9 +117,9 @@ func NewServer(config Config, runtime Runtime) (Server, error) { path string handler restful.RouteFunction }{ - {"/exec/{containerID}", s.serveExec}, - {"/attach/{containerID}", s.serveAttach}, - {"/portforward/{podSandboxID}", s.servePortForward}, + {"/exec/{token}", s.serveExec}, + {"/attach/{token}", s.serveAttach}, + {"/portforward/{token}", s.servePortForward}, } // If serving relative to a base path, set that here. pathPrefix := path.Dir(s.config.BaseURL.Path) @@ -139,37 +142,45 @@ type server struct { config Config runtime *criAdapter handler http.Handler + cache *requestCache } func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) { - url := s.buildURL("exec", req.GetContainerId(), streamOpts{ - stdin: req.GetStdin(), - stdout: true, - stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout. - tty: req.GetTty(), - command: req.GetCmd(), - }) + if req.GetContainerId() == "" { + return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id") + } + token, err := s.cache.Insert(req) + if err != nil { + return nil, err + } return &runtimeapi.ExecResponse{ - Url: &url, + Url: s.buildURL("exec", token), }, nil } func (s *server) GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error) { - url := s.buildURL("attach", req.GetContainerId(), streamOpts{ - stdin: req.GetStdin(), - stdout: true, - stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout. - tty: req.GetTty(), - }) + if req.GetContainerId() == "" { + return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id") + } + token, err := s.cache.Insert(req) + if err != nil { + return nil, err + } return &runtimeapi.AttachResponse{ - Url: &url, + Url: s.buildURL("attach", token), }, nil } func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) { - url := s.buildURL("portforward", req.GetPodSandboxId(), streamOpts{}) + if req.GetPodSandboxId() == "" { + return nil, grpc.Errorf(codes.InvalidArgument, "missing required pod_sandbox_id") + } + token, err := s.cache.Insert(req) + if err != nil { + return nil, err + } return &runtimeapi.PortForwardResponse{ - Url: &url, + Url: s.buildURL("portforward", token), }, nil } @@ -200,63 +211,32 @@ func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.handler.ServeHTTP(w, r) } -type streamOpts struct { - stdin bool - stdout bool - stderr bool - tty bool - - command []string - port []int32 -} - -const ( - urlParamStdin = api.ExecStdinParam - urlParamStdout = api.ExecStdoutParam - urlParamStderr = api.ExecStderrParam - urlParamTTY = api.ExecTTYParam - urlParamCommand = api.ExecCommandParamm -) - -func (s *server) buildURL(method, id string, opts streamOpts) string { - loc := &url.URL{ - Path: path.Join(method, id), - } - - query := url.Values{} - if opts.stdin { - query.Add(urlParamStdin, "1") - } - if opts.stdout { - query.Add(urlParamStdout, "1") - } - if opts.stderr { - query.Add(urlParamStderr, "1") - } - if opts.tty { - query.Add(urlParamTTY, "1") - } - for _, c := range opts.command { - query.Add(urlParamCommand, c) - } - loc.RawQuery = query.Encode() - - return s.config.BaseURL.ResolveReference(loc).String() +func (s *server) buildURL(method, token string) *string { + loc := s.config.BaseURL.ResolveReference(&url.URL{ + Path: path.Join(method, token), + }).String() + return &loc } func (s *server) serveExec(req *restful.Request, resp *restful.Response) { - containerID := req.PathParameter("containerID") - if containerID == "" { - resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) + token := req.PathParameter("token") + cachedRequest, ok := s.cache.Consume(token) + if !ok { + http.NotFound(resp.ResponseWriter, req.Request) + return + } + exec, ok := cachedRequest.(*runtimeapi.ExecRequest) + if !ok { + http.NotFound(resp.ResponseWriter, req.Request) return } - streamOpts, err := remotecommand.NewOptions(req.Request) - if err != nil { - resp.WriteError(http.StatusBadRequest, err) - return + streamOpts := &remotecommand.Options{ + Stdin: exec.GetStdin(), + Stdout: true, + Stderr: !exec.GetTty(), + TTY: exec.GetTty(), } - cmd := req.Request.URL.Query()[api.ExecCommandParamm] remotecommand.ServeExec( resp.ResponseWriter, @@ -264,8 +244,8 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { s.runtime, "", // unused: podName "", // unusued: podUID - containerID, - cmd, + exec.GetContainerId(), + exec.GetCmd(), streamOpts, s.config.StreamIdleTimeout, s.config.StreamCreationTimeout, @@ -273,25 +253,31 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { } func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { - containerID := req.PathParameter("containerID") - if containerID == "" { - resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) + token := req.PathParameter("token") + cachedRequest, ok := s.cache.Consume(token) + if !ok { + http.NotFound(resp.ResponseWriter, req.Request) + return + } + attach, ok := cachedRequest.(*runtimeapi.AttachRequest) + if !ok { + http.NotFound(resp.ResponseWriter, req.Request) return } - streamOpts, err := remotecommand.NewOptions(req.Request) - if err != nil { - resp.WriteError(http.StatusBadRequest, err) - return + streamOpts := &remotecommand.Options{ + Stdin: attach.GetStdin(), + Stdout: true, + Stderr: !attach.GetTty(), + TTY: attach.GetTty(), } - remotecommand.ServeAttach( resp.ResponseWriter, req.Request, s.runtime, "", // unused: podName "", // unusued: podUID - containerID, + attach.GetContainerId(), streamOpts, s.config.StreamIdleTimeout, s.config.StreamCreationTimeout, @@ -299,9 +285,15 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { } func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { - podSandboxID := req.PathParameter("podSandboxID") - if podSandboxID == "" { - resp.WriteError(http.StatusBadRequest, errors.New("missing required podSandboxID path parameter")) + token := req.PathParameter("token") + cachedRequest, ok := s.cache.Consume(token) + if !ok { + http.NotFound(resp.ResponseWriter, req.Request) + return + } + pf, ok := cachedRequest.(*runtimeapi.PortForwardRequest) + if !ok { + http.NotFound(resp.ResponseWriter, req.Request) return } @@ -309,7 +301,7 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response) resp.ResponseWriter, req.Request, s.runtime, - podSandboxID, + pf.GetPodSandboxId(), "", // unused: podUID s.config.StreamIdleTimeout, s.config.StreamCreationTimeout) diff --git a/pkg/kubelet/server/streaming/server_test.go b/pkg/kubelet/server/streaming/server_test.go index c76f458aabf..e6170a9222c 100644 --- a/pkg/kubelet/server/streaming/server_test.go +++ b/pkg/kubelet/server/streaming/server_test.go @@ -18,12 +18,12 @@ package streaming import ( "crypto/tls" - "fmt" "io" "net/http" "net/http/httptest" "net/url" "strconv" + "strings" "sync" "testing" @@ -46,18 +46,18 @@ const ( ) func TestGetExec(t *testing.T) { - testcases := []struct { - cmd []string - tty bool - stdin bool - expectedQuery string - }{ - {[]string{"echo", "foo"}, false, false, "?command=echo&command=foo&error=1&output=1"}, - {[]string{"date"}, true, false, "?command=date&output=1&tty=1"}, - {[]string{"date"}, false, true, "?command=date&error=1&input=1&output=1"}, - {[]string{"date"}, true, true, "?command=date&input=1&output=1&tty=1"}, + type testcase struct { + cmd []string + tty bool + stdin bool } - server, err := NewServer(Config{ + testcases := []testcase{ + {[]string{"echo", "foo"}, false, false}, + {[]string{"date"}, true, false}, + {[]string{"date"}, false, true}, + {[]string{"date"}, true, true}, + } + serv, err := NewServer(Config{ Addr: testAddr, }, nil) assert.NoError(t, err) @@ -79,6 +79,14 @@ func TestGetExec(t *testing.T) { }, nil) assert.NoError(t, err) + assertRequestToken := func(test testcase, cache *requestCache, token string) { + req, ok := cache.Consume(token) + require.True(t, ok, "token %s not found! testcase=%+v", token, test) + assert.Equal(t, testContainerID, req.(*runtimeapi.ExecRequest).GetContainerId(), "testcase=%+v", test) + assert.Equal(t, test.cmd, req.(*runtimeapi.ExecRequest).GetCmd(), "testcase=%+v", test) + assert.Equal(t, test.tty, req.(*runtimeapi.ExecRequest).GetTty(), "testcase=%+v", test) + assert.Equal(t, test.stdin, req.(*runtimeapi.ExecRequest).GetStdin(), "testcase=%+v", test) + } containerID := testContainerID for _, test := range testcases { request := &runtimeapi.ExecRequest{ @@ -87,38 +95,47 @@ func TestGetExec(t *testing.T) { Tty: &test.tty, Stdin: &test.stdin, } - // Non-TLS - resp, err := server.GetExec(request) - assert.NoError(t, err, "testcase=%+v", test) - expectedURL := "http://" + testAddr + "/exec/" + testContainerID + test.expectedQuery - assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + { // Non-TLS + resp, err := serv.GetExec(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "http://" + testAddr + "/exec/" + assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + assertRequestToken(test, serv.(*server).cache, token) + } - // TLS - resp, err = tlsServer.GetExec(request) - assert.NoError(t, err, "testcase=%+v", test) - expectedURL = "https://" + testAddr + "/exec/" + testContainerID + test.expectedQuery - assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + { // TLS + resp, err := tlsServer.GetExec(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "https://" + testAddr + "/exec/" + assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + assertRequestToken(test, tlsServer.(*server).cache, token) + } - // Path prefix - resp, err = prefixServer.GetExec(request) - assert.NoError(t, err, "testcase=%+v", test) - expectedURL = "http://" + testAddr + "/" + pathPrefix + "/exec/" + testContainerID + test.expectedQuery - assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + { // Path prefix + resp, err := prefixServer.GetExec(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/" + assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + assertRequestToken(test, prefixServer.(*server).cache, token) + } } } func TestGetAttach(t *testing.T) { - testcases := []struct { - tty bool - stdin bool - expectedQuery string - }{ - {false, false, "?error=1&output=1"}, - {true, false, "?output=1&tty=1"}, - {false, true, "?error=1&input=1&output=1"}, - {true, true, "?input=1&output=1&tty=1"}, + type testcase struct { + tty bool + stdin bool } - server, err := NewServer(Config{ + testcases := []testcase{ + {false, false}, + {true, false}, + {false, true}, + {true, true}, + } + serv, err := NewServer(Config{ Addr: testAddr, }, nil) assert.NoError(t, err) @@ -129,6 +146,13 @@ func TestGetAttach(t *testing.T) { }, nil) assert.NoError(t, err) + assertRequestToken := func(test testcase, cache *requestCache, token string) { + req, ok := cache.Consume(token) + require.True(t, ok, "token %s not found! testcase=%+v", token, test) + assert.Equal(t, testContainerID, req.(*runtimeapi.AttachRequest).GetContainerId(), "testcase=%+v", test) + assert.Equal(t, test.tty, req.(*runtimeapi.AttachRequest).GetTty(), "testcase=%+v", test) + assert.Equal(t, test.stdin, req.(*runtimeapi.AttachRequest).GetStdin(), "testcase=%+v", test) + } containerID := testContainerID for _, test := range testcases { request := &runtimeapi.AttachRequest{ @@ -136,17 +160,23 @@ func TestGetAttach(t *testing.T) { Stdin: &test.stdin, Tty: &test.tty, } - // Non-TLS - resp, err := server.GetAttach(request) - assert.NoError(t, err, "testcase=%+v", test) - expectedURL := "http://" + testAddr + "/attach/" + testContainerID + test.expectedQuery - assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + { // Non-TLS + resp, err := serv.GetAttach(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "http://" + testAddr + "/attach/" + assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + assertRequestToken(test, serv.(*server).cache, token) + } - // TLS - resp, err = tlsServer.GetAttach(request) - assert.NoError(t, err, "testcase=%+v", test) - expectedURL = "https://" + testAddr + "/attach/" + testContainerID + test.expectedQuery - assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + { // TLS + resp, err := tlsServer.GetAttach(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "https://" + testAddr + "/attach/" + assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + assertRequestToken(test, tlsServer.(*server).cache, token) + } } } @@ -157,26 +187,36 @@ func TestGetPortForward(t *testing.T) { Port: []int32{1, 2, 3, 4}, } - // Non-TLS - server, err := NewServer(Config{ - Addr: testAddr, - }, nil) - assert.NoError(t, err) - resp, err := server.GetPortForward(request) - assert.NoError(t, err) - expectedURL := "http://" + testAddr + "/portforward/" + testPodSandboxID - assert.Equal(t, expectedURL, resp.GetUrl()) + { // Non-TLS + serv, err := NewServer(Config{ + Addr: testAddr, + }, nil) + assert.NoError(t, err) + resp, err := serv.GetPortForward(request) + assert.NoError(t, err) + expectedURL := "http://" + testAddr + "/portforward/" + assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL)) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + req, ok := serv.(*server).cache.Consume(token) + require.True(t, ok, "token %s not found!", token) + assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId()) + } - // TLS - tlsServer, err := NewServer(Config{ - Addr: testAddr, - TLSConfig: &tls.Config{}, - }, nil) - assert.NoError(t, err) - resp, err = tlsServer.GetPortForward(request) - assert.NoError(t, err) - expectedURL = "https://" + testAddr + "/portforward/" + testPodSandboxID - assert.Equal(t, expectedURL, resp.GetUrl()) + { // TLS + tlsServer, err := NewServer(Config{ + Addr: testAddr, + TLSConfig: &tls.Config{}, + }, nil) + assert.NoError(t, err) + resp, err := tlsServer.GetPortForward(request) + assert.NoError(t, err) + expectedURL := "https://" + testAddr + "/portforward/" + assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL)) + token := strings.TrimPrefix(resp.GetUrl(), expectedURL) + req, ok := tlsServer.(*server).cache.Consume(token) + require.True(t, ok, "token %s not found!", token) + assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId()) + } } func TestServeExec(t *testing.T) { @@ -188,21 +228,18 @@ func TestServeAttach(t *testing.T) { } func TestServePortForward(t *testing.T) { - rt := newFakeRuntime(t) - s, err := NewServer(DefaultConfig, rt) - require.NoError(t, err) - testServer := httptest.NewServer(s) + s, testServer := startTestServer(t) defer testServer.Close() - testURL, err := url.Parse(testServer.URL) + podSandboxID := testPodSandboxID + resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{ + PodSandboxId: &podSandboxID, + }) + require.NoError(t, err) + reqURL, err := url.Parse(resp.GetUrl()) require.NoError(t, err) - loc := &url.URL{ - Scheme: testURL.Scheme, - Host: testURL.Host, - } - loc.Path = fmt.Sprintf("/%s/%s", "portforward", testPodSandboxID) - exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) + exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) require.NoError(t, err) streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) require.NoError(t, err) @@ -227,22 +264,30 @@ func TestServePortForward(t *testing.T) { // Run the remote command test. // commandType is either "exec" or "attach". func runRemoteCommandTest(t *testing.T, commandType string) { - rt := newFakeRuntime(t) - s, err := NewServer(DefaultConfig, rt) - require.NoError(t, err) - testServer := httptest.NewServer(s) + s, testServer := startTestServer(t) defer testServer.Close() - testURL, err := url.Parse(testServer.URL) - require.NoError(t, err) - query := url.Values{} - query.Add(urlParamStdin, "1") - query.Add(urlParamStdout, "1") - query.Add(urlParamStderr, "1") - loc := &url.URL{ - Scheme: testURL.Scheme, - Host: testURL.Host, - RawQuery: query.Encode(), + var reqURL *url.URL + stdin := true + containerID := testContainerID + switch commandType { + case "exec": + resp, err := s.GetExec(&runtimeapi.ExecRequest{ + ContainerId: &containerID, + Cmd: []string{"echo"}, + Stdin: &stdin, + }) + require.NoError(t, err) + reqURL, err = url.Parse(resp.GetUrl()) + require.NoError(t, err) + case "attach": + resp, err := s.GetAttach(&runtimeapi.AttachRequest{ + ContainerId: &containerID, + Stdin: &stdin, + }) + require.NoError(t, err) + reqURL, err = url.Parse(resp.GetUrl()) + require.NoError(t, err) } wg := sync.WaitGroup{} @@ -254,8 +299,7 @@ func runRemoteCommandTest(t *testing.T, commandType string) { go func() { defer wg.Done() - loc.Path = fmt.Sprintf("/%s/%s", commandType, testContainerID) - exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) + exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) require.NoError(t, err) opts := remotecommand.StreamOptions{ @@ -275,6 +319,36 @@ func runRemoteCommandTest(t *testing.T, commandType string) { }() wg.Wait() + + // Repeat request with the same URL should be a 404. + resp, err := http.Get(reqURL.String()) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func startTestServer(t *testing.T) (Server, *httptest.Server) { + var s Server + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.ServeHTTP(w, r) + })) + cleanup := true + defer func() { + if cleanup { + testServer.Close() + } + }() + + testURL, err := url.Parse(testServer.URL) + require.NoError(t, err) + + rt := newFakeRuntime(t) + config := DefaultConfig + config.BaseURL = testURL + s, err = NewServer(config, rt) + require.NoError(t, err) + + cleanup = false // Caller must close the test server. + return s, testServer } const ( diff --git a/test/e2e/kubectl.go b/test/e2e/kubectl.go index 2c5275152cf..462e1bbc520 100644 --- a/test/e2e/kubectl.go +++ b/test/e2e/kubectl.go @@ -391,6 +391,14 @@ var _ = framework.KubeDescribe("Kubectl client", func() { framework.Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) } + By("executing a very long command in the container") + veryLongData := make([]rune, 20000) + for i := 0; i < len(veryLongData); i++ { + veryLongData[i] = 'a' + } + execOutput = framework.RunKubectlOrDie("exec", fmt.Sprintf("--namespace=%v", ns), simplePodName, "echo", string(veryLongData)) + Expect(string(veryLongData)).To(Equal(strings.TrimSpace(execOutput)), "Unexpected kubectl exec output") + By("executing a command in the container with noninteractive stdin") execOutput = framework.NewKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "cat"). WithStdinData("abcd1234").