Compare commits

...

10 Commits

Author SHA1 Message Date
Kubernetes Publisher
f481aa0f41 Merge pull request #112394 from tkashem/automated-cherry-pick-of-#112067-upstream-release-1.22
Automated cherry pick of #112067: client-go: make retry in Request thread safe

Kubernetes-commit: 490d07d58cec093a3c91699b6659f53dfc2fc97c
2022-10-05 07:57:03 +00:00
Abu Kashem
1ca54d7d5b client-go: make retry in Request thread safe
Kubernetes-commit: 05f1bea8f7864df91cbc5efe8730bc13298cffcd
2022-03-29 13:09:26 -04:00
Kubernetes Publisher
02baeca50e Merge pull request #112339 from enj/automated-cherry-pick-of-#112017-upstream-release-1.22
Automated cherry pick of #112017: exec auth: support TLS config caching

Kubernetes-commit: c400c8bb1229d739ec6119a8b79a1f0852ab5905
2022-09-09 00:25:24 -07:00
Monis Khan
2d6ac8e675 exec auth: support TLS config caching
This change updates the transport.Config .Dial and .TLS.GetCert fields
to use a struct wrapper.  This indirection via a pointer allows the
functions to be compared and thus makes them valid to use as map keys.
This change is then leveraged by the existing global exec auth and TLS
config caches to return the same authenticator and TLS config even when
distinct but identical rest configs were used to create distinct
clientsets.

Signed-off-by: Monis Khan <mok@microsoft.com>

Kubernetes-commit: b1b8578ae937b77e64a1ce75c23f0dacb6d45064
2022-08-24 16:04:19 +00:00
Monis Khan
6f28fe1a58 client-go exec: make sure round tripper can be unwrapped
Signed-off-by: Monis Khan <mok@vmware.com>

Kubernetes-commit: 02b0b8d4c8daa2030672de6d6dbc86b009f5db32
2021-10-29 17:59:52 -04:00
Kubernetes Publisher
48190f8df6 Merge pull request #111272 from Abirdcfly/automated-cherry-pick-of-#111235-upstream-release-1.22
Automated cherry pick of #111235: fix a possible panic because of taking the address of nil

Kubernetes-commit: cdfdbd03d6b159e8a765b7872363dce068dd785c
2022-08-11 14:41:06 +00:00
Abirdcfly
595613289c fix a possible panic because of taking the address of nil
Signed-off-by: Abirdcfly <fp544037857@gmail.com>

Kubernetes-commit: b4431bcac3530dab757d345c2a177fb57015e7e8
2022-07-19 10:39:08 +08:00
Kubernetes Publisher
b626b5b2e9 sync: initially remove files BUILD */BUILD BUILD.bazel */BUILD.bazel Gopkg.toml */.gitattributes 2022-03-25 17:21:24 +00:00
Kubernetes Publisher
be23bb76ca Merge pull request #107637 from gjkim42/cherry-pick-of-#106473-upstream-release-1.22
Update k/utils to v0.0.0-20211116205334-6203023598ed

Kubernetes-commit: b2604798aaa65dcecf94a67ede54c2cf0c3c3e63
2022-01-24 17:29:54 +00:00
Gunju Kim
561c55c5a6 Update k/utils to v0.0.0-20211116205334-6203023598ed
Kubernetes-commit: 723c1b334ecf48c43fb9ddf6abf4b45406f1110b
2022-01-19 20:47:53 +09:00
13 changed files with 462 additions and 48 deletions

10
go.mod
View File

@@ -30,15 +30,15 @@ require (
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac
google.golang.org/protobuf v1.26.0
k8s.io/api v0.0.0-20220115132223-0b924f4f7b63
k8s.io/apimachinery v0.0.0-20220115131919-1ea05acba7c0
k8s.io/api v0.0.0-20220222193716-39d32415443c
k8s.io/apimachinery v0.0.0-20220919093433-dcdd25a23be9
k8s.io/klog/v2 v2.9.0
k8s.io/utils v0.0.0-20210819203725-bdf08cb9a70a
k8s.io/utils v0.0.0-20211116205334-6203023598ed
sigs.k8s.io/structured-merge-diff/v4 v4.2.1
sigs.k8s.io/yaml v1.2.0
)
replace (
k8s.io/api => k8s.io/api v0.0.0-20220115132223-0b924f4f7b63
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20220115131919-1ea05acba7c0
k8s.io/api => k8s.io/api v0.0.0-20220222193716-39d32415443c
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20220919093433-dcdd25a23be9
)

12
go.sum
View File

@@ -444,18 +444,18 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
k8s.io/api v0.0.0-20220115132223-0b924f4f7b63 h1:ARkTMQaUQg/XAxZGClgPO8jLBGe+0vyHOzkv1klocn4=
k8s.io/api v0.0.0-20220115132223-0b924f4f7b63/go.mod h1:VZ3awwdNsOtdAUAYPvruiFtrH/OD4gT5qmZHaQmXgxY=
k8s.io/apimachinery v0.0.0-20220115131919-1ea05acba7c0 h1:SoF7jCfsIaGo2y+PiJGjFnRMeGh4zO4aGEYc0IVGAu4=
k8s.io/apimachinery v0.0.0-20220115131919-1ea05acba7c0/go.mod h1:ZvVLP5iLhwVFg2Yx9Gh5W0um0DUauExbRhe+2Z8I1EU=
k8s.io/api v0.0.0-20220222193716-39d32415443c h1:7uKzkkSldmiACEmFGRb0+RdhsZShf8z4pC+523YGlRE=
k8s.io/api v0.0.0-20220222193716-39d32415443c/go.mod h1:VZ3awwdNsOtdAUAYPvruiFtrH/OD4gT5qmZHaQmXgxY=
k8s.io/apimachinery v0.0.0-20220919093433-dcdd25a23be9 h1:Ys+9Q+7lj+x4whF0Ckupf3oj0SA9pmK8K2vZsZKVo/4=
k8s.io/apimachinery v0.0.0-20220919093433-dcdd25a23be9/go.mod h1:ZvVLP5iLhwVFg2Yx9Gh5W0um0DUauExbRhe+2Z8I1EU=
k8s.io/gengo v0.0.0-20200413195148-3a45101e95ac/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0=
k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE=
k8s.io/klog/v2 v2.9.0 h1:D7HV+n1V57XeZ0m6tdRkfknthUaM06VFbWldOFh8kzM=
k8s.io/klog/v2 v2.9.0/go.mod h1:hy9LJ/NvuK+iVyP4Ehqva4HxZG/oXyIS3n3Jmire4Ec=
k8s.io/kube-openapi v0.0.0-20211109043538-20434351676c h1:jvamsI1tn9V0S8jicyX82qaFC0H/NKxv2e5mbqsgR80=
k8s.io/kube-openapi v0.0.0-20211109043538-20434351676c/go.mod h1:vHXdDvt9+2spS2Rx9ql3I8tycm3H9FDfdUoIuKCefvw=
k8s.io/utils v0.0.0-20210819203725-bdf08cb9a70a h1:8dYfu/Fc9Gz2rNJKB9IQRGgQOh2clmRzNIPPY1xLY5g=
k8s.io/utils v0.0.0-20210819203725-bdf08cb9a70a/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA=
k8s.io/utils v0.0.0-20211116205334-6203023598ed h1:ck1fRPWPJWsMd8ZRFsWc6mh/zHp5fZ/shhbrgPUxDAE=
k8s.io/utils v0.0.0-20211116205334-6203023598ed/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=

View File

@@ -1 +0,0 @@
base.go export-subst

View File

@@ -39,6 +39,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/clock"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/client-go/pkg/apis/clientauthentication"
"k8s.io/client-go/pkg/apis/clientauthentication/install"
clientauthenticationv1 "k8s.io/client-go/pkg/apis/clientauthentication/v1"
@@ -200,14 +201,18 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC
now: time.Now,
environ: os.Environ,
defaultDialer: defaultDialer,
connTracker: connTracker,
connTracker: connTracker,
}
for _, env := range config.Env {
a.env = append(a.env, env.Name+"="+env.Value)
}
// these functions are made comparable and stored in the cache so that repeated clientset
// construction with the same rest.Config results in a single TLS cache and Authenticator
a.getCert = &transport.GetCertHolder{GetCert: a.cert}
a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext}
return c.put(key, a), nil
}
@@ -262,8 +267,6 @@ type Authenticator struct {
now func() time.Time
environ func() []string
// defaultDialer is used for clients which don't specify a custom dialer
defaultDialer *connrotation.Dialer
// connTracker tracks all connections opened that we need to close when rotating a client certificate
connTracker *connrotation.ConnectionTracker
@@ -274,6 +277,12 @@ type Authenticator struct {
mu sync.Mutex
cachedCreds *credentials
exp time.Time
// getCert makes Authenticator.cert comparable to support TLS config caching
getCert *transport.GetCertHolder
// dial is used for clients which do not specify a custom dialer
// it is comparable to support TLS config caching
dial *transport.DialHolder
}
type credentials struct {
@@ -301,26 +310,34 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
if c.TLS.GetCert != nil {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
}
c.TLS.GetCert = a.cert
c.TLS.GetCert = a.getCert.GetCert
c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching
var d *connrotation.Dialer
if c.Dial != nil {
// if c has a custom dialer, we have to wrap it
d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
// TLS config caching is not supported for this config
d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
c.Dial = d.DialContext
c.DialHolder = nil
} else {
d = a.defaultDialer
c.Dial = a.dial.Dial
c.DialHolder = a.dial // comparable for TLS config caching
}
c.Dial = d.DialContext
return nil
}
var _ utilnet.RoundTripperWrapper = &roundTripper{}
type roundTripper struct {
a *Authenticator
base http.RoundTripper
}
func (r *roundTripper) WrappedRoundTripper() http.RoundTripper {
return r.base
}
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// If a user has already set credentials, use that. This makes commands like
// "kubectl get --token (token) pods" work.

View File

@@ -0,0 +1,106 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package exec_test // separate package to prevent circular import
import (
"context"
"testing"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilnet "k8s.io/apimachinery/pkg/util/net"
clientset "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)
// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used.
//
// In particular, when:
// - multiple identical rest configs exist as distinct objects, and
// - these rest configs use exec auth, and
// - these rest configs are used to create distinct clientsets, then
//
// the underlying TLS config is shared between those clientsets.
func TestExecTLSCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
t.Cleanup(cancel)
config1 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client1 := clientset.NewForConfigOrDie(config1)
config2 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client2 := clientset.NewForConfigOrDie(config2)
config3 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
Args: []string{"make this exec auth different"},
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client3 := clientset.NewForConfigOrDie(config3)
_, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
_, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
_, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{})
rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport
rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport
rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport
tlsConfig1, err := utilnet.TLSClientConfig(rt1)
if err != nil {
t.Fatal(err)
}
tlsConfig2, err := utilnet.TLSClientConfig(rt2)
if err != nil {
t.Fatal(err)
}
tlsConfig3, err := utilnet.TLSClientConfig(rt3)
if err != nil {
t.Fatal(err)
}
if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil {
t.Fatal("expected non-nil TLS configs")
}
if tlsConfig1 != tlsConfig2 {
t.Fatal("expected the same TLS config for matching exec config via rest config")
}
if tlsConfig1 == tlsConfig3 {
t.Fatal("expected different TLS config for non-matching exec config via rest config")
}
}

View File

@@ -82,6 +82,12 @@ func (r *RequestConstructionError) Error() string {
var noBackoff = &NoBackoff{}
type requestRetryFunc func(maxRetries int) WithRetry
func defaultRequestRetryFn(maxRetries int) WithRetry {
return &withRetry{maxRetries: maxRetries}
}
// Request allows for building up a request to a server in a chained fashion.
// Any errors are stored until the end of your call, so you only have to
// check once.
@@ -93,6 +99,7 @@ type Request struct {
rateLimiter flowcontrol.RateLimiter
backoff BackoffManager
timeout time.Duration
maxRetries int
// generic components accessible via method setters
verb string
@@ -109,9 +116,10 @@ type Request struct {
subresource string
// output
err error
body io.Reader
retry WithRetry
err error
body io.Reader
retryFn requestRetryFunc
}
// NewRequest creates a new request helper object for accessing runtime.Objects on a server.
@@ -142,7 +150,8 @@ func NewRequest(c *RESTClient) *Request {
backoff: backoff,
timeout: timeout,
pathPrefix: pathPrefix,
retry: &withRetry{maxRetries: 10},
maxRetries: 10,
retryFn: defaultRequestRetryFn,
warningHandler: c.warningHandler,
}
@@ -408,7 +417,10 @@ func (r *Request) Timeout(d time.Duration) *Request {
// function is specifically called with a different value.
// A zero maxRetries prevent it from doing retires and return an error immediately.
func (r *Request) MaxRetries(maxRetries int) *Request {
r.retry.SetMaxRetries(maxRetries)
if maxRetries < 0 {
maxRetries = 0
}
r.maxRetries = maxRetries
return r
}
@@ -688,8 +700,10 @@ func (r *Request) Watch(ctx context.Context) (watch.Interface, error) {
}
return false
}
var retryAfter *RetryAfter
url := r.URL().String()
withRetry := r.retryFn(r.maxRetries)
for {
req, err := r.newHTTPRequest(ctx)
if err != nil {
@@ -724,9 +738,9 @@ func (r *Request) Watch(ctx context.Context) (watch.Interface, error) {
defer readAndCloseResponseBody(resp)
var retry bool
retryAfter, retry = r.retry.NextRetry(req, resp, err, isErrRetryableFunc)
retryAfter, retry = withRetry.NextRetry(req, resp, err, isErrRetryableFunc)
if retry {
err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
err := withRetry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
if err == nil {
return false, nil
}
@@ -817,6 +831,7 @@ func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) {
}
var retryAfter *RetryAfter
withRetry := r.retryFn(r.maxRetries)
url := r.URL().String()
for {
req, err := r.newHTTPRequest(ctx)
@@ -862,9 +877,9 @@ func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) {
defer resp.Body.Close()
var retry bool
retryAfter, retry = r.retry.NextRetry(req, resp, err, neverRetryError)
retryAfter, retry = withRetry.NextRetry(req, resp, err, neverRetryError)
if retry {
err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
err := withRetry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
if err == nil {
return false, nil
}
@@ -961,6 +976,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp
// Right now we make about ten retry attempts if we get a Retry-After response.
var retryAfter *RetryAfter
withRetry := r.retryFn(r.maxRetries)
for {
req, err := r.newHTTPRequest(ctx)
if err != nil {
@@ -997,7 +1013,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp
}
var retry bool
retryAfter, retry = r.retry.NextRetry(req, resp, err, func(req *http.Request, err error) bool {
retryAfter, retry = withRetry.NextRetry(req, resp, err, func(req *http.Request, err error) bool {
// "Connection reset by peer" or "apiserver is shutting down" are usually a transient errors.
// Thus in case of "GET" operations, we simply retry it.
// We are not automatically retrying "write" operations, as they are not idempotent.
@@ -1011,7 +1027,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp
return false
})
if retry {
err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body)
err := withRetry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body)
if err == nil {
return false
}

View File

@@ -32,6 +32,7 @@ import (
"reflect"
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"
@@ -1193,7 +1194,8 @@ func TestRequestWatch(t *testing.T) {
c.Client = client
}
testCase.Request.backoff = &noSleepBackOff{}
testCase.Request.retry = &withRetry{maxRetries: testCase.maxRetries}
testCase.Request.maxRetries = testCase.maxRetries
testCase.Request.retryFn = defaultRequestRetryFn
watch, err := testCase.Request.Watch(context.Background())
@@ -1406,7 +1408,8 @@ func TestRequestStream(t *testing.T) {
c.Client = client
}
testCase.Request.backoff = &noSleepBackOff{}
testCase.Request.retry = &withRetry{maxRetries: testCase.maxRetries}
testCase.Request.maxRetries = testCase.maxRetries
testCase.Request.retryFn = defaultRequestRetryFn
body, err := testCase.Request.Stream(context.Background())
@@ -1495,7 +1498,7 @@ func TestRequestDo(t *testing.T) {
}
for i, testCase := range testCases {
testCase.Request.backoff = &NoBackoff{}
testCase.Request.retry = &withRetry{}
testCase.Request.retryFn = defaultRequestRetryFn
body, err := testCase.Request.Do(context.Background()).Raw()
hasErr := err != nil
if hasErr != testCase.Err {
@@ -1658,8 +1661,9 @@ func TestConnectionResetByPeerIsRetried(t *testing.T) {
return nil, &net.OpError{Err: syscall.ECONNRESET}
}),
},
backoff: backoff,
retry: &withRetry{maxRetries: 10},
backoff: backoff,
maxRetries: 10,
retryFn: defaultRequestRetryFn,
}
// We expect two retries of "connection reset by peer" and the success.
_, err := req.Do(context.Background()).Raw()
@@ -2730,8 +2734,9 @@ func TestRequestWithRetry(t *testing.T) {
c: &RESTClient{
Client: client,
},
backoff: &noSleepBackOff{},
retry: &withRetry{maxRetries: 1},
backoff: &noSleepBackOff{},
maxRetries: 1,
retryFn: defaultRequestRetryFn,
}
var transformFuncInvoked int
@@ -2921,8 +2926,9 @@ func testRequestWithRetry(t *testing.T, key string, doFunc func(ctx context.Cont
content: defaultContentConfig(),
Client: client,
},
backoff: &noSleepBackOff{},
retry: &withRetry{maxRetries: test.maxRetries},
backoff: &noSleepBackOff{},
maxRetries: test.maxRetries,
retryFn: defaultRequestRetryFn,
}
doFunc(context.Background(), req)
@@ -2944,3 +2950,50 @@ func testRequestWithRetry(t *testing.T, key string, doFunc func(ctx context.Cont
})
}
}
func TestRequestConcurrencyWithRetry(t *testing.T) {
var attempts int32
client := clientForFunc(func(req *http.Request) (*http.Response, error) {
defer func() {
atomic.AddInt32(&attempts, 1)
}()
// always send a retry-after response
return &http.Response{
StatusCode: http.StatusInternalServerError,
Header: http.Header{"Retry-After": []string{"1"}},
}, nil
})
req := &Request{
verb: "POST",
c: &RESTClient{
content: defaultContentConfig(),
Client: client,
},
backoff: &noSleepBackOff{},
maxRetries: 9, // 10 attempts in total, including the first
retryFn: defaultRequestRetryFn,
}
concurrency := 20
wg := sync.WaitGroup{}
wg.Add(concurrency)
startCh := make(chan struct{})
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
<-startCh
req.Do(context.Background())
}()
}
close(startCh)
wg.Wait()
// we expect (concurrency*req.maxRetries+1) attempts to be recorded
expected := concurrency * (req.maxRetries + 1)
if atomic.LoadInt32(&attempts) != int32(expected) {
t.Errorf("Expected attempts: %d, but got: %d", expected, attempts)
}
}

View File

@@ -51,10 +51,10 @@ func (a *PromptingAuthLoader) LoadAuth(path string) (*clientauth.Info, error) {
// Prompt for user/pass and write a file if none exists.
if _, err := os.Stat(path); os.IsNotExist(err) {
authPtr, err := a.Prompt()
auth := *authPtr
if err != nil {
return nil, err
}
auth := *authPtr
data, err := json.Marshal(auth)
if err != nil {
return &auth, err

View File

@@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"fmt"
"net"
"net/http"
@@ -50,6 +51,9 @@ type tlsCacheKey struct {
serverName string
nextProtos string
disableCompression bool
// these functions are wrapped to allow them to be used as map keys
getCert *GetCertHolder
dial *DialHolder
}
func (t tlsCacheKey) String() string {
@@ -57,7 +61,8 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 {
keyText = "<redacted>"
}
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression)
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
}
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@@ -87,8 +92,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
return http.DefaultTransport, nil
}
dial := config.Dial
if dial == nil {
var dial func(ctx context.Context, network, address string) (net.Conn, error)
if config.Dial != nil {
dial = config.Dial
} else {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
@@ -133,10 +140,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
return tlsCacheKey{}, false, err
}
if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil {
if c.Proxy != nil {
// cannot determine equality for functions
return tlsCacheKey{}, false, nil
}
if c.Dial != nil && c.DialHolder == nil {
// cannot determine equality for dial function that doesn't have non-nil DialHolder set as well
return tlsCacheKey{}, false, nil
}
if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil {
// cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well
return tlsCacheKey{}, false, nil
}
k := tlsCacheKey{
insecure: c.TLS.Insecure,
@@ -144,6 +159,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
serverName: c.TLS.ServerName,
nextProtos: strings.Join(c.TLS.NextProtos, ","),
disableCompression: c.DisableCompression,
getCert: c.TLS.GetCertHolder,
dial: c.DialHolder,
}
if c.TLS.ReloadTLSFiles {

View File

@@ -21,6 +21,7 @@ import (
"crypto/tls"
"net"
"net/http"
"net/url"
"testing"
)
@@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
if keyA != (tlsCacheKey{}) {
t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
}
}
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
getCert := func() (*tls.Certificate, error) { return nil, nil }
getCertHolder := &GetCertHolder{GetCert: getCert}
uniqueConfigurations := map[string]*Config{
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
@@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) {
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
"getCert3": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
GetCertHolder: getCertHolder,
},
},
"getCert1, key 2": {
TLS: TLSConfig{
KeyData: []byte{2},

View File

@@ -68,7 +68,11 @@ type Config struct {
WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections.
// If specified, this transport will be non-cacheable unless DialHolder is also set.
Dial func(ctx context.Context, network, address string) (net.Conn, error)
// DialHolder can be populated to make transport configs cacheable.
// If specified, DialHolder.Dial must be equal to Dial.
DialHolder *DialHolder
// Proxy is the proxy func to be used for all requests made by this
// transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy
@@ -78,6 +82,11 @@ type Config struct {
Proxy func(*http.Request) (*url.URL, error)
}
// DialHolder is used to make the wrapped function comparable so that it can be used as a map key.
type DialHolder struct {
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options
type ImpersonationConfig struct {
// UserName matches user.Info.GetName()
@@ -141,5 +150,15 @@ type TLSConfig struct {
// To use only http/1.1, set to ["http/1.1"].
NextProtos []string
GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// If specified, this transport is non-cacheable unless CertHolder is populated.
GetCert func() (*tls.Certificate, error)
// CertHolder can be populated to make transport configs that set GetCert cacheable.
// If set, CertHolder.GetCert must be equal to GetCert.
GetCertHolder *GetCertHolder
}
// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key.
type GetCertHolder struct {
GetCert func() (*tls.Certificate, error)
}

View File

@@ -24,6 +24,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"reflect"
"sync"
"time"
@@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) {
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
}
if !isValidHolders(config) {
return nil, fmt.Errorf("misconfigured holder for dialer or cert callback")
}
var (
rt http.RoundTripper
err error
@@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) {
return HTTPWrappersForConfig(config, rt)
}
func isValidHolders(config *Config) bool {
if config.TLS.GetCertHolder != nil {
if config.TLS.GetCertHolder.GetCert == nil ||
config.TLS.GetCert == nil ||
reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() {
return false
}
}
if config.DialHolder != nil {
if config.DialHolder.Dial == nil ||
config.Dial == nil ||
reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() {
return false
}
}
return true
}
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(c *Config) (*tls.Config, error) {

View File

@@ -21,6 +21,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"testing"
)
@@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
)
func TestNew(t *testing.T) {
globalGetCert := &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
}
globalDial := &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
}
testCases := map[string]struct {
Config *Config
Err bool
@@ -255,6 +263,144 @@ func TestNew(t *testing.T) {
},
},
},
"nil holders and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: false,
TLSCert: false,
TLSErr: false,
Default: true,
Insecure: false,
DefaultRoots: false,
},
"nil holders and non-nil regular get cert": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"nil holders and non-nil regular dial": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: false,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"non-nil dial holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
},
},
Err: true,
},
"non-nil cert holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil holders+internal and non-nil regular with correct address": {
Config: &Config{
TLS: TLSConfig{
GetCert: globalGetCert.GetCert,
GetCertHolder: globalGetCert,
},
Dial: globalDial.Dial,
DialHolder: globalDial,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
}
for k, testCase := range testCases {
t.Run(k, func(t *testing.T) {