mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-23 03:41:45 +00:00
Merge pull request #108531 from tallclair/redirects
Don't follow redirects with spdy
This commit is contained in:
commit
ea006f5246
@ -863,7 +863,7 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) {
|
||||
|
||||
url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1"
|
||||
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil, true, true)
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil)
|
||||
c := &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
resp, err := c.Do(makeReq(t, "POST", url, "v4.channel.k8s.io"))
|
||||
@ -1019,7 +1019,7 @@ func testExecAttach(t *testing.T, verb string) {
|
||||
upgradeRoundTripper httpstream.UpgradeRoundTripper
|
||||
c *http.Client
|
||||
)
|
||||
upgradeRoundTripper = spdy.NewRoundTripper(nil, true, true)
|
||||
upgradeRoundTripper = spdy.NewRoundTripper(nil)
|
||||
c = &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
resp, err = c.Do(makeReq(t, "POST", url, "v4.channel.k8s.io"))
|
||||
@ -1115,7 +1115,7 @@ func TestServePortForwardIdleTimeout(t *testing.T) {
|
||||
|
||||
url := fw.testHTTPServer.URL + "/portForward/" + podNamespace + "/" + podName
|
||||
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil, true, true)
|
||||
upgradeRoundTripper := spdy.NewRoundTripper(nil)
|
||||
c := &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
req := makeReq(t, "POST", url, "portforward.k8s.io")
|
||||
@ -1214,7 +1214,7 @@ func TestServePortForward(t *testing.T) {
|
||||
c *http.Client
|
||||
)
|
||||
|
||||
upgradeRoundTripper = spdy.NewRoundTripper(nil, true, true)
|
||||
upgradeRoundTripper = spdy.NewRoundTripper(nil)
|
||||
c = &http.Client{Transport: upgradeRoundTripper}
|
||||
|
||||
req := makeReq(t, "POST", url, "portforward.k8s.io")
|
||||
|
@ -67,12 +67,6 @@ type SpdyRoundTripper struct {
|
||||
// Used primarily for mocking the proxy discovery in tests.
|
||||
proxier func(req *http.Request) (*url.URL, error)
|
||||
|
||||
// followRedirects indicates if the round tripper should examine responses for redirects and
|
||||
// follow them.
|
||||
followRedirects bool
|
||||
// requireSameHostRedirects restricts redirect following to only follow redirects to the same host
|
||||
// as the original request.
|
||||
requireSameHostRedirects bool
|
||||
// pingPeriod is a period for sending Ping frames over established
|
||||
// connections.
|
||||
pingPeriod time.Duration
|
||||
@ -84,22 +78,18 @@ var _ utilnet.Dialer = &SpdyRoundTripper{}
|
||||
|
||||
// NewRoundTripper creates a new SpdyRoundTripper that will use the specified
|
||||
// tlsConfig.
|
||||
func NewRoundTripper(tlsConfig *tls.Config, followRedirects, requireSameHostRedirects bool) *SpdyRoundTripper {
|
||||
func NewRoundTripper(tlsConfig *tls.Config) *SpdyRoundTripper {
|
||||
return NewRoundTripperWithConfig(RoundTripperConfig{
|
||||
TLS: tlsConfig,
|
||||
FollowRedirects: followRedirects,
|
||||
RequireSameHostRedirects: requireSameHostRedirects,
|
||||
TLS: tlsConfig,
|
||||
})
|
||||
}
|
||||
|
||||
// NewRoundTripperWithProxy creates a new SpdyRoundTripper that will use the
|
||||
// specified tlsConfig and proxy func.
|
||||
func NewRoundTripperWithProxy(tlsConfig *tls.Config, followRedirects, requireSameHostRedirects bool, proxier func(*http.Request) (*url.URL, error)) *SpdyRoundTripper {
|
||||
func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) *SpdyRoundTripper {
|
||||
return NewRoundTripperWithConfig(RoundTripperConfig{
|
||||
TLS: tlsConfig,
|
||||
FollowRedirects: followRedirects,
|
||||
RequireSameHostRedirects: requireSameHostRedirects,
|
||||
Proxier: proxier,
|
||||
TLS: tlsConfig,
|
||||
Proxier: proxier,
|
||||
})
|
||||
}
|
||||
|
||||
@ -110,11 +100,9 @@ func NewRoundTripperWithConfig(cfg RoundTripperConfig) *SpdyRoundTripper {
|
||||
cfg.Proxier = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
|
||||
}
|
||||
return &SpdyRoundTripper{
|
||||
tlsConfig: cfg.TLS,
|
||||
followRedirects: cfg.FollowRedirects,
|
||||
requireSameHostRedirects: cfg.RequireSameHostRedirects,
|
||||
proxier: cfg.Proxier,
|
||||
pingPeriod: cfg.PingPeriod,
|
||||
tlsConfig: cfg.TLS,
|
||||
proxier: cfg.Proxier,
|
||||
pingPeriod: cfg.PingPeriod,
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,9 +115,6 @@ type RoundTripperConfig struct {
|
||||
// PingPeriod is a period for sending SPDY Pings on the connection.
|
||||
// Optional.
|
||||
PingPeriod time.Duration
|
||||
|
||||
FollowRedirects bool
|
||||
RequireSameHostRedirects bool
|
||||
}
|
||||
|
||||
// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during
|
||||
@ -365,13 +350,9 @@ func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
err error
|
||||
)
|
||||
|
||||
if s.followRedirects {
|
||||
conn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, req.URL, header, req.Body, s, s.requireSameHostRedirects)
|
||||
} else {
|
||||
clone := utilnet.CloneRequest(req)
|
||||
clone.Header = header
|
||||
conn, err = s.Dial(clone)
|
||||
}
|
||||
clone := utilnet.CloneRequest(req)
|
||||
clone.Header = header
|
||||
conn, err = s.Dial(clone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -17,7 +17,6 @@ limitations under the License.
|
||||
package net
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
@ -446,104 +445,6 @@ type Dialer interface {
|
||||
Dial(req *http.Request) (net.Conn, error)
|
||||
}
|
||||
|
||||
// ConnectWithRedirects uses dialer to send req, following up to 10 redirects (relative to
|
||||
// originalLocation). It returns the opened net.Conn and the raw response bytes.
|
||||
// If requireSameHostRedirects is true, only redirects to the same host are permitted.
|
||||
func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer, requireSameHostRedirects bool) (net.Conn, []byte, error) {
|
||||
const (
|
||||
maxRedirects = 9 // Fail on the 10th redirect
|
||||
maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers
|
||||
)
|
||||
|
||||
var (
|
||||
location = originalLocation
|
||||
method = originalMethod
|
||||
intermediateConn net.Conn
|
||||
rawResponse = bytes.NewBuffer(make([]byte, 0, 256))
|
||||
body = originalBody
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if intermediateConn != nil {
|
||||
intermediateConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
redirectLoop:
|
||||
for redirects := 0; ; redirects++ {
|
||||
if redirects > maxRedirects {
|
||||
return nil, nil, fmt.Errorf("too many redirects (%d)", redirects)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, location.String(), body)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
req.Header = header
|
||||
|
||||
intermediateConn, err = dialer.Dial(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Peek at the backend response.
|
||||
rawResponse.Reset()
|
||||
respReader := bufio.NewReader(io.TeeReader(
|
||||
io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes.
|
||||
rawResponse)) // Save the raw response.
|
||||
resp, err := http.ReadResponse(respReader, nil)
|
||||
if err != nil {
|
||||
// Unable to read the backend response; let the client handle it.
|
||||
klog.Warningf("Error reading backend response: %v", err)
|
||||
break redirectLoop
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusFound:
|
||||
// Redirect, continue.
|
||||
default:
|
||||
// Don't redirect.
|
||||
break redirectLoop
|
||||
}
|
||||
|
||||
// Redirected requests switch to "GET" according to the HTTP spec:
|
||||
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3
|
||||
method = "GET"
|
||||
// don't send a body when following redirects
|
||||
body = nil
|
||||
|
||||
resp.Body.Close() // not used
|
||||
|
||||
// Prepare to follow the redirect.
|
||||
redirectStr := resp.Header.Get("Location")
|
||||
if redirectStr == "" {
|
||||
return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode)
|
||||
}
|
||||
// We have to parse relative to the current location, NOT originalLocation. For example,
|
||||
// if we request http://foo.com/a and get back "http://bar.com/b", the result should be
|
||||
// http://bar.com/b. If we then make that request and get back a redirect to "/c", the result
|
||||
// should be http://bar.com/c, not http://foo.com/c.
|
||||
location, err = location.Parse(redirectStr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("malformed Location header: %v", err)
|
||||
}
|
||||
|
||||
// Only follow redirects to the same host. Otherwise, propagate the redirect response back.
|
||||
if requireSameHostRedirects && location.Hostname() != originalLocation.Hostname() {
|
||||
return nil, nil, fmt.Errorf("hostname mismatch: expected %s, found %s", originalLocation.Hostname(), location.Hostname())
|
||||
}
|
||||
|
||||
// Reset the connection.
|
||||
intermediateConn.Close()
|
||||
intermediateConn = nil
|
||||
}
|
||||
|
||||
connToReturn := intermediateConn
|
||||
intermediateConn = nil // Don't close the connection when we return it.
|
||||
return connToReturn, rawResponse.Bytes(), nil
|
||||
}
|
||||
|
||||
// CloneRequest creates a shallow copy of the request along with a deep copy of the Headers.
|
||||
func CloneRequest(req *http.Request) *http.Request {
|
||||
r := new(http.Request)
|
||||
|
@ -20,15 +20,11 @@ limitations under the License.
|
||||
package net
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
@ -36,8 +32,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
netutils "k8s.io/utils/net"
|
||||
)
|
||||
|
||||
@ -293,157 +287,6 @@ func TestJoinPreservingTrailingSlash(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithRedirects(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
redirects []string
|
||||
method string // initial request method, empty == GET
|
||||
expectError bool
|
||||
expectedRedirects int
|
||||
newPort bool // special case different port test
|
||||
}{{
|
||||
desc: "relative redirects allowed",
|
||||
redirects: []string{"/ok"},
|
||||
expectedRedirects: 1,
|
||||
}, {
|
||||
desc: "redirects to the same host are allowed",
|
||||
redirects: []string{"http://HOST/ok"}, // HOST replaced with server address in test
|
||||
expectedRedirects: 1,
|
||||
}, {
|
||||
desc: "POST redirects to GET",
|
||||
method: http.MethodPost,
|
||||
redirects: []string{"/ok"},
|
||||
expectedRedirects: 1,
|
||||
}, {
|
||||
desc: "PUT redirects to GET",
|
||||
method: http.MethodPut,
|
||||
redirects: []string{"/ok"},
|
||||
expectedRedirects: 1,
|
||||
}, {
|
||||
desc: "DELETE redirects to GET",
|
||||
method: http.MethodDelete,
|
||||
redirects: []string{"/ok"},
|
||||
expectedRedirects: 1,
|
||||
}, {
|
||||
desc: "9 redirects are allowed",
|
||||
redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9"},
|
||||
expectedRedirects: 9,
|
||||
}, {
|
||||
desc: "10 redirects are forbidden",
|
||||
redirects: []string{"/1", "/2", "/3", "/4", "/5", "/6", "/7", "/8", "/9", "/10"},
|
||||
expectError: true,
|
||||
}, {
|
||||
desc: "redirect to different host are prevented",
|
||||
redirects: []string{"http://example.com/foo"},
|
||||
expectError: true,
|
||||
}, {
|
||||
desc: "multiple redirect to different host forbidden",
|
||||
redirects: []string{"/1", "/2", "/3", "http://example.com/foo"},
|
||||
expectError: true,
|
||||
}, {
|
||||
desc: "redirect to different port is allowed",
|
||||
redirects: []string{"http://HOST/foo"},
|
||||
expectedRedirects: 1,
|
||||
newPort: true,
|
||||
}}
|
||||
|
||||
const resultString = "Test output"
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
redirectCount := 0
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
// Verify redirect request.
|
||||
if redirectCount > 0 {
|
||||
expectedURL, err := url.Parse(test.redirects[redirectCount-1])
|
||||
require.NoError(t, err, "test URL error")
|
||||
assert.Equal(t, req.URL.Path, expectedURL.Path, "unknown redirect path")
|
||||
assert.Equal(t, http.MethodGet, req.Method, "redirects must always be GET")
|
||||
}
|
||||
if redirectCount < len(test.redirects) {
|
||||
http.Redirect(w, req, test.redirects[redirectCount], http.StatusFound)
|
||||
redirectCount++
|
||||
} else if redirectCount == len(test.redirects) {
|
||||
w.Write([]byte(resultString))
|
||||
} else {
|
||||
t.Errorf("unexpected number of redirects %d to %s", redirectCount, req.URL.String())
|
||||
}
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
u, err := url.Parse(s.URL)
|
||||
require.NoError(t, err, "Error parsing server URL")
|
||||
host := u.Host
|
||||
|
||||
// Special case new-port test with a secondary server.
|
||||
if test.newPort {
|
||||
s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte(resultString))
|
||||
}))
|
||||
defer s2.Close()
|
||||
u2, err := url.Parse(s2.URL)
|
||||
require.NoError(t, err, "Error parsing secondary server URL")
|
||||
|
||||
// Sanity check: secondary server uses same hostname, different port.
|
||||
require.Equal(t, u.Hostname(), u2.Hostname(), "sanity check: same hostname")
|
||||
require.NotEqual(t, u.Port(), u2.Port(), "sanity check: different port")
|
||||
|
||||
// Redirect to the secondary server.
|
||||
host = u2.Host
|
||||
|
||||
}
|
||||
|
||||
// Update redirect URLs with actual host.
|
||||
for i := range test.redirects {
|
||||
test.redirects[i] = strings.Replace(test.redirects[i], "HOST", host, 1)
|
||||
}
|
||||
|
||||
method := test.method
|
||||
if method == "" {
|
||||
method = http.MethodGet
|
||||
}
|
||||
|
||||
netdialer := &net.Dialer{
|
||||
Timeout: wait.ForeverTestTimeout,
|
||||
KeepAlive: wait.ForeverTestTimeout,
|
||||
}
|
||||
dialer := DialerFunc(func(req *http.Request) (net.Conn, error) {
|
||||
conn, err := netdialer.Dial("tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
return conn, err
|
||||
}
|
||||
if err = req.Write(conn); err != nil {
|
||||
require.NoError(t, conn.Close())
|
||||
return nil, fmt.Errorf("error sending request: %v", err)
|
||||
}
|
||||
return conn, err
|
||||
})
|
||||
conn, rawResponse, err := ConnectWithRedirects(method, u, http.Header{} /*body*/, nil, dialer, true)
|
||||
if test.expectError {
|
||||
require.Error(t, err, "expected request error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "unexpected request error")
|
||||
assert.NoError(t, conn.Close(), "error closing connection")
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(rawResponse)), nil)
|
||||
require.NoError(t, err, "unexpected request error")
|
||||
|
||||
result, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
if test.expectedRedirects < len(test.redirects) {
|
||||
// Expect the last redirect to be returned.
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode, "Final response is not a redirect")
|
||||
assert.Equal(t, test.redirects[len(test.redirects)-1], resp.Header.Get("Location"))
|
||||
assert.NotEqual(t, resultString, string(result), "wrong content")
|
||||
} else {
|
||||
assert.Equal(t, resultString, string(result), "stream content does not match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowsHTTP2(t *testing.T) {
|
||||
testcases := []struct {
|
||||
Name string
|
||||
|
@ -44,11 +44,9 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er
|
||||
proxy = config.Proxy
|
||||
}
|
||||
upgradeRoundTripper := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{
|
||||
TLS: tlsConfig,
|
||||
FollowRedirects: true,
|
||||
RequireSameHostRedirects: false,
|
||||
Proxier: proxy,
|
||||
PingPeriod: time.Second * 5,
|
||||
TLS: tlsConfig,
|
||||
Proxier: proxy,
|
||||
PingPeriod: time.Second * 5,
|
||||
})
|
||||
wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user