diff --git a/discovery/discovery_client.go b/discovery/discovery_client.go index 08350861..8beecb2f 100644 --- a/discovery/discovery_client.go +++ b/discovery/discovery_client.go @@ -22,6 +22,7 @@ import ( "net/url" "sort" "strings" + "time" "github.com/golang/protobuf/proto" "github.com/googleapis/gnostic/OpenAPIv2" @@ -43,6 +44,13 @@ const ( mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf" ) +var ( + // defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient. + // Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist. + // It's a variable to be able to change it in tests. + defaultTimeout = 32 * time.Second +) + // DiscoveryInterface holds the methods that discover server-supported API groups, // versions and resources. type DiscoveryInterface interface { @@ -373,6 +381,9 @@ func withRetries(maxRetries int, f func() ([]*metav1.APIResourceList, error)) ([ func setDiscoveryDefaults(config *restclient.Config) error { config.APIPath = "" config.GroupVersion = nil + if config.Timeout == 0 { + config.Timeout = defaultTimeout + } codec := runtime.NoopEncoder{Decoder: scheme.Codecs.UniversalDecoder()} config.NegotiatedSerializer = serializer.NegotiatedSerializerWrapper(runtime.SerializerInfo{Serializer: codec}) if len(config.UserAgent) == 0 { diff --git a/discovery/discovery_client_test.go b/discovery/discovery_client_test.go index 6f4de6b4..f2f5ae89 100644 --- a/discovery/discovery_client_test.go +++ b/discovery/discovery_client_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package discovery_test +package discovery import ( "encoding/json" @@ -23,7 +23,9 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "testing" + "time" "github.com/gogo/protobuf/proto" "github.com/googleapis/gnostic/OpenAPIv2" @@ -32,7 +34,6 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/version" - . "k8s.io/client-go/discovery" restclient "k8s.io/client-go/rest" ) @@ -129,6 +130,21 @@ func TestGetServerGroupsWithBrokenServer(t *testing.T) { } } } +func TestGetServerGroupsWithTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + tmp := defaultTimeout + defaultTimeout = 1 * time.Second + client := NewDiscoveryClientForConfigOrDie(&restclient.Config{Host: server.URL}) + _, err := client.ServerGroups() + if err == nil || strings.Contains(err.Error(), "deadline") { + t.Fatalf("unexpected error: %v", err) + } + defaultTimeout = tmp +} func TestGetServerResourcesWithV1Server(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { diff --git a/rest/request.go b/rest/request.go index 6ca9e019..9df0b444 100644 --- a/rest/request.go +++ b/rest/request.go @@ -353,8 +353,8 @@ func (r *Request) SetHeader(key string, values ...string) *Request { return r } -// Timeout makes the request use the given duration as a timeout. Sets the "timeout" -// parameter. +// Timeout makes the request use the given duration as an overall timeout for the +// request. Additionally, if set passes the value as "timeout" parameter in URL. func (r *Request) Timeout(d time.Duration) *Request { if r.err != nil { return r @@ -640,7 +640,6 @@ func (r *Request) request(fn func(*http.Request, *http.Response)) error { } // Right now we make about ten retry attempts if we get a Retry-After response. - // TODO: Change to a timeout based approach. maxRetries := 10 retries := 0 for { @@ -649,6 +648,14 @@ func (r *Request) request(fn func(*http.Request, *http.Response)) error { if err != nil { return err } + if r.timeout > 0 { + if r.ctx == nil { + r.ctx = context.Background() + } + var cancelFn context.CancelFunc + r.ctx, cancelFn = context.WithTimeout(r.ctx, r.timeout) + defer cancelFn() + } if r.ctx != nil { req = req.WithContext(r.ctx) }