From 4a75b93cb4ed8c16e3d555ba3c707029e76b6879 Mon Sep 17 00:00:00 2001 From: Mikhail Mazurskiy Date: Sat, 19 May 2018 08:14:37 +1000 Subject: [PATCH] Use Dial with context Kubernetes-commit: 5e8e570dbda6ed89af9bc2e0a05e3d94bfdfcb61 --- discovery/discovery_client.go | 4 ---- discovery/discovery_client_test.go | 33 ++++++------------------------ rest/config.go | 3 ++- rest/config_test.go | 25 +++++++++++----------- transport/cache.go | 4 ++-- transport/cache_test.go | 6 ++++-- transport/config.go | 3 ++- 7 files changed, 28 insertions(+), 50 deletions(-) diff --git a/discovery/discovery_client.go b/discovery/discovery_client.go index cef4d401..a9660297 100644 --- a/discovery/discovery_client.go +++ b/discovery/discovery_client.go @@ -44,12 +44,8 @@ const ( defaultRetries = 2 // protobuf mime type mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf" -) - -var ( // defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient. // Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist. - // It's a variable to be able to change it in tests. defaultTimeout = 32 * time.Second ) diff --git a/discovery/discovery_client_test.go b/discovery/discovery_client_test.go index 9148bfb4..10e49432 100644 --- a/discovery/discovery_client_test.go +++ b/discovery/discovery_client_test.go @@ -23,12 +23,11 @@ import ( "net/http" "net/http/httptest" "reflect" - "strings" "testing" - "time" "github.com/gogo/protobuf/proto" "github.com/googleapis/gnostic/OpenAPIv2" + "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" @@ -131,31 +130,11 @@ func TestGetServerGroupsWithBrokenServer(t *testing.T) { } } } -func TestGetServerGroupsWithTimeout(t *testing.T) { - done := make(chan bool) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // first we need to write headers, otherwise http client will complain about - // exceeding timeout awaiting headers, only after we can block the call - w.Header().Set("Connection", "keep-alive") - if wf, ok := w.(http.Flusher); ok { - wf.Flush() - } - <-done - })) - defer server.Close() - defer close(done) - client := NewDiscoveryClientForConfigOrDie(&restclient.Config{Host: server.URL, Timeout: 2 * time.Second}) - _, err := client.ServerGroups() - // the error we're getting here is wrapped in errors.errorString which makes - // it impossible to unwrap and check it's attributes, so instead we're checking - // the textual output which is presenting http.httpError with timeout set to true - if err == nil { - t.Fatal("missing error") - } - if !strings.Contains(err.Error(), "timeout:true") && - !strings.Contains(err.Error(), "context.deadlineExceededError") { - t.Fatalf("unexpected error: %v", err) - } + +func TestTimeoutIsSet(t *testing.T) { + cfg := &restclient.Config{} + setDiscoveryDefaults(cfg) + assert.Equal(t, defaultTimeout, cfg.Timeout) } func TestGetServerResourcesWithV1Server(t *testing.T) { diff --git a/rest/config.go b/rest/config.go index af2cbb99..7934a019 100644 --- a/rest/config.go +++ b/rest/config.go @@ -17,6 +17,7 @@ limitations under the License. package rest import ( + "context" "fmt" "io/ioutil" "net" @@ -110,7 +111,7 @@ type Config struct { Timeout time.Duration // Dial specifies the dial function for creating unencrypted TCP connections. - Dial func(network, addr string) (net.Conn, error) + Dial func(ctx context.Context, network, address string) (net.Conn, error) // Version forces a specific version to be used (if registered) // Do we need this? diff --git a/rest/config_test.go b/rest/config_test.go index a9495d79..34786428 100644 --- a/rest/config_test.go +++ b/rest/config_test.go @@ -17,6 +17,8 @@ limitations under the License. package rest import ( + "context" + "errors" "io" "net" "net/http" @@ -25,8 +27,6 @@ import ( "strings" "testing" - fuzz "github.com/google/gofuzz" - "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -35,8 +35,7 @@ import ( clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "k8s.io/client-go/util/flowcontrol" - "errors" - + fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/assert" ) @@ -208,7 +207,7 @@ func (n *fakeNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder, return &fakeCodec{} } -var fakeDialFunc = func(network, addr string) (net.Conn, error) { +var fakeDialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, fakeDialerError } var fakeDialerError = errors.New("fakedialer") @@ -253,7 +252,7 @@ func TestAnonymousConfig(t *testing.T) { r.Config = map[string]string{} }, // Dial does not require fuzzer - func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {}, + func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {}, ) for i := 0; i < 20; i++ { original := &Config{} @@ -284,10 +283,10 @@ func TestAnonymousConfig(t *testing.T) { expected.WrapTransport = nil } if actual.Dial != nil { - _, actualError := actual.Dial("", "") - _, expectedError := actual.Dial("", "") + _, actualError := actual.Dial(context.Background(), "", "") + _, expectedError := expected.Dial(context.Background(), "", "") if !reflect.DeepEqual(expectedError, actualError) { - t.Fatalf("CopyConfig dropped the Dial field") + t.Fatalf("CopyConfig dropped the Dial field") } } else { actual.Dial = nil @@ -329,7 +328,7 @@ func TestCopyConfig(t *testing.T) { func(r *AuthProviderConfigPersister, f fuzz.Continue) { *r = fakeAuthProviderConfigPersister{} }, - func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) { + func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) { *r = fakeDialFunc }, ) @@ -351,8 +350,8 @@ func TestCopyConfig(t *testing.T) { expected.WrapTransport = nil } if actual.Dial != nil { - _, actualError := actual.Dial("", "") - _, expectedError := actual.Dial("", "") + _, actualError := actual.Dial(context.Background(), "", "") + _, expectedError := expected.Dial(context.Background(), "", "") if !reflect.DeepEqual(expectedError, actualError) { t.Fatalf("CopyConfig dropped the Dial field") } @@ -361,7 +360,7 @@ func TestCopyConfig(t *testing.T) { expected.Dial = nil if actual.AuthConfigPersister != nil { actualError := actual.AuthConfigPersister.Persist(nil) - expectedError := actual.AuthConfigPersister.Persist(nil) + expectedError := expected.AuthConfigPersister.Persist(nil) if !reflect.DeepEqual(expectedError, actualError) { t.Fatalf("CopyConfig dropped the Dial field") } diff --git a/transport/cache.go b/transport/cache.go index 83291c57..540af849 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -85,7 +85,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { dial = (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - }).Dial + }).DialContext } // Cache a single transport for these options c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ @@ -93,7 +93,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, MaxIdleConnsPerHost: idleConnsPerHost, - Dial: dial, + DialContext: dial, }) return c.transports[key], nil } diff --git a/transport/cache_test.go b/transport/cache_test.go index d3d14099..61f3affc 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "net" "net/http" "testing" @@ -52,10 +53,11 @@ func TestTLSConfigKey(t *testing.T) { } // Make sure config fields that affect the tls config affect the cache key + dialer := net.Dialer{} uniqueConfigurations := map[string]*Config{ "no tls": {}, - "dialer": {Dial: net.Dial}, - "dialer2": {Dial: func(network, address string) (net.Conn, error) { return nil, nil }}, + "dialer": {Dial: dialer.DialContext}, + "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}, "insecure": {TLS: TLSConfig{Insecure: true}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, diff --git a/transport/config.go b/transport/config.go index af347daf..90f705d2 100644 --- a/transport/config.go +++ b/transport/config.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "net" "net/http" ) @@ -53,7 +54,7 @@ type Config struct { WrapTransport func(rt http.RoundTripper) http.RoundTripper // Dial specifies the dial function for creating unencrypted TCP connections. - Dial func(network, addr string) (net.Conn, error) + Dial func(ctx context.Context, network, address string) (net.Conn, error) } // ImpersonationConfig has all the available impersonation options