mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-18 16:21:13 +00:00
Make wrapping a client transport more pleasant
Properly wrapping a transport can be tricky. Make the normal case (adding a non-nil transport wrapper to a config) easier with a helper. Also enforce a rough ordering, which in the future we can use to simplify the WrapTransport mechanism down into an array of functions we execute in order and avoid wrapping altogether.
This commit is contained in:
parent
09890b6c48
commit
1f590e697e
@ -25,7 +25,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/googleapis/gnostic/OpenAPIv2"
|
openapi_v2 "github.com/googleapis/gnostic/OpenAPIv2"
|
||||||
"k8s.io/klog"
|
"k8s.io/klog"
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
@ -266,13 +266,10 @@ func NewCachedDiscoveryClientForConfig(config *restclient.Config, discoveryCache
|
|||||||
if len(httpCacheDir) > 0 {
|
if len(httpCacheDir) > 0 {
|
||||||
// update the given restconfig with a custom roundtripper that
|
// update the given restconfig with a custom roundtripper that
|
||||||
// understands how to handle cache responses.
|
// understands how to handle cache responses.
|
||||||
wt := config.WrapTransport
|
config = restclient.CopyConfig(config)
|
||||||
config.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
|
config.Wrap(func(rt http.RoundTripper) http.RoundTripper {
|
||||||
if wt != nil {
|
|
||||||
rt = wt(rt)
|
|
||||||
}
|
|
||||||
return newCacheRoundTripper(httpCacheDir, rt)
|
return newCacheRoundTripper(httpCacheDir, rt)
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
discoveryClient, err := NewDiscoveryClientForConfig(config)
|
discoveryClient, err := NewDiscoveryClientForConfig(config)
|
||||||
|
@ -32,7 +32,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1"
|
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/runtime"
|
"k8s.io/apimachinery/pkg/runtime"
|
||||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||||
"k8s.io/apimachinery/pkg/runtime/serializer"
|
"k8s.io/apimachinery/pkg/runtime/serializer"
|
||||||
@ -172,13 +172,9 @@ type credentials struct {
|
|||||||
// UpdateTransportConfig updates the transport.Config to use credentials
|
// UpdateTransportConfig updates the transport.Config to use credentials
|
||||||
// returned by the plugin.
|
// returned by the plugin.
|
||||||
func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
|
func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
|
||||||
wt := c.WrapTransport
|
c.Wrap(func(rt http.RoundTripper) http.RoundTripper {
|
||||||
c.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
|
|
||||||
if wt != nil {
|
|
||||||
rt = wt(rt)
|
|
||||||
}
|
|
||||||
return &roundTripper{a, rt}
|
return &roundTripper{a, rt}
|
||||||
}
|
})
|
||||||
|
|
||||||
if c.TLS.GetCert != nil {
|
if c.TLS.GetCert != nil {
|
||||||
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
|
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
|
||||||
|
@ -34,6 +34,7 @@ import (
|
|||||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||||
"k8s.io/client-go/pkg/version"
|
"k8s.io/client-go/pkg/version"
|
||||||
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
|
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
|
||||||
|
"k8s.io/client-go/transport"
|
||||||
certutil "k8s.io/client-go/util/cert"
|
certutil "k8s.io/client-go/util/cert"
|
||||||
"k8s.io/client-go/util/flowcontrol"
|
"k8s.io/client-go/util/flowcontrol"
|
||||||
"k8s.io/klog"
|
"k8s.io/klog"
|
||||||
@ -95,13 +96,16 @@ type Config struct {
|
|||||||
|
|
||||||
// Transport may be used for custom HTTP behavior. This attribute may not
|
// Transport may be used for custom HTTP behavior. This attribute may not
|
||||||
// be specified with the TLS client certificate options. Use WrapTransport
|
// be specified with the TLS client certificate options. Use WrapTransport
|
||||||
// for most client level operations.
|
// to provide additional per-server middleware behavior.
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
// WrapTransport will be invoked for custom HTTP behavior after the underlying
|
// WrapTransport will be invoked for custom HTTP behavior after the underlying
|
||||||
// transport is initialized (either the transport created from TLSClientConfig,
|
// transport is initialized (either the transport created from TLSClientConfig,
|
||||||
// Transport, or http.DefaultTransport). The config may layer other RoundTrippers
|
// Transport, or http.DefaultTransport). The config may layer other RoundTrippers
|
||||||
// on top of the returned RoundTripper.
|
// on top of the returned RoundTripper.
|
||||||
WrapTransport func(rt http.RoundTripper) http.RoundTripper
|
//
|
||||||
|
// A future release will change this field to an array. Use config.Wrap()
|
||||||
|
// instead of setting this value directly.
|
||||||
|
WrapTransport transport.WrapperFunc
|
||||||
|
|
||||||
// QPS indicates the maximum QPS to the master from this client.
|
// QPS indicates the maximum QPS to the master from this client.
|
||||||
// If it's zero, the created RESTClient will use DefaultQPS: 5
|
// If it's zero, the created RESTClient will use DefaultQPS: 5
|
||||||
|
@ -27,12 +27,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"k8s.io/api/core/v1"
|
v1 "k8s.io/api/core/v1"
|
||||||
"k8s.io/apimachinery/pkg/runtime"
|
"k8s.io/apimachinery/pkg/runtime"
|
||||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||||
"k8s.io/apimachinery/pkg/util/diff"
|
"k8s.io/apimachinery/pkg/util/diff"
|
||||||
"k8s.io/client-go/kubernetes/scheme"
|
"k8s.io/client-go/kubernetes/scheme"
|
||||||
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
|
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
|
||||||
|
"k8s.io/client-go/transport"
|
||||||
"k8s.io/client-go/util/flowcontrol"
|
"k8s.io/client-go/util/flowcontrol"
|
||||||
|
|
||||||
fuzz "github.com/google/gofuzz"
|
fuzz "github.com/google/gofuzz"
|
||||||
@ -236,6 +237,9 @@ func TestAnonymousConfig(t *testing.T) {
|
|||||||
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
|
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
|
||||||
*fn = fakeWrapperFunc
|
*fn = fakeWrapperFunc
|
||||||
},
|
},
|
||||||
|
func(fn *transport.WrapperFunc, f fuzz.Continue) {
|
||||||
|
*fn = fakeWrapperFunc
|
||||||
|
},
|
||||||
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
|
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
|
||||||
serializer := &fakeNegotiatedSerializer{}
|
serializer := &fakeNegotiatedSerializer{}
|
||||||
f.Fuzz(serializer)
|
f.Fuzz(serializer)
|
||||||
@ -316,6 +320,9 @@ func TestCopyConfig(t *testing.T) {
|
|||||||
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
|
func(fn *func(http.RoundTripper) http.RoundTripper, f fuzz.Continue) {
|
||||||
*fn = fakeWrapperFunc
|
*fn = fakeWrapperFunc
|
||||||
},
|
},
|
||||||
|
func(fn *transport.WrapperFunc, f fuzz.Continue) {
|
||||||
|
*fn = fakeWrapperFunc
|
||||||
|
},
|
||||||
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
|
func(r *runtime.NegotiatedSerializer, f fuzz.Continue) {
|
||||||
serializer := &fakeNegotiatedSerializer{}
|
serializer := &fakeNegotiatedSerializer{}
|
||||||
f.Fuzz(serializer)
|
f.Fuzz(serializer)
|
||||||
|
@ -103,14 +103,15 @@ func (c *Config) TransportConfig() (*transport.Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
wt := conf.WrapTransport
|
conf.Wrap(provider.WrapTransport)
|
||||||
if wt != nil {
|
|
||||||
conf.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
|
|
||||||
return provider.WrapTransport(wt(rt))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
conf.WrapTransport = provider.WrapTransport
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wrap adds a transport middleware function that will give the caller
|
||||||
|
// an opportunity to wrap the underlying http.RoundTripper prior to the
|
||||||
|
// first API call being made. The provided function is invoked after any
|
||||||
|
// existing transport wrappers are invoked.
|
||||||
|
func (c *Config) Wrap(fn transport.WrapperFunc) {
|
||||||
|
c.WrapTransport = transport.Wrappers(c.WrapTransport, fn)
|
||||||
|
}
|
||||||
|
@ -57,7 +57,10 @@ type Config struct {
|
|||||||
// from TLSClientConfig, Transport, or http.DefaultTransport). The
|
// from TLSClientConfig, Transport, or http.DefaultTransport). The
|
||||||
// config may layer other RoundTrippers on top of the returned
|
// config may layer other RoundTrippers on top of the returned
|
||||||
// RoundTripper.
|
// RoundTripper.
|
||||||
WrapTransport func(rt http.RoundTripper) http.RoundTripper
|
//
|
||||||
|
// A future release will change this field to an array. Use config.Wrap()
|
||||||
|
// instead of setting this value directly.
|
||||||
|
WrapTransport WrapperFunc
|
||||||
|
|
||||||
// Dial specifies the dial function for creating unencrypted TCP connections.
|
// Dial specifies the dial function for creating unencrypted TCP connections.
|
||||||
Dial func(ctx context.Context, network, address string) (net.Conn, error)
|
Dial func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
@ -98,6 +101,14 @@ func (c *Config) HasCertCallback() bool {
|
|||||||
return c.TLS.GetCert != nil
|
return c.TLS.GetCert != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wrap adds a transport middleware function that will give the caller
|
||||||
|
// an opportunity to wrap the underlying http.RoundTripper prior to the
|
||||||
|
// first API call being made. The provided function is invoked after any
|
||||||
|
// existing transport wrappers are invoked.
|
||||||
|
func (c *Config) Wrap(fn WrapperFunc) {
|
||||||
|
c.WrapTransport = Wrappers(c.WrapTransport, fn)
|
||||||
|
}
|
||||||
|
|
||||||
// TLSConfig holds the information needed to set up a TLS transport.
|
// TLSConfig holds the information needed to set up a TLS transport.
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
CAFile string // Path of the PEM-encoded server trusted root certificates.
|
CAFile string // Path of the PEM-encoded server trusted root certificates.
|
||||||
|
@ -167,3 +167,32 @@ func rootCertPool(caData []byte) *x509.CertPool {
|
|||||||
certPool.AppendCertsFromPEM(caData)
|
certPool.AppendCertsFromPEM(caData)
|
||||||
return certPool
|
return certPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WrapperFunc wraps an http.RoundTripper when a new transport
|
||||||
|
// is created for a client, allowing per connection behavior
|
||||||
|
// to be injected.
|
||||||
|
type WrapperFunc func(rt http.RoundTripper) http.RoundTripper
|
||||||
|
|
||||||
|
// Wrappers accepts any number of wrappers and returns a wrapper
|
||||||
|
// function that is the equivalent of calling each of them in order. Nil
|
||||||
|
// values are ignored, which makes this function convenient for incrementally
|
||||||
|
// wrapping a function.
|
||||||
|
func Wrappers(fns ...WrapperFunc) WrapperFunc {
|
||||||
|
if len(fns) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// optimize the common case of wrapping a possibly nil transport wrapper
|
||||||
|
// with an additional wrapper
|
||||||
|
if len(fns) == 2 && fns[0] == nil {
|
||||||
|
return fns[1]
|
||||||
|
}
|
||||||
|
return func(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
base := rt
|
||||||
|
for _, fn := range fns {
|
||||||
|
if fn != nil {
|
||||||
|
base = fn(base)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -310,3 +310,90 @@ func TestNew(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeRoundTripper struct {
|
||||||
|
Req *http.Request
|
||||||
|
Resp *http.Response
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
rt.Req = req
|
||||||
|
return rt.Resp, rt.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
type chainRoundTripper struct {
|
||||||
|
rt http.RoundTripper
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func testChain(value string) WrapperFunc {
|
||||||
|
return func(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
return &chainRoundTripper{rt: rt, value: value}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
resp, err := rt.rt.RoundTrip(req)
|
||||||
|
if resp != nil {
|
||||||
|
if resp.Header == nil {
|
||||||
|
resp.Header = make(http.Header)
|
||||||
|
}
|
||||||
|
resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrappers(t *testing.T) {
|
||||||
|
resp1 := &http.Response{}
|
||||||
|
wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
return &fakeRoundTripper{Resp: resp1}
|
||||||
|
}
|
||||||
|
resp2 := &http.Response{}
|
||||||
|
wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
return &fakeRoundTripper{Resp: resp2}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fns []WrapperFunc
|
||||||
|
wantNil bool
|
||||||
|
want func(*http.Response) bool
|
||||||
|
}{
|
||||||
|
{fns: []WrapperFunc{}, wantNil: true},
|
||||||
|
{fns: []WrapperFunc{nil, nil}, wantNil: true},
|
||||||
|
{fns: []WrapperFunc{nil}, wantNil: false},
|
||||||
|
|
||||||
|
{fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
|
||||||
|
{fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
|
||||||
|
{fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
|
||||||
|
{fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
|
||||||
|
{fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
|
||||||
|
{fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
|
||||||
|
|
||||||
|
{fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
|
||||||
|
{fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
|
||||||
|
{fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
|
||||||
|
{fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := Wrappers(tt.fns...)
|
||||||
|
if got == nil != tt.wantNil {
|
||||||
|
t.Errorf("Wrappers() = %v", got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := &fakeRoundTripper{Resp: &http.Response{}}
|
||||||
|
nested := got(rt)
|
||||||
|
req := &http.Request{}
|
||||||
|
resp, _ := nested.RoundTrip(req)
|
||||||
|
if tt.want != nil && !tt.want(resp) {
|
||||||
|
t.Errorf("unexpected response: %#v", resp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user