mirror of
https://github.com/rancher/dynamiclistener.git
synced 2026-06-30 22:35:20 +00:00
Add support for wildcard SANs (#272)
* factory: relax cnRegexp to accept RFC 6125 single-label wildcards * factory: escape '*' in getAnnotationKey to satisfy K8s annotation key rules * factory: NeedsUpdate honors existing wildcard SANs (RFC 6125 match) * factory: tests for cert-lifecycle paths with wildcard SANs * listener: filter wildcards from runtime sources (TLS SNI, TCP, HTTP) --------- Co-authored-by: Eshaan Lumba <lumbaeshaan@microsoft.com>
This commit is contained in:
@@ -29,7 +29,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
cnRegexp = regexp.MustCompile("^([A-Za-z0-9:][-A-Za-z0-9_.:]*)?[A-Za-z0-9:]$")
|
||||
cnRegexp = regexp.MustCompile(`^(\*\.)?([A-Za-z0-9:][-A-Za-z0-9_.:]*)?[A-Za-z0-9:]$`)
|
||||
)
|
||||
|
||||
type TLS struct {
|
||||
@@ -268,22 +268,49 @@ func IsStatic(secret *v1.Secret) bool {
|
||||
|
||||
// NeedsUpdate returns true if any of the CNs are not currently present on the
|
||||
// secret's Certificate, as recorded in the cnPrefix annotations. It will return
|
||||
// false if all requested CNs are already present, or if maxSANs is non-zero and has
|
||||
// false if all requested CNs are already present (either explicitly, or covered
|
||||
// by an existing wildcard SAN per RFC 6125), or if maxSANs is non-zero and has
|
||||
// been exceeded.
|
||||
func NeedsUpdate(maxSANs int, secret *v1.Secret, cn ...string) bool {
|
||||
if secret == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
existingCNs := cns(secret)
|
||||
for _, cn := range cn {
|
||||
if secret.Annotations[getAnnotationKey(cn)] == "" {
|
||||
if maxSANs > 0 && len(cns(secret)) >= maxSANs {
|
||||
return false
|
||||
}
|
||||
if secret.Annotations[getAnnotationKey(cn)] != "" {
|
||||
continue
|
||||
}
|
||||
if isCoveredByWildcard(cn, existingCNs) {
|
||||
continue
|
||||
}
|
||||
if maxSANs > 0 && len(existingCNs) >= maxSANs {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isCoveredByWildcard reports whether cn is matched by any "*.parent" entry in
|
||||
// existing, per RFC 6125 single-label leftmost-wildcard semantics.
|
||||
//
|
||||
// "*.example.com" covers "foo.example.com" but NOT "a.b.example.com" and NOT "example.com".
|
||||
//
|
||||
// A wildcard cn is never considered covered by another wildcard.
|
||||
func isCoveredByWildcard(cn string, existing []string) bool {
|
||||
if strings.HasPrefix(cn, "*.") {
|
||||
return false
|
||||
}
|
||||
dot := strings.IndexByte(cn, '.')
|
||||
if dot < 1 {
|
||||
return false
|
||||
}
|
||||
parent := "*" + cn[dot:]
|
||||
for _, e := range existing {
|
||||
if e == parent {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -348,11 +375,12 @@ func NewPrivateKey() (crypto.Signer, error) {
|
||||
func getAnnotationKey(cn string) string {
|
||||
cn = cnPrefix + cn
|
||||
cnLen := len(cn)
|
||||
if cnLen < 64 && !strings.ContainsRune(cn, ':') {
|
||||
if cnLen < 64 && !strings.ContainsRune(cn, ':') && !strings.ContainsRune(cn, '*') {
|
||||
return cn
|
||||
}
|
||||
digest := sha256.Sum256([]byte(cn))
|
||||
cn = strings.ReplaceAll(cn, ":", "_")
|
||||
cn = strings.ReplaceAll(cn, "*", "_")
|
||||
if cnLen > 56 {
|
||||
cnLen = 56
|
||||
}
|
||||
|
||||
322
factory/gen_test.go
Normal file
322
factory/gen_test.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rancher/dynamiclistener/cert"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
)
|
||||
|
||||
var hashSuffixRe = regexp.MustCompile(`-[0-9a-f]{6}$`)
|
||||
|
||||
func TestCnRegexp_Wildcards(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
want bool
|
||||
}{
|
||||
// Existing valid CNs still validate.
|
||||
{"plain hostname", "kubernetes", true},
|
||||
{"two-label FQDN", "foo.example.com", true},
|
||||
{"multi-label FQDN", "a.b.c.example.com", true},
|
||||
{"IPv4", "127.0.0.1", true},
|
||||
{"IPv6", "2001:db8::1", true},
|
||||
|
||||
// New: RFC 6125 single-label leading wildcard.
|
||||
{"leading wildcard, two-label parent", "*.example.com", true},
|
||||
{"leading wildcard, multi-label parent", "*.foo.bar.example.com", true},
|
||||
{"leading wildcard, single-char label after", "*.a", true},
|
||||
|
||||
// Still rejected: invalid wildcard forms.
|
||||
{"bare wildcard", "*", false},
|
||||
{"multi-label wildcard", "*.*.example.com", false},
|
||||
{"embedded wildcard", "foo*.example.com", false},
|
||||
{"prefix wildcard", "*foo.example.com", false},
|
||||
{"double leading wildcard", "**.example.com", false},
|
||||
{"trailing dot FQDN", "*.example.com.", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := cnRegexp.MatchString(tt.cn)
|
||||
if got != tt.want {
|
||||
t.Errorf("cnRegexp.MatchString(%q) = %v, want %v", tt.cn, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnnotationKey_EscapesWildcard(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
}{
|
||||
{"two-label wildcard", "*.example.com"},
|
||||
{"multi-label wildcard", "*.foo.bar.example.com"},
|
||||
{"long wildcard hostname", "*.this-is-a-very-long-subdomain-that-makes-the-whole-thing-exceed-sixty-three.example.com"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := getAnnotationKey(tt.cn)
|
||||
|
||||
if strings.ContainsRune(key, '*') {
|
||||
t.Errorf("getAnnotationKey(%q) = %q, contains '*' (invalid in K8s annotation keys)", tt.cn, key)
|
||||
}
|
||||
|
||||
nameLen := len(strings.TrimPrefix(key, cnPrefix))
|
||||
if nameLen >= 64 {
|
||||
t.Errorf("getAnnotationKey(%q) name part is %d chars, must be < 64", tt.cn, nameLen)
|
||||
}
|
||||
|
||||
if got := getAnnotationKey(tt.cn); got != key {
|
||||
t.Errorf("getAnnotationKey(%q) is not deterministic: %q vs %q", tt.cn, key, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnnotationKey_IPv6AndWildcardCoexist(t *testing.T) {
|
||||
t.Run("IPv6 still escaped", func(t *testing.T) {
|
||||
ipv6 := getAnnotationKey("2001:db8::1")
|
||||
if strings.ContainsRune(ipv6, ':') {
|
||||
t.Errorf("getAnnotationKey(IPv6) = %q, contains ':'", ipv6)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wildcard and colon coexist", func(t *testing.T) {
|
||||
mixed := getAnnotationKey("*.foo:bar.example.com")
|
||||
if strings.ContainsAny(mixed, "*:") {
|
||||
t.Errorf("getAnnotationKey(mixed) = %q, contains '*' or ':'", mixed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAnnotationKey_LongWildcardHostname(t *testing.T) {
|
||||
cn := "*.really.long.subdomain.example.com.foo.bar.baz.thing.thing.thing"
|
||||
key := getAnnotationKey(cn)
|
||||
|
||||
nameLen := len(strings.TrimPrefix(key, cnPrefix))
|
||||
if nameLen >= 64 {
|
||||
t.Errorf("name part is %d chars, must be < 64", nameLen)
|
||||
}
|
||||
if !hashSuffixRe.MatchString(key) {
|
||||
t.Errorf("expected hash suffix '-XXXXXX' (6 hex chars) at end of %q", key)
|
||||
}
|
||||
if strings.ContainsRune(key, '*') {
|
||||
t.Errorf("key %q contains '*'", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCoveredByWildcard(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
existing []string
|
||||
want bool
|
||||
}{
|
||||
{"wildcard covers single-label match", "foo.example.com", []string{"*.example.com"}, true},
|
||||
{"wildcard covers when other entries also present", "a.example.com", []string{"foo.example.com", "*.example.com"}, true},
|
||||
{"wildcard does not cover multi-label", "a.b.example.com", []string{"*.example.com"}, false},
|
||||
{"wildcard does not cover apex", "example.com", []string{"*.example.com"}, false},
|
||||
{"no wildcard in existing", "foo.example.com", []string{"foo.example.com", "bar.example.com"}, false},
|
||||
{"wildcard cn never covered (exact wildcard)", "*.example.com", []string{"*.example.com"}, false},
|
||||
{"wildcard cn never covered (more specific)", "*.foo.example.com", []string{"*.example.com"}, false},
|
||||
{"wrong parent", "foo.evil.com", []string{"*.example.com"}, false},
|
||||
{"cn with no dot cannot be covered", "localhost", []string{"*.localhost"}, false},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isCoveredByWildcard(tt.cn, tt.existing)
|
||||
if got != tt.want {
|
||||
t.Errorf("isCoveredByWildcard(%q, %v) = %v, want %v", tt.cn, tt.existing, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeSecretWithCNs(cns ...string) *v1.Secret {
|
||||
s := &v1.Secret{}
|
||||
s.Annotations = map[string]string{}
|
||||
for _, cn := range cns {
|
||||
s.Annotations[getAnnotationKey(cn)] = cn
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestNeedsUpdate_WildcardCovers(t *testing.T) {
|
||||
secret := makeSecretWithCNs("*.example.com")
|
||||
|
||||
t.Run("subdomain covered by wildcard", func(t *testing.T) {
|
||||
if NeedsUpdate(0, secret, "foo.example.com") {
|
||||
t.Error("NeedsUpdate should be false: foo.example.com is covered by *.example.com")
|
||||
}
|
||||
})
|
||||
t.Run("multi-label not covered", func(t *testing.T) {
|
||||
if !NeedsUpdate(0, secret, "a.b.example.com") {
|
||||
t.Error("NeedsUpdate should be true: a.b.example.com is multi-label, not covered")
|
||||
}
|
||||
})
|
||||
t.Run("apex not covered", func(t *testing.T) {
|
||||
if !NeedsUpdate(0, secret, "example.com") {
|
||||
t.Error("NeedsUpdate should be true: example.com is the apex, not covered by *.example.com")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNeedsUpdate_WildcardDoesNotCoverWildcard(t *testing.T) {
|
||||
secret := makeSecretWithCNs("*.example.com")
|
||||
|
||||
t.Run("exact wildcard match", func(t *testing.T) {
|
||||
if NeedsUpdate(0, secret, "*.example.com") {
|
||||
t.Error("NeedsUpdate should be false: exact wildcard match")
|
||||
}
|
||||
})
|
||||
t.Run("more specific wildcard not covered", func(t *testing.T) {
|
||||
if !NeedsUpdate(0, secret, "*.foo.example.com") {
|
||||
t.Error("NeedsUpdate should be true: *.foo.example.com is a different wildcard")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNeedsUpdate_WildcardCountsAsOneSAN(t *testing.T) {
|
||||
t.Run("room for one more SAN", func(t *testing.T) {
|
||||
secret := makeSecretWithCNs("a", "b", "c", "d", "e", "f", "g", "h", "i")
|
||||
if !NeedsUpdate(10, secret, "*.new.com") {
|
||||
t.Error("NeedsUpdate should be true: room for one more SAN")
|
||||
}
|
||||
})
|
||||
t.Run("MaxSANs reached", func(t *testing.T) {
|
||||
secret := makeSecretWithCNs("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
|
||||
if NeedsUpdate(10, secret, "*.new.com") {
|
||||
t.Error("NeedsUpdate should be false: MaxSANs reached")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newTestTLS(t *testing.T) *TLS {
|
||||
t.Helper()
|
||||
caKey, err := NewPrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("NewPrivateKey: %v", err)
|
||||
}
|
||||
caCert, err := NewSelfSignedCACert(caKey, "test-ca", "test-org")
|
||||
if err != nil {
|
||||
t.Fatalf("NewSelfSignedCACert: %v", err)
|
||||
}
|
||||
return &TLS{
|
||||
CACert: []*x509.Certificate{caCert},
|
||||
CAKey: caKey,
|
||||
CN: "test-cn",
|
||||
Organization: []string{"test-org"},
|
||||
}
|
||||
}
|
||||
|
||||
func assertCertHasDNSName(t *testing.T, secret *v1.Secret, name string) {
|
||||
t.Helper()
|
||||
certs, err := cert.ParseCertsPEM(secret.Data[v1.TLSCertKey])
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertsPEM: %v", err)
|
||||
}
|
||||
if len(certs) == 0 {
|
||||
t.Fatal("no certs in secret")
|
||||
}
|
||||
for _, n := range certs[0].DNSNames {
|
||||
if n == name {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("cert DNSNames %v does not contain %q", certs[0].DNSNames, name)
|
||||
}
|
||||
|
||||
func assertCertDoesNotHaveDNSName(t *testing.T, secret *v1.Secret, name string) {
|
||||
t.Helper()
|
||||
certs, err := cert.ParseCertsPEM(secret.Data[v1.TLSCertKey])
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertsPEM: %v", err)
|
||||
}
|
||||
if len(certs) == 0 {
|
||||
t.Fatal("no certs in secret")
|
||||
}
|
||||
for _, n := range certs[0].DNSNames {
|
||||
if n == name {
|
||||
t.Errorf("cert DNSNames %v unexpectedly contains %q", certs[0].DNSNames, name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCert_WildcardSAN(t *testing.T) {
|
||||
tlsFactory := newTestTLS(t)
|
||||
secret, _, err := tlsFactory.AddCN(nil, "*.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("AddCN: %v", err)
|
||||
}
|
||||
assertCertHasDNSName(t, secret, "*.example.com")
|
||||
}
|
||||
|
||||
func TestRenew_PreservesWildcard(t *testing.T) {
|
||||
tlsFactory := newTestTLS(t)
|
||||
secret, _, err := tlsFactory.AddCN(nil, "*.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("AddCN: %v", err)
|
||||
}
|
||||
renewed, err := tlsFactory.Renew(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Renew: %v", err)
|
||||
}
|
||||
assertCertHasDNSName(t, renewed, "*.example.com")
|
||||
}
|
||||
|
||||
func TestRegenerate_PreservesWildcard(t *testing.T) {
|
||||
tlsFactory := newTestTLS(t)
|
||||
secret, _, err := tlsFactory.AddCN(nil, "*.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("AddCN: %v", err)
|
||||
}
|
||||
regen, err := tlsFactory.Regenerate(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Regenerate: %v", err)
|
||||
}
|
||||
assertCertHasDNSName(t, regen, "*.example.com")
|
||||
}
|
||||
|
||||
func TestMerge_WildcardCovering(t *testing.T) {
|
||||
tlsFactory := newTestTLS(t)
|
||||
|
||||
target, _, err := tlsFactory.AddCN(nil, "*.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("AddCN target: %v", err)
|
||||
}
|
||||
additional, _, err := tlsFactory.AddCN(nil, "foo.example.com", "bar.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("AddCN additional: %v", err)
|
||||
}
|
||||
|
||||
merged, _, err := tlsFactory.Merge(target, additional)
|
||||
if err != nil {
|
||||
t.Fatalf("Merge: %v", err)
|
||||
}
|
||||
|
||||
// Wildcard must be present.
|
||||
assertCertHasDNSName(t, merged, "*.example.com")
|
||||
// And the specific names from `additional` must NOT have been added —
|
||||
// the wildcard already covers them per RFC 6125. This is what distinguishes
|
||||
// the cover-short-circuit path from a regenerate-fallthrough path.
|
||||
assertCertDoesNotHaveDNSName(t, merged, "foo.example.com")
|
||||
assertCertDoesNotHaveDNSName(t, merged, "bar.example.com")
|
||||
}
|
||||
|
||||
func TestAddCN_WildcardAndSpecificCoexist(t *testing.T) {
|
||||
// Realistic shape: admin configures both a wildcard SAN and one or more
|
||||
// specific SANs in a single call. Both must end up in the cert.
|
||||
tlsFactory := newTestTLS(t)
|
||||
secret, _, err := tlsFactory.AddCN(nil, "*.example.com", "other.org")
|
||||
if err != nil {
|
||||
t.Fatalf("AddCN: %v", err)
|
||||
}
|
||||
assertCertHasDNSName(t, secret, "*.example.com")
|
||||
assertCertHasDNSName(t, secret, "other.org")
|
||||
}
|
||||
30
listener.go
30
listener.go
@@ -17,6 +17,13 @@ import (
|
||||
v1 "k8s.io/api/core/v1"
|
||||
)
|
||||
|
||||
// isWildcardSAN reports whether cn is an RFC 6125 wildcard pattern.
|
||||
// Wildcards are accepted only from Config.SANs (admin), never from runtime
|
||||
// sources (TLS SNI, TCP LocalAddr, HTTP Host header).
|
||||
func isWildcardSAN(cn string) bool {
|
||||
return strings.HasPrefix(cn, "*.")
|
||||
}
|
||||
|
||||
type TLSStorage interface {
|
||||
Get() (*v1.Secret, error)
|
||||
Update(secret *v1.Secret) error
|
||||
@@ -271,6 +278,9 @@ func (l *listener) checkExpiration(days int) error {
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
l.init.Do(func() {
|
||||
if len(l.sans) > 0 {
|
||||
// Trusted path: Config.SANs is admin-controlled (--tls-san), so wildcards
|
||||
// are permitted here. Runtime-discovered SANs (SNI, TCP, HTTP) MUST go
|
||||
// through isWildcardSAN gates - see the call sites below.
|
||||
if err := l.updateCert(l.sans...); err != nil {
|
||||
logrus.Errorf("dynamiclistener %s: failed to update cert with configured SANs: %v", l.Addr(), err)
|
||||
return
|
||||
@@ -303,8 +313,10 @@ func (l *listener) Accept() (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
if err := l.updateCert(host); err != nil {
|
||||
logrus.Errorf("dynamiclistener %s: failed to update cert with connection local address: %v", l.Addr(), err)
|
||||
if !isWildcardSAN(host) {
|
||||
if err := l.updateCert(host); err != nil {
|
||||
logrus.Errorf("dynamiclistener %s: failed to update cert with connection local address: %v", l.Addr(), err)
|
||||
}
|
||||
}
|
||||
|
||||
if l.conns != nil {
|
||||
@@ -350,7 +362,9 @@ func (c *closeWrapper) Close() error {
|
||||
func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
newConn := hello.Conn
|
||||
if hello.ServerName != "" {
|
||||
if err := l.updateCert(hello.ServerName); err != nil {
|
||||
if isWildcardSAN(hello.ServerName) {
|
||||
logrus.Debugf("dynamiclistener %s: ignoring wildcard SAN from TLS SNI: %s", l.Addr(), hello.ServerName)
|
||||
} else if err := l.updateCert(hello.ServerName); err != nil {
|
||||
logrus.Errorf("dynamiclistener %s: failed to update cert with TLS ServerName: %v", l.Addr(), err)
|
||||
return nil, err
|
||||
}
|
||||
@@ -483,8 +497,14 @@ func (l *listener) cacheHandler() http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
if err := l.updateCert(h); err != nil {
|
||||
logrus.Errorf("dynamiclistener %s: failed to update cert with HTTP request Host header: %v", l.Addr(), err)
|
||||
// Defense-in-depth: the surrounding `if len(ip) > 0` block already
|
||||
// excludes non-IP hosts (which is where wildcards would appear). This gate
|
||||
// guards against future relaxation of that filter - wildcards must never
|
||||
// enter the cert from runtime-discovered HTTP Host headers.
|
||||
if !isWildcardSAN(h) {
|
||||
if err := l.updateCert(h); err != nil {
|
||||
logrus.Errorf("dynamiclistener %s: failed to update cert with HTTP request Host header: %v", l.Addr(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
163
listener_test.go
163
listener_test.go
@@ -2,6 +2,8 @@ package dynamiclistener
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -223,3 +225,164 @@ func (fakeConn) RemoteAddr() net.Addr { return nil }
|
||||
func (fakeConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (fakeConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestIsWildcardSAN(t *testing.T) {
|
||||
cases := []struct {
|
||||
cn string
|
||||
want bool
|
||||
}{
|
||||
{"*.example.com", true},
|
||||
{"*.foo.bar.com", true},
|
||||
{"foo.example.com", false},
|
||||
{"*", false},
|
||||
{"foo*", false},
|
||||
{"", false},
|
||||
{"*foo.example.com", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := isWildcardSAN(c.cn)
|
||||
if got != c.want {
|
||||
t.Errorf("isWildcardSAN(%q) = %v, want %v", c.cn, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trackingStorage counts Update calls and remembers the last secret.
|
||||
// (MockTLSStorage.Update panics, so we replace rather than wrap.)
|
||||
type trackingStorage struct {
|
||||
secret *v1.Secret
|
||||
updateCalls int
|
||||
}
|
||||
|
||||
func (s *trackingStorage) Get() (*v1.Secret, error) {
|
||||
if s.secret == nil {
|
||||
return &v1.Secret{}, nil
|
||||
}
|
||||
return s.secret, nil
|
||||
}
|
||||
|
||||
func (s *trackingStorage) Update(secret *v1.Secret) error {
|
||||
s.updateCalls++
|
||||
s.secret = secret
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestTLSFactory(t *testing.T) *factory.TLS {
|
||||
t.Helper()
|
||||
caCert, caKey, err := factory.GenCA()
|
||||
if err != nil {
|
||||
t.Fatalf("factory.GenCA: %v", err)
|
||||
}
|
||||
return &factory.TLS{
|
||||
CACert: []*x509.Certificate{caCert},
|
||||
CAKey: caKey,
|
||||
CN: "test",
|
||||
Organization: []string{"test"},
|
||||
}
|
||||
}
|
||||
|
||||
// fakeListener is a no-op net.Listener used to satisfy listener.Addr() in unit tests.
|
||||
type fakeListener struct{}
|
||||
|
||||
func (fakeListener) Accept() (net.Conn, error) {
|
||||
return nil, errors.New("fakeListener: Accept not supported")
|
||||
}
|
||||
func (fakeListener) Close() error { return nil }
|
||||
func (fakeListener) Addr() net.Addr { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} }
|
||||
|
||||
func newTestListener(t *testing.T, configSANs []string) (*listener, *trackingStorage) {
|
||||
t.Helper()
|
||||
storage := &trackingStorage{}
|
||||
l := &listener{
|
||||
Listener: fakeListener{},
|
||||
factory: newTestTLSFactory(t),
|
||||
storage: storage,
|
||||
sans: configSANs,
|
||||
certReady: make(chan struct{}),
|
||||
}
|
||||
return l, storage
|
||||
}
|
||||
|
||||
func storedHasCN(storage *trackingStorage, cn string) bool {
|
||||
if storage.secret == nil {
|
||||
return false
|
||||
}
|
||||
for _, v := range storage.secret.Annotations {
|
||||
if v == cn {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestListener_RejectsWildcardFromSNI(t *testing.T) {
|
||||
l, storage := newTestListener(t, []string{"foo.example.com"})
|
||||
|
||||
hello := &tls.ClientHelloInfo{ServerName: "*.evil.com"}
|
||||
_, _ = l.getCertificate(hello)
|
||||
|
||||
if storage.updateCalls != 0 {
|
||||
t.Errorf("storage.Update called %d times, expected 0 (wildcard SNI should be rejected)", storage.updateCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListener_AcceptsWildcardFromConfigSANs(t *testing.T) {
|
||||
l, storage := newTestListener(t, []string{"*.example.com"})
|
||||
|
||||
if err := l.updateCert(l.sans...); err != nil {
|
||||
t.Fatalf("updateCert from trusted source failed: %v", err)
|
||||
}
|
||||
if storage.updateCalls == 0 {
|
||||
t.Error("expected storage.Update to be called for admin-supplied wildcard SAN")
|
||||
}
|
||||
if !storedHasCN(storage, "*.example.com") {
|
||||
t.Errorf("expected wildcard *.example.com in stored secret annotations; got %v", storage.secret.Annotations)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListener_AdminWildcardSuppressesRuntimeSubdomainRegen(t *testing.T) {
|
||||
l, storage := newTestListener(t, []string{"*.example.com"})
|
||||
|
||||
if err := l.updateCert(l.sans...); err != nil {
|
||||
t.Fatalf("initial updateCert: %v", err)
|
||||
}
|
||||
updatesAfterInit := storage.updateCalls
|
||||
|
||||
hello := &tls.ClientHelloInfo{ServerName: "foo.example.com"}
|
||||
_, _ = l.getCertificate(hello)
|
||||
|
||||
if storage.updateCalls != updatesAfterInit {
|
||||
t.Errorf("storage.Update called again (%d -> %d), expected NO regen for covered subdomain",
|
||||
updatesAfterInit, storage.updateCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListener_FilterCallbackCanReturnWildcards(t *testing.T) {
|
||||
// Documents the trust model: Config.FilterCN is integrator-controlled, and runs
|
||||
// inside updateCert AFTER the listener-layer gate. So if FilterCN returns a wildcard
|
||||
// for an input the gate already permitted, the wildcard IS accepted into the cert.
|
||||
// This is intentional, not a bug - FilterCN is part of the integrator's
|
||||
// admin-controlled trust boundary.
|
||||
tlsFactory := newTestTLSFactory(t)
|
||||
tlsFactory.FilterCN = func(cn ...string) []string {
|
||||
return []string{"*.example.com"}
|
||||
}
|
||||
storage := &trackingStorage{}
|
||||
l := &listener{
|
||||
Listener: fakeListener{},
|
||||
factory: tlsFactory,
|
||||
storage: storage,
|
||||
sans: nil,
|
||||
certReady: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Drive a non-wildcard SNI through the gate at getCertificate. The gate permits it
|
||||
// (not a wildcard), so updateCert is invoked; inside updateCert, Filter then
|
||||
// replaces the input with the wildcard, which reaches storage.
|
||||
hello := &tls.ClientHelloInfo{ServerName: "non-wildcard.input.com"}
|
||||
_, _ = l.getCertificate(hello)
|
||||
|
||||
if !storedHasCN(storage, "*.example.com") {
|
||||
t.Errorf("expected wildcard from FilterCN to reach the cert; annotations: %v", storage.secret)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user