mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-23 11:50:44 +00:00
Remove mutation of authn options by binding flag setters to a tracking boolean in options
This commit is contained in:
parent
67bdb110b4
commit
baaa38c099
@ -100,8 +100,9 @@ type BuiltInAuthenticationOptions struct {
|
||||
|
||||
// AnonymousAuthenticationOptions contains anonymous authentication options for API Server
|
||||
type AnonymousAuthenticationOptions struct {
|
||||
Allow bool
|
||||
areFlagsSet func() bool
|
||||
Allow bool
|
||||
// FlagsSet tracks whether any of the configuration options were set via a command-line flag.
|
||||
FlagsSet bool
|
||||
}
|
||||
|
||||
// BootstrapTokenAuthenticationOptions contains bootstrap token authentication options for API Server
|
||||
@ -121,8 +122,8 @@ type OIDCAuthenticationOptions struct {
|
||||
SigningAlgs []string
|
||||
RequiredClaims map[string]string
|
||||
|
||||
// areFlagsConfigured is a function that returns true if any of the oidc-* flags are configured.
|
||||
areFlagsConfigured func() bool
|
||||
// FlagsSet tracks whether any of the configuration options were set via a command-line flag.
|
||||
FlagsSet bool
|
||||
}
|
||||
|
||||
// ServiceAccountAuthenticationOptions contains service account authentication options for API Server
|
||||
@ -183,8 +184,7 @@ func (o *BuiltInAuthenticationOptions) WithAll() *BuiltInAuthenticationOptions {
|
||||
// WithAnonymous set default value for anonymous authentication
|
||||
func (o *BuiltInAuthenticationOptions) WithAnonymous() *BuiltInAuthenticationOptions {
|
||||
o.Anonymous = &AnonymousAuthenticationOptions{
|
||||
Allow: true,
|
||||
areFlagsSet: func() bool { return false },
|
||||
Allow: true,
|
||||
}
|
||||
return o
|
||||
}
|
||||
@ -204,9 +204,8 @@ func (o *BuiltInAuthenticationOptions) WithClientCert() *BuiltInAuthenticationOp
|
||||
// WithOIDC set default value for OIDC authentication
|
||||
func (o *BuiltInAuthenticationOptions) WithOIDC() *BuiltInAuthenticationOptions {
|
||||
o.OIDC = &OIDCAuthenticationOptions{
|
||||
areFlagsConfigured: func() bool { return false },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
}
|
||||
return o
|
||||
}
|
||||
@ -337,10 +336,7 @@ func (o *BuiltInAuthenticationOptions) AddFlags(fs *pflag.FlagSet) {
|
||||
"Enables anonymous requests to the secure port of the API server. "+
|
||||
"Requests that are not rejected by another authentication method are treated as anonymous requests. "+
|
||||
"Anonymous requests have a username of system:anonymous, and a group name of system:unauthenticated.")
|
||||
|
||||
o.Anonymous.areFlagsSet = func() bool {
|
||||
return fs.Changed("anonymous-auth")
|
||||
}
|
||||
trackProvidedFlag(fs, "anonymous-auth", &o.Anonymous.FlagsSet)
|
||||
}
|
||||
|
||||
if o.BootstrapToken != nil {
|
||||
@ -357,54 +353,51 @@ func (o *BuiltInAuthenticationOptions) AddFlags(fs *pflag.FlagSet) {
|
||||
fs.StringVar(&o.OIDC.IssuerURL, oidcIssuerURLFlag, o.OIDC.IssuerURL, ""+
|
||||
"The URL of the OpenID issuer, only HTTPS scheme will be accepted. "+
|
||||
"If set, it will be used to verify the OIDC JSON Web Token (JWT).")
|
||||
trackProvidedFlag(fs, oidcIssuerURLFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringVar(&o.OIDC.ClientID, oidcClientIDFlag, o.OIDC.ClientID, ""+
|
||||
"The client ID for the OpenID Connect client, must be set if oidc-issuer-url is set.")
|
||||
trackProvidedFlag(fs, oidcClientIDFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringVar(&o.OIDC.CAFile, oidcCAFileFlag, o.OIDC.CAFile, ""+
|
||||
"If set, the OpenID server's certificate will be verified by one of the authorities "+
|
||||
"in the oidc-ca-file, otherwise the host's root CA set will be used.")
|
||||
trackProvidedFlag(fs, oidcCAFileFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringVar(&o.OIDC.UsernameClaim, oidcUsernameClaimFlag, o.OIDC.UsernameClaim, ""+
|
||||
"The OpenID claim to use as the user name. Note that claims other than the default ('sub') "+
|
||||
"is not guaranteed to be unique and immutable. This flag is experimental, please see "+
|
||||
"the authentication documentation for further details.")
|
||||
trackProvidedFlag(fs, oidcUsernameClaimFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringVar(&o.OIDC.UsernamePrefix, oidcUsernamePrefixFlag, o.OIDC.UsernamePrefix, ""+
|
||||
"If provided, all usernames will be prefixed with this value. If not provided, "+
|
||||
"username claims other than 'email' are prefixed by the issuer URL to avoid "+
|
||||
"clashes. To skip any prefixing, provide the value '-'.")
|
||||
trackProvidedFlag(fs, oidcUsernamePrefixFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringVar(&o.OIDC.GroupsClaim, oidcGroupsClaimFlag, o.OIDC.GroupsClaim, ""+
|
||||
"If provided, the name of a custom OpenID Connect claim for specifying user groups. "+
|
||||
"The claim value is expected to be a string or array of strings. This flag is experimental, "+
|
||||
"please see the authentication documentation for further details.")
|
||||
trackProvidedFlag(fs, oidcGroupsClaimFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringVar(&o.OIDC.GroupsPrefix, oidcGroupsPrefixFlag, o.OIDC.GroupsPrefix, ""+
|
||||
"If provided, all groups will be prefixed with this value to prevent conflicts with "+
|
||||
"other authentication strategies.")
|
||||
trackProvidedFlag(fs, oidcGroupsPrefixFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.StringSliceVar(&o.OIDC.SigningAlgs, oidcSigningAlgsFlag, o.OIDC.SigningAlgs, ""+
|
||||
"Comma-separated list of allowed JOSE asymmetric signing algorithms. JWTs with a "+
|
||||
"supported 'alg' header values are: RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384, PS512. "+
|
||||
"Values are defined by RFC 7518 https://tools.ietf.org/html/rfc7518#section-3.1.")
|
||||
trackProvidedFlag(fs, oidcSigningAlgsFlag, &o.OIDC.FlagsSet)
|
||||
|
||||
fs.Var(cliflag.NewMapStringStringNoSplit(&o.OIDC.RequiredClaims), oidcRequiredClaimFlag, ""+
|
||||
"A key=value pair that describes a required claim in the ID Token. "+
|
||||
"If set, the claim is verified to be present in the ID Token with a matching value. "+
|
||||
"Repeat this flag to specify multiple claims.")
|
||||
|
||||
o.OIDC.areFlagsConfigured = func() bool {
|
||||
return fs.Changed(oidcIssuerURLFlag) ||
|
||||
fs.Changed(oidcClientIDFlag) ||
|
||||
fs.Changed(oidcCAFileFlag) ||
|
||||
fs.Changed(oidcUsernameClaimFlag) ||
|
||||
fs.Changed(oidcUsernamePrefixFlag) ||
|
||||
fs.Changed(oidcGroupsClaimFlag) ||
|
||||
fs.Changed(oidcGroupsPrefixFlag) ||
|
||||
fs.Changed(oidcSigningAlgsFlag) ||
|
||||
fs.Changed(oidcRequiredClaimFlag)
|
||||
}
|
||||
trackProvidedFlag(fs, oidcRequiredClaimFlag, &o.OIDC.FlagsSet)
|
||||
}
|
||||
|
||||
if o.RequestHeader != nil {
|
||||
@ -572,7 +565,7 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat
|
||||
// Set up anonymous authenticator from config file or flags
|
||||
if o.Anonymous != nil {
|
||||
switch {
|
||||
case ret.AuthenticationConfig.Anonymous != nil && o.Anonymous.areFlagsSet():
|
||||
case ret.AuthenticationConfig.Anonymous != nil && o.Anonymous.FlagsSet:
|
||||
// Flags and config file are mutually exclusive
|
||||
return kubeauthenticator.Config{}, field.Forbidden(field.NewPath("anonymous"), "--anonynous-auth flag cannot be set when anonymous field is configured in authentication configuration file")
|
||||
case ret.AuthenticationConfig.Anonymous != nil:
|
||||
@ -823,12 +816,17 @@ func (o *BuiltInAuthenticationOptions) ApplyAuthorization(authorization *BuiltIn
|
||||
}
|
||||
}
|
||||
|
||||
func trackProvidedFlag(fs *pflag.FlagSet, flagName string, provided *bool) {
|
||||
f := fs.Lookup(flagName)
|
||||
f.Value = cliflag.NewTracker(f.Value, provided)
|
||||
}
|
||||
|
||||
func (o *BuiltInAuthenticationOptions) validateOIDCOptions() []error {
|
||||
var allErrors []error
|
||||
|
||||
// Existing validation when jwt authenticator is configured with oidc-* flags
|
||||
if len(o.AuthenticationConfigFile) == 0 {
|
||||
if o.OIDC != nil && o.OIDC.areFlagsConfigured() && (len(o.OIDC.IssuerURL) == 0 || len(o.OIDC.ClientID) == 0) {
|
||||
if o.OIDC != nil && o.OIDC.FlagsSet && (len(o.OIDC.IssuerURL) == 0 || len(o.OIDC.ClientID) == 0) {
|
||||
allErrors = append(allErrors, fmt.Errorf("oidc-issuer-url and oidc-client-id must be specified together when any oidc-* flags are set"))
|
||||
}
|
||||
|
||||
@ -843,7 +841,7 @@ func (o *BuiltInAuthenticationOptions) validateOIDCOptions() []error {
|
||||
}
|
||||
|
||||
// Authentication config file and oidc-* flags are mutually exclusive
|
||||
if o.OIDC != nil && o.OIDC.areFlagsConfigured() {
|
||||
if o.OIDC != nil && o.OIDC.FlagsSet {
|
||||
allErrors = append(allErrors, fmt.Errorf("authentication-config file and oidc-* flags are mutually exclusive"))
|
||||
}
|
||||
|
||||
|
@ -71,11 +71,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when OIDC and ServiceAccounts are valid",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{"http://foo.bar.com"},
|
||||
@ -85,10 +85,10 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when OIDC is invalid",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{"http://foo.bar.com"},
|
||||
@ -99,11 +99,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccounts doesn't have key file",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{"http://foo.bar.com"},
|
||||
@ -113,11 +113,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccounts doesn't have issuer",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{},
|
||||
@ -127,11 +127,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccounts has empty string as issuer",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{""},
|
||||
@ -141,11 +141,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccounts has duplicate issuers",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{"http://foo.bar.com", "http://foo.bar.com"},
|
||||
@ -155,11 +155,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccount has bad issuer",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
Issuers: []string{"http://[::1]:namedport"},
|
||||
@ -169,11 +169,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccounts has invalid JWKSURI",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
KeyFiles: []string{"cert", "key"},
|
||||
@ -185,11 +185,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when ServiceAccounts has invalid JWKSURI (not https scheme)",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
KeyFiles: []string{"cert", "key"},
|
||||
@ -201,11 +201,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
{
|
||||
name: "test when WebHook has invalid retry attempts",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
testSA: &ServiceAccountAuthenticationOptions{
|
||||
KeyFiles: []string{"cert", "key"},
|
||||
@ -234,11 +234,11 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
name: "test when authentication config file and oidc-* flags are set",
|
||||
testAuthenticationConfigFile: "configfile",
|
||||
testOIDC: &OIDCAuthenticationOptions{
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
areFlagsConfigured: func() bool { return true },
|
||||
UsernameClaim: "sub",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
IssuerURL: "https://testIssuerURL",
|
||||
ClientID: "testClientID",
|
||||
FlagsSet: true,
|
||||
},
|
||||
expectErr: "authentication-config file and oidc-* flags are mutually exclusive",
|
||||
},
|
||||
@ -247,8 +247,8 @@ func TestAuthenticationValidate(t *testing.T) {
|
||||
disabledFeatures: []featuregate.Feature{features.AnonymousAuthConfigurableEndpoints},
|
||||
testAuthenticationConfigFile: "configfile",
|
||||
testAnonymous: &AnonymousAuthenticationOptions{
|
||||
Allow: true,
|
||||
areFlagsSet: func() bool { return true },
|
||||
Allow: true,
|
||||
FlagsSet: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -413,7 +413,8 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) {
|
||||
expected := &BuiltInAuthenticationOptions{
|
||||
APIAudiences: []string{"foo"},
|
||||
Anonymous: &AnonymousAuthenticationOptions{
|
||||
Allow: true,
|
||||
Allow: true,
|
||||
FlagsSet: true,
|
||||
},
|
||||
BootstrapToken: &BootstrapTokenAuthenticationOptions{
|
||||
Enable: true,
|
||||
@ -428,6 +429,7 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) {
|
||||
UsernameClaim: "sub",
|
||||
UsernamePrefix: "-",
|
||||
SigningAlgs: []string{"RS256"},
|
||||
FlagsSet: true,
|
||||
},
|
||||
RequestHeader: &apiserveroptions.RequestHeaderAuthenticationOptions{
|
||||
ClientCAFile: "testdata/root.pem",
|
||||
@ -470,19 +472,6 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !opts.OIDC.areFlagsConfigured() {
|
||||
t.Fatal("OIDC flags should be configured")
|
||||
}
|
||||
// nil these out because you cannot compare functions
|
||||
opts.OIDC.areFlagsConfigured = nil
|
||||
|
||||
if !opts.Anonymous.areFlagsSet() {
|
||||
t.Fatalf("Anonymous flags should be configured")
|
||||
}
|
||||
|
||||
// nil these out because you cannot compare functions
|
||||
opts.Anonymous.areFlagsSet = nil
|
||||
|
||||
if !reflect.DeepEqual(opts, expected) {
|
||||
t.Error(cmp.Diff(opts, expected, cmp.AllowUnexported(OIDCAuthenticationOptions{}, AnonymousAuthenticationOptions{})))
|
||||
}
|
||||
|
82
staging/src/k8s.io/component-base/cli/flag/tracker_flag.go
Normal file
82
staging/src/k8s.io/component-base/cli/flag/tracker_flag.go
Normal file
@ -0,0 +1,82 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package flag
|
||||
|
||||
import (
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// TrackerValue wraps a non-boolean value and stores true in the provided boolean when it is set.
|
||||
type TrackerValue struct {
|
||||
value pflag.Value
|
||||
provided *bool
|
||||
}
|
||||
|
||||
// BoolTrackerValue wraps a boolean value and stores true in the provided boolean when it is set.
|
||||
type BoolTrackerValue struct {
|
||||
boolValue
|
||||
provided *bool
|
||||
}
|
||||
|
||||
type boolValue interface {
|
||||
pflag.Value
|
||||
IsBoolFlag() bool
|
||||
}
|
||||
|
||||
var _ pflag.Value = &TrackerValue{}
|
||||
var _ boolValue = &BoolTrackerValue{}
|
||||
|
||||
// NewTracker returns a Value wrapping the given value which stores true in the provided boolean when it is set.
|
||||
func NewTracker(value pflag.Value, provided *bool) pflag.Value {
|
||||
if value == nil {
|
||||
panic("value must not be nil")
|
||||
}
|
||||
|
||||
if provided == nil {
|
||||
panic("provided boolean must not be nil")
|
||||
}
|
||||
|
||||
if boolValue, ok := value.(boolValue); ok {
|
||||
return &BoolTrackerValue{boolValue: boolValue, provided: provided}
|
||||
}
|
||||
return &TrackerValue{value: value, provided: provided}
|
||||
}
|
||||
|
||||
func (f *TrackerValue) String() string {
|
||||
return f.value.String()
|
||||
}
|
||||
|
||||
func (f *TrackerValue) Set(value string) error {
|
||||
err := f.value.Set(value)
|
||||
if err == nil {
|
||||
*f.provided = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *TrackerValue) Type() string {
|
||||
return f.value.Type()
|
||||
}
|
||||
|
||||
func (f *BoolTrackerValue) Set(value string) error {
|
||||
err := f.boolValue.Set(value)
|
||||
if err == nil {
|
||||
*f.provided = true
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
285
staging/src/k8s.io/component-base/cli/flag/tracker_flag_test.go
Normal file
285
staging/src/k8s.io/component-base/cli/flag/tracker_flag_test.go
Normal file
@ -0,0 +1,285 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package flag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
func TestNewTracker(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value pflag.Value
|
||||
provided *bool
|
||||
wantType string
|
||||
}{
|
||||
{
|
||||
name: "non-bool-tracker",
|
||||
value: &nonBoolFlagMockValue{val: "initial", typ: "string"},
|
||||
provided: new(bool),
|
||||
wantType: "string",
|
||||
},
|
||||
{
|
||||
name: "bool-tracker",
|
||||
value: &boolFlagMockValue{val: "false", typ: "bool", isBool: true},
|
||||
provided: new(bool),
|
||||
wantType: "bool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewTracker(tt.value, tt.provided)
|
||||
|
||||
if tracker.Type() != tt.wantType {
|
||||
t.Errorf("Want type %s, got %s", tt.wantType, tracker.Type())
|
||||
}
|
||||
|
||||
if trackerValue, ok := tracker.(*TrackerValue); ok {
|
||||
if trackerValue.provided != tt.provided {
|
||||
t.Errorf("Provided pointer not stored correctly in TrackerValue")
|
||||
}
|
||||
} else if boolTrackerValue, ok := tracker.(*BoolTrackerValue); ok {
|
||||
if boolTrackerValue.provided != tt.provided {
|
||||
t.Errorf("Provided pointer not stored correctly in BoolTrackerValue")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTrackerPanics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value pflag.Value
|
||||
provided *bool
|
||||
panicMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil-value",
|
||||
value: nil,
|
||||
provided: new(bool),
|
||||
panicMsg: "value must not be nil",
|
||||
},
|
||||
{
|
||||
name: "nil-provided",
|
||||
value: &boolFlagMockValue{val: "test"},
|
||||
provided: nil,
|
||||
panicMsg: "provided boolean must not be nil",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("expected panic, but did not panic")
|
||||
} else if r != tt.panicMsg {
|
||||
t.Errorf("expected panic message %q, got %q", tt.panicMsg, r)
|
||||
}
|
||||
}()
|
||||
NewTracker(tt.value, tt.provided)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackerValue_String(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mockValue pflag.Value
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "bool-flag",
|
||||
mockValue: &boolFlagMockValue{val: "bool-test"},
|
||||
want: "bool-test",
|
||||
},
|
||||
{
|
||||
name: "non-bool-flag",
|
||||
mockValue: &nonBoolFlagMockValue{val: "non-bool-test"},
|
||||
want: "non-bool-test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tracker := NewTracker(tc.mockValue, new(bool))
|
||||
result := tracker.String()
|
||||
if result != tc.want {
|
||||
t.Errorf("Want %q, but got %q", tc.want, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackerValue_Set(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mockValue pflag.Value
|
||||
provided *bool
|
||||
mockErr error
|
||||
wantProvided bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success-bool-flag",
|
||||
mockValue: &boolFlagMockValue{val: "bool-test"},
|
||||
provided: new(bool),
|
||||
wantProvided: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "success-non-bool-flag",
|
||||
mockValue: &nonBoolFlagMockValue{val: "bool-test"},
|
||||
provided: new(bool),
|
||||
wantProvided: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error-bool-flag",
|
||||
mockValue: &boolFlagMockValue{val: "bool-test", err: fmt.Errorf("set error")},
|
||||
provided: new(bool),
|
||||
wantProvided: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error-non-bool-flag",
|
||||
mockValue: &nonBoolFlagMockValue{val: "bool-test", err: fmt.Errorf("set error")},
|
||||
provided: new(bool),
|
||||
wantProvided: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tracker := NewTracker(tc.mockValue, tc.provided)
|
||||
err := tracker.Set("new value")
|
||||
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Want error: %v, got: %v", tc.wantErr, err != nil)
|
||||
}
|
||||
|
||||
if *tc.provided != tc.wantProvided {
|
||||
t.Errorf("Want provided to be %v, got: %v", tc.wantProvided, *tc.provided)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackerValue_MultipleSetCalls(t *testing.T) {
|
||||
provided := false
|
||||
mock := &boolFlagMockValue{val: "initial"}
|
||||
tracker := NewTracker(mock, &provided)
|
||||
|
||||
err := tracker.Set("new value")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if mock.val != "new value" {
|
||||
t.Errorf("Expected mock value to be 'new value', got '%s'", mock.val)
|
||||
}
|
||||
if !provided {
|
||||
t.Error("Expected 'provided' to be true, got false")
|
||||
}
|
||||
|
||||
provided = false // reset
|
||||
mock.err = fmt.Errorf("set error")
|
||||
err = tracker.Set("failed set")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if provided {
|
||||
t.Error("Expected 'provided' to be false after error, got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackerValue_Type(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mockValue pflag.Value
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "success-bool-flag",
|
||||
mockValue: &boolFlagMockValue{typ: "mockBoolType"},
|
||||
want: "mockBoolType",
|
||||
},
|
||||
{
|
||||
name: "success-non-bool-flag",
|
||||
mockValue: &nonBoolFlagMockValue{typ: "mockNonBoolType"},
|
||||
want: "mockNonBoolType",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tracker := NewTracker(tc.mockValue, new(bool))
|
||||
result := tracker.Type()
|
||||
if result != tc.want {
|
||||
t.Errorf("Want %q, but got %q", tc.want, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type boolFlagMockValue struct {
|
||||
val string
|
||||
typ string
|
||||
isBool bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *boolFlagMockValue) String() string {
|
||||
return m.val
|
||||
}
|
||||
|
||||
func (m *boolFlagMockValue) Set(value string) error {
|
||||
m.val = value
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *boolFlagMockValue) Type() string {
|
||||
return m.typ
|
||||
}
|
||||
|
||||
func (m *boolFlagMockValue) IsBoolFlag() bool {
|
||||
return m.isBool
|
||||
}
|
||||
|
||||
type nonBoolFlagMockValue struct {
|
||||
val string
|
||||
typ string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *nonBoolFlagMockValue) String() string {
|
||||
return m.val
|
||||
}
|
||||
|
||||
func (m *nonBoolFlagMockValue) Set(value string) error {
|
||||
m.val = value
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *nonBoolFlagMockValue) Type() string {
|
||||
return m.typ
|
||||
}
|
Loading…
Reference in New Issue
Block a user