Pass context for TLS dial

This commit is contained in:
Mikhail Mazurskiy 2022-04-29 11:13:15 +10:00
parent 6d74474972
commit 1dda915659
2 changed files with 27 additions and 21 deletions

View File

@ -273,29 +273,20 @@ func (s *SpdyRoundTripper) tlsConn(ctx context.Context, rwc net.Conn, targetHost
// dialWithoutProxy dials the host specified by url, using TLS if appropriate.
func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer := s.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
if url.Scheme == "http" {
if s.Dialer == nil {
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
} else {
return s.Dialer.DialContext(ctx, "tcp", dialAddr)
}
return dialer.DialContext(ctx, "tcp", dialAddr)
}
// TODO validate the TLSClientConfig is set up?
var conn *tls.Conn
var err error
if s.Dialer == nil {
conn, err = tls.Dial("tcp", dialAddr, s.tlsConfig)
} else {
conn, err = tls.DialWithDialer(s.Dialer, "tcp", dialAddr, s.tlsConfig)
tlsDialer := tls.Dialer{
NetDialer: dialer,
Config: s.tlsConfig,
}
if err != nil {
return nil, err
}
return conn, nil
return tlsDialer.DialContext(ctx, "tcp", dialAddr)
}
// proxyAuth returns, for a given proxy URL, the value to be used for the Proxy-Authorization header
@ -325,9 +316,7 @@ func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
resp, err := http.ReadResponse(responseReader, nil)
if err != nil {
if conn != nil {
conn.Close()
}
conn.Close()
return nil, err
}

View File

@ -31,6 +31,8 @@ import (
"github.com/armon/go-socks5"
"github.com/elazarl/goproxy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream"
)
@ -682,6 +684,21 @@ func TestRoundTripSocks5AndNewConnection(t *testing.T) {
}
}
func TestRoundTripPassesContextToDialer(t *testing.T) {
urls := []string{"http://127.0.0.1:1233/", "https://127.0.0.1:1233/"}
for _, u := range urls {
t.Run(u, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
require.NoError(t, err)
spdyTransport := NewRoundTripper(&tls.Config{})
_, err = spdyTransport.Dial(req)
assert.EqualError(t, err, "dial tcp 127.0.0.1:1233: operation was canceled")
})
}
}
// exampleCert was generated from crypto/tls/generate_cert.go with the following command:
// go run generate_cert.go --rsa-bits 2048 --host example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var exampleCert = []byte(`-----BEGIN CERTIFICATE-----