Make wrapping a client transport more pleasant

Properly wrapping a transport can be tricky. Make the normal case
(adding a non-nil transport wrapper to a config) easier with a helper.
Also enforce a rough ordering, which in the future we can use to
simplify the WrapTransport mechanism down into an array of functions
we execute in order and avoid wrapping altogether.

Kubernetes-commit: 1f590e697ef64812620c787720b4b5942027e4a1
This commit is contained in:
Clayton Coleman 2018-12-27 11:47:50 -05:00 committed by Kubernetes Publisher
parent 30b06a83d6
commit 615e8e2492
8 changed files with 158 additions and 26 deletions

View File

@ -25,7 +25,7 @@ import (
"sync"
"time"
"github.com/googleapis/gnostic/OpenAPIv2"
openapi_v2 "github.com/googleapis/gnostic/OpenAPIv2"
"k8s.io/klog"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -266,13 +266,10 @@ func NewCachedDiscoveryClientForConfig(config *restclient.Config, discoveryCache
if len(httpCacheDir) > 0 {
// update the given restconfig with a custom roundtripper that
// understands how to handle cache responses.
wt := config.WrapTransport
config.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
if wt != nil {
rt = wt(rt)
}
config = restclient.CopyConfig(config)
config.Wrap(func(rt http.RoundTripper) http.RoundTripper {
return newCacheRoundTripper(httpCacheDir, rt)
}
})
}
discoveryClient, err := NewDiscoveryClientForConfig(config)

View File

@ -32,7 +32,7 @@ import (
"time"
"golang.org/x/crypto/ssh/terminal"
"k8s.io/apimachinery/pkg/apis/meta/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
@ -172,13 +172,9 @@ type credentials struct {
// UpdateTransportConfig updates the transport.Config to use credentials
// returned by the plugin.
func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
wt := c.WrapTransport
c.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
if wt != nil {
rt = wt(rt)
}
c.Wrap(func(rt http.RoundTripper) http.RoundTripper {
return &roundTripper{a, rt}
}
})
if c.TLS.GetCert != nil {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")

View File

@ -34,6 +34,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/pkg/version"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/transport"
certutil "k8s.io/client-go/util/cert"
"k8s.io/client-go/util/flowcontrol"
"k8s.io/klog"
@ -95,13 +96,16 @@ type Config struct {
// Transport may be used for custom HTTP behavior. This attribute may not
// be specified with the TLS client certificate options. Use WrapTransport
// for most client level operations.
// to provide additional per-server middleware behavior.
Transport http.RoundTripper
// WrapTransport will be invoked for custom HTTP behavior after the underlying
// transport is initialized (either the transport created from TLSClientConfig,
// Transport, or http.DefaultTransport). The config may layer other RoundTrippers
// on top of the returned RoundTripper.
WrapTransport func(rt http.RoundTripper) http.RoundTripper
//
// A future release will change this field to an array. Use config.Wrap()
// instead of setting this value directly.
WrapTransport transport.WrapperFunc
// QPS indicates the maximum QPS to the master from this client.
// If it's zero, the created RESTClient will use DefaultQPS: 5

View File

@ -27,12 +27,13 @@ import (
"strings"
"testing"
"k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/diff"
"k8s.io/client-go/kubernetes/scheme"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/transport"
"k8s.io/client-go/util/flowcontrol"
fuzz "github.com/google/gofuzz"
@ -236,6 +237,9 @@ func TestAnonymousConfig(t *testing.T) {
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
*fn = fakeWrapperFunc
},
func(fn *transport.WrapperFunc, f fuzz.Continue) {
*fn = fakeWrapperFunc
},
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
serializer := &fakeNegotiatedSerializer{}
f.Fuzz(serializer)
@ -316,6 +320,9 @@ func TestCopyConfig(t *testing.T) {
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
*fn = fakeWrapperFunc
},
func(fn *transport.WrapperFunc, f fuzz.Continue) {
*fn = fakeWrapperFunc
},
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
serializer := &fakeNegotiatedSerializer{}
f.Fuzz(serializer)

View File

@ -103,14 +103,15 @@ func (c *Config) TransportConfig() (*transport.Config, error) {
if err != nil {
return nil, err
}
wt := conf.WrapTransport
if wt != nil {
conf.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
return provider.WrapTransport(wt(rt))
}
} else {
conf.WrapTransport = provider.WrapTransport
}
conf.Wrap(provider.WrapTransport)
}
return conf, nil
}
// Wrap adds a transport middleware function that will give the caller
// an opportunity to wrap the underlying http.RoundTripper prior to the
// first API call being made. The provided function is invoked after any
// existing transport wrappers are invoked.
func (c *Config) Wrap(fn transport.WrapperFunc) {
c.WrapTransport = transport.Wrappers(c.WrapTransport, fn)
}

View File

@ -57,7 +57,10 @@ type Config struct {
// from TLSClientConfig, Transport, or http.DefaultTransport). The
// config may layer other RoundTrippers on top of the returned
// RoundTripper.
WrapTransport func(rt http.RoundTripper) http.RoundTripper
//
// A future release will change this field to an array. Use config.Wrap()
// instead of setting this value directly.
WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(ctx context.Context, network, address string) (net.Conn, error)
@ -98,6 +101,14 @@ func (c *Config) HasCertCallback() bool {
return c.TLS.GetCert != nil
}
// Wrap adds a transport middleware function that will give the caller
// an opportunity to wrap the underlying http.RoundTripper prior to the
// first API call being made. The provided function is invoked after any
// existing transport wrappers are invoked.
func (c *Config) Wrap(fn WrapperFunc) {
c.WrapTransport = Wrappers(c.WrapTransport, fn)
}
// TLSConfig holds the information needed to set up a TLS transport.
type TLSConfig struct {
CAFile string // Path of the PEM-encoded server trusted root certificates.

View File

@ -167,3 +167,32 @@ func rootCertPool(caData []byte) *x509.CertPool {
certPool.AppendCertsFromPEM(caData)
return certPool
}
// WrapperFunc wraps an http.RoundTripper when a new transport
// is created for a client, allowing per connection behavior
// to be injected.
type WrapperFunc func(rt http.RoundTripper) http.RoundTripper
// Wrappers accepts any number of wrappers and returns a wrapper
// function that is the equivalent of calling each of them in order. Nil
// values are ignored, which makes this function convenient for incrementally
// wrapping a function.
func Wrappers(fns ...WrapperFunc) WrapperFunc {
if len(fns) == 0 {
return nil
}
// optimize the common case of wrapping a possibly nil transport wrapper
// with an additional wrapper
if len(fns) == 2 && fns[0] == nil {
return fns[1]
}
return func(rt http.RoundTripper) http.RoundTripper {
base := rt
for _, fn := range fns {
if fn != nil {
base = fn(base)
}
}
return base
}
}

View File

@ -310,3 +310,90 @@ func TestNew(t *testing.T) {
})
}
}
type fakeRoundTripper struct {
Req *http.Request
Resp *http.Response
Err error
}
func (rt *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.Req = req
return rt.Resp, rt.Err
}
type chainRoundTripper struct {
rt http.RoundTripper
value string
}
func testChain(value string) WrapperFunc {
return func(rt http.RoundTripper) http.RoundTripper {
return &chainRoundTripper{rt: rt, value: value}
}
}
func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rt.rt.RoundTrip(req)
if resp != nil {
if resp.Header == nil {
resp.Header = make(http.Header)
}
resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
}
return resp, err
}
func TestWrappers(t *testing.T) {
resp1 := &http.Response{}
wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
return &fakeRoundTripper{Resp: resp1}
}
resp2 := &http.Response{}
wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
return &fakeRoundTripper{Resp: resp2}
}
tests := []struct {
name string
fns []WrapperFunc
wantNil bool
want func(*http.Response) bool
}{
{fns: []WrapperFunc{}, wantNil: true},
{fns: []WrapperFunc{nil, nil}, wantNil: true},
{fns: []WrapperFunc{nil}, wantNil: false},
{fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
{fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
{fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
{fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
{fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
{fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
{fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Wrappers(tt.fns...)
if got == nil != tt.wantNil {
t.Errorf("Wrappers() = %v", got)
return
}
if got == nil {
return
}
rt := &fakeRoundTripper{Resp: &http.Response{}}
nested := got(rt)
req := &http.Request{}
resp, _ := nested.RoundTrip(req)
if tt.want != nil && !tt.want(resp) {
t.Errorf("unexpected response: %#v", resp)
}
})
}
}