diff --git a/discovery/cached_discovery.go b/discovery/cached_discovery.go index df69d6a1..61a758c0 100644 --- a/discovery/cached_discovery.go +++ b/discovery/cached_discovery.go @@ -25,7 +25,7 @@ import ( "sync" "time" - "github.com/googleapis/gnostic/OpenAPIv2" + openapi_v2 "github.com/googleapis/gnostic/OpenAPIv2" "k8s.io/klog" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -266,13 +266,10 @@ func NewCachedDiscoveryClientForConfig(config *restclient.Config, discoveryCache if len(httpCacheDir) > 0 { // update the given restconfig with a custom roundtripper that // understands how to handle cache responses. - wt := config.WrapTransport - config.WrapTransport = func(rt http.RoundTripper) http.RoundTripper { - if wt != nil { - rt = wt(rt) - } + config = restclient.CopyConfig(config) + config.Wrap(func(rt http.RoundTripper) http.RoundTripper { return newCacheRoundTripper(httpCacheDir, rt) - } + }) } discoveryClient, err := NewDiscoveryClientForConfig(config) diff --git a/plugin/pkg/client/auth/exec/exec.go b/plugin/pkg/client/auth/exec/exec.go index 4d725265..be4814bc 100644 --- a/plugin/pkg/client/auth/exec/exec.go +++ b/plugin/pkg/client/auth/exec/exec.go @@ -32,7 +32,7 @@ import ( "time" "golang.org/x/crypto/ssh/terminal" - "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/serializer" @@ -172,13 +172,9 @@ type credentials struct { // UpdateTransportConfig updates the transport.Config to use credentials // returned by the plugin. func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error { - wt := c.WrapTransport - c.WrapTransport = func(rt http.RoundTripper) http.RoundTripper { - if wt != nil { - rt = wt(rt) - } + c.Wrap(func(rt http.RoundTripper) http.RoundTripper { return &roundTripper{a, rt} - } + }) if c.TLS.GetCert != nil { return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set") diff --git a/rest/config.go b/rest/config.go index 072e7392..271693c2 100644 --- a/rest/config.go +++ b/rest/config.go @@ -34,6 +34,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/pkg/version" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" + "k8s.io/client-go/transport" certutil "k8s.io/client-go/util/cert" "k8s.io/client-go/util/flowcontrol" "k8s.io/klog" @@ -95,13 +96,16 @@ type Config struct { // Transport may be used for custom HTTP behavior. This attribute may not // be specified with the TLS client certificate options. Use WrapTransport - // for most client level operations. + // to provide additional per-server middleware behavior. Transport http.RoundTripper // WrapTransport will be invoked for custom HTTP behavior after the underlying // transport is initialized (either the transport created from TLSClientConfig, // Transport, or http.DefaultTransport). The config may layer other RoundTrippers // on top of the returned RoundTripper. - WrapTransport func(rt http.RoundTripper) http.RoundTripper + // + // A future release will change this field to an array. Use config.Wrap() + // instead of setting this value directly. + WrapTransport transport.WrapperFunc // QPS indicates the maximum QPS to the master from this client. // If it's zero, the created RESTClient will use DefaultQPS: 5 diff --git a/rest/config_test.go b/rest/config_test.go index 22c18d77..8f5cce67 100644 --- a/rest/config_test.go +++ b/rest/config_test.go @@ -27,12 +27,13 @@ import ( "strings" "testing" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/diff" "k8s.io/client-go/kubernetes/scheme" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" + "k8s.io/client-go/transport" "k8s.io/client-go/util/flowcontrol" fuzz "github.com/google/gofuzz" @@ -236,6 +237,9 @@ func TestAnonymousConfig(t *testing.T) { func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) { *fn = fakeWrapperFunc }, + func(fn *transport.WrapperFunc, f fuzz.Continue) { + *fn = fakeWrapperFunc + }, func(r *runtime.NegotiatedSerializer, f fuzz.Continue) { serializer := &fakeNegotiatedSerializer{} f.Fuzz(serializer) @@ -316,6 +320,9 @@ func TestCopyConfig(t *testing.T) { func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) { *fn = fakeWrapperFunc }, + func(fn *transport.WrapperFunc, f fuzz.Continue) { + *fn = fakeWrapperFunc + }, func(r *runtime.NegotiatedSerializer, f fuzz.Continue) { serializer := &fakeNegotiatedSerializer{} f.Fuzz(serializer) diff --git a/rest/transport.go b/rest/transport.go index 25c1801b..bd5749dc 100644 --- a/rest/transport.go +++ b/rest/transport.go @@ -103,14 +103,15 @@ func (c *Config) TransportConfig() (*transport.Config, error) { if err != nil { return nil, err } - wt := conf.WrapTransport - if wt != nil { - conf.WrapTransport = func(rt http.RoundTripper) http.RoundTripper { - return provider.WrapTransport(wt(rt)) - } - } else { - conf.WrapTransport = provider.WrapTransport - } + conf.Wrap(provider.WrapTransport) } return conf, nil } + +// Wrap adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper prior to the +// first API call being made. The provided function is invoked after any +// existing transport wrappers are invoked. +func (c *Config) Wrap(fn transport.WrapperFunc) { + c.WrapTransport = transport.Wrappers(c.WrapTransport, fn) +} diff --git a/transport/config.go b/transport/config.go index acb126d8..5de0a2cb 100644 --- a/transport/config.go +++ b/transport/config.go @@ -57,7 +57,10 @@ type Config struct { // from TLSClientConfig, Transport, or http.DefaultTransport). The // config may layer other RoundTrippers on top of the returned // RoundTripper. - WrapTransport func(rt http.RoundTripper) http.RoundTripper + // + // A future release will change this field to an array. Use config.Wrap() + // instead of setting this value directly. + WrapTransport WrapperFunc // Dial specifies the dial function for creating unencrypted TCP connections. Dial func(ctx context.Context, network, address string) (net.Conn, error) @@ -98,6 +101,14 @@ func (c *Config) HasCertCallback() bool { return c.TLS.GetCert != nil } +// Wrap adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper prior to the +// first API call being made. The provided function is invoked after any +// existing transport wrappers are invoked. +func (c *Config) Wrap(fn WrapperFunc) { + c.WrapTransport = Wrappers(c.WrapTransport, fn) +} + // TLSConfig holds the information needed to set up a TLS transport. type TLSConfig struct { CAFile string // Path of the PEM-encoded server trusted root certificates. diff --git a/transport/transport.go b/transport/transport.go index c19739fd..f62f8003 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -167,3 +167,32 @@ func rootCertPool(caData []byte) *x509.CertPool { certPool.AppendCertsFromPEM(caData) return certPool } + +// WrapperFunc wraps an http.RoundTripper when a new transport +// is created for a client, allowing per connection behavior +// to be injected. +type WrapperFunc func(rt http.RoundTripper) http.RoundTripper + +// Wrappers accepts any number of wrappers and returns a wrapper +// function that is the equivalent of calling each of them in order. Nil +// values are ignored, which makes this function convenient for incrementally +// wrapping a function. +func Wrappers(fns ...WrapperFunc) WrapperFunc { + if len(fns) == 0 { + return nil + } + // optimize the common case of wrapping a possibly nil transport wrapper + // with an additional wrapper + if len(fns) == 2 && fns[0] == nil { + return fns[1] + } + return func(rt http.RoundTripper) http.RoundTripper { + base := rt + for _, fn := range fns { + if fn != nil { + base = fn(base) + } + } + return base + } +} diff --git a/transport/transport_test.go b/transport/transport_test.go index eead38aa..66850121 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -310,3 +310,90 @@ func TestNew(t *testing.T) { }) } } + +type fakeRoundTripper struct { + Req *http.Request + Resp *http.Response + Err error +} + +func (rt *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.Req = req + return rt.Resp, rt.Err +} + +type chainRoundTripper struct { + rt http.RoundTripper + value string +} + +func testChain(value string) WrapperFunc { + return func(rt http.RoundTripper) http.RoundTripper { + return &chainRoundTripper{rt: rt, value: value} + } +} + +func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := rt.rt.RoundTrip(req) + if resp != nil { + if resp.Header == nil { + resp.Header = make(http.Header) + } + resp.Header.Set("Value", resp.Header.Get("Value")+rt.value) + } + return resp, err +} + +func TestWrappers(t *testing.T) { + resp1 := &http.Response{} + wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper { + return &fakeRoundTripper{Resp: resp1} + } + resp2 := &http.Response{} + wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper { + return &fakeRoundTripper{Resp: resp2} + } + + tests := []struct { + name string + fns []WrapperFunc + wantNil bool + want func(*http.Response) bool + }{ + {fns: []WrapperFunc{}, wantNil: true}, + {fns: []WrapperFunc{nil, nil}, wantNil: true}, + {fns: []WrapperFunc{nil}, wantNil: false}, + + {fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }}, + {fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }}, + {fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }}, + {fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }}, + {fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }}, + {fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }}, + + {fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }}, + {fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }}, + {fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }}, + {fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Wrappers(tt.fns...) + if got == nil != tt.wantNil { + t.Errorf("Wrappers() = %v", got) + return + } + if got == nil { + return + } + + rt := &fakeRoundTripper{Resp: &http.Response{}} + nested := got(rt) + req := &http.Request{} + resp, _ := nested.RoundTrip(req) + if tt.want != nil && !tt.want(resp) { + t.Errorf("unexpected response: %#v", resp) + } + }) + } +}