Disable HTTP2 while proxying a "Connection: upgrade" request

When proxying connection upgrade requests, like websockets, we dial
the target and then manually write the http.Request to the wire,
bypassing the http.Client.  In this scenario we are by default using
HTTP/1.1 from both the client and to the target server we are proxying.
Because of this we must disable HTTP2 in the TLS handshake so that the
server does not think we are writing a HTTP2 request. We do this by
setting the TLSConfig.NextProtos field to "http/1.1".

Signed-off-by: Darren Shepherd <darren@rancher.com>
This commit is contained in:
Darren Shepherd 2020-03-03 15:53:15 -07:00
parent 861c918a44
commit eb9cf777dc
4 changed files with 73 additions and 10 deletions

View File

@ -30,7 +30,12 @@ import (
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
// dialURL will dial the specified URL using the underlying dialer held by the passed
// RoundTripper. The primary use of this method is to support proxying upgradable connections.
// For this reason this method will prefer to negotiate http/1.1 if the URL scheme is https.
// If you wish to ensure ALPN negotiates http2 then set NextProto=[]string{"http2"} in the
// TLSConfig of the http.Transport
func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer, err := utilnet.DialerFor(transport)
@ -81,6 +86,15 @@ func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne
tlsConfigCopy.ServerName = inferredHost
tlsConfig = tlsConfigCopy
}
// Since this method is primary used within a "Connection: Upgrade" call we assume the caller is
// going to write HTTP/1.1 request to the wire. http2 should not be allowed in the TLSConfig.NextProtos,
// so we explicitly set that here. We only do this check if the TLSConfig support http/1.1.
if supportsHTTP11(tlsConfig.NextProtos) {
tlsConfig = tlsConfig.Clone()
tlsConfig.NextProtos = []string{"http/1.1"}
}
tlsConn = tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
netConn.Close()
@ -115,3 +129,15 @@ func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
}
}
func supportsHTTP11(nextProtos []string) bool {
if len(nextProtos) == 0 {
return true
}
for _, proto := range nextProtos {
if proto == "http/1.1" {
return true
}
}
return false
}

View File

@ -49,6 +49,7 @@ func TestDialURL(t *testing.T) {
TLSConfig *tls.Config
Dial utilnet.DialFunc
ExpectError string
ExpectProto string
}{
"insecure": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
@ -90,13 +91,28 @@ func TestDialURL(t *testing.T) {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: d.DialContext,
},
"ensure we use http2 if specified": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com", NextProtos: []string{"http2"}},
Dial: d.DialContext,
ExpectProto: "http2",
},
"ensure we use http/1.1 if unspecified": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: d.DialContext,
ExpectProto: "http/1.1",
},
"ensure we use http/1.1 if available": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com", NextProtos: []string{"http2", "http/1.1"}},
Dial: d.DialContext,
ExpectProto: "http/1.1",
},
}
for k, tc := range testcases {
func() {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}))
defer ts.Close()
ts.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
ts.TLS = &tls.Config{Certificates: []tls.Certificate{cert}, NextProtos: []string{"http2", "http/1.1"}}
ts.StartTLS()
// Make a copy of the config
@ -127,7 +143,7 @@ func TestDialURL(t *testing.T) {
u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := DialURL(context.Background(), u, transport)
conn, err := dialURL(context.Background(), u, transport)
// Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {
@ -143,6 +159,14 @@ func TestDialURL(t *testing.T) {
}
return
}
tlsConn := conn.(*tls.Conn)
if tc.ExpectProto != "" {
if tlsConn.ConnectionState().NegotiatedProtocol != tc.ExpectProto {
t.Errorf("%s: expected proto %s, got %s", k, tc.ExpectProto, tlsConn.ConnectionState().NegotiatedProtocol)
}
}
conn.Close()
if tc.ExpectError != "" {
t.Errorf("%s: expected error %q, got none", k, tc.ExpectError)

View File

@ -384,10 +384,6 @@ func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Reques
return true
}
func (h *UpgradeAwareHandler) Dial(req *http.Request) (net.Conn, error) {
return dial(req, h.Transport)
}
func (h *UpgradeAwareHandler) DialForUpgrade(req *http.Request) (net.Conn, error) {
if h.UpgradeTransport == nil {
return dial(req, h.Transport)
@ -414,7 +410,7 @@ func getResponse(r io.Reader) (*http.Response, []byte, error) {
// dial dials the backend at req.URL and writes req to it.
func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
conn, err := DialURL(req.Context(), req.URL, transport)
conn, err := dialURL(req.Context(), req.URL, transport)
if err != nil {
return nil, fmt.Errorf("error dialing backend: %v", err)
}
@ -427,8 +423,6 @@ func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
return conn, err
}
var _ utilnet.Dialer = &UpgradeAwareHandler{}
func (h *UpgradeAwareHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper {
scheme := url.Scheme
host := url.Host

View File

@ -355,6 +355,25 @@ func TestProxyUpgrade(t *testing.T) {
ServerFunc: httptest.NewServer,
ProxyTransport: nil,
},
"both client and server support http2, but force to http/1.1 for upgrade": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(exampleCert, exampleKey)
if err != nil {
t.Errorf("https (invalid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{"http2", "http/1.1"},
}
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{TLSClientConfig: &tls.Config{
NextProtos: []string{"http2", "http/1.1"},
InsecureSkipVerify: true,
}}),
},
"https (invalid hostname + InsecureSkipVerify)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(exampleCert, exampleKey)