diff --git a/rest/config.go b/rest/config.go index e179e012..4909dc53 100644 --- a/rest/config.go +++ b/rest/config.go @@ -305,6 +305,8 @@ type ContentConfig struct { // object. Note that a RESTClient may require fields that are optional when initializing a Client. // A RESTClient created by this method is generic - it expects to operate on an API that follows // the Kubernetes conventions, but may not be the Kubernetes API. +// RESTClientFor is equivalent to calling RESTClientForConfigAndClient(config, httpClient), +// where httpClient was generated with HTTPClientFor(config). func RESTClientFor(config *Config) (*RESTClient, error) { if config.GroupVersion == nil { return nil, fmt.Errorf("GroupVersion is required when initializing a RESTClient") @@ -313,24 +315,40 @@ func RESTClientFor(config *Config) (*RESTClient, error) { return nil, fmt.Errorf("NegotiatedSerializer is required when initializing a RESTClient") } + // Validate config.Host before constructing the transport/client so we can fail fast. + // ServerURL will be obtained later in RESTClientForConfigAndClient() + _, _, err := defaultServerUrlFor(config) + if err != nil { + return nil, err + } + + httpClient, err := HTTPClientFor(config) + if err != nil { + return nil, err + } + + return RESTClientForConfigAndClient(config, httpClient) +} + +// RESTClientForConfigAndClient returns a RESTClient that satisfies the requested attributes on a +// client Config object. +// Unlike RESTClientFor, RESTClientForConfigAndClient allows to pass an http.Client that is shared +// between all the API Groups and Versions. +// Note that the http client takes precedence over the transport values configured. +// The http client defaults to the `http.DefaultClient` if nil. +func RESTClientForConfigAndClient(config *Config, httpClient *http.Client) (*RESTClient, error) { + if config.GroupVersion == nil { + return nil, fmt.Errorf("GroupVersion is required when initializing a RESTClient") + } + if config.NegotiatedSerializer == nil { + return nil, fmt.Errorf("NegotiatedSerializer is required when initializing a RESTClient") + } + baseURL, versionedAPIPath, err := defaultServerUrlFor(config) if err != nil { return nil, err } - transport, err := TransportFor(config) - if err != nil { - return nil, err - } - - var httpClient *http.Client - if transport != http.DefaultTransport { - httpClient = &http.Client{Transport: transport} - if config.Timeout > 0 { - httpClient.Timeout = config.Timeout - } - } - rateLimiter := config.RateLimiter if rateLimiter == nil { qps := config.QPS @@ -371,24 +389,33 @@ func UnversionedRESTClientFor(config *Config) (*RESTClient, error) { return nil, fmt.Errorf("NegotiatedSerializer is required when initializing a RESTClient") } + // Validate config.Host before constructing the transport/client so we can fail fast. + // ServerURL will be obtained later in UnversionedRESTClientForConfigAndClient() + _, _, err := defaultServerUrlFor(config) + if err != nil { + return nil, err + } + + httpClient, err := HTTPClientFor(config) + if err != nil { + return nil, err + } + + return UnversionedRESTClientForConfigAndClient(config, httpClient) +} + +// UnversionedRESTClientForConfigAndClient is the same as RESTClientForConfigAndClient, +// except that it allows the config.Version to be empty. +func UnversionedRESTClientForConfigAndClient(config *Config, httpClient *http.Client) (*RESTClient, error) { + if config.NegotiatedSerializer == nil { + return nil, fmt.Errorf("NegotiatedSerializer is required when initializing a RESTClient") + } + baseURL, versionedAPIPath, err := defaultServerUrlFor(config) if err != nil { return nil, err } - transport, err := TransportFor(config) - if err != nil { - return nil, err - } - - var httpClient *http.Client - if transport != http.DefaultTransport { - httpClient = &http.Client{Transport: transport} - if config.Timeout > 0 { - httpClient.Timeout = config.Timeout - } - } - rateLimiter := config.RateLimiter if rateLimiter == nil { qps := config.QPS diff --git a/rest/connection_test.go b/rest/connection_test.go index e58aff19..70fd2aa1 100644 --- a/rest/connection_test.go +++ b/rest/connection_test.go @@ -26,6 +26,7 @@ import ( "net/url" "os" "strconv" + "strings" "sync/atomic" "testing" "time" @@ -162,3 +163,33 @@ func TestReconnectBrokenTCP(t *testing.T) { t.Fatalf("expected %d dials, got %d", 2, dials) } } + +func TestRestClientTimeout(t *testing.T) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + fmt.Fprintf(w, "Hello, %s", r.Proto) + })) + ts.Start() + defer ts.Close() + + config := &Config{ + Host: ts.URL, + Timeout: 1 * time.Second, + // These fields are required to create a REST client. + ContentConfig: ContentConfig{ + GroupVersion: &schema.GroupVersion{}, + NegotiatedSerializer: &serializer.CodecFactory{}, + }, + } + client, err := RESTClientFor(config) + if err != nil { + t.Fatalf("failed to create REST client: %v", err) + } + _, err = client.Get().AbsPath("/").DoRaw(context.TODO()) + if err == nil { + t.Fatalf("timeout error expected") + } + if !strings.Contains(err.Error(), "deadline exceeded") { + t.Fatalf("timeout error expected, received %v", err) + } +} diff --git a/rest/transport.go b/rest/transport.go index 57d9215c..7c38c6d9 100644 --- a/rest/transport.go +++ b/rest/transport.go @@ -26,6 +26,27 @@ import ( "k8s.io/client-go/transport" ) +// HTTPClientFor returns an http.Client that will provide the authentication +// or transport level security defined by the provided Config. Will return the +// default http.DefaultClient if no special case behavior is needed. +func HTTPClientFor(config *Config) (*http.Client, error) { + transport, err := TransportFor(config) + if err != nil { + return nil, err + } + var httpClient *http.Client + if transport != http.DefaultTransport || config.Timeout > 0 { + httpClient = &http.Client{ + Transport: transport, + Timeout: config.Timeout, + } + } else { + httpClient = http.DefaultClient + } + + return httpClient, nil +} + // TLSConfigFor returns a tls.Config that will provide the transport level security defined // by the provided Config. Will return nil if no transport level security is requested. func TLSConfigFor(config *Config) (*tls.Config, error) {