mirror of
https://github.com/k3s-io/kubernetes.git
synced 2026-01-04 23:17:50 +00:00
Don't include user data in CRI streaming redirect URLs
This commit is contained in:
@@ -18,12 +18,12 @@ package streaming
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -46,18 +46,18 @@ const (
|
||||
)
|
||||
|
||||
func TestGetExec(t *testing.T) {
|
||||
testcases := []struct {
|
||||
cmd []string
|
||||
tty bool
|
||||
stdin bool
|
||||
expectedQuery string
|
||||
}{
|
||||
{[]string{"echo", "foo"}, false, false, "?command=echo&command=foo&error=1&output=1"},
|
||||
{[]string{"date"}, true, false, "?command=date&output=1&tty=1"},
|
||||
{[]string{"date"}, false, true, "?command=date&error=1&input=1&output=1"},
|
||||
{[]string{"date"}, true, true, "?command=date&input=1&output=1&tty=1"},
|
||||
type testcase struct {
|
||||
cmd []string
|
||||
tty bool
|
||||
stdin bool
|
||||
}
|
||||
server, err := NewServer(Config{
|
||||
testcases := []testcase{
|
||||
{[]string{"echo", "foo"}, false, false},
|
||||
{[]string{"date"}, true, false},
|
||||
{[]string{"date"}, false, true},
|
||||
{[]string{"date"}, true, true},
|
||||
}
|
||||
serv, err := NewServer(Config{
|
||||
Addr: testAddr,
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -79,6 +79,14 @@ func TestGetExec(t *testing.T) {
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assertRequestToken := func(test testcase, cache *requestCache, token string) {
|
||||
req, ok := cache.Consume(token)
|
||||
require.True(t, ok, "token %s not found! testcase=%+v", token, test)
|
||||
assert.Equal(t, testContainerID, req.(*runtimeapi.ExecRequest).GetContainerId(), "testcase=%+v", test)
|
||||
assert.Equal(t, test.cmd, req.(*runtimeapi.ExecRequest).GetCmd(), "testcase=%+v", test)
|
||||
assert.Equal(t, test.tty, req.(*runtimeapi.ExecRequest).GetTty(), "testcase=%+v", test)
|
||||
assert.Equal(t, test.stdin, req.(*runtimeapi.ExecRequest).GetStdin(), "testcase=%+v", test)
|
||||
}
|
||||
containerID := testContainerID
|
||||
for _, test := range testcases {
|
||||
request := &runtimeapi.ExecRequest{
|
||||
@@ -87,38 +95,47 @@ func TestGetExec(t *testing.T) {
|
||||
Tty: &test.tty,
|
||||
Stdin: &test.stdin,
|
||||
}
|
||||
// Non-TLS
|
||||
resp, err := server.GetExec(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "http://" + testAddr + "/exec/" + testContainerID + test.expectedQuery
|
||||
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
|
||||
{ // Non-TLS
|
||||
resp, err := serv.GetExec(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "http://" + testAddr + "/exec/"
|
||||
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
assertRequestToken(test, serv.(*server).cache, token)
|
||||
}
|
||||
|
||||
// TLS
|
||||
resp, err = tlsServer.GetExec(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL = "https://" + testAddr + "/exec/" + testContainerID + test.expectedQuery
|
||||
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
|
||||
{ // TLS
|
||||
resp, err := tlsServer.GetExec(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "https://" + testAddr + "/exec/"
|
||||
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
assertRequestToken(test, tlsServer.(*server).cache, token)
|
||||
}
|
||||
|
||||
// Path prefix
|
||||
resp, err = prefixServer.GetExec(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL = "http://" + testAddr + "/" + pathPrefix + "/exec/" + testContainerID + test.expectedQuery
|
||||
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
|
||||
{ // Path prefix
|
||||
resp, err := prefixServer.GetExec(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/"
|
||||
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
assertRequestToken(test, prefixServer.(*server).cache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttach(t *testing.T) {
|
||||
testcases := []struct {
|
||||
tty bool
|
||||
stdin bool
|
||||
expectedQuery string
|
||||
}{
|
||||
{false, false, "?error=1&output=1"},
|
||||
{true, false, "?output=1&tty=1"},
|
||||
{false, true, "?error=1&input=1&output=1"},
|
||||
{true, true, "?input=1&output=1&tty=1"},
|
||||
type testcase struct {
|
||||
tty bool
|
||||
stdin bool
|
||||
}
|
||||
server, err := NewServer(Config{
|
||||
testcases := []testcase{
|
||||
{false, false},
|
||||
{true, false},
|
||||
{false, true},
|
||||
{true, true},
|
||||
}
|
||||
serv, err := NewServer(Config{
|
||||
Addr: testAddr,
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -129,6 +146,13 @@ func TestGetAttach(t *testing.T) {
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assertRequestToken := func(test testcase, cache *requestCache, token string) {
|
||||
req, ok := cache.Consume(token)
|
||||
require.True(t, ok, "token %s not found! testcase=%+v", token, test)
|
||||
assert.Equal(t, testContainerID, req.(*runtimeapi.AttachRequest).GetContainerId(), "testcase=%+v", test)
|
||||
assert.Equal(t, test.tty, req.(*runtimeapi.AttachRequest).GetTty(), "testcase=%+v", test)
|
||||
assert.Equal(t, test.stdin, req.(*runtimeapi.AttachRequest).GetStdin(), "testcase=%+v", test)
|
||||
}
|
||||
containerID := testContainerID
|
||||
for _, test := range testcases {
|
||||
request := &runtimeapi.AttachRequest{
|
||||
@@ -136,17 +160,23 @@ func TestGetAttach(t *testing.T) {
|
||||
Stdin: &test.stdin,
|
||||
Tty: &test.tty,
|
||||
}
|
||||
// Non-TLS
|
||||
resp, err := server.GetAttach(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "http://" + testAddr + "/attach/" + testContainerID + test.expectedQuery
|
||||
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
|
||||
{ // Non-TLS
|
||||
resp, err := serv.GetAttach(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "http://" + testAddr + "/attach/"
|
||||
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
assertRequestToken(test, serv.(*server).cache, token)
|
||||
}
|
||||
|
||||
// TLS
|
||||
resp, err = tlsServer.GetAttach(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL = "https://" + testAddr + "/attach/" + testContainerID + test.expectedQuery
|
||||
assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test)
|
||||
{ // TLS
|
||||
resp, err := tlsServer.GetAttach(request)
|
||||
assert.NoError(t, err, "testcase=%+v", test)
|
||||
expectedURL := "https://" + testAddr + "/attach/"
|
||||
assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test)
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
assertRequestToken(test, tlsServer.(*server).cache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,26 +187,36 @@ func TestGetPortForward(t *testing.T) {
|
||||
Port: []int32{1, 2, 3, 4},
|
||||
}
|
||||
|
||||
// Non-TLS
|
||||
server, err := NewServer(Config{
|
||||
Addr: testAddr,
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
resp, err := server.GetPortForward(request)
|
||||
assert.NoError(t, err)
|
||||
expectedURL := "http://" + testAddr + "/portforward/" + testPodSandboxID
|
||||
assert.Equal(t, expectedURL, resp.GetUrl())
|
||||
{ // Non-TLS
|
||||
serv, err := NewServer(Config{
|
||||
Addr: testAddr,
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
resp, err := serv.GetPortForward(request)
|
||||
assert.NoError(t, err)
|
||||
expectedURL := "http://" + testAddr + "/portforward/"
|
||||
assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL))
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
req, ok := serv.(*server).cache.Consume(token)
|
||||
require.True(t, ok, "token %s not found!", token)
|
||||
assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId())
|
||||
}
|
||||
|
||||
// TLS
|
||||
tlsServer, err := NewServer(Config{
|
||||
Addr: testAddr,
|
||||
TLSConfig: &tls.Config{},
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
resp, err = tlsServer.GetPortForward(request)
|
||||
assert.NoError(t, err)
|
||||
expectedURL = "https://" + testAddr + "/portforward/" + testPodSandboxID
|
||||
assert.Equal(t, expectedURL, resp.GetUrl())
|
||||
{ // TLS
|
||||
tlsServer, err := NewServer(Config{
|
||||
Addr: testAddr,
|
||||
TLSConfig: &tls.Config{},
|
||||
}, nil)
|
||||
assert.NoError(t, err)
|
||||
resp, err := tlsServer.GetPortForward(request)
|
||||
assert.NoError(t, err)
|
||||
expectedURL := "https://" + testAddr + "/portforward/"
|
||||
assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL))
|
||||
token := strings.TrimPrefix(resp.GetUrl(), expectedURL)
|
||||
req, ok := tlsServer.(*server).cache.Consume(token)
|
||||
require.True(t, ok, "token %s not found!", token)
|
||||
assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeExec(t *testing.T) {
|
||||
@@ -188,21 +228,18 @@ func TestServeAttach(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServePortForward(t *testing.T) {
|
||||
rt := newFakeRuntime(t)
|
||||
s, err := NewServer(DefaultConfig, rt)
|
||||
require.NoError(t, err)
|
||||
testServer := httptest.NewServer(s)
|
||||
s, testServer := startTestServer(t)
|
||||
defer testServer.Close()
|
||||
|
||||
testURL, err := url.Parse(testServer.URL)
|
||||
podSandboxID := testPodSandboxID
|
||||
resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{
|
||||
PodSandboxId: &podSandboxID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
reqURL, err := url.Parse(resp.GetUrl())
|
||||
require.NoError(t, err)
|
||||
loc := &url.URL{
|
||||
Scheme: testURL.Scheme,
|
||||
Host: testURL.Host,
|
||||
}
|
||||
|
||||
loc.Path = fmt.Sprintf("/%s/%s", "portforward", testPodSandboxID)
|
||||
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc)
|
||||
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
|
||||
require.NoError(t, err)
|
||||
streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name)
|
||||
require.NoError(t, err)
|
||||
@@ -227,22 +264,30 @@ func TestServePortForward(t *testing.T) {
|
||||
// Run the remote command test.
|
||||
// commandType is either "exec" or "attach".
|
||||
func runRemoteCommandTest(t *testing.T, commandType string) {
|
||||
rt := newFakeRuntime(t)
|
||||
s, err := NewServer(DefaultConfig, rt)
|
||||
require.NoError(t, err)
|
||||
testServer := httptest.NewServer(s)
|
||||
s, testServer := startTestServer(t)
|
||||
defer testServer.Close()
|
||||
|
||||
testURL, err := url.Parse(testServer.URL)
|
||||
require.NoError(t, err)
|
||||
query := url.Values{}
|
||||
query.Add(urlParamStdin, "1")
|
||||
query.Add(urlParamStdout, "1")
|
||||
query.Add(urlParamStderr, "1")
|
||||
loc := &url.URL{
|
||||
Scheme: testURL.Scheme,
|
||||
Host: testURL.Host,
|
||||
RawQuery: query.Encode(),
|
||||
var reqURL *url.URL
|
||||
stdin := true
|
||||
containerID := testContainerID
|
||||
switch commandType {
|
||||
case "exec":
|
||||
resp, err := s.GetExec(&runtimeapi.ExecRequest{
|
||||
ContainerId: &containerID,
|
||||
Cmd: []string{"echo"},
|
||||
Stdin: &stdin,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
reqURL, err = url.Parse(resp.GetUrl())
|
||||
require.NoError(t, err)
|
||||
case "attach":
|
||||
resp, err := s.GetAttach(&runtimeapi.AttachRequest{
|
||||
ContainerId: &containerID,
|
||||
Stdin: &stdin,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
reqURL, err = url.Parse(resp.GetUrl())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
@@ -254,8 +299,7 @@ func runRemoteCommandTest(t *testing.T, commandType string) {
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
loc.Path = fmt.Sprintf("/%s/%s", commandType, testContainerID)
|
||||
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc)
|
||||
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := remotecommand.StreamOptions{
|
||||
@@ -275,6 +319,36 @@ func runRemoteCommandTest(t *testing.T, commandType string) {
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Repeat request with the same URL should be a 404.
|
||||
resp, err := http.Get(reqURL.String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T) (Server, *httptest.Server) {
|
||||
var s Server
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.ServeHTTP(w, r)
|
||||
}))
|
||||
cleanup := true
|
||||
defer func() {
|
||||
if cleanup {
|
||||
testServer.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
testURL, err := url.Parse(testServer.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
rt := newFakeRuntime(t)
|
||||
config := DefaultConfig
|
||||
config.BaseURL = testURL
|
||||
s, err = NewServer(config, rt)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup = false // Caller must close the test server.
|
||||
return s, testServer
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
Reference in New Issue
Block a user