Compare commits

...

14 Commits

Author SHA1 Message Date
Kubernetes Publisher
84e38b0e36 Update dependencies to v0.23.13 tag 2022-10-13 20:07:29 +00:00
Kubernetes Publisher
4fe0cac705 Merge pull request #112067 from tkashem/automated-cherry-pick-of-#109114-upstream-release-1.23-1661519732
client-go: make retry in Request thread safe

Kubernetes-commit: bca383612aed64083b6b2c10d2634082877dc6da
2022-10-04 19:53:33 +00:00
Kubernetes Publisher
398dfd1513 Merge pull request #112338 from enj/automated-cherry-pick-of-#112017-upstream-release-1.23
Automated cherry pick of #112017: exec auth: support TLS config caching

Kubernetes-commit: 6c32e02f49079daa992e85a149407e38505e92ec
2022-09-09 00:25:24 -07:00
Monis Khan
9285cfe8df exec auth: support TLS config caching
This change updates the transport.Config .Dial and .TLS.GetCert fields
to use a struct wrapper.  This indirection via a pointer allows the
functions to be compared and thus makes them valid to use as map keys.
This change is then leveraged by the existing global exec auth and TLS
config caches to return the same authenticator and TLS config even when
distinct but identical rest configs were used to create distinct
clientsets.

Signed-off-by: Monis Khan <mok@microsoft.com>

Kubernetes-commit: b69bbf362018ce4f105ff88fe118d4cf18f96706
2022-08-24 16:04:19 +00:00
Kubernetes Publisher
80c66f4109 Merge pull request #111273 from Abirdcfly/automated-cherry-pick-of-#111235-upstream-release-1.23
Automated cherry pick of #111235: fix a possible panic because of taking the address of nil

Kubernetes-commit: e9dbbc3f7471c711cfda65d7a154695700f0b282
2022-08-11 14:48:16 +00:00
Abirdcfly
aca59e4fb1 fix a possible panic because of taking the address of nil
Signed-off-by: Abirdcfly <fp544037857@gmail.com>

Kubernetes-commit: 5f436c0fb35b3a37c2baf3ffefe499b3c9284496
2022-07-19 10:39:08 +08:00
Kubernetes Publisher
a475c28713 Merge pull request #108791 from aojea/cherry-pick-108772
Cherry pick 108772

Kubernetes-commit: 05a09629ab8207a870c9178b0f5694aa92919d92
2022-03-31 09:19:36 +00:00
Kubernetes Publisher
8041ba924b Merge pull request #109159 from wojtek-t/automated-cherry-pick-of-#109137-upstream-release-1.23
Automated cherry pick of #109137 upstream release 1.23

Kubernetes-commit: 4f9c753bf05d1c353fe96eba1e06aeafc58259c8
2022-03-31 05:22:55 +00:00
Michael Bolot
0c2c708257 Addresses the issue which caused #109115
Kubernetes-commit: 814ae980477ab06f8dbe13ae4c4318110e6922f6
2022-03-29 12:35:13 -05:00
Wojciech Tyczyński
9b9e45fc6d Add test for indexer with multiple values
Kubernetes-commit: 6ba5a0bc38306b713723281d4e54d257e2936890
2022-03-30 08:52:10 +02:00
Abu Kashem
7763f75022 client-go: make retry in Request thread safe
Kubernetes-commit: 091f4f00395272e23a777d6bf068d67793bf8931
2022-03-29 13:09:26 -04:00
Kubernetes Publisher
ad6be0fa0b sync: initially remove files BUILD */BUILD BUILD.bazel */BUILD.bazel Gopkg.toml */.gitattributes 2022-03-25 17:27:50 +00:00
Antonio Ojea
a26f2df3da client-go: update generated
Kubernetes-commit: f628706339c120c19d28ccfa7b1a580516be1d1a
2022-03-17 16:03:10 +01:00
Antonio Ojea
b6f49c1554 default kubernetes agent for generated clients
Set default kubernetes agent if empty

Kubernetes-commit: 3de44bd759a34f6fead9fe7254b5204684b99de8
2022-03-17 13:19:43 +01:00
17 changed files with 622 additions and 53 deletions

8
go.mod
View File

@@ -30,8 +30,8 @@ require (
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac
google.golang.org/protobuf v1.27.1
k8s.io/api v0.0.0-20220124172526-d42c342a4737
k8s.io/apimachinery v0.0.0-20220124172104-276a8a7530a3
k8s.io/api v0.23.13
k8s.io/apimachinery v0.23.13
k8s.io/klog/v2 v2.30.0
k8s.io/kube-openapi v0.0.0-20211115234752-e816edb12b65
k8s.io/utils v0.0.0-20211116205334-6203023598ed
@@ -40,6 +40,6 @@ require (
)
replace (
k8s.io/api => k8s.io/api v0.0.0-20220124172526-d42c342a4737
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20220124172104-276a8a7530a3
k8s.io/api => k8s.io/api v0.23.13
k8s.io/apimachinery => k8s.io/apimachinery v0.23.13
)

8
go.sum
View File

@@ -610,10 +610,10 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
k8s.io/api v0.0.0-20220124172526-d42c342a4737 h1:q+j9LyZumYp5R/AOjXgzWYZa33Z4usAz0yInds3I1yE=
k8s.io/api v0.0.0-20220124172526-d42c342a4737/go.mod h1:9Vo9KEQyesRPDP6eG+jDkliZvSar3OLe8f5zfqiRoec=
k8s.io/apimachinery v0.0.0-20220124172104-276a8a7530a3 h1:n2y4Rh6ixZSeru9imTOu4bYNuAObzwMBnUcPVFzaXnk=
k8s.io/apimachinery v0.0.0-20220124172104-276a8a7530a3/go.mod h1:BEuFMMBaIbcOqVIJqNZJXGFTP4W6AycEpb5+m/97hrM=
k8s.io/api v0.23.13 h1:3iknOAH1Ves8mFity5Zf3rkHExJNz+a7vS9uxNUiTpw=
k8s.io/api v0.23.13/go.mod h1:/mjkHiefu63pWKPvk2Yo/4lJl+xHcClv20am3uUoutI=
k8s.io/apimachinery v0.23.13 h1:D/aFU9tzVCt3/uFHmBDWHeuHVTo2pvxmPGXyJktSM0k=
k8s.io/apimachinery v0.23.13/go.mod h1:BEuFMMBaIbcOqVIJqNZJXGFTP4W6AycEpb5+m/97hrM=
k8s.io/gengo v0.0.0-20210813121822-485abfe95c7c/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E=
k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE=
k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y=

View File

@@ -413,6 +413,10 @@ func (c *Clientset) Discovery() discovery.DiscoveryInterface {
func NewForConfig(c *rest.Config) (*Clientset, error) {
configShallowCopy := *c
if configShallowCopy.UserAgent == "" {
configShallowCopy.UserAgent = rest.DefaultKubernetesUserAgent()
}
// share the transport between all clients
httpClient, err := rest.HTTPClientFor(&configShallowCopy)
if err != nil {

View File

@@ -0,0 +1,88 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package kubernetes_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
"k8s.io/client-go/rest"
)
func TestClientUserAgent(t *testing.T) {
tests := []struct {
name string
userAgent string
expect string
}{
{
name: "empty",
expect: rest.DefaultKubernetesUserAgent(),
},
{
name: "custom",
userAgent: "test-agent",
expect: "test-agent",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userAgent := r.Header.Get("User-Agent")
if userAgent != tc.expect {
t.Errorf("User Agent expected: %s got: %s", tc.expect, userAgent)
http.Error(w, "Unexpected user agent", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{}"))
}))
ts.Start()
defer ts.Close()
gv := v1.SchemeGroupVersion
config := &rest.Config{
Host: ts.URL,
}
config.GroupVersion = &gv
config.NegotiatedSerializer = scheme.Codecs.WithoutConversion()
config.UserAgent = tc.userAgent
config.ContentType = "application/json"
client, err := kubernetes.NewForConfig(config)
if err != nil {
t.Fatalf("failed to create REST client: %v", err)
}
_, err = client.CoreV1().Pods("").List(context.TODO(), metav1.ListOptions{})
if err != nil {
t.Error(err)
}
_, err = client.CoreV1().Secrets("").List(context.TODO(), metav1.ListOptions{})
if err != nil {
t.Error(err)
}
})
}
}

View File

@@ -1 +0,0 @@
base.go export-subst

View File

@@ -201,14 +201,18 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC
now: time.Now,
environ: os.Environ,
defaultDialer: defaultDialer,
connTracker: connTracker,
connTracker: connTracker,
}
for _, env := range config.Env {
a.env = append(a.env, env.Name+"="+env.Value)
}
// these functions are made comparable and stored in the cache so that repeated clientset
// construction with the same rest.Config results in a single TLS cache and Authenticator
a.getCert = &transport.GetCertHolder{GetCert: a.cert}
a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext}
return c.put(key, a), nil
}
@@ -263,8 +267,6 @@ type Authenticator struct {
now func() time.Time
environ func() []string
// defaultDialer is used for clients which don't specify a custom dialer
defaultDialer *connrotation.Dialer
// connTracker tracks all connections opened that we need to close when rotating a client certificate
connTracker *connrotation.ConnectionTracker
@@ -275,6 +277,12 @@ type Authenticator struct {
mu sync.Mutex
cachedCreds *credentials
exp time.Time
// getCert makes Authenticator.cert comparable to support TLS config caching
getCert *transport.GetCertHolder
// dial is used for clients which do not specify a custom dialer
// it is comparable to support TLS config caching
dial *transport.DialHolder
}
type credentials struct {
@@ -302,18 +310,20 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
if c.TLS.GetCert != nil {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
}
c.TLS.GetCert = a.cert
c.TLS.GetCert = a.getCert.GetCert
c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching
var d *connrotation.Dialer
if c.Dial != nil {
// if c has a custom dialer, we have to wrap it
d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
// TLS config caching is not supported for this config
d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
c.Dial = d.DialContext
c.DialHolder = nil
} else {
d = a.defaultDialer
c.Dial = a.dial.Dial
c.DialHolder = a.dial // comparable for TLS config caching
}
c.Dial = d.DialContext
return nil
}

View File

@@ -0,0 +1,106 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package exec_test // separate package to prevent circular import
import (
"context"
"testing"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilnet "k8s.io/apimachinery/pkg/util/net"
clientset "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)
// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used.
//
// In particular, when:
// - multiple identical rest configs exist as distinct objects, and
// - these rest configs use exec auth, and
// - these rest configs are used to create distinct clientsets, then
//
// the underlying TLS config is shared between those clientsets.
func TestExecTLSCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
t.Cleanup(cancel)
config1 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client1 := clientset.NewForConfigOrDie(config1)
config2 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client2 := clientset.NewForConfigOrDie(config2)
config3 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
Args: []string{"make this exec auth different"},
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client3 := clientset.NewForConfigOrDie(config3)
_, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
_, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
_, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{})
rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport
rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport
rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport
tlsConfig1, err := utilnet.TLSClientConfig(rt1)
if err != nil {
t.Fatal(err)
}
tlsConfig2, err := utilnet.TLSClientConfig(rt2)
if err != nil {
t.Fatal(err)
}
tlsConfig3, err := utilnet.TLSClientConfig(rt3)
if err != nil {
t.Fatal(err)
}
if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil {
t.Fatal("expected non-nil TLS configs")
}
if tlsConfig1 != tlsConfig2 {
t.Fatal("expected the same TLS config for matching exec config via rest config")
}
if tlsConfig1 == tlsConfig3 {
t.Fatal("expected different TLS config for non-matching exec config via rest config")
}
}

View File

@@ -82,6 +82,12 @@ func (r *RequestConstructionError) Error() string {
var noBackoff = &NoBackoff{}
type requestRetryFunc func(maxRetries int) WithRetry
func defaultRequestRetryFn(maxRetries int) WithRetry {
return &withRetry{maxRetries: maxRetries}
}
// Request allows for building up a request to a server in a chained fashion.
// Any errors are stored until the end of your call, so you only have to
// check once.
@@ -93,6 +99,7 @@ type Request struct {
rateLimiter flowcontrol.RateLimiter
backoff BackoffManager
timeout time.Duration
maxRetries int
// generic components accessible via method setters
verb string
@@ -109,9 +116,10 @@ type Request struct {
subresource string
// output
err error
body io.Reader
retry WithRetry
err error
body io.Reader
retryFn requestRetryFunc
}
// NewRequest creates a new request helper object for accessing runtime.Objects on a server.
@@ -142,7 +150,8 @@ func NewRequest(c *RESTClient) *Request {
backoff: backoff,
timeout: timeout,
pathPrefix: pathPrefix,
retry: &withRetry{maxRetries: 10},
maxRetries: 10,
retryFn: defaultRequestRetryFn,
warningHandler: c.warningHandler,
}
@@ -408,7 +417,10 @@ func (r *Request) Timeout(d time.Duration) *Request {
// function is specifically called with a different value.
// A zero maxRetries prevent it from doing retires and return an error immediately.
func (r *Request) MaxRetries(maxRetries int) *Request {
r.retry.SetMaxRetries(maxRetries)
if maxRetries < 0 {
maxRetries = 0
}
r.maxRetries = maxRetries
return r
}
@@ -688,8 +700,10 @@ func (r *Request) Watch(ctx context.Context) (watch.Interface, error) {
}
return false
}
var retryAfter *RetryAfter
url := r.URL().String()
withRetry := r.retryFn(r.maxRetries)
for {
req, err := r.newHTTPRequest(ctx)
if err != nil {
@@ -724,9 +738,9 @@ func (r *Request) Watch(ctx context.Context) (watch.Interface, error) {
defer readAndCloseResponseBody(resp)
var retry bool
retryAfter, retry = r.retry.NextRetry(req, resp, err, isErrRetryableFunc)
retryAfter, retry = withRetry.NextRetry(req, resp, err, isErrRetryableFunc)
if retry {
err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
err := withRetry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
if err == nil {
return false, nil
}
@@ -817,6 +831,7 @@ func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) {
}
var retryAfter *RetryAfter
withRetry := r.retryFn(r.maxRetries)
url := r.URL().String()
for {
req, err := r.newHTTPRequest(ctx)
@@ -862,9 +877,9 @@ func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) {
defer resp.Body.Close()
var retry bool
retryAfter, retry = r.retry.NextRetry(req, resp, err, neverRetryError)
retryAfter, retry = withRetry.NextRetry(req, resp, err, neverRetryError)
if retry {
err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
err := withRetry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body)
if err == nil {
return false, nil
}
@@ -961,6 +976,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp
// Right now we make about ten retry attempts if we get a Retry-After response.
var retryAfter *RetryAfter
withRetry := r.retryFn(r.maxRetries)
for {
req, err := r.newHTTPRequest(ctx)
if err != nil {
@@ -997,7 +1013,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp
}
var retry bool
retryAfter, retry = r.retry.NextRetry(req, resp, err, func(req *http.Request, err error) bool {
retryAfter, retry = withRetry.NextRetry(req, resp, err, func(req *http.Request, err error) bool {
// "Connection reset by peer" or "apiserver is shutting down" are usually a transient errors.
// Thus in case of "GET" operations, we simply retry it.
// We are not automatically retrying "write" operations, as they are not idempotent.
@@ -1011,7 +1027,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp
return false
})
if retry {
err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body)
err := withRetry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body)
if err == nil {
return false
}

View File

@@ -32,6 +32,7 @@ import (
"reflect"
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"
@@ -1194,7 +1195,8 @@ func TestRequestWatch(t *testing.T) {
c.Client = client
}
testCase.Request.backoff = &noSleepBackOff{}
testCase.Request.retry = &withRetry{maxRetries: testCase.maxRetries}
testCase.Request.maxRetries = testCase.maxRetries
testCase.Request.retryFn = defaultRequestRetryFn
watch, err := testCase.Request.Watch(context.Background())
@@ -1407,7 +1409,8 @@ func TestRequestStream(t *testing.T) {
c.Client = client
}
testCase.Request.backoff = &noSleepBackOff{}
testCase.Request.retry = &withRetry{maxRetries: testCase.maxRetries}
testCase.Request.maxRetries = testCase.maxRetries
testCase.Request.retryFn = defaultRequestRetryFn
body, err := testCase.Request.Stream(context.Background())
@@ -1462,7 +1465,7 @@ func TestRequestDo(t *testing.T) {
}
for i, testCase := range testCases {
testCase.Request.backoff = &NoBackoff{}
testCase.Request.retry = &withRetry{}
testCase.Request.retryFn = defaultRequestRetryFn
body, err := testCase.Request.Do(context.Background()).Raw()
hasErr := err != nil
if hasErr != testCase.Err {
@@ -1625,8 +1628,9 @@ func TestConnectionResetByPeerIsRetried(t *testing.T) {
return nil, &net.OpError{Err: syscall.ECONNRESET}
}),
},
backoff: backoff,
retry: &withRetry{maxRetries: 10},
backoff: backoff,
maxRetries: 10,
retryFn: defaultRequestRetryFn,
}
// We expect two retries of "connection reset by peer" and the success.
_, err := req.Do(context.Background()).Raw()
@@ -2699,8 +2703,9 @@ func TestRequestWithRetry(t *testing.T) {
c: &RESTClient{
Client: client,
},
backoff: &noSleepBackOff{},
retry: &withRetry{maxRetries: 1},
backoff: &noSleepBackOff{},
maxRetries: 1,
retryFn: defaultRequestRetryFn,
}
var transformFuncInvoked int
@@ -2890,8 +2895,9 @@ func testRequestWithRetry(t *testing.T, key string, doFunc func(ctx context.Cont
content: defaultContentConfig(),
Client: client,
},
backoff: &noSleepBackOff{},
retry: &withRetry{maxRetries: test.maxRetries},
backoff: &noSleepBackOff{},
maxRetries: test.maxRetries,
retryFn: defaultRequestRetryFn,
}
doFunc(context.Background(), req)
@@ -3093,3 +3099,50 @@ func TestTransportConcurrency(t *testing.T) {
})
}
}
func TestRequestConcurrencyWithRetry(t *testing.T) {
var attempts int32
client := clientForFunc(func(req *http.Request) (*http.Response, error) {
defer func() {
atomic.AddInt32(&attempts, 1)
}()
// always send a retry-after response
return &http.Response{
StatusCode: http.StatusInternalServerError,
Header: http.Header{"Retry-After": []string{"1"}},
}, nil
})
req := &Request{
verb: "POST",
c: &RESTClient{
content: defaultContentConfig(),
Client: client,
},
backoff: &noSleepBackOff{},
maxRetries: 9, // 10 attempts in total, including the first
retryFn: defaultRequestRetryFn,
}
concurrency := 20
wg := sync.WaitGroup{}
wg.Add(concurrency)
startCh := make(chan struct{})
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
<-startCh
req.Do(context.Background())
}()
}
close(startCh)
wg.Wait()
// we expect (concurrency*req.maxRetries+1) attempts to be recorded
expected := concurrency * (req.maxRetries + 1)
if atomic.LoadInt32(&attempts) != int32(expected) {
t.Errorf("Expected attempts: %d, but got: %d", expected, attempts)
}
}

View File

@@ -284,18 +284,15 @@ func (c *threadSafeMap) updateIndices(oldObj interface{}, newObj interface{}, ke
c.indices[name] = index
}
if len(indexValues) == 1 && len(oldIndexValues) == 1 && indexValues[0] == oldIndexValues[0] {
// We optimize for the most common case where indexFunc returns a single value which has not been changed
continue
}
for _, value := range oldIndexValues {
// We optimize for the most common case where index returns a single value.
if len(indexValues) == 1 && value == indexValues[0] {
continue
}
c.deleteKeyFromIndex(key, value, index)
}
for _, value := range indexValues {
// We optimize for the most common case where index returns a single value.
if len(oldIndexValues) == 1 && value == oldIndexValues[0] {
continue
}
c.addKeyToIndex(key, value, index)
}
}

View File

@@ -18,7 +18,11 @@ package cache
import (
"fmt"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
)
func TestThreadSafeStoreDeleteRemovesEmptySetsFromIndex(t *testing.T) {
@@ -92,6 +96,75 @@ func TestThreadSafeStoreAddKeepsNonEmptySetPostDeleteFromIndex(t *testing.T) {
}
}
func TestThreadSafeStoreIndexingFunctionsWithMultipleValues(t *testing.T) {
testIndexer := "testIndexer"
indexers := Indexers{
testIndexer: func(obj interface{}) ([]string, error) {
return strings.Split(obj.(string), ","), nil
},
}
indices := Indices{}
store := NewThreadSafeStore(indexers, indices).(*threadSafeMap)
store.Add("key1", "foo")
store.Add("key2", "bar")
assert := assert.New(t)
compare := func(key string, expected []string) error {
values := store.indices[testIndexer][key].List()
if cmp.Equal(values, expected) {
return nil
}
return fmt.Errorf("unexpected index for key %s, diff=%s", key, cmp.Diff(values, expected))
}
assert.NoError(compare("foo", []string{"key1"}))
assert.NoError(compare("bar", []string{"key2"}))
store.Update("key2", "foo,bar")
assert.NoError(compare("foo", []string{"key1", "key2"}))
assert.NoError(compare("bar", []string{"key2"}))
store.Update("key1", "foo,bar")
assert.NoError(compare("foo", []string{"key1", "key2"}))
assert.NoError(compare("bar", []string{"key1", "key2"}))
store.Add("key3", "foo,bar,baz")
assert.NoError(compare("foo", []string{"key1", "key2", "key3"}))
assert.NoError(compare("bar", []string{"key1", "key2", "key3"}))
assert.NoError(compare("baz", []string{"key3"}))
store.Update("key1", "foo")
assert.NoError(compare("foo", []string{"key1", "key2", "key3"}))
assert.NoError(compare("bar", []string{"key2", "key3"}))
assert.NoError(compare("baz", []string{"key3"}))
store.Update("key2", "bar")
assert.NoError(compare("foo", []string{"key1", "key3"}))
assert.NoError(compare("bar", []string{"key2", "key3"}))
assert.NoError(compare("baz", []string{"key3"}))
store.Delete("key1")
assert.NoError(compare("foo", []string{"key3"}))
assert.NoError(compare("bar", []string{"key2", "key3"}))
assert.NoError(compare("baz", []string{"key3"}))
store.Delete("key3")
assert.NoError(compare("foo", []string{}))
assert.NoError(compare("bar", []string{"key2"}))
assert.NoError(compare("baz", []string{}))
}
func BenchmarkIndexer(b *testing.B) {
testIndexer := "testIndexer"

View File

@@ -51,10 +51,10 @@ func (a *PromptingAuthLoader) LoadAuth(path string) (*clientauth.Info, error) {
// Prompt for user/pass and write a file if none exists.
if _, err := os.Stat(path); os.IsNotExist(err) {
authPtr, err := a.Prompt()
auth := *authPtr
if err != nil {
return nil, err
}
auth := *authPtr
data, err := json.Marshal(auth)
if err != nil {
return &auth, err

View File

@@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"fmt"
"net"
"net/http"
@@ -50,6 +51,9 @@ type tlsCacheKey struct {
serverName string
nextProtos string
disableCompression bool
// these functions are wrapped to allow them to be used as map keys
getCert *GetCertHolder
dial *DialHolder
}
func (t tlsCacheKey) String() string {
@@ -57,7 +61,8 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 {
keyText = "<redacted>"
}
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression)
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
}
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@@ -87,8 +92,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
return http.DefaultTransport, nil
}
dial := config.Dial
if dial == nil {
var dial func(ctx context.Context, network, address string) (net.Conn, error)
if config.Dial != nil {
dial = config.Dial
} else {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
@@ -133,10 +140,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
return tlsCacheKey{}, false, err
}
if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil {
if c.Proxy != nil {
// cannot determine equality for functions
return tlsCacheKey{}, false, nil
}
if c.Dial != nil && c.DialHolder == nil {
// cannot determine equality for dial function that doesn't have non-nil DialHolder set as well
return tlsCacheKey{}, false, nil
}
if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil {
// cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well
return tlsCacheKey{}, false, nil
}
k := tlsCacheKey{
insecure: c.TLS.Insecure,
@@ -144,6 +159,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
serverName: c.TLS.ServerName,
nextProtos: strings.Join(c.TLS.NextProtos, ","),
disableCompression: c.DisableCompression,
getCert: c.TLS.GetCertHolder,
dial: c.DialHolder,
}
if c.TLS.ReloadTLSFiles {

View File

@@ -21,6 +21,7 @@ import (
"crypto/tls"
"net"
"net/http"
"net/url"
"testing"
)
@@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
if keyA != (tlsCacheKey{}) {
t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
}
}
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
getCert := func() (*tls.Certificate, error) { return nil, nil }
getCertHolder := &GetCertHolder{GetCert: getCert}
uniqueConfigurations := map[string]*Config{
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
@@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) {
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
"getCert3": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
GetCertHolder: getCertHolder,
},
},
"getCert1, key 2": {
TLS: TLSConfig{
KeyData: []byte{2},

View File

@@ -68,7 +68,11 @@ type Config struct {
WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections.
// If specified, this transport will be non-cacheable unless DialHolder is also set.
Dial func(ctx context.Context, network, address string) (net.Conn, error)
// DialHolder can be populated to make transport configs cacheable.
// If specified, DialHolder.Dial must be equal to Dial.
DialHolder *DialHolder
// Proxy is the proxy func to be used for all requests made by this
// transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy
@@ -78,6 +82,11 @@ type Config struct {
Proxy func(*http.Request) (*url.URL, error)
}
// DialHolder is used to make the wrapped function comparable so that it can be used as a map key.
type DialHolder struct {
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options
type ImpersonationConfig struct {
// UserName matches user.Info.GetName()
@@ -143,5 +152,15 @@ type TLSConfig struct {
// To use only http/1.1, set to ["http/1.1"].
NextProtos []string
GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// If specified, this transport is non-cacheable unless CertHolder is populated.
GetCert func() (*tls.Certificate, error)
// CertHolder can be populated to make transport configs that set GetCert cacheable.
// If set, CertHolder.GetCert must be equal to GetCert.
GetCertHolder *GetCertHolder
}
// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key.
type GetCertHolder struct {
GetCert func() (*tls.Certificate, error)
}

View File

@@ -24,6 +24,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"reflect"
"sync"
"time"
@@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) {
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
}
if !isValidHolders(config) {
return nil, fmt.Errorf("misconfigured holder for dialer or cert callback")
}
var (
rt http.RoundTripper
err error
@@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) {
return HTTPWrappersForConfig(config, rt)
}
func isValidHolders(config *Config) bool {
if config.TLS.GetCertHolder != nil {
if config.TLS.GetCertHolder.GetCert == nil ||
config.TLS.GetCert == nil ||
reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() {
return false
}
}
if config.DialHolder != nil {
if config.DialHolder.Dial == nil ||
config.Dial == nil ||
reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() {
return false
}
}
return true
}
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(c *Config) (*tls.Config, error) {

View File

@@ -21,6 +21,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"testing"
)
@@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
)
func TestNew(t *testing.T) {
globalGetCert := &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
}
globalDial := &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
}
testCases := map[string]struct {
Config *Config
Err bool
@@ -255,6 +263,144 @@ func TestNew(t *testing.T) {
},
},
},
"nil holders and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: false,
TLSCert: false,
TLSErr: false,
Default: true,
Insecure: false,
DefaultRoots: false,
},
"nil holders and non-nil regular get cert": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"nil holders and non-nil regular dial": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: false,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"non-nil dial holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
},
},
Err: true,
},
"non-nil cert holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil holders+internal and non-nil regular with correct address": {
Config: &Config{
TLS: TLSConfig{
GetCert: globalGetCert.GetCert,
GetCertHolder: globalGetCert,
},
Dial: globalDial.Dial,
DialHolder: globalDial,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
}
for k, testCase := range testCases {
t.Run(k, func(t *testing.T) {