Use Dial with context

This commit is contained in:
Mikhail Mazurskiy 2018-05-19 08:14:37 +10:00
parent 77a08ee2d7
commit 5e8e570dbd
No known key found for this signature in database
GPG Key ID: 93551ECC96E2F568
25 changed files with 111 additions and 110 deletions

View File

@ -261,7 +261,7 @@ func CreateNodeDialer(s completedServerRunOptions) (tunneler.Tunneler, *http.Tra
// Proxying to pods and services is IP-based... don't expect to be able to verify the hostname // Proxying to pods and services is IP-based... don't expect to be able to verify the hostname
proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true} proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true}
proxyTransport := utilnet.SetTransportDefaults(&http.Transport{ proxyTransport := utilnet.SetTransportDefaults(&http.Transport{
Dial: proxyDialerFn, DialContext: proxyDialerFn,
TLSClientConfig: proxyTLSClientConfig, TLSClientConfig: proxyTLSClientConfig,
}) })
return nodeTunneler, proxyTransport, nil return nodeTunneler, proxyTransport, nil
@ -522,8 +522,8 @@ func BuildGenericConfig(
if err != nil { if err != nil {
return nil, err return nil, err
} }
if proxyTransport != nil && proxyTransport.Dial != nil { if proxyTransport != nil && proxyTransport.DialContext != nil {
ret.Dial = proxyTransport.Dial ret.Dial = proxyTransport.DialContext
} }
return ret, err return ret, err
}, },

View File

@ -74,7 +74,7 @@ func MakeTransport(config *KubeletClientConfig) (http.RoundTripper, error) {
rt := http.DefaultTransport rt := http.DefaultTransport
if config.Dial != nil || tlsConfig != nil { if config.Dial != nil || tlsConfig != nil {
rt = utilnet.SetOldTransportDefaults(&http.Transport{ rt = utilnet.SetOldTransportDefaults(&http.Transport{
Dial: config.Dial, DialContext: config.Dial,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
}) })
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package master package master
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
@ -108,7 +109,7 @@ func setUp(t *testing.T) (*etcdtesting.EtcdTestServer, Config, informers.SharedI
config.GenericConfig.LoopbackClientConfig = &restclient.Config{APIPath: "/api", ContentConfig: restclient.ContentConfig{NegotiatedSerializer: legacyscheme.Codecs}} config.GenericConfig.LoopbackClientConfig = &restclient.Config{APIPath: "/api", ContentConfig: restclient.ContentConfig{NegotiatedSerializer: legacyscheme.Codecs}}
config.ExtraConfig.KubeletClientConfig = kubeletclient.KubeletClientConfig{Port: 10250} config.ExtraConfig.KubeletClientConfig = kubeletclient.KubeletClientConfig{Port: 10250}
config.ExtraConfig.ProxyTransport = utilnet.SetTransportDefaults(&http.Transport{ config.ExtraConfig.ProxyTransport = utilnet.SetTransportDefaults(&http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return nil, nil }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil },
TLSClientConfig: &tls.Config{}, TLSClientConfig: &tls.Config{},
}) })

View File

@ -43,7 +43,7 @@ type AddressFunc func() (addresses []string, err error)
type Tunneler interface { type Tunneler interface {
Run(AddressFunc) Run(AddressFunc)
Stop() Stop()
Dial(net, addr string) (net.Conn, error) Dial(ctx context.Context, net, addr string) (net.Conn, error)
SecondsSinceSync() int64 SecondsSinceSync() int64
SecondsSinceSSHKeySync() int64 SecondsSinceSSHKeySync() int64
} }
@ -149,8 +149,8 @@ func (c *SSHTunneler) Stop() {
} }
} }
func (c *SSHTunneler) Dial(net, addr string) (net.Conn, error) { func (c *SSHTunneler) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
return c.tunnels.Dial(net, addr) return c.tunnels.Dial(ctx, net, addr)
} }
func (c *SSHTunneler) SecondsSinceSync() int64 { func (c *SSHTunneler) SecondsSinceSync() int64 {

View File

@ -17,6 +17,7 @@ limitations under the License.
package tunneler package tunneler
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -111,11 +112,11 @@ type FakeTunneler struct {
SecondsSinceSSHKeySyncValue int64 SecondsSinceSSHKeySyncValue int64
} }
func (t *FakeTunneler) Run(AddressFunc) {} func (t *FakeTunneler) Run(AddressFunc) {}
func (t *FakeTunneler) Stop() {} func (t *FakeTunneler) Stop() {}
func (t *FakeTunneler) Dial(net, addr string) (net.Conn, error) { return nil, nil } func (t *FakeTunneler) Dial(ctx context.Context, net, addr string) (net.Conn, error) { return nil, nil }
func (t *FakeTunneler) SecondsSinceSync() int64 { return t.SecondsSinceSyncValue } func (t *FakeTunneler) SecondsSinceSync() int64 { return t.SecondsSinceSyncValue }
func (t *FakeTunneler) SecondsSinceSSHKeySync() int64 { return t.SecondsSinceSSHKeySyncValue } func (t *FakeTunneler) SecondsSinceSSHKeySync() int64 { return t.SecondsSinceSSHKeySyncValue }
// TestIsTunnelSyncHealthy verifies that the 600 second lag test // TestIsTunnelSyncHealthy verifies that the 600 second lag test
// is honored. // is honored.

View File

@ -18,6 +18,7 @@ package ssh
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
@ -121,10 +122,11 @@ func (s *SSHTunnel) Open() error {
return err return err
} }
func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) { func (s *SSHTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if s.client == nil { if s.client == nil {
return nil, errors.New("tunnel is not opened.") return nil, errors.New("tunnel is not opened.")
} }
// This Dial method does not allow to pass a context unfortunately
return s.client.Dial(network, address) return s.client.Dial(network, address)
} }
@ -294,7 +296,7 @@ func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
type tunnel interface { type tunnel interface {
Open() error Open() error
Close() error Close() error
Dial(network, address string) (net.Conn, error) Dial(ctx context.Context, network, address string) (net.Conn, error)
} }
type sshTunnelEntry struct { type sshTunnelEntry struct {
@ -361,7 +363,7 @@ func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration
func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error { func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
// GET the healthcheck path using the provided tunnel's dial function. // GET the healthcheck path using the provided tunnel's dial function.
transport := utilnet.SetTransportDefaults(&http.Transport{ transport := utilnet.SetTransportDefaults(&http.Transport{
Dial: e.Tunnel.Dial, DialContext: e.Tunnel.Dial,
// TODO(cjcullen): Plumb real TLS options through. // TODO(cjcullen): Plumb real TLS options through.
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
// We don't reuse the clients, so disable the keep-alive to properly // We don't reuse the clients, so disable the keep-alive to properly
@ -394,7 +396,7 @@ func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
go l.createAndAddTunnel(e.Address) go l.createAndAddTunnel(e.Address)
} }
func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) { func (l *SSHTunnelList) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
start := time.Now() start := time.Now()
id := mathrand.Int63() // So you can match begins/ends in the log. id := mathrand.Int63() // So you can match begins/ends in the log.
glog.Infof("[%x: %v] Dialing...", id, addr) glog.Infof("[%x: %v] Dialing...", id, addr)
@ -405,7 +407,7 @@ func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tunnel.Dial(net, addr) return tunnel.Dial(ctx, net, addr)
} }
func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) { func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package ssh package ssh
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -145,7 +146,7 @@ func TestSSHTunnel(t *testing.T) {
t.FailNow() t.FailNow()
} }
_, err = tunnel.Dial("tcp", "127.0.0.1:8080") _, err = tunnel.Dial(context.Background(), "tcp", "127.0.0.1:8080")
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
@ -176,7 +177,7 @@ func (*fakeTunnel) Close() error {
return nil return nil
} }
func (*fakeTunnel) Dial(network, address string) (net.Conn, error) { func (*fakeTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil return nil, nil
} }

View File

@ -19,6 +19,7 @@ package spdy
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -118,7 +119,7 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
} }
if proxyURL == nil { if proxyURL == nil {
return s.dialWithoutProxy(req.URL) return s.dialWithoutProxy(req.Context(), req.URL)
} }
// ensure we use a canonical host with proxyReq // ensure we use a canonical host with proxyReq
@ -136,7 +137,7 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
proxyReq.Header.Set("Proxy-Authorization", pa) proxyReq.Header.Set("Proxy-Authorization", pa)
} }
proxyDialConn, err := s.dialWithoutProxy(proxyURL) proxyDialConn, err := s.dialWithoutProxy(req.Context(), proxyURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -187,14 +188,15 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
} }
// dialWithoutProxy dials the host specified by url, using TLS if appropriate. // dialWithoutProxy dials the host specified by url, using TLS if appropriate.
func (s *SpdyRoundTripper) dialWithoutProxy(url *url.URL) (net.Conn, error) { func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url) dialAddr := netutil.CanonicalAddr(url)
if url.Scheme == "http" { if url.Scheme == "http" {
if s.Dialer == nil { if s.Dialer == nil {
return net.Dial("tcp", dialAddr) var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
} else { } else {
return s.Dialer.Dial("tcp", dialAddr) return s.Dialer.DialContext(ctx, "tcp", dialAddr)
} }
} }

View File

@ -19,6 +19,7 @@ package net
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -90,8 +91,8 @@ func SetOldTransportDefaults(t *http.Transport) *http.Transport {
// ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY // ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment) t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
} }
if t.Dial == nil { if t.DialContext == nil {
t.Dial = defaultTransport.Dial t.DialContext = defaultTransport.DialContext
} }
if t.TLSHandshakeTimeout == 0 { if t.TLSHandshakeTimeout == 0 {
t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
@ -119,7 +120,7 @@ type RoundTripperWrapper interface {
WrappedRoundTripper() http.RoundTripper WrappedRoundTripper() http.RoundTripper
} }
type DialFunc func(net, addr string) (net.Conn, error) type DialFunc func(ctx context.Context, net, addr string) (net.Conn, error)
func DialerFor(transport http.RoundTripper) (DialFunc, error) { func DialerFor(transport http.RoundTripper) (DialFunc, error) {
if transport == nil { if transport == nil {
@ -128,7 +129,7 @@ func DialerFor(transport http.RoundTripper) (DialFunc, error) {
switch transport := transport.(type) { switch transport := transport.(type) {
case *http.Transport: case *http.Transport:
return transport.Dial, nil return transport.DialContext, nil
case RoundTripperWrapper: case RoundTripperWrapper:
return DialerFor(transport.WrappedRoundTripper()) return DialerFor(transport.WrappedRoundTripper())
default: default:

View File

@ -17,6 +17,7 @@ limitations under the License.
package proxy package proxy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@ -29,7 +30,7 @@ import (
"k8s.io/apimachinery/third_party/forked/golang/netutil" "k8s.io/apimachinery/third_party/forked/golang/netutil"
) )
func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url) dialAddr := netutil.CanonicalAddr(url)
dialer, err := utilnet.DialerFor(transport) dialer, err := utilnet.DialerFor(transport)
@ -40,9 +41,10 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
switch url.Scheme { switch url.Scheme {
case "http": case "http":
if dialer != nil { if dialer != nil {
return dialer("tcp", dialAddr) return dialer(ctx, "tcp", dialAddr)
} }
return net.Dial("tcp", dialAddr) var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
case "https": case "https":
// Get the tls config from the transport if we recognize it // Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config var tlsConfig *tls.Config
@ -56,7 +58,7 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
if dialer != nil { if dialer != nil {
// We have a dialer; use it to open the connection, then // We have a dialer; use it to open the connection, then
// create a tls client using the connection. // create a tls client using the connection.
netConn, err := dialer("tcp", dialAddr) netConn, err := dialer(ctx, "tcp", dialAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -86,7 +88,7 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
} }
} else { } else {
// Dial // Dial. This Dial method does not allow to pass a context unfortunately
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -17,6 +17,7 @@ limitations under the License.
package proxy package proxy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
@ -42,6 +43,7 @@ func TestDialURL(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var d net.Dialer
testcases := map[string]struct { testcases := map[string]struct {
TLSConfig *tls.Config TLSConfig *tls.Config
@ -68,25 +70,25 @@ func TestDialURL(t *testing.T) {
"insecure, custom dial": { "insecure, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: true}, TLSConfig: &tls.Config{InsecureSkipVerify: true},
Dial: net.Dial, Dial: d.DialContext,
}, },
"secure, no roots, custom dial": { "secure, no roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false}, TLSConfig: &tls.Config{InsecureSkipVerify: false},
Dial: net.Dial, Dial: d.DialContext,
ExpectError: "unknown authority", ExpectError: "unknown authority",
}, },
"secure with roots, custom dial": { "secure with roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots}, TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
Dial: net.Dial, Dial: d.DialContext,
}, },
"secure with mismatched server, custom dial": { "secure with mismatched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"}, TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
Dial: net.Dial, Dial: d.DialContext,
ExpectError: "not bogus.com", ExpectError: "not bogus.com",
}, },
"secure with matched server, custom dial": { "secure with matched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"}, TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: net.Dial, Dial: d.DialContext,
}, },
} }
@ -102,7 +104,7 @@ func TestDialURL(t *testing.T) {
// Clone() mutates the receiver (!), so also call it on the copy // Clone() mutates the receiver (!), so also call it on the copy
tlsConfigCopy.Clone() tlsConfigCopy.Clone()
transport := &http.Transport{ transport := &http.Transport{
Dial: tc.Dial, DialContext: tc.Dial,
TLSClientConfig: tlsConfigCopy, TLSClientConfig: tlsConfigCopy,
} }
@ -125,7 +127,7 @@ func TestDialURL(t *testing.T) {
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host) _, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p) u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := DialURL(u, transport) conn, err := DialURL(context.Background(), u, transport)
// Make sure dialing doesn't mutate the transport's TLSConfig // Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) { if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {

View File

@ -347,7 +347,7 @@ func (h *UpgradeAwareHandler) DialForUpgrade(req *http.Request) (net.Conn, error
// dial dials the backend at req.URL and writes req to it. // dial dials the backend at req.URL and writes req to it.
func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) { func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
conn, err := DialURL(req.URL, transport) conn, err := DialURL(req.Context(), req.URL, transport)
if err != nil { if err != nil {
return nil, fmt.Errorf("error dialing backend: %v", err) return nil, fmt.Errorf("error dialing backend: %v", err)
} }

View File

@ -19,6 +19,7 @@ package proxy
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
@ -341,6 +342,7 @@ func TestProxyUpgrade(t *testing.T) {
if !localhostPool.AppendCertsFromPEM(localhostCert) { if !localhostPool.AppendCertsFromPEM(localhostCert) {
t.Errorf("error setting up localhostCert pool") t.Errorf("error setting up localhostCert pool")
} }
var d net.Dialer
testcases := map[string]struct { testcases := map[string]struct {
ServerFunc func(http.Handler) *httptest.Server ServerFunc func(http.Handler) *httptest.Server
@ -395,7 +397,7 @@ func TestProxyUpgrade(t *testing.T) {
ts.StartTLS() ts.StartTLS()
return ts return ts
}, },
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}), ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
}, },
"https (valid hostname + RootCAs + custom dialer + bearer token)": { "https (valid hostname + RootCAs + custom dialer + bearer token)": {
ServerFunc: func(h http.Handler) *httptest.Server { ServerFunc: func(h http.Handler) *httptest.Server {
@ -410,9 +412,9 @@ func TestProxyUpgrade(t *testing.T) {
ts.StartTLS() ts.StartTLS()
return ts return ts
}, },
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}), ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
UpgradeTransport: NewUpgradeRequestRoundTripper( UpgradeTransport: NewUpgradeRequestRoundTripper(
utilnet.SetOldTransportDefaults(&http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}), utilnet.SetOldTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
RoundTripperFunc(func(req *http.Request) (*http.Response, error) { RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = utilnet.CloneRequest(req) req = utilnet.CloneRequest(req)
req.Header.Set("Authorization", "Bearer 1234") req.Header.Set("Authorization", "Bearer 1234")
@ -496,9 +498,15 @@ func TestProxyUpgradeErrorResponse(t *testing.T) {
expectedErr = errors.New("EXPECTED") expectedErr = errors.New("EXPECTED")
) )
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
transport := http.DefaultTransport.(*http.Transport) transport := &http.Transport{
transport.Dial = func(network, addr string) (net.Conn, error) { Proxy: http.ProxyFromEnvironment,
return &fakeConn{err: expectedErr}, nil DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return &fakeConn{err: expectedErr}, nil
},
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
} }
responder = &fakeResponder{t: t, w: w} responder = &fakeResponder{t: t, w: w}
proxyHandler := NewUpgradeAwareHandler( proxyHandler := NewUpgradeAwareHandler(

View File

@ -17,6 +17,7 @@ limitations under the License.
package config package config
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -147,9 +148,10 @@ func (cm *ClientManager) HookClient(h *v1beta1.Webhook) (*rest.RESTClient, error
delegateDialer := cfg.Dial delegateDialer := cfg.Dial
if delegateDialer == nil { if delegateDialer == nil {
delegateDialer = net.Dial var d net.Dialer
delegateDialer = d.DialContext
} }
cfg.Dial = func(network, addr string) (net.Conn, error) { cfg.Dial = func(ctx context.Context, network, addr string) (net.Conn, error) {
if addr == host { if addr == host {
u, err := cm.serviceResolver.ResolveEndpoint(svc.Namespace, svc.Name) u, err := cm.serviceResolver.ResolveEndpoint(svc.Namespace, svc.Name)
if err != nil { if err != nil {
@ -157,7 +159,7 @@ func (cm *ClientManager) HookClient(h *v1beta1.Webhook) (*rest.RESTClient, error
} }
addr = u.Host addr = u.Host
} }
return delegateDialer(network, addr) return delegateDialer(ctx, network, addr)
} }
return complete(cfg) return complete(cfg)

View File

@ -69,10 +69,10 @@ func newTransportForETCD2(certFile, keyFile, caFile string) (*http.Transport, er
// TODO: Determine if transport needs optimization // TODO: Determine if transport needs optimization
tr := utilnet.SetTransportDefaults(&http.Transport{ tr := utilnet.SetTransportDefaults(&http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
}).Dial, }).DialContext,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
MaxIdleConnsPerHost: 500, MaxIdleConnsPerHost: 500,
TLSClientConfig: cfg, TLSClientConfig: cfg,

View File

@ -44,12 +44,8 @@ const (
defaultRetries = 2 defaultRetries = 2
// protobuf mime type // protobuf mime type
mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf" mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf"
)
var (
// defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient. // defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient.
// Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist. // Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist.
// It's a variable to be able to change it in tests.
defaultTimeout = 32 * time.Second defaultTimeout = 32 * time.Second
) )

View File

@ -23,12 +23,11 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"strings"
"testing" "testing"
"time"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/googleapis/gnostic/OpenAPIv2" "github.com/googleapis/gnostic/OpenAPIv2"
"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
@ -131,31 +130,11 @@ func TestGetServerGroupsWithBrokenServer(t *testing.T) {
} }
} }
} }
func TestGetServerGroupsWithTimeout(t *testing.T) {
done := make(chan bool) func TestTimeoutIsSet(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { cfg := &restclient.Config{}
// first we need to write headers, otherwise http client will complain about setDiscoveryDefaults(cfg)
// exceeding timeout awaiting headers, only after we can block the call assert.Equal(t, defaultTimeout, cfg.Timeout)
w.Header().Set("Connection", "keep-alive")
if wf, ok := w.(http.Flusher); ok {
wf.Flush()
}
<-done
}))
defer server.Close()
defer close(done)
client := NewDiscoveryClientForConfigOrDie(&restclient.Config{Host: server.URL, Timeout: 2 * time.Second})
_, err := client.ServerGroups()
// the error we're getting here is wrapped in errors.errorString which makes
// it impossible to unwrap and check it's attributes, so instead we're checking
// the textual output which is presenting http.httpError with timeout set to true
if err == nil {
t.Fatal("missing error")
}
if !strings.Contains(err.Error(), "timeout:true") &&
!strings.Contains(err.Error(), "context.deadlineExceededError") {
t.Fatalf("unexpected error: %v", err)
}
} }
func TestGetServerResourcesWithV1Server(t *testing.T) { func TestGetServerResourcesWithV1Server(t *testing.T) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package rest package rest
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -110,7 +111,7 @@ type Config struct {
Timeout time.Duration Timeout time.Duration
// Dial specifies the dial function for creating unencrypted TCP connections. // Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error) Dial func(ctx context.Context, network, address string) (net.Conn, error)
// Version forces a specific version to be used (if registered) // Version forces a specific version to be used (if registered)
// Do we need this? // Do we need this?

View File

@ -17,6 +17,8 @@ limitations under the License.
package rest package rest
import ( import (
"context"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -25,8 +27,6 @@ import (
"strings" "strings"
"testing" "testing"
fuzz "github.com/google/gofuzz"
"k8s.io/api/core/v1" "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
@ -35,8 +35,7 @@ import (
clientcmdapi "k8s.io/client-go/tools/clientcmd/api" clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/util/flowcontrol" "k8s.io/client-go/util/flowcontrol"
"errors" fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -208,7 +207,7 @@ func (n *fakeNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder,
return &fakeCodec{} return &fakeCodec{}
} }
var fakeDialFunc = func(network, addr string) (net.Conn, error) { var fakeDialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, fakeDialerError return nil, fakeDialerError
} }
var fakeDialerError = errors.New("fakedialer") var fakeDialerError = errors.New("fakedialer")
@ -253,7 +252,7 @@ func TestAnonymousConfig(t *testing.T) {
r.Config = map[string]string{} r.Config = map[string]string{}
}, },
// Dial does not require fuzzer // Dial does not require fuzzer
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {}, func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {},
) )
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
original := &Config{} original := &Config{}
@ -284,10 +283,10 @@ func TestAnonymousConfig(t *testing.T) {
expected.WrapTransport = nil expected.WrapTransport = nil
} }
if actual.Dial != nil { if actual.Dial != nil {
_, actualError := actual.Dial("", "") _, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := actual.Dial("", "") _, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) { if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field") t.Fatalf("CopyConfig dropped the Dial field")
} }
} else { } else {
actual.Dial = nil actual.Dial = nil
@ -329,7 +328,7 @@ func TestCopyConfig(t *testing.T) {
func(r *AuthProviderConfigPersister, f fuzz.Continue) { func(r *AuthProviderConfigPersister, f fuzz.Continue) {
*r = fakeAuthProviderConfigPersister{} *r = fakeAuthProviderConfigPersister{}
}, },
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) { func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {
*r = fakeDialFunc *r = fakeDialFunc
}, },
) )
@ -351,8 +350,8 @@ func TestCopyConfig(t *testing.T) {
expected.WrapTransport = nil expected.WrapTransport = nil
} }
if actual.Dial != nil { if actual.Dial != nil {
_, actualError := actual.Dial("", "") _, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := actual.Dial("", "") _, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) { if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field") t.Fatalf("CopyConfig dropped the Dial field")
} }
@ -361,7 +360,7 @@ func TestCopyConfig(t *testing.T) {
expected.Dial = nil expected.Dial = nil
if actual.AuthConfigPersister != nil { if actual.AuthConfigPersister != nil {
actualError := actual.AuthConfigPersister.Persist(nil) actualError := actual.AuthConfigPersister.Persist(nil)
expectedError := actual.AuthConfigPersister.Persist(nil) expectedError := expected.AuthConfigPersister.Persist(nil)
if !reflect.DeepEqual(expectedError, actualError) { if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field") t.Fatalf("CopyConfig dropped the Dial field")
} }

View File

@ -85,7 +85,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dial = (&net.Dialer{ dial = (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
}).Dial }).DialContext
} }
// Cache a single transport for these options // Cache a single transport for these options
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
@ -93,7 +93,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: idleConnsPerHost, MaxIdleConnsPerHost: idleConnsPerHost,
Dial: dial, DialContext: dial,
}) })
return c.transports[key], nil return c.transports[key], nil
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport package transport
import ( import (
"context"
"net" "net"
"net/http" "net/http"
"testing" "testing"
@ -52,10 +53,11 @@ func TestTLSConfigKey(t *testing.T) {
} }
// Make sure config fields that affect the tls config affect the cache key // Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
uniqueConfigurations := map[string]*Config{ uniqueConfigurations := map[string]*Config{
"no tls": {}, "no tls": {},
"dialer": {Dial: net.Dial}, "dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(network, address string) (net.Conn, error) { return nil, nil }}, "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"insecure": {TLS: TLSConfig{Insecure: true}}, "insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport package transport
import ( import (
"context"
"net" "net"
"net/http" "net/http"
) )
@ -53,7 +54,7 @@ type Config struct {
WrapTransport func(rt http.RoundTripper) http.RoundTripper WrapTransport func(rt http.RoundTripper) http.RoundTripper
// Dial specifies the dial function for creating unencrypted TCP connections. // Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error) Dial func(ctx context.Context, network, address string) (net.Conn, error)
} }
// ImpersonationConfig has all the available impersonation options // ImpersonationConfig has all the available impersonation options

View File

@ -209,10 +209,10 @@ func (r *proxyHandler) updateAPIService(apiService *apiregistrationapi.APIServic
serviceAvailable: apiregistrationapi.IsAPIServiceConditionTrue(apiService, apiregistrationapi.Available), serviceAvailable: apiregistrationapi.IsAPIServiceConditionTrue(apiService, apiregistrationapi.Available),
} }
newInfo.proxyRoundTripper, newInfo.transportBuildingError = restclient.TransportFor(newInfo.restConfig) newInfo.proxyRoundTripper, newInfo.transportBuildingError = restclient.TransportFor(newInfo.restConfig)
if newInfo.transportBuildingError == nil && r.proxyTransport != nil && r.proxyTransport.Dial != nil { if newInfo.transportBuildingError == nil && r.proxyTransport != nil && r.proxyTransport.DialContext != nil {
switch transport := newInfo.proxyRoundTripper.(type) { switch transport := newInfo.proxyRoundTripper.(type) {
case *http.Transport: case *http.Transport:
transport.Dial = r.proxyTransport.Dial transport.DialContext = r.proxyTransport.DialContext
default: default:
newInfo.transportBuildingError = fmt.Errorf("unable to set dialer for %s/%s as rest transport is of type %T", apiService.Spec.Service.Namespace, apiService.Spec.Service.Name, newInfo.proxyRoundTripper) newInfo.transportBuildingError = fmt.Errorf("unable to set dialer for %s/%s as rest transport is of type %T", apiService.Spec.Service.Namespace, apiService.Spec.Service.Name, newInfo.proxyRoundTripper)
glog.Warning(newInfo.transportBuildingError.Error()) glog.Warning(newInfo.transportBuildingError.Error())

View File

@ -1868,11 +1868,12 @@ func startProxyServer() (int, *exec.Cmd, error) {
} }
func curlUnix(url string, path string) (string, error) { func curlUnix(url string, path string) (string, error) {
dial := func(proto, addr string) (net.Conn, error) { dial := func(ctx context.Context, proto, addr string) (net.Conn, error) {
return net.Dial("unix", path) var d net.Dialer
return d.DialContext(ctx, "unix", path)
} }
transport := utilnet.SetTransportDefaults(&http.Transport{ transport := utilnet.SetTransportDefaults(&http.Transport{
Dial: dial, DialContext: dial,
}) })
return curlTransport(url, transport) return curlTransport(url, transport)
} }

View File

@ -373,10 +373,10 @@ func createClients(numberOfClients int) ([]clientset.Interface, []internalclient
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: 100, MaxIdleConnsPerHost: 100,
Dial: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
}).Dial, }).DialContext,
}) })
// Overwrite TLS-related fields from config to avoid collision with // Overwrite TLS-related fields from config to avoid collision with
// Transport field. // Transport field.