mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-23 11:50:44 +00:00
Merge pull request #112017 from enj/enj/i/exec_tls_cache
exec auth: support TLS config caching
This commit is contained in:
commit
082da2f04e
@ -199,7 +199,6 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC
|
|||||||
now: time.Now,
|
now: time.Now,
|
||||||
environ: os.Environ,
|
environ: os.Environ,
|
||||||
|
|
||||||
defaultDialer: defaultDialer,
|
|
||||||
connTracker: connTracker,
|
connTracker: connTracker,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,6 +206,11 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC
|
|||||||
a.env = append(a.env, env.Name+"="+env.Value)
|
a.env = append(a.env, env.Name+"="+env.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// these functions are made comparable and stored in the cache so that repeated clientset
|
||||||
|
// construction with the same rest.Config results in a single TLS cache and Authenticator
|
||||||
|
a.getCert = &transport.GetCertHolder{GetCert: a.cert}
|
||||||
|
a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext}
|
||||||
|
|
||||||
return c.put(key, a), nil
|
return c.put(key, a), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -261,8 +265,6 @@ type Authenticator struct {
|
|||||||
now func() time.Time
|
now func() time.Time
|
||||||
environ func() []string
|
environ func() []string
|
||||||
|
|
||||||
// defaultDialer is used for clients which don't specify a custom dialer
|
|
||||||
defaultDialer *connrotation.Dialer
|
|
||||||
// connTracker tracks all connections opened that we need to close when rotating a client certificate
|
// connTracker tracks all connections opened that we need to close when rotating a client certificate
|
||||||
connTracker *connrotation.ConnectionTracker
|
connTracker *connrotation.ConnectionTracker
|
||||||
|
|
||||||
@ -273,6 +275,12 @@ type Authenticator struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cachedCreds *credentials
|
cachedCreds *credentials
|
||||||
exp time.Time
|
exp time.Time
|
||||||
|
|
||||||
|
// getCert makes Authenticator.cert comparable to support TLS config caching
|
||||||
|
getCert *transport.GetCertHolder
|
||||||
|
// dial is used for clients which do not specify a custom dialer
|
||||||
|
// it is comparable to support TLS config caching
|
||||||
|
dial *transport.DialHolder
|
||||||
}
|
}
|
||||||
|
|
||||||
type credentials struct {
|
type credentials struct {
|
||||||
@ -300,17 +308,19 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
|
|||||||
if c.HasCertCallback() {
|
if c.HasCertCallback() {
|
||||||
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
|
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
|
||||||
}
|
}
|
||||||
c.TLS.GetCert = a.cert
|
c.TLS.GetCert = a.getCert.GetCert
|
||||||
|
c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching
|
||||||
|
|
||||||
var d *connrotation.Dialer
|
|
||||||
if c.Dial != nil {
|
if c.Dial != nil {
|
||||||
// if c has a custom dialer, we have to wrap it
|
// if c has a custom dialer, we have to wrap it
|
||||||
d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
|
// TLS config caching is not supported for this config
|
||||||
} else {
|
d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
|
||||||
d = a.defaultDialer
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Dial = d.DialContext
|
c.Dial = d.DialContext
|
||||||
|
c.DialHolder = nil
|
||||||
|
} else {
|
||||||
|
c.Dial = a.dial.Dial
|
||||||
|
c.DialHolder = a.dial // comparable for TLS config caching
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,106 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2022 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package exec_test // separate package to prevent circular import
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||||
|
clientset "k8s.io/client-go/kubernetes"
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used.
|
||||||
|
//
|
||||||
|
// In particular, when:
|
||||||
|
// - multiple identical rest configs exist as distinct objects, and
|
||||||
|
// - these rest configs use exec auth, and
|
||||||
|
// - these rest configs are used to create distinct clientsets, then
|
||||||
|
//
|
||||||
|
// the underlying TLS config is shared between those clientsets.
|
||||||
|
func TestExecTLSCache(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
config1 := &rest.Config{
|
||||||
|
Host: "https://localhost",
|
||||||
|
ExecProvider: &clientcmdapi.ExecConfig{
|
||||||
|
Command: "./testdata/test-plugin.sh",
|
||||||
|
APIVersion: "client.authentication.k8s.io/v1",
|
||||||
|
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client1 := clientset.NewForConfigOrDie(config1)
|
||||||
|
|
||||||
|
config2 := &rest.Config{
|
||||||
|
Host: "https://localhost",
|
||||||
|
ExecProvider: &clientcmdapi.ExecConfig{
|
||||||
|
Command: "./testdata/test-plugin.sh",
|
||||||
|
APIVersion: "client.authentication.k8s.io/v1",
|
||||||
|
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client2 := clientset.NewForConfigOrDie(config2)
|
||||||
|
|
||||||
|
config3 := &rest.Config{
|
||||||
|
Host: "https://localhost",
|
||||||
|
ExecProvider: &clientcmdapi.ExecConfig{
|
||||||
|
Command: "./testdata/test-plugin.sh",
|
||||||
|
Args: []string{"make this exec auth different"},
|
||||||
|
APIVersion: "client.authentication.k8s.io/v1",
|
||||||
|
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client3 := clientset.NewForConfigOrDie(config3)
|
||||||
|
|
||||||
|
_, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
|
||||||
|
_, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
|
||||||
|
_, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{})
|
||||||
|
|
||||||
|
rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport
|
||||||
|
rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport
|
||||||
|
rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport
|
||||||
|
|
||||||
|
tlsConfig1, err := utilnet.TLSClientConfig(rt1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tlsConfig2, err := utilnet.TLSClientConfig(rt2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tlsConfig3, err := utilnet.TLSClientConfig(rt3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil {
|
||||||
|
t.Fatal("expected non-nil TLS configs")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig1 != tlsConfig2 {
|
||||||
|
t.Fatal("expected the same TLS config for matching exec config via rest config")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig1 == tlsConfig3 {
|
||||||
|
t.Fatal("expected different TLS config for non-matching exec config via rest config")
|
||||||
|
}
|
||||||
|
}
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -55,6 +56,9 @@ type tlsCacheKey struct {
|
|||||||
serverName string
|
serverName string
|
||||||
nextProtos string
|
nextProtos string
|
||||||
disableCompression bool
|
disableCompression bool
|
||||||
|
// these functions are wrapped to allow them to be used as map keys
|
||||||
|
getCert *GetCertHolder
|
||||||
|
dial *DialHolder
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t tlsCacheKey) String() string {
|
func (t tlsCacheKey) String() string {
|
||||||
@ -62,7 +66,8 @@ func (t tlsCacheKey) String() string {
|
|||||||
if len(t.keyData) > 0 {
|
if len(t.keyData) > 0 {
|
||||||
keyText = "<redacted>"
|
keyText = "<redacted>"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression)
|
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
|
||||||
|
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
|
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
|
||||||
@ -92,8 +97,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
|
|||||||
return http.DefaultTransport, nil
|
return http.DefaultTransport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dial := config.Dial
|
var dial func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
if dial == nil {
|
if config.Dial != nil {
|
||||||
|
dial = config.Dial
|
||||||
|
} else {
|
||||||
dial = (&net.Dialer{
|
dial = (&net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
KeepAlive: 30 * time.Second,
|
KeepAlive: 30 * time.Second,
|
||||||
@ -138,10 +145,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
|
|||||||
return tlsCacheKey{}, false, err
|
return tlsCacheKey{}, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil {
|
if c.Proxy != nil {
|
||||||
// cannot determine equality for functions
|
// cannot determine equality for functions
|
||||||
return tlsCacheKey{}, false, nil
|
return tlsCacheKey{}, false, nil
|
||||||
}
|
}
|
||||||
|
if c.Dial != nil && c.DialHolder == nil {
|
||||||
|
// cannot determine equality for dial function that doesn't have non-nil DialHolder set as well
|
||||||
|
return tlsCacheKey{}, false, nil
|
||||||
|
}
|
||||||
|
if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil {
|
||||||
|
// cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well
|
||||||
|
return tlsCacheKey{}, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
k := tlsCacheKey{
|
k := tlsCacheKey{
|
||||||
insecure: c.TLS.Insecure,
|
insecure: c.TLS.Insecure,
|
||||||
@ -149,6 +164,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
|
|||||||
serverName: c.TLS.ServerName,
|
serverName: c.TLS.ServerName,
|
||||||
nextProtos: strings.Join(c.TLS.NextProtos, ","),
|
nextProtos: strings.Join(c.TLS.NextProtos, ","),
|
||||||
disableCompression: c.DisableCompression,
|
disableCompression: c.DisableCompression,
|
||||||
|
getCert: c.TLS.GetCertHolder,
|
||||||
|
dial: c.DialHolder,
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.TLS.ReloadTLSFiles {
|
if c.TLS.ReloadTLSFiles {
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) {
|
|||||||
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
|
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if keyA != (tlsCacheKey{}) {
|
||||||
|
t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure config fields that affect the tls config affect the cache key
|
// Make sure config fields that affect the tls config affect the cache key
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
getCert := func() (*tls.Certificate, error) { return nil, nil }
|
getCert := func() (*tls.Certificate, error) { return nil, nil }
|
||||||
|
getCertHolder := &GetCertHolder{GetCert: getCert}
|
||||||
uniqueConfigurations := map[string]*Config{
|
uniqueConfigurations := map[string]*Config{
|
||||||
|
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
|
||||||
"no tls": {},
|
"no tls": {},
|
||||||
"dialer": {Dial: dialer.DialContext},
|
"dialer": {Dial: dialer.DialContext},
|
||||||
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
|
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
|
||||||
|
"dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}},
|
||||||
|
"dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
|
||||||
"insecure": {TLS: TLSConfig{Insecure: true}},
|
"insecure": {TLS: TLSConfig{Insecure: true}},
|
||||||
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
|
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
|
||||||
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
|
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
|
||||||
@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) {
|
|||||||
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"getCert3": {
|
||||||
|
TLS: TLSConfig{
|
||||||
|
KeyData: []byte{1},
|
||||||
|
GetCert: getCert,
|
||||||
|
GetCertHolder: getCertHolder,
|
||||||
|
},
|
||||||
|
},
|
||||||
"getCert1, key 2": {
|
"getCert1, key 2": {
|
||||||
TLS: TLSConfig{
|
TLS: TLSConfig{
|
||||||
KeyData: []byte{2},
|
KeyData: []byte{2},
|
||||||
|
@ -68,7 +68,11 @@ type Config struct {
|
|||||||
WrapTransport WrapperFunc
|
WrapTransport WrapperFunc
|
||||||
|
|
||||||
// Dial specifies the dial function for creating unencrypted TCP connections.
|
// Dial specifies the dial function for creating unencrypted TCP connections.
|
||||||
|
// If specified, this transport will be non-cacheable unless DialHolder is also set.
|
||||||
Dial func(ctx context.Context, network, address string) (net.Conn, error)
|
Dial func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
// DialHolder can be populated to make transport configs cacheable.
|
||||||
|
// If specified, DialHolder.Dial must be equal to Dial.
|
||||||
|
DialHolder *DialHolder
|
||||||
|
|
||||||
// Proxy is the proxy func to be used for all requests made by this
|
// Proxy is the proxy func to be used for all requests made by this
|
||||||
// transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy
|
// transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy
|
||||||
@ -78,6 +82,11 @@ type Config struct {
|
|||||||
Proxy func(*http.Request) (*url.URL, error)
|
Proxy func(*http.Request) (*url.URL, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialHolder is used to make the wrapped function comparable so that it can be used as a map key.
|
||||||
|
type DialHolder struct {
|
||||||
|
Dial func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
// ImpersonationConfig has all the available impersonation options
|
// ImpersonationConfig has all the available impersonation options
|
||||||
type ImpersonationConfig struct {
|
type ImpersonationConfig struct {
|
||||||
// UserName matches user.Info.GetName()
|
// UserName matches user.Info.GetName()
|
||||||
@ -143,5 +152,15 @@ type TLSConfig struct {
|
|||||||
// To use only http/1.1, set to ["http/1.1"].
|
// To use only http/1.1, set to ["http/1.1"].
|
||||||
NextProtos []string
|
NextProtos []string
|
||||||
|
|
||||||
GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
|
// Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
|
||||||
|
// If specified, this transport is non-cacheable unless CertHolder is populated.
|
||||||
|
GetCert func() (*tls.Certificate, error)
|
||||||
|
// CertHolder can be populated to make transport configs that set GetCert cacheable.
|
||||||
|
// If set, CertHolder.GetCert must be equal to GetCert.
|
||||||
|
GetCertHolder *GetCertHolder
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key.
|
||||||
|
type GetCertHolder struct {
|
||||||
|
GetCert func() (*tls.Certificate, error)
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) {
|
|||||||
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
|
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !isValidHolders(config) {
|
||||||
|
return nil, fmt.Errorf("misconfigured holder for dialer or cert callback")
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
rt http.RoundTripper
|
rt http.RoundTripper
|
||||||
err error
|
err error
|
||||||
@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) {
|
|||||||
return HTTPWrappersForConfig(config, rt)
|
return HTTPWrappersForConfig(config, rt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isValidHolders(config *Config) bool {
|
||||||
|
if config.TLS.GetCertHolder != nil {
|
||||||
|
if config.TLS.GetCertHolder.GetCert == nil ||
|
||||||
|
config.TLS.GetCert == nil ||
|
||||||
|
reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.DialHolder != nil {
|
||||||
|
if config.DialHolder.Dial == nil ||
|
||||||
|
config.Dial == nil ||
|
||||||
|
reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
|
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
|
||||||
// by the provided Config. Will return nil if no transport level security is requested.
|
// by the provided Config. Will return nil if no transport level security is requested.
|
||||||
func TLSConfigFor(c *Config) (*tls.Config, error) {
|
func TLSConfigFor(c *Config) (*tls.Config, error) {
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
|
globalGetCert := &GetCertHolder{
|
||||||
|
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
||||||
|
}
|
||||||
|
globalDial := &DialHolder{
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
|
||||||
|
}
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
Err bool
|
Err bool
|
||||||
@ -255,6 +263,144 @@ func TestNew(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"nil holders and nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: nil,
|
||||||
|
GetCertHolder: nil,
|
||||||
|
},
|
||||||
|
Dial: nil,
|
||||||
|
DialHolder: nil,
|
||||||
|
},
|
||||||
|
Err: false,
|
||||||
|
TLS: false,
|
||||||
|
TLSCert: false,
|
||||||
|
TLSErr: false,
|
||||||
|
Default: true,
|
||||||
|
Insecure: false,
|
||||||
|
DefaultRoots: false,
|
||||||
|
},
|
||||||
|
"nil holders and non-nil regular get cert": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
||||||
|
GetCertHolder: nil,
|
||||||
|
},
|
||||||
|
Dial: nil,
|
||||||
|
DialHolder: nil,
|
||||||
|
},
|
||||||
|
Err: false,
|
||||||
|
TLS: true,
|
||||||
|
TLSCert: true,
|
||||||
|
TLSErr: false,
|
||||||
|
Default: false,
|
||||||
|
Insecure: false,
|
||||||
|
DefaultRoots: true,
|
||||||
|
},
|
||||||
|
"nil holders and non-nil regular dial": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: nil,
|
||||||
|
GetCertHolder: nil,
|
||||||
|
},
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
|
||||||
|
DialHolder: nil,
|
||||||
|
},
|
||||||
|
Err: false,
|
||||||
|
TLS: true,
|
||||||
|
TLSCert: false,
|
||||||
|
TLSErr: false,
|
||||||
|
Default: false,
|
||||||
|
Insecure: false,
|
||||||
|
DefaultRoots: true,
|
||||||
|
},
|
||||||
|
"non-nil dial holder and nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: nil,
|
||||||
|
GetCertHolder: nil,
|
||||||
|
},
|
||||||
|
Dial: nil,
|
||||||
|
DialHolder: &DialHolder{},
|
||||||
|
},
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
"non-nil cert holder and nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: nil,
|
||||||
|
GetCertHolder: &GetCertHolder{},
|
||||||
|
},
|
||||||
|
Dial: nil,
|
||||||
|
DialHolder: nil,
|
||||||
|
},
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
"non-nil dial holder and non-nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: nil,
|
||||||
|
GetCertHolder: nil,
|
||||||
|
},
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
|
||||||
|
DialHolder: &DialHolder{},
|
||||||
|
},
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
"non-nil cert holder and non-nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
||||||
|
GetCertHolder: &GetCertHolder{},
|
||||||
|
},
|
||||||
|
Dial: nil,
|
||||||
|
DialHolder: nil,
|
||||||
|
},
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
"non-nil dial holder+internal and non-nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: nil,
|
||||||
|
GetCertHolder: nil,
|
||||||
|
},
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
|
||||||
|
DialHolder: &DialHolder{
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
"non-nil cert holder+internal and non-nil regular": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
||||||
|
GetCertHolder: &GetCertHolder{
|
||||||
|
GetCert: func() (*tls.Certificate, error) { return nil, nil },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Dial: nil,
|
||||||
|
DialHolder: nil,
|
||||||
|
},
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
"non-nil holders+internal and non-nil regular with correct address": {
|
||||||
|
Config: &Config{
|
||||||
|
TLS: TLSConfig{
|
||||||
|
GetCert: globalGetCert.GetCert,
|
||||||
|
GetCertHolder: globalGetCert,
|
||||||
|
},
|
||||||
|
Dial: globalDial.Dial,
|
||||||
|
DialHolder: globalDial,
|
||||||
|
},
|
||||||
|
Err: false,
|
||||||
|
TLS: true,
|
||||||
|
TLSCert: true,
|
||||||
|
TLSErr: false,
|
||||||
|
Default: false,
|
||||||
|
Insecure: false,
|
||||||
|
DefaultRoots: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for k, testCase := range testCases {
|
for k, testCase := range testCases {
|
||||||
t.Run(k, func(t *testing.T) {
|
t.Run(k, func(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user