diff --git a/pkg/kubelet/server/streaming/BUILD b/pkg/kubelet/server/streaming/BUILD new file mode 100644 index 00000000000..a5a8469926a --- /dev/null +++ b/pkg/kubelet/server/streaming/BUILD @@ -0,0 +1,44 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +load( + "@io_bazel_rules_go//go:def.bzl", + "go_binary", + "go_library", + "go_test", + "cgo_library", +) + +go_library( + name = "go_default_library", + srcs = ["server.go"], + tags = ["automanaged"], + deps = [ + "//pkg/kubelet/api/v1alpha1/runtime:go_default_library", + "//pkg/kubelet/server/portforward:go_default_library", + "//pkg/kubelet/server/remotecommand:go_default_library", + "//pkg/types:go_default_library", + "//pkg/util/term:go_default_library", + "//vendor:github.com/emicklei/go-restful", + "//vendor:k8s.io/client-go/pkg/api", + ], +) + +go_test( + name = "go_default_test", + srcs = ["server_test.go"], + library = "go_default_library", + tags = ["automanaged"], + deps = [ + "//pkg/client/restclient:go_default_library", + "//pkg/client/unversioned/remotecommand:go_default_library", + "//pkg/kubelet/api/v1alpha1/runtime:go_default_library", + "//pkg/kubelet/server/portforward:go_default_library", + "//pkg/kubelet/server/remotecommand:go_default_library", + "//pkg/util/term:go_default_library", + "//vendor:github.com/stretchr/testify/assert", + "//vendor:github.com/stretchr/testify/require", + "//vendor:k8s.io/client-go/pkg/api", + ], +) diff --git a/pkg/kubelet/server/streaming/server.go b/pkg/kubelet/server/streaming/server.go new file mode 100644 index 00000000000..994e6be1a5d --- /dev/null +++ b/pkg/kubelet/server/streaming/server.go @@ -0,0 +1,312 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package streaming + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" + + restful "github.com/emicklei/go-restful" + + "k8s.io/client-go/pkg/api" + runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" + "k8s.io/kubernetes/pkg/kubelet/server/portforward" + "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" + "k8s.io/kubernetes/pkg/types" + "k8s.io/kubernetes/pkg/util/term" +) + +// The library interface to serve the stream requests. +type Server interface { + http.Handler + + // Get the serving URL for the requests. Server must be started before these are called. + // Requests must not be nil. Responses may be nil iff an error is returned. + GetExec(*runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) + GetAttach(req *runtimeapi.AttachRequest, tty bool) (*runtimeapi.AttachResponse, error) + GetPortForward(*runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) + + // Start the server. + // addr is the address to serve on (address:port) stayUp indicates whether the server should + // listen until Stop() is called, or automatically stop after all expected connections are + // closed. Calling Get{Exec,Attach,PortForward} increments the expected connection count. + // Function does not return until the server is stopped. + Start(stayUp bool) error + // Stop the server, and terminate any open connections. + Stop() error +} + +// The interface to execute the commands and provide the streams. +type Runtime interface { + Exec(containerID string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error + Attach(containerID string, in io.Reader, out, err io.WriteCloser, resize <-chan term.Size) error + PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error +} + +// Config defines the options used for running the stream server. +type Config struct { + // The host:port address the server will listen on. + Addr string + + // How long to leave idle connections open for. + StreamIdleTimeout time.Duration + // How long to wait for clients to create streams. Only used for SPDY streaming. + StreamCreationTimeout time.Duration + + // The streaming protocols the server supports (understands and permits). See + // k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols. + // Only used for SPDY streaming. + SupportedProtocols []string + + // The config for serving over TLS. If nil, TLS will not be used. + TLSConfig *tls.Config +} + +// DefaultConfig provides default values for server Config. The DefaultConfig is partial, so +// some fields like Addr must still be provided. +var DefaultConfig = Config{ + StreamIdleTimeout: 4 * time.Hour, + StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout, + SupportedProtocols: remotecommand.SupportedStreamingProtocols, +} + +// TODO(timstclair): Add auth(n/z) interface & handling. +func NewServer(config Config, runtime Runtime) (Server, error) { + s := &server{ + config: config, + runtime: &criAdapter{runtime}, + } + + ws := &restful.WebService{} + endpoints := []struct { + path string + handler restful.RouteFunction + }{ + {"/exec/{containerID}", s.serveExec}, + {"/attach/{containerID}", s.serveAttach}, + {"/portforward/{podSandboxID}", s.servePortForward}, + } + for _, e := range endpoints { + for _, method := range []string{"GET", "POST"} { + ws.Route(ws. + Method(method). + Path(e.path). + To(e.handler)) + } + } + handler := restful.NewContainer() + handler.Add(ws) + s.handler = handler + + return s, nil +} + +type server struct { + config Config + runtime *criAdapter + handler http.Handler +} + +func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) { + url := s.buildURL("exec", req.GetContainerId(), streamOpts{ + stdin: req.GetStdin(), + stdout: true, + stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout. + tty: req.GetTty(), + command: req.GetCmd(), + }) + return &runtimeapi.ExecResponse{ + Url: &url, + }, nil +} + +func (s *server) GetAttach(req *runtimeapi.AttachRequest, tty bool) (*runtimeapi.AttachResponse, error) { + url := s.buildURL("attach", req.GetContainerId(), streamOpts{ + stdin: req.GetStdin(), + stdout: true, + stderr: !tty, // For TTY connections, both stderr is combined with stdout. + tty: tty, + }) + return &runtimeapi.AttachResponse{ + Url: &url, + }, nil +} + +func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) { + url := s.buildURL("portforward", req.GetPodSandboxId(), streamOpts{}) + return &runtimeapi.PortForwardResponse{ + Url: &url, + }, nil +} + +func (s *server) Start(stayUp bool) error { + if !stayUp { + // TODO(timstclair): Implement this. + return errors.New("stayUp=false is not yet implemented") + } + + server := &http.Server{ + Addr: s.config.Addr, + Handler: s.handler, + TLSConfig: s.config.TLSConfig, + } + if s.config.TLSConfig != nil { + return server.ListenAndServeTLS("", "") // Use certs from TLSConfig. + } else { + return server.ListenAndServe() + } +} + +func (s *server) Stop() error { + // TODO(timstclair): Implement this. + return errors.New("not yet implemented") +} + +func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) +} + +type streamOpts struct { + stdin bool + stdout bool + stderr bool + tty bool + + command []string + port []int32 +} + +const ( + urlParamStdin = api.ExecStdinParam + urlParamStdout = api.ExecStdoutParam + urlParamStderr = api.ExecStderrParam + urlParamTTY = api.ExecTTYParam + urlParamCommand = api.ExecCommandParamm +) + +func (s *server) buildURL(method, id string, opts streamOpts) string { + loc := url.URL{ + Scheme: "http", + Host: s.config.Addr, + Path: fmt.Sprintf("/%s/%s", method, id), + } + if s.config.TLSConfig != nil { + loc.Scheme = "https" + } + + query := url.Values{} + if opts.stdin { + query.Add(urlParamStdin, "1") + } + if opts.stdout { + query.Add(urlParamStdout, "1") + } + if opts.stderr { + query.Add(urlParamStderr, "1") + } + if opts.tty { + query.Add(urlParamTTY, "1") + } + for _, c := range opts.command { + query.Add(urlParamCommand, c) + } + loc.RawQuery = query.Encode() + + return loc.String() +} + +func (s *server) serveExec(req *restful.Request, resp *restful.Response) { + containerID := req.PathParameter("containerID") + if containerID == "" { + resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) + return + } + + remotecommand.ServeExec( + resp.ResponseWriter, + req.Request, + s.runtime, + "", // unused: podName + "", // unusued: podUID + containerID, + s.config.StreamIdleTimeout, + s.config.StreamCreationTimeout, + s.config.SupportedProtocols) +} + +func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { + containerID := req.PathParameter("containerID") + if containerID == "" { + resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) + return + } + + remotecommand.ServeAttach( + resp.ResponseWriter, + req.Request, + s.runtime, + "", // unused: podName + "", // unusued: podUID + containerID, + s.config.StreamIdleTimeout, + s.config.StreamCreationTimeout, + s.config.SupportedProtocols) +} + +func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { + podSandboxID := req.PathParameter("podSandboxID") + if podSandboxID == "" { + resp.WriteError(http.StatusBadRequest, errors.New("missing required podSandboxID path parameter")) + return + } + + portforward.ServePortForward( + resp.ResponseWriter, + req.Request, + s.runtime, + podSandboxID, + "", // unused: podUID + s.config.StreamIdleTimeout, + s.config.StreamCreationTimeout) +} + +// criAdapter wraps the Runtime functions to conform to the remotecommand interfaces. +// The adapter binds the container ID to the container name argument, and the pod sandbox ID to the pod name. +type criAdapter struct { + Runtime +} + +var _ remotecommand.Executor = &criAdapter{} +var _ remotecommand.Attacher = &criAdapter{} +var _ portforward.PortForwarder = &criAdapter{} + +func (a *criAdapter) ExecInContainer(podName string, podUID types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error { + return a.Exec(container, cmd, in, out, err, tty, resize) +} + +func (a *criAdapter) AttachContainer(podName string, podUID types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error { + return a.Attach(container, in, out, err, resize) +} + +func (a *criAdapter) PortForward(podName string, podUID types.UID, port uint16, stream io.ReadWriteCloser) error { + return a.Runtime.PortForward(podName, int32(port), stream) +} diff --git a/pkg/kubelet/server/streaming/server_test.go b/pkg/kubelet/server/streaming/server_test.go new file mode 100644 index 00000000000..9922449ea6a --- /dev/null +++ b/pkg/kubelet/server/streaming/server_test.go @@ -0,0 +1,331 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package streaming + +import ( + "crypto/tls" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "k8s.io/client-go/pkg/api" + "k8s.io/kubernetes/pkg/client/restclient" + "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" + runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" + kubeletportforward "k8s.io/kubernetes/pkg/kubelet/server/portforward" + kubeletremotecommand "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" + "k8s.io/kubernetes/pkg/util/term" +) + +const ( + testAddr = "localhost:12345" + testContainerID = "container789" + testPodSandboxID = "pod0987" +) + +func TestGetExec(t *testing.T) { + testcases := []struct { + cmd []string + tty bool + stdin bool + expectedQuery string + }{ + {[]string{"echo", "foo"}, false, false, "?command=echo&command=foo&error=1&output=1"}, + {[]string{"date"}, true, false, "?command=date&output=1&tty=1"}, + {[]string{"date"}, false, true, "?command=date&error=1&input=1&output=1"}, + {[]string{"date"}, true, true, "?command=date&input=1&output=1&tty=1"}, + } + server, err := NewServer(Config{ + Addr: testAddr, + }, nil) + assert.NoError(t, err) + + tlsServer, err := NewServer(Config{ + Addr: testAddr, + TLSConfig: &tls.Config{}, + }, nil) + assert.NoError(t, err) + + containerID := testContainerID + for _, test := range testcases { + request := &runtimeapi.ExecRequest{ + ContainerId: &containerID, + Cmd: test.cmd, + Tty: &test.tty, + Stdin: &test.stdin, + } + // Non-TLS + resp, err := server.GetExec(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "http://" + testAddr + "/exec/" + testContainerID + test.expectedQuery + assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + + // TLS + resp, err = tlsServer.GetExec(request) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL = "https://" + testAddr + "/exec/" + testContainerID + test.expectedQuery + assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + } +} + +func TestGetAttach(t *testing.T) { + testcases := []struct { + tty bool + stdin bool + expectedQuery string + }{ + {false, false, "?error=1&output=1"}, + {true, false, "?output=1&tty=1"}, + {false, true, "?error=1&input=1&output=1"}, + {true, true, "?input=1&output=1&tty=1"}, + } + server, err := NewServer(Config{ + Addr: testAddr, + }, nil) + assert.NoError(t, err) + + tlsServer, err := NewServer(Config{ + Addr: testAddr, + TLSConfig: &tls.Config{}, + }, nil) + assert.NoError(t, err) + + containerID := testContainerID + for _, test := range testcases { + request := &runtimeapi.AttachRequest{ + ContainerId: &containerID, + Stdin: &test.stdin, + } + // Non-TLS + resp, err := server.GetAttach(request, test.tty) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL := "http://" + testAddr + "/attach/" + testContainerID + test.expectedQuery + assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + + // TLS + resp, err = tlsServer.GetAttach(request, test.tty) + assert.NoError(t, err, "testcase=%+v", test) + expectedURL = "https://" + testAddr + "/attach/" + testContainerID + test.expectedQuery + assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) + } +} + +func TestGetPortForward(t *testing.T) { + podSandboxID := testPodSandboxID + request := &runtimeapi.PortForwardRequest{ + PodSandboxId: &podSandboxID, + Port: []int32{1, 2, 3, 4}, + } + + // Non-TLS + server, err := NewServer(Config{ + Addr: testAddr, + }, nil) + assert.NoError(t, err) + resp, err := server.GetPortForward(request) + assert.NoError(t, err) + expectedURL := "http://" + testAddr + "/portforward/" + testPodSandboxID + assert.Equal(t, expectedURL, resp.GetUrl()) + + // TLS + tlsServer, err := NewServer(Config{ + Addr: testAddr, + TLSConfig: &tls.Config{}, + }, nil) + assert.NoError(t, err) + resp, err = tlsServer.GetPortForward(request) + assert.NoError(t, err) + expectedURL = "https://" + testAddr + "/portforward/" + testPodSandboxID + assert.Equal(t, expectedURL, resp.GetUrl()) +} + +func TestServeExec(t *testing.T) { + runRemoteCommandTest(t, "exec") +} + +func TestServeAttach(t *testing.T) { + runRemoteCommandTest(t, "attach") +} + +func TestServePortForward(t *testing.T) { + rt := newFakeRuntime(t) + s, err := NewServer(DefaultConfig, rt) + require.NoError(t, err) + testServer := httptest.NewServer(s) + defer testServer.Close() + + testURL, err := url.Parse(testServer.URL) + require.NoError(t, err) + loc := &url.URL{ + Scheme: testURL.Scheme, + Host: testURL.Host, + } + + loc.Path = fmt.Sprintf("/%s/%s", "portforward", testPodSandboxID) + exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) + require.NoError(t, err) + streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) + require.NoError(t, err) + defer streamConn.Close() + + // Create the streams. + headers := http.Header{} + // Error stream is required, but unused in this test. + headers.Set(api.StreamType, api.StreamTypeError) + headers.Set(api.PortHeader, strconv.Itoa(testPort)) + _, err = streamConn.CreateStream(headers) + require.NoError(t, err) + // Setup the data stream. + headers.Set(api.StreamType, api.StreamTypeData) + headers.Set(api.PortHeader, strconv.Itoa(testPort)) + stream, err := streamConn.CreateStream(headers) + require.NoError(t, err) + + doClientStreams(t, "portforward", stream, stream, nil) +} + +// Run the remote command test. +// commandType is either "exec" or "attach". +func runRemoteCommandTest(t *testing.T, commandType string) { + rt := newFakeRuntime(t) + s, err := NewServer(DefaultConfig, rt) + require.NoError(t, err) + testServer := httptest.NewServer(s) + defer testServer.Close() + + testURL, err := url.Parse(testServer.URL) + require.NoError(t, err) + query := url.Values{} + query.Add(urlParamStdin, "1") + query.Add(urlParamStdout, "1") + query.Add(urlParamStderr, "1") + loc := &url.URL{ + Scheme: testURL.Scheme, + Host: testURL.Host, + RawQuery: query.Encode(), + } + + wg := sync.WaitGroup{} + wg.Add(2) + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + stderrR, stderrW := io.Pipe() + + go func() { + defer wg.Done() + loc.Path = fmt.Sprintf("/%s/%s", commandType, testContainerID) + exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) + require.NoError(t, err) + + opts := remotecommand.StreamOptions{ + SupportedProtocols: kubeletremotecommand.SupportedStreamingProtocols, + Stdin: stdinR, + Stdout: stdoutW, + Stderr: stderrW, + Tty: false, + TerminalSizeQueue: nil, + } + require.NoError(t, exec.Stream(opts)) + }() + + go func() { + defer wg.Done() + doClientStreams(t, commandType, stdinW, stdoutR, stderrR) + }() + + wg.Wait() +} + +const ( + testInput = "abcdefg" + testOutput = "fooBARbaz" + testErr = "ERROR!!!" + testPort = 12345 +) + +func newFakeRuntime(t *testing.T) *fakeRuntime { + return &fakeRuntime{ + t: t, + } +} + +type fakeRuntime struct { + t *testing.T +} + +func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan term.Size) error { + assert.Equal(f.t, testContainerID, containerID) + doServerStreams(f.t, "exec", stdin, stdout, stderr) + return nil +} + +func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, resize <-chan term.Size) error { + assert.Equal(f.t, testContainerID, containerID) + doServerStreams(f.t, "attach", stdin, stdout, stderr) + return nil +} + +func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + assert.Equal(f.t, testPodSandboxID, podSandboxID) + assert.EqualValues(f.t, testPort, port) + doServerStreams(f.t, "portforward", stream, stream, nil) + return nil +} + +// Send & receive expected input/output. Must be the inverse of doClientStreams. +// Function will block until the expected i/o is finished. +func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) { + if stderr != nil { + writeExpected(t, "server stderr", stderr, prefix+testErr) + } + readExpected(t, "server stdin", stdin, prefix+testInput) + writeExpected(t, "server stdout", stdout, prefix+testOutput) +} + +// Send & receive expected input/output. Must be the inverse of doServerStreams. +// Function will block until the expected i/o is finished. +func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) { + if stderr != nil { + readExpected(t, "client stderr", stderr, prefix+testErr) + } + writeExpected(t, "client stdin", stdin, prefix+testInput) + readExpected(t, "client stdout", stdout, prefix+testOutput) +} + +// Read and verify the expected string from the stream. +func readExpected(t *testing.T, streamName string, r io.Reader, expected string) { + result := make([]byte, len(expected)) + _, err := io.ReadAtLeast(r, result, len(expected)) + assert.NoError(t, err, "stream %s", streamName) + assert.Equal(t, expected, string(result), "stream %s", streamName) +} + +// Write and verify success of the data over the stream. +func writeExpected(t *testing.T, streamName string, w io.Writer, data string) { + n, err := io.WriteString(w, data) + assert.NoError(t, err, "stream %s", streamName) + assert.Equal(t, len(data), n, "stream %s", streamName) +}