Use Dial with context

Kubernetes-commit: 5e8e570dbda6ed89af9bc2e0a05e3d94bfdfcb61
This commit is contained in:
Mikhail Mazurskiy 2018-05-19 08:14:37 +10:00 committed by Kubernetes Publisher
parent 4bb327ea2f
commit 4a75b93cb4
7 changed files with 28 additions and 50 deletions

View File

@ -44,12 +44,8 @@ const (
defaultRetries = 2
// protobuf mime type
mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf"
)
var (
// defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient.
// Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist.
// It's a variable to be able to change it in tests.
defaultTimeout = 32 * time.Second
)

View File

@ -23,12 +23,11 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gogo/protobuf/proto"
"github.com/googleapis/gnostic/OpenAPIv2"
"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
@ -131,31 +130,11 @@ func TestGetServerGroupsWithBrokenServer(t *testing.T) {
}
}
}
func TestGetServerGroupsWithTimeout(t *testing.T) {
done := make(chan bool)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// first we need to write headers, otherwise http client will complain about
// exceeding timeout awaiting headers, only after we can block the call
w.Header().Set("Connection", "keep-alive")
if wf, ok := w.(http.Flusher); ok {
wf.Flush()
}
<-done
}))
defer server.Close()
defer close(done)
client := NewDiscoveryClientForConfigOrDie(&restclient.Config{Host: server.URL, Timeout: 2 * time.Second})
_, err := client.ServerGroups()
// the error we're getting here is wrapped in errors.errorString which makes
// it impossible to unwrap and check it's attributes, so instead we're checking
// the textual output which is presenting http.httpError with timeout set to true
if err == nil {
t.Fatal("missing error")
}
if !strings.Contains(err.Error(), "timeout:true") &&
!strings.Contains(err.Error(), "context.deadlineExceededError") {
t.Fatalf("unexpected error: %v", err)
}
func TestTimeoutIsSet(t *testing.T) {
cfg := &restclient.Config{}
setDiscoveryDefaults(cfg)
assert.Equal(t, defaultTimeout, cfg.Timeout)
}
func TestGetServerResourcesWithV1Server(t *testing.T) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package rest
import (
"context"
"fmt"
"io/ioutil"
"net"
@ -110,7 +111,7 @@ type Config struct {
Timeout time.Duration
// Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, address string) (net.Conn, error)
// Version forces a specific version to be used (if registered)
// Do we need this?

View File

@ -17,6 +17,8 @@ limitations under the License.
package rest
import (
"context"
"errors"
"io"
"net"
"net/http"
@ -25,8 +27,6 @@ import (
"strings"
"testing"
fuzz "github.com/google/gofuzz"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
@ -35,8 +35,7 @@ import (
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/util/flowcontrol"
"errors"
fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/assert"
)
@ -208,7 +207,7 @@ func (n *fakeNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder,
return &fakeCodec{}
}
var fakeDialFunc = func(network, addr string) (net.Conn, error) {
var fakeDialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, fakeDialerError
}
var fakeDialerError = errors.New("fakedialer")
@ -253,7 +252,7 @@ func TestAnonymousConfig(t *testing.T) {
r.Config = map[string]string{}
},
// Dial does not require fuzzer
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {},
func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {},
)
for i := 0; i < 20; i++ {
original := &Config{}
@ -284,10 +283,10 @@ func TestAnonymousConfig(t *testing.T) {
expected.WrapTransport = nil
}
if actual.Dial != nil {
_, actualError := actual.Dial("", "")
_, expectedError := actual.Dial("", "")
_, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field")
t.Fatalf("CopyConfig dropped the Dial field")
}
} else {
actual.Dial = nil
@ -329,7 +328,7 @@ func TestCopyConfig(t *testing.T) {
func(r *AuthProviderConfigPersister, f fuzz.Continue) {
*r = fakeAuthProviderConfigPersister{}
},
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {
func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {
*r = fakeDialFunc
},
)
@ -351,8 +350,8 @@ func TestCopyConfig(t *testing.T) {
expected.WrapTransport = nil
}
if actual.Dial != nil {
_, actualError := actual.Dial("", "")
_, expectedError := actual.Dial("", "")
_, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field")
}
@ -361,7 +360,7 @@ func TestCopyConfig(t *testing.T) {
expected.Dial = nil
if actual.AuthConfigPersister != nil {
actualError := actual.AuthConfigPersister.Persist(nil)
expectedError := actual.AuthConfigPersister.Persist(nil)
expectedError := expected.AuthConfigPersister.Persist(nil)
if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field")
}

View File

@ -85,7 +85,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial
}).DialContext
}
// Cache a single transport for these options
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
@ -93,7 +93,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: idleConnsPerHost,
Dial: dial,
DialContext: dial,
})
return c.transports[key], nil
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"net"
"net/http"
"testing"
@ -52,10 +53,11 @@ func TestTLSConfigKey(t *testing.T) {
}
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
uniqueConfigurations := map[string]*Config{
"no tls": {},
"dialer": {Dial: net.Dial},
"dialer2": {Dial: func(network, address string) (net.Conn, error) { return nil, nil }},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"net"
"net/http"
)
@ -53,7 +54,7 @@ type Config struct {
WrapTransport func(rt http.RoundTripper) http.RoundTripper
// Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options