Merge pull request #110079 from ash2k/dial-with-context

Always dial using a context
This commit is contained in:
Kubernetes Prow Robot 2022-05-24 08:52:17 -07:00 committed by GitHub
commit 114cdea709
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 28 deletions

View File

@ -51,10 +51,7 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne
return d.DialContext(ctx, "tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
var tlsConn *tls.Conn
var err error
tlsConfig, err = utilnet.TLSClientConfig(transport)
tlsConfig, err := utilnet.TLSClientConfig(transport)
if err != nil {
klog.V(5).Infof("Unable to unwrap transport %T to get at TLS config: %v", transport, err)
}
@ -74,7 +71,7 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne
InsecureSkipVerify: true,
}
} else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
// tls.Handshake() requires ServerName or InsecureSkipVerify
// tls.HandshakeContext() requires ServerName or InsecureSkipVerify
// infer the ServerName from the hostname we're connecting to.
inferredHost := dialAddr
if host, _, err := net.SplitHostPort(dialAddr); err == nil {
@ -86,7 +83,7 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne
tlsConfig = tlsConfigCopy
}
// Since this method is primary used within a "Connection: Upgrade" call we assume the caller is
// Since this method is primarily used within a "Connection: Upgrade" call we assume the caller is
// going to write HTTP/1.1 request to the wire. http2 should not be allowed in the TLSConfig.NextProtos,
// so we explicitly set that here. We only do this check if the TLSConfig support http/1.1.
if supportsHTTP11(tlsConfig.NextProtos) {
@ -94,23 +91,21 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne
tlsConfig.NextProtos = []string{"http/1.1"}
}
tlsConn = tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
tlsConn := tls.Client(netConn, tlsConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
netConn.Close()
return nil, err
}
return tlsConn, nil
} else {
// Dial. This Dial method does not allow to pass a context unfortunately
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
// Dial.
tlsDialer := tls.Dialer{
Config: tlsConfig,
}
return tlsDialer.DialContext(ctx, "tcp", dialAddr)
}
return tlsConn, nil
default:
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
return nil, fmt.Errorf("unknown scheme: %s", url.Scheme)
}
}

View File

@ -157,7 +157,11 @@ func (g *grpcProxier) proxy(ctx context.Context, addr string) (net.Conn, error)
type proxyServerConnector interface {
// connect establishes connection to the proxy server, and returns a
// proxier based on the connection.
connect() (proxier, error)
//
// The provided Context must be non-nil. The context is used for connecting to the proxy only.
// If the context expires before the connection is complete, an error is returned.
// Once successfully connected to the proxy, any expiration of the context will not affect the connection.
connect(context.Context) (proxier, error)
}
type tcpHTTPConnectConnector struct {
@ -165,8 +169,11 @@ type tcpHTTPConnectConnector struct {
tlsConfig *tls.Config
}
func (t *tcpHTTPConnectConnector) connect() (proxier, error) {
conn, err := tls.Dial("tcp", t.proxyAddress, t.tlsConfig)
func (t *tcpHTTPConnectConnector) connect(ctx context.Context) (proxier, error) {
d := tls.Dialer{
Config: t.tlsConfig,
}
conn, err := d.DialContext(ctx, "tcp", t.proxyAddress)
if err != nil {
return nil, err
}
@ -177,8 +184,9 @@ type udsHTTPConnectConnector struct {
udsName string
}
func (u *udsHTTPConnectConnector) connect() (proxier, error) {
conn, err := net.Dial("unix", u.udsName)
func (u *udsHTTPConnectConnector) connect(ctx context.Context) (proxier, error) {
var d net.Dialer
conn, err := d.DialContext(ctx, "unix", u.udsName)
if err != nil {
return nil, err
}
@ -189,18 +197,24 @@ type udsGRPCConnector struct {
udsName string
}
func (u *udsGRPCConnector) connect() (proxier, error) {
// connect establishes a connection to a proxy over gRPC.
// TODO At the moment, it does not use the provided context.
func (u *udsGRPCConnector) connect(_ context.Context) (proxier, error) {
udsName := u.udsName
dialOption := grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
c, err := net.Dial("unix", udsName)
dialOption := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
c, err := d.DialContext(ctx, "unix", udsName)
if err != nil {
klog.Errorf("failed to create connection to uds name %s, error: %v", udsName, err)
}
return c, err
})
ctx := context.TODO()
tunnel, err := client.CreateSingleUseGrpcTunnel(ctx, udsName, dialOption, grpc.WithInsecure())
// CreateSingleUseGrpcTunnel() unfortunately couples dial and connection contexts. Because of that,
// we cannot use ctx just for dialing and control the connection lifetime separately.
// See https://github.com/kubernetes-sigs/apiserver-network-proxy/issues/357.
tunnelCtx := context.TODO()
tunnel, err := client.CreateSingleUseGrpcTunnel(tunnelCtx, udsName, dialOption, grpc.WithInsecure())
if err != nil {
return nil, err
}
@ -226,7 +240,7 @@ func (d *dialerCreator) createDialer() utilnet.DialFunc {
trace := utiltrace.New(fmt.Sprintf("Proxy via %s protocol over %s", d.options.protocol, d.options.transport), utiltrace.Field{Key: "address", Value: addr})
defer trace.LogIfLong(500 * time.Millisecond)
start := egressmetrics.Metrics.Clock().Now()
proxier, err := d.connector.connect()
proxier, err := d.connector.connect(ctx)
if err != nil {
egressmetrics.Metrics.ObserveDialFailure(d.options.protocol, d.options.transport, egressmetrics.StageConnect)
return nil, err

View File

@ -176,7 +176,7 @@ type fakeProxyServerConnector struct {
proxierErr bool
}
func (f *fakeProxyServerConnector) connect() (proxier, error) {
func (f *fakeProxyServerConnector) connect(context.Context) (proxier, error) {
if f.connectorErr {
return nil, fmt.Errorf("fake error")
}