Make wrapping a client transport more pleasant

Properly wrapping a transport can be tricky. Make the normal case
(adding a non-nil transport wrapper to a config) easier with a helper.
Also enforce a rough ordering, which in the future we can use to
simplify the WrapTransport mechanism down into an array of functions
we execute in order and avoid wrapping altogether.
This commit is contained in:
Clayton Coleman 2018-12-27 11:47:50 -05:00
parent 09890b6c48
commit 1f590e697e
No known key found for this signature in database
GPG Key ID: 3D16906B4F1C5CB3
8 changed files with 158 additions and 26 deletions

View File

@ -25,7 +25,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/googleapis/gnostic/OpenAPIv2" openapi_v2 "github.com/googleapis/gnostic/OpenAPIv2"
"k8s.io/klog" "k8s.io/klog"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -266,13 +266,10 @@ func NewCachedDiscoveryClientForConfig(config *restclient.Config, discoveryCache
if len(httpCacheDir) > 0 { if len(httpCacheDir) > 0 {
// update the given restconfig with a custom roundtripper that // update the given restconfig with a custom roundtripper that
// understands how to handle cache responses. // understands how to handle cache responses.
wt := config.WrapTransport config = restclient.CopyConfig(config)
config.WrapTransport = func(rt http.RoundTripper) http.RoundTripper { config.Wrap(func(rt http.RoundTripper) http.RoundTripper {
if wt != nil {
rt = wt(rt)
}
return newCacheRoundTripper(httpCacheDir, rt) return newCacheRoundTripper(httpCacheDir, rt)
} })
} }
discoveryClient, err := NewDiscoveryClientForConfig(config) discoveryClient, err := NewDiscoveryClientForConfig(config)

View File

@ -32,7 +32,7 @@ import (
"time" "time"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
"k8s.io/apimachinery/pkg/apis/meta/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apimachinery/pkg/runtime/serializer"
@ -172,13 +172,9 @@ type credentials struct {
// UpdateTransportConfig updates the transport.Config to use credentials // UpdateTransportConfig updates the transport.Config to use credentials
// returned by the plugin. // returned by the plugin.
func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error { func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
wt := c.WrapTransport c.Wrap(func(rt http.RoundTripper) http.RoundTripper {
c.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
if wt != nil {
rt = wt(rt)
}
return &roundTripper{a, rt} return &roundTripper{a, rt}
} })
if c.TLS.GetCert != nil { if c.TLS.GetCert != nil {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set") return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")

View File

@ -34,6 +34,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/pkg/version" "k8s.io/client-go/pkg/version"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api" clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/transport"
certutil "k8s.io/client-go/util/cert" certutil "k8s.io/client-go/util/cert"
"k8s.io/client-go/util/flowcontrol" "k8s.io/client-go/util/flowcontrol"
"k8s.io/klog" "k8s.io/klog"
@ -95,13 +96,16 @@ type Config struct {
// Transport may be used for custom HTTP behavior. This attribute may not // Transport may be used for custom HTTP behavior. This attribute may not
// be specified with the TLS client certificate options. Use WrapTransport // be specified with the TLS client certificate options. Use WrapTransport
// for most client level operations. // to provide additional per-server middleware behavior.
Transport http.RoundTripper Transport http.RoundTripper
// WrapTransport will be invoked for custom HTTP behavior after the underlying // WrapTransport will be invoked for custom HTTP behavior after the underlying
// transport is initialized (either the transport created from TLSClientConfig, // transport is initialized (either the transport created from TLSClientConfig,
// Transport, or http.DefaultTransport). The config may layer other RoundTrippers // Transport, or http.DefaultTransport). The config may layer other RoundTrippers
// on top of the returned RoundTripper. // on top of the returned RoundTripper.
WrapTransport func(rt http.RoundTripper) http.RoundTripper //
// A future release will change this field to an array. Use config.Wrap()
// instead of setting this value directly.
WrapTransport transport.WrapperFunc
// QPS indicates the maximum QPS to the master from this client. // QPS indicates the maximum QPS to the master from this client.
// If it's zero, the created RESTClient will use DefaultQPS: 5 // If it's zero, the created RESTClient will use DefaultQPS: 5

View File

@ -27,12 +27,13 @@ import (
"strings" "strings"
"testing" "testing"
"k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/diff" "k8s.io/apimachinery/pkg/util/diff"
"k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/kubernetes/scheme"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api" clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/transport"
"k8s.io/client-go/util/flowcontrol" "k8s.io/client-go/util/flowcontrol"
fuzz "github.com/google/gofuzz" fuzz "github.com/google/gofuzz"
@ -236,6 +237,9 @@ func TestAnonymousConfig(t *testing.T) {
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) { func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
*fn = fakeWrapperFunc *fn = fakeWrapperFunc
}, },
func(fn *transport.WrapperFunc, f fuzz.Continue) {
*fn = fakeWrapperFunc
},
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) { func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
serializer := &fakeNegotiatedSerializer{} serializer := &fakeNegotiatedSerializer{}
f.Fuzz(serializer) f.Fuzz(serializer)
@ -316,6 +320,9 @@ func TestCopyConfig(t *testing.T) {
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) { func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
*fn = fakeWrapperFunc *fn = fakeWrapperFunc
}, },
func(fn *transport.WrapperFunc, f fuzz.Continue) {
*fn = fakeWrapperFunc
},
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) { func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
serializer := &fakeNegotiatedSerializer{} serializer := &fakeNegotiatedSerializer{}
f.Fuzz(serializer) f.Fuzz(serializer)

View File

@ -103,14 +103,15 @@ func (c *Config) TransportConfig() (*transport.Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
wt := conf.WrapTransport conf.Wrap(provider.WrapTransport)
if wt != nil {
conf.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
return provider.WrapTransport(wt(rt))
}
} else {
conf.WrapTransport = provider.WrapTransport
}
} }
return conf, nil return conf, nil
} }
// Wrap adds a transport middleware function that will give the caller
// an opportunity to wrap the underlying http.RoundTripper prior to the
// first API call being made. The provided function is invoked after any
// existing transport wrappers are invoked.
func (c *Config) Wrap(fn transport.WrapperFunc) {
c.WrapTransport = transport.Wrappers(c.WrapTransport, fn)
}

View File

@ -57,7 +57,10 @@ type Config struct {
// from TLSClientConfig, Transport, or http.DefaultTransport). The // from TLSClientConfig, Transport, or http.DefaultTransport). The
// config may layer other RoundTrippers on top of the returned // config may layer other RoundTrippers on top of the returned
// RoundTripper. // RoundTripper.
WrapTransport func(rt http.RoundTripper) http.RoundTripper //
// A future release will change this field to an array. Use config.Wrap()
// instead of setting this value directly.
WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections. // Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(ctx context.Context, network, address string) (net.Conn, error) Dial func(ctx context.Context, network, address string) (net.Conn, error)
@ -98,6 +101,14 @@ func (c *Config) HasCertCallback() bool {
return c.TLS.GetCert != nil return c.TLS.GetCert != nil
} }
// Wrap adds a transport middleware function that will give the caller
// an opportunity to wrap the underlying http.RoundTripper prior to the
// first API call being made. The provided function is invoked after any
// existing transport wrappers are invoked.
func (c *Config) Wrap(fn WrapperFunc) {
c.WrapTransport = Wrappers(c.WrapTransport, fn)
}
// TLSConfig holds the information needed to set up a TLS transport. // TLSConfig holds the information needed to set up a TLS transport.
type TLSConfig struct { type TLSConfig struct {
CAFile string // Path of the PEM-encoded server trusted root certificates. CAFile string // Path of the PEM-encoded server trusted root certificates.

View File

@ -167,3 +167,32 @@ func rootCertPool(caData []byte) *x509.CertPool {
certPool.AppendCertsFromPEM(caData) certPool.AppendCertsFromPEM(caData)
return certPool return certPool
} }
// WrapperFunc wraps an http.RoundTripper when a new transport
// is created for a client, allowing per connection behavior
// to be injected.
type WrapperFunc func(rt http.RoundTripper) http.RoundTripper
// Wrappers accepts any number of wrappers and returns a wrapper
// function that is the equivalent of calling each of them in order. Nil
// values are ignored, which makes this function convenient for incrementally
// wrapping a function.
func Wrappers(fns ...WrapperFunc) WrapperFunc {
if len(fns) == 0 {
return nil
}
// optimize the common case of wrapping a possibly nil transport wrapper
// with an additional wrapper
if len(fns) == 2 && fns[0] == nil {
return fns[1]
}
return func(rt http.RoundTripper) http.RoundTripper {
base := rt
for _, fn := range fns {
if fn != nil {
base = fn(base)
}
}
return base
}
}

View File

@ -310,3 +310,90 @@ func TestNew(t *testing.T) {
}) })
} }
} }
type fakeRoundTripper struct {
Req *http.Request
Resp *http.Response
Err error
}
func (rt *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.Req = req
return rt.Resp, rt.Err
}
type chainRoundTripper struct {
rt http.RoundTripper
value string
}
func testChain(value string) WrapperFunc {
return func(rt http.RoundTripper) http.RoundTripper {
return &chainRoundTripper{rt: rt, value: value}
}
}
func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rt.rt.RoundTrip(req)
if resp != nil {
if resp.Header == nil {
resp.Header = make(http.Header)
}
resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
}
return resp, err
}
func TestWrappers(t *testing.T) {
resp1 := &http.Response{}
wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
return &fakeRoundTripper{Resp: resp1}
}
resp2 := &http.Response{}
wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
return &fakeRoundTripper{Resp: resp2}
}
tests := []struct {
name string
fns []WrapperFunc
wantNil bool
want func(*http.Response) bool
}{
{fns: []WrapperFunc{}, wantNil: true},
{fns: []WrapperFunc{nil, nil}, wantNil: true},
{fns: []WrapperFunc{nil}, wantNil: false},
{fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
{fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
{fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
{fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
{fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
{fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Wrappers(tt.fns...)
if got == nil != tt.wantNil {
t.Errorf("Wrappers() = %v", got)
return
}
if got == nil {
return
}
rt := &fakeRoundTripper{Resp: &http.Response{}}
nested := got(rt)
req := &http.Request{}
resp, _ := nested.RoundTrip(req)
if tt.want != nil && !tt.want(resp) {
t.Errorf("unexpected response: %#v", resp)
}
})
}
}