From 2d62c57533f96985b7847dd63f91471167bd6006 Mon Sep 17 00:00:00 2001 From: Mikhail Mazurskiy Date: Wed, 18 May 2022 10:39:35 +1000 Subject: [PATCH] Always dial using a context --- .../apimachinery/pkg/util/proxy/dial.go | 27 ++++++-------- .../server/egressselector/egress_selector.go | 36 +++++++++++++------ .../egressselector/egress_selector_test.go | 2 +- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go b/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go index 18165d8dae2..4ceb2e06eab 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go @@ -51,10 +51,7 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne return d.DialContext(ctx, "tcp", dialAddr) case "https": // Get the tls config from the transport if we recognize it - var tlsConfig *tls.Config - var tlsConn *tls.Conn - var err error - tlsConfig, err = utilnet.TLSClientConfig(transport) + tlsConfig, err := utilnet.TLSClientConfig(transport) if err != nil { klog.V(5).Infof("Unable to unwrap transport %T to get at TLS config: %v", transport, err) } @@ -74,7 +71,7 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne InsecureSkipVerify: true, } } else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - // tls.Handshake() requires ServerName or InsecureSkipVerify + // tls.HandshakeContext() requires ServerName or InsecureSkipVerify // infer the ServerName from the hostname we're connecting to. inferredHost := dialAddr if host, _, err := net.SplitHostPort(dialAddr); err == nil { @@ -86,7 +83,7 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne tlsConfig = tlsConfigCopy } - // Since this method is primary used within a "Connection: Upgrade" call we assume the caller is + // Since this method is primarily used within a "Connection: Upgrade" call we assume the caller is // going to write HTTP/1.1 request to the wire. http2 should not be allowed in the TLSConfig.NextProtos, // so we explicitly set that here. We only do this check if the TLSConfig support http/1.1. if supportsHTTP11(tlsConfig.NextProtos) { @@ -94,23 +91,21 @@ func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (ne tlsConfig.NextProtos = []string{"http/1.1"} } - tlsConn = tls.Client(netConn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { + tlsConn := tls.Client(netConn, tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { netConn.Close() return nil, err } - + return tlsConn, nil } else { - // Dial. This Dial method does not allow to pass a context unfortunately - tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) - if err != nil { - return nil, err + // Dial. + tlsDialer := tls.Dialer{ + Config: tlsConfig, } + return tlsDialer.DialContext(ctx, "tcp", dialAddr) } - - return tlsConn, nil default: - return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme) + return nil, fmt.Errorf("unknown scheme: %s", url.Scheme) } } diff --git a/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go b/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go index 9da0e2a099c..1115c66d701 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go +++ b/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go @@ -157,7 +157,11 @@ func (g *grpcProxier) proxy(ctx context.Context, addr string) (net.Conn, error) type proxyServerConnector interface { // connect establishes connection to the proxy server, and returns a // proxier based on the connection. - connect() (proxier, error) + // + // The provided Context must be non-nil. The context is used for connecting to the proxy only. + // If the context expires before the connection is complete, an error is returned. + // Once successfully connected to the proxy, any expiration of the context will not affect the connection. + connect(context.Context) (proxier, error) } type tcpHTTPConnectConnector struct { @@ -165,8 +169,11 @@ type tcpHTTPConnectConnector struct { tlsConfig *tls.Config } -func (t *tcpHTTPConnectConnector) connect() (proxier, error) { - conn, err := tls.Dial("tcp", t.proxyAddress, t.tlsConfig) +func (t *tcpHTTPConnectConnector) connect(ctx context.Context) (proxier, error) { + d := tls.Dialer{ + Config: t.tlsConfig, + } + conn, err := d.DialContext(ctx, "tcp", t.proxyAddress) if err != nil { return nil, err } @@ -177,8 +184,9 @@ type udsHTTPConnectConnector struct { udsName string } -func (u *udsHTTPConnectConnector) connect() (proxier, error) { - conn, err := net.Dial("unix", u.udsName) +func (u *udsHTTPConnectConnector) connect(ctx context.Context) (proxier, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, "unix", u.udsName) if err != nil { return nil, err } @@ -189,18 +197,24 @@ type udsGRPCConnector struct { udsName string } -func (u *udsGRPCConnector) connect() (proxier, error) { +// connect establishes a connection to a proxy over gRPC. +// TODO At the moment, it does not use the provided context. +func (u *udsGRPCConnector) connect(_ context.Context) (proxier, error) { udsName := u.udsName - dialOption := grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { - c, err := net.Dial("unix", udsName) + dialOption := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + c, err := d.DialContext(ctx, "unix", udsName) if err != nil { klog.Errorf("failed to create connection to uds name %s, error: %v", udsName, err) } return c, err }) - ctx := context.TODO() - tunnel, err := client.CreateSingleUseGrpcTunnel(ctx, udsName, dialOption, grpc.WithInsecure()) + // CreateSingleUseGrpcTunnel() unfortunately couples dial and connection contexts. Because of that, + // we cannot use ctx just for dialing and control the connection lifetime separately. + // See https://github.com/kubernetes-sigs/apiserver-network-proxy/issues/357. + tunnelCtx := context.TODO() + tunnel, err := client.CreateSingleUseGrpcTunnel(tunnelCtx, udsName, dialOption, grpc.WithInsecure()) if err != nil { return nil, err } @@ -226,7 +240,7 @@ func (d *dialerCreator) createDialer() utilnet.DialFunc { trace := utiltrace.New(fmt.Sprintf("Proxy via %s protocol over %s", d.options.protocol, d.options.transport), utiltrace.Field{Key: "address", Value: addr}) defer trace.LogIfLong(500 * time.Millisecond) start := egressmetrics.Metrics.Clock().Now() - proxier, err := d.connector.connect() + proxier, err := d.connector.connect(ctx) if err != nil { egressmetrics.Metrics.ObserveDialFailure(d.options.protocol, d.options.transport, egressmetrics.StageConnect) return nil, err diff --git a/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector_test.go b/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector_test.go index bab1f3ee8b9..c5aaf6b6474 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector_test.go @@ -176,7 +176,7 @@ type fakeProxyServerConnector struct { proxierErr bool } -func (f *fakeProxyServerConnector) connect() (proxier, error) { +func (f *fakeProxyServerConnector) connect(context.Context) (proxier, error) { if f.connectorErr { return nil, fmt.Errorf("fake error") }