Use Dial with context

Kubernetes-commit: 5e8e570dbda6ed89af9bc2e0a05e3d94bfdfcb61
This commit is contained in:
Mikhail Mazurskiy 2018-05-19 08:14:37 +10:00 committed by Kubernetes Publisher
parent 4bb327ea2f
commit 4a75b93cb4
7 changed files with 28 additions and 50 deletions

View File

@ -44,12 +44,8 @@ const (
defaultRetries = 2 defaultRetries = 2
// protobuf mime type // protobuf mime type
mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf" mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf"
)
var (
// defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient. // defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient.
// Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist. // Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist.
// It's a variable to be able to change it in tests.
defaultTimeout = 32 * time.Second defaultTimeout = 32 * time.Second
) )

View File

@ -23,12 +23,11 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"strings"
"testing" "testing"
"time"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/googleapis/gnostic/OpenAPIv2" "github.com/googleapis/gnostic/OpenAPIv2"
"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
@ -131,31 +130,11 @@ func TestGetServerGroupsWithBrokenServer(t *testing.T) {
} }
} }
} }
func TestGetServerGroupsWithTimeout(t *testing.T) {
done := make(chan bool) func TestTimeoutIsSet(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { cfg := &restclient.Config{}
// first we need to write headers, otherwise http client will complain about setDiscoveryDefaults(cfg)
// exceeding timeout awaiting headers, only after we can block the call assert.Equal(t, defaultTimeout, cfg.Timeout)
w.Header().Set("Connection", "keep-alive")
if wf, ok := w.(http.Flusher); ok {
wf.Flush()
}
<-done
}))
defer server.Close()
defer close(done)
client := NewDiscoveryClientForConfigOrDie(&restclient.Config{Host: server.URL, Timeout: 2 * time.Second})
_, err := client.ServerGroups()
// the error we're getting here is wrapped in errors.errorString which makes
// it impossible to unwrap and check it's attributes, so instead we're checking
// the textual output which is presenting http.httpError with timeout set to true
if err == nil {
t.Fatal("missing error")
}
if !strings.Contains(err.Error(), "timeout:true") &&
!strings.Contains(err.Error(), "context.deadlineExceededError") {
t.Fatalf("unexpected error: %v", err)
}
} }
func TestGetServerResourcesWithV1Server(t *testing.T) { func TestGetServerResourcesWithV1Server(t *testing.T) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package rest package rest
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -110,7 +111,7 @@ type Config struct {
Timeout time.Duration Timeout time.Duration
// Dial specifies the dial function for creating unencrypted TCP connections. // Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error) Dial func(ctx context.Context, network, address string) (net.Conn, error)
// Version forces a specific version to be used (if registered) // Version forces a specific version to be used (if registered)
// Do we need this? // Do we need this?

View File

@ -17,6 +17,8 @@ limitations under the License.
package rest package rest
import ( import (
"context"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -25,8 +27,6 @@ import (
"strings" "strings"
"testing" "testing"
fuzz "github.com/google/gofuzz"
"k8s.io/api/core/v1" "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
@ -35,8 +35,7 @@ import (
clientcmdapi "k8s.io/client-go/tools/clientcmd/api" clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/util/flowcontrol" "k8s.io/client-go/util/flowcontrol"
"errors" fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -208,7 +207,7 @@ func (n *fakeNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder,
return &fakeCodec{} return &fakeCodec{}
} }
var fakeDialFunc = func(network, addr string) (net.Conn, error) { var fakeDialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, fakeDialerError return nil, fakeDialerError
} }
var fakeDialerError = errors.New("fakedialer") var fakeDialerError = errors.New("fakedialer")
@ -253,7 +252,7 @@ func TestAnonymousConfig(t *testing.T) {
r.Config = map[string]string{} r.Config = map[string]string{}
}, },
// Dial does not require fuzzer // Dial does not require fuzzer
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {}, func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {},
) )
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
original := &Config{} original := &Config{}
@ -284,10 +283,10 @@ func TestAnonymousConfig(t *testing.T) {
expected.WrapTransport = nil expected.WrapTransport = nil
} }
if actual.Dial != nil { if actual.Dial != nil {
_, actualError := actual.Dial("", "") _, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := actual.Dial("", "") _, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) { if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field") t.Fatalf("CopyConfig dropped the Dial field")
} }
} else { } else {
actual.Dial = nil actual.Dial = nil
@ -329,7 +328,7 @@ func TestCopyConfig(t *testing.T) {
func(r *AuthProviderConfigPersister, f fuzz.Continue) { func(r *AuthProviderConfigPersister, f fuzz.Continue) {
*r = fakeAuthProviderConfigPersister{} *r = fakeAuthProviderConfigPersister{}
}, },
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) { func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {
*r = fakeDialFunc *r = fakeDialFunc
}, },
) )
@ -351,8 +350,8 @@ func TestCopyConfig(t *testing.T) {
expected.WrapTransport = nil expected.WrapTransport = nil
} }
if actual.Dial != nil { if actual.Dial != nil {
_, actualError := actual.Dial("", "") _, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := actual.Dial("", "") _, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) { if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field") t.Fatalf("CopyConfig dropped the Dial field")
} }
@ -361,7 +360,7 @@ func TestCopyConfig(t *testing.T) {
expected.Dial = nil expected.Dial = nil
if actual.AuthConfigPersister != nil { if actual.AuthConfigPersister != nil {
actualError := actual.AuthConfigPersister.Persist(nil) actualError := actual.AuthConfigPersister.Persist(nil)
expectedError := actual.AuthConfigPersister.Persist(nil) expectedError := expected.AuthConfigPersister.Persist(nil)
if !reflect.DeepEqual(expectedError, actualError) { if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field") t.Fatalf("CopyConfig dropped the Dial field")
} }

View File

@ -85,7 +85,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dial = (&net.Dialer{ dial = (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
}).Dial }).DialContext
} }
// Cache a single transport for these options // Cache a single transport for these options
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
@ -93,7 +93,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: idleConnsPerHost, MaxIdleConnsPerHost: idleConnsPerHost,
Dial: dial, DialContext: dial,
}) })
return c.transports[key], nil return c.transports[key], nil
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport package transport
import ( import (
"context"
"net" "net"
"net/http" "net/http"
"testing" "testing"
@ -52,10 +53,11 @@ func TestTLSConfigKey(t *testing.T) {
} }
// Make sure config fields that affect the tls config affect the cache key // Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
uniqueConfigurations := map[string]*Config{ uniqueConfigurations := map[string]*Config{
"no tls": {}, "no tls": {},
"dialer": {Dial: net.Dial}, "dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(network, address string) (net.Conn, error) { return nil, nil }}, "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"insecure": {TLS: TLSConfig{Insecure: true}}, "insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport package transport
import ( import (
"context"
"net" "net"
"net/http" "net/http"
) )
@ -53,7 +54,7 @@ type Config struct {
WrapTransport func(rt http.RoundTripper) http.RoundTripper WrapTransport func(rt http.RoundTripper) http.RoundTripper
// Dial specifies the dial function for creating unencrypted TCP connections. // Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error) Dial func(ctx context.Context, network, address string) (net.Conn, error)
} }
// ImpersonationConfig has all the available impersonation options // ImpersonationConfig has all the available impersonation options