From baaa38c0992e7220c2d0e509cf4297552f371337 Mon Sep 17 00:00:00 2001 From: Richa Banker Date: Tue, 18 Mar 2025 22:11:28 -0700 Subject: [PATCH] Remove mutation of authn options by binding flag setters to a tracking boolean in options --- pkg/kubeapiserver/options/authentication.go | 54 ++-- .../options/authentication_test.go | 129 ++++---- .../component-base/cli/flag/tracker_flag.go | 82 +++++ .../cli/flag/tracker_flag_test.go | 285 ++++++++++++++++++ 4 files changed, 452 insertions(+), 98 deletions(-) create mode 100644 staging/src/k8s.io/component-base/cli/flag/tracker_flag.go create mode 100644 staging/src/k8s.io/component-base/cli/flag/tracker_flag_test.go diff --git a/pkg/kubeapiserver/options/authentication.go b/pkg/kubeapiserver/options/authentication.go index 6b5c9f5718a..72bfdd8bd11 100644 --- a/pkg/kubeapiserver/options/authentication.go +++ b/pkg/kubeapiserver/options/authentication.go @@ -100,8 +100,9 @@ type BuiltInAuthenticationOptions struct { // AnonymousAuthenticationOptions contains anonymous authentication options for API Server type AnonymousAuthenticationOptions struct { - Allow bool - areFlagsSet func() bool + Allow bool + // FlagsSet tracks whether any of the configuration options were set via a command-line flag. + FlagsSet bool } // BootstrapTokenAuthenticationOptions contains bootstrap token authentication options for API Server @@ -121,8 +122,8 @@ type OIDCAuthenticationOptions struct { SigningAlgs []string RequiredClaims map[string]string - // areFlagsConfigured is a function that returns true if any of the oidc-* flags are configured. - areFlagsConfigured func() bool + // FlagsSet tracks whether any of the configuration options were set via a command-line flag. + FlagsSet bool } // ServiceAccountAuthenticationOptions contains service account authentication options for API Server @@ -183,8 +184,7 @@ func (o *BuiltInAuthenticationOptions) WithAll() *BuiltInAuthenticationOptions { // WithAnonymous set default value for anonymous authentication func (o *BuiltInAuthenticationOptions) WithAnonymous() *BuiltInAuthenticationOptions { o.Anonymous = &AnonymousAuthenticationOptions{ - Allow: true, - areFlagsSet: func() bool { return false }, + Allow: true, } return o } @@ -204,9 +204,8 @@ func (o *BuiltInAuthenticationOptions) WithClientCert() *BuiltInAuthenticationOp // WithOIDC set default value for OIDC authentication func (o *BuiltInAuthenticationOptions) WithOIDC() *BuiltInAuthenticationOptions { o.OIDC = &OIDCAuthenticationOptions{ - areFlagsConfigured: func() bool { return false }, - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, } return o } @@ -337,10 +336,7 @@ func (o *BuiltInAuthenticationOptions) AddFlags(fs *pflag.FlagSet) { "Enables anonymous requests to the secure port of the API server. "+ "Requests that are not rejected by another authentication method are treated as anonymous requests. "+ "Anonymous requests have a username of system:anonymous, and a group name of system:unauthenticated.") - - o.Anonymous.areFlagsSet = func() bool { - return fs.Changed("anonymous-auth") - } + trackProvidedFlag(fs, "anonymous-auth", &o.Anonymous.FlagsSet) } if o.BootstrapToken != nil { @@ -357,54 +353,51 @@ func (o *BuiltInAuthenticationOptions) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&o.OIDC.IssuerURL, oidcIssuerURLFlag, o.OIDC.IssuerURL, ""+ "The URL of the OpenID issuer, only HTTPS scheme will be accepted. "+ "If set, it will be used to verify the OIDC JSON Web Token (JWT).") + trackProvidedFlag(fs, oidcIssuerURLFlag, &o.OIDC.FlagsSet) fs.StringVar(&o.OIDC.ClientID, oidcClientIDFlag, o.OIDC.ClientID, ""+ "The client ID for the OpenID Connect client, must be set if oidc-issuer-url is set.") + trackProvidedFlag(fs, oidcClientIDFlag, &o.OIDC.FlagsSet) fs.StringVar(&o.OIDC.CAFile, oidcCAFileFlag, o.OIDC.CAFile, ""+ "If set, the OpenID server's certificate will be verified by one of the authorities "+ "in the oidc-ca-file, otherwise the host's root CA set will be used.") + trackProvidedFlag(fs, oidcCAFileFlag, &o.OIDC.FlagsSet) fs.StringVar(&o.OIDC.UsernameClaim, oidcUsernameClaimFlag, o.OIDC.UsernameClaim, ""+ "The OpenID claim to use as the user name. Note that claims other than the default ('sub') "+ "is not guaranteed to be unique and immutable. This flag is experimental, please see "+ "the authentication documentation for further details.") + trackProvidedFlag(fs, oidcUsernameClaimFlag, &o.OIDC.FlagsSet) fs.StringVar(&o.OIDC.UsernamePrefix, oidcUsernamePrefixFlag, o.OIDC.UsernamePrefix, ""+ "If provided, all usernames will be prefixed with this value. If not provided, "+ "username claims other than 'email' are prefixed by the issuer URL to avoid "+ "clashes. To skip any prefixing, provide the value '-'.") + trackProvidedFlag(fs, oidcUsernamePrefixFlag, &o.OIDC.FlagsSet) fs.StringVar(&o.OIDC.GroupsClaim, oidcGroupsClaimFlag, o.OIDC.GroupsClaim, ""+ "If provided, the name of a custom OpenID Connect claim for specifying user groups. "+ "The claim value is expected to be a string or array of strings. This flag is experimental, "+ "please see the authentication documentation for further details.") + trackProvidedFlag(fs, oidcGroupsClaimFlag, &o.OIDC.FlagsSet) fs.StringVar(&o.OIDC.GroupsPrefix, oidcGroupsPrefixFlag, o.OIDC.GroupsPrefix, ""+ "If provided, all groups will be prefixed with this value to prevent conflicts with "+ "other authentication strategies.") + trackProvidedFlag(fs, oidcGroupsPrefixFlag, &o.OIDC.FlagsSet) fs.StringSliceVar(&o.OIDC.SigningAlgs, oidcSigningAlgsFlag, o.OIDC.SigningAlgs, ""+ "Comma-separated list of allowed JOSE asymmetric signing algorithms. JWTs with a "+ "supported 'alg' header values are: RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384, PS512. "+ "Values are defined by RFC 7518 https://tools.ietf.org/html/rfc7518#section-3.1.") + trackProvidedFlag(fs, oidcSigningAlgsFlag, &o.OIDC.FlagsSet) fs.Var(cliflag.NewMapStringStringNoSplit(&o.OIDC.RequiredClaims), oidcRequiredClaimFlag, ""+ "A key=value pair that describes a required claim in the ID Token. "+ "If set, the claim is verified to be present in the ID Token with a matching value. "+ "Repeat this flag to specify multiple claims.") - - o.OIDC.areFlagsConfigured = func() bool { - return fs.Changed(oidcIssuerURLFlag) || - fs.Changed(oidcClientIDFlag) || - fs.Changed(oidcCAFileFlag) || - fs.Changed(oidcUsernameClaimFlag) || - fs.Changed(oidcUsernamePrefixFlag) || - fs.Changed(oidcGroupsClaimFlag) || - fs.Changed(oidcGroupsPrefixFlag) || - fs.Changed(oidcSigningAlgsFlag) || - fs.Changed(oidcRequiredClaimFlag) - } + trackProvidedFlag(fs, oidcRequiredClaimFlag, &o.OIDC.FlagsSet) } if o.RequestHeader != nil { @@ -572,7 +565,7 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat // Set up anonymous authenticator from config file or flags if o.Anonymous != nil { switch { - case ret.AuthenticationConfig.Anonymous != nil && o.Anonymous.areFlagsSet(): + case ret.AuthenticationConfig.Anonymous != nil && o.Anonymous.FlagsSet: // Flags and config file are mutually exclusive return kubeauthenticator.Config{}, field.Forbidden(field.NewPath("anonymous"), "--anonynous-auth flag cannot be set when anonymous field is configured in authentication configuration file") case ret.AuthenticationConfig.Anonymous != nil: @@ -823,12 +816,17 @@ func (o *BuiltInAuthenticationOptions) ApplyAuthorization(authorization *BuiltIn } } +func trackProvidedFlag(fs *pflag.FlagSet, flagName string, provided *bool) { + f := fs.Lookup(flagName) + f.Value = cliflag.NewTracker(f.Value, provided) +} + func (o *BuiltInAuthenticationOptions) validateOIDCOptions() []error { var allErrors []error // Existing validation when jwt authenticator is configured with oidc-* flags if len(o.AuthenticationConfigFile) == 0 { - if o.OIDC != nil && o.OIDC.areFlagsConfigured() && (len(o.OIDC.IssuerURL) == 0 || len(o.OIDC.ClientID) == 0) { + if o.OIDC != nil && o.OIDC.FlagsSet && (len(o.OIDC.IssuerURL) == 0 || len(o.OIDC.ClientID) == 0) { allErrors = append(allErrors, fmt.Errorf("oidc-issuer-url and oidc-client-id must be specified together when any oidc-* flags are set")) } @@ -843,7 +841,7 @@ func (o *BuiltInAuthenticationOptions) validateOIDCOptions() []error { } // Authentication config file and oidc-* flags are mutually exclusive - if o.OIDC != nil && o.OIDC.areFlagsConfigured() { + if o.OIDC != nil && o.OIDC.FlagsSet { allErrors = append(allErrors, fmt.Errorf("authentication-config file and oidc-* flags are mutually exclusive")) } diff --git a/pkg/kubeapiserver/options/authentication_test.go b/pkg/kubeapiserver/options/authentication_test.go index 0b377214ea1..0cf77b12f63 100644 --- a/pkg/kubeapiserver/options/authentication_test.go +++ b/pkg/kubeapiserver/options/authentication_test.go @@ -71,11 +71,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when OIDC and ServiceAccounts are valid", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{"http://foo.bar.com"}, @@ -85,10 +85,10 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when OIDC is invalid", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{"http://foo.bar.com"}, @@ -99,11 +99,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccounts doesn't have key file", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{"http://foo.bar.com"}, @@ -113,11 +113,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccounts doesn't have issuer", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{}, @@ -127,11 +127,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccounts has empty string as issuer", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{""}, @@ -141,11 +141,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccounts has duplicate issuers", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{"http://foo.bar.com", "http://foo.bar.com"}, @@ -155,11 +155,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccount has bad issuer", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ Issuers: []string{"http://[::1]:namedport"}, @@ -169,11 +169,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccounts has invalid JWKSURI", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ KeyFiles: []string{"cert", "key"}, @@ -185,11 +185,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when ServiceAccounts has invalid JWKSURI (not https scheme)", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ KeyFiles: []string{"cert", "key"}, @@ -201,11 +201,11 @@ func TestAuthenticationValidate(t *testing.T) { { name: "test when WebHook has invalid retry attempts", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, testSA: &ServiceAccountAuthenticationOptions{ KeyFiles: []string{"cert", "key"}, @@ -234,11 +234,11 @@ func TestAuthenticationValidate(t *testing.T) { name: "test when authentication config file and oidc-* flags are set", testAuthenticationConfigFile: "configfile", testOIDC: &OIDCAuthenticationOptions{ - UsernameClaim: "sub", - SigningAlgs: []string{"RS256"}, - IssuerURL: "https://testIssuerURL", - ClientID: "testClientID", - areFlagsConfigured: func() bool { return true }, + UsernameClaim: "sub", + SigningAlgs: []string{"RS256"}, + IssuerURL: "https://testIssuerURL", + ClientID: "testClientID", + FlagsSet: true, }, expectErr: "authentication-config file and oidc-* flags are mutually exclusive", }, @@ -247,8 +247,8 @@ func TestAuthenticationValidate(t *testing.T) { disabledFeatures: []featuregate.Feature{features.AnonymousAuthConfigurableEndpoints}, testAuthenticationConfigFile: "configfile", testAnonymous: &AnonymousAuthenticationOptions{ - Allow: true, - areFlagsSet: func() bool { return true }, + Allow: true, + FlagsSet: true, }, }, } @@ -413,7 +413,8 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) { expected := &BuiltInAuthenticationOptions{ APIAudiences: []string{"foo"}, Anonymous: &AnonymousAuthenticationOptions{ - Allow: true, + Allow: true, + FlagsSet: true, }, BootstrapToken: &BootstrapTokenAuthenticationOptions{ Enable: true, @@ -428,6 +429,7 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) { UsernameClaim: "sub", UsernamePrefix: "-", SigningAlgs: []string{"RS256"}, + FlagsSet: true, }, RequestHeader: &apiserveroptions.RequestHeaderAuthenticationOptions{ ClientCAFile: "testdata/root.pem", @@ -470,19 +472,6 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) { t.Fatal(err) } - if !opts.OIDC.areFlagsConfigured() { - t.Fatal("OIDC flags should be configured") - } - // nil these out because you cannot compare functions - opts.OIDC.areFlagsConfigured = nil - - if !opts.Anonymous.areFlagsSet() { - t.Fatalf("Anonymous flags should be configured") - } - - // nil these out because you cannot compare functions - opts.Anonymous.areFlagsSet = nil - if !reflect.DeepEqual(opts, expected) { t.Error(cmp.Diff(opts, expected, cmp.AllowUnexported(OIDCAuthenticationOptions{}, AnonymousAuthenticationOptions{}))) } diff --git a/staging/src/k8s.io/component-base/cli/flag/tracker_flag.go b/staging/src/k8s.io/component-base/cli/flag/tracker_flag.go new file mode 100644 index 00000000000..a7f6efed3e9 --- /dev/null +++ b/staging/src/k8s.io/component-base/cli/flag/tracker_flag.go @@ -0,0 +1,82 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package flag + +import ( + "github.com/spf13/pflag" +) + +// TrackerValue wraps a non-boolean value and stores true in the provided boolean when it is set. +type TrackerValue struct { + value pflag.Value + provided *bool +} + +// BoolTrackerValue wraps a boolean value and stores true in the provided boolean when it is set. +type BoolTrackerValue struct { + boolValue + provided *bool +} + +type boolValue interface { + pflag.Value + IsBoolFlag() bool +} + +var _ pflag.Value = &TrackerValue{} +var _ boolValue = &BoolTrackerValue{} + +// NewTracker returns a Value wrapping the given value which stores true in the provided boolean when it is set. +func NewTracker(value pflag.Value, provided *bool) pflag.Value { + if value == nil { + panic("value must not be nil") + } + + if provided == nil { + panic("provided boolean must not be nil") + } + + if boolValue, ok := value.(boolValue); ok { + return &BoolTrackerValue{boolValue: boolValue, provided: provided} + } + return &TrackerValue{value: value, provided: provided} +} + +func (f *TrackerValue) String() string { + return f.value.String() +} + +func (f *TrackerValue) Set(value string) error { + err := f.value.Set(value) + if err == nil { + *f.provided = true + } + return err +} + +func (f *TrackerValue) Type() string { + return f.value.Type() +} + +func (f *BoolTrackerValue) Set(value string) error { + err := f.boolValue.Set(value) + if err == nil { + *f.provided = true + } + + return err +} diff --git a/staging/src/k8s.io/component-base/cli/flag/tracker_flag_test.go b/staging/src/k8s.io/component-base/cli/flag/tracker_flag_test.go new file mode 100644 index 00000000000..6f0a1ade607 --- /dev/null +++ b/staging/src/k8s.io/component-base/cli/flag/tracker_flag_test.go @@ -0,0 +1,285 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package flag + +import ( + "fmt" + "testing" + + "github.com/spf13/pflag" +) + +func TestNewTracker(t *testing.T) { + tests := []struct { + name string + value pflag.Value + provided *bool + wantType string + }{ + { + name: "non-bool-tracker", + value: &nonBoolFlagMockValue{val: "initial", typ: "string"}, + provided: new(bool), + wantType: "string", + }, + { + name: "bool-tracker", + value: &boolFlagMockValue{val: "false", typ: "bool", isBool: true}, + provided: new(bool), + wantType: "bool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewTracker(tt.value, tt.provided) + + if tracker.Type() != tt.wantType { + t.Errorf("Want type %s, got %s", tt.wantType, tracker.Type()) + } + + if trackerValue, ok := tracker.(*TrackerValue); ok { + if trackerValue.provided != tt.provided { + t.Errorf("Provided pointer not stored correctly in TrackerValue") + } + } else if boolTrackerValue, ok := tracker.(*BoolTrackerValue); ok { + if boolTrackerValue.provided != tt.provided { + t.Errorf("Provided pointer not stored correctly in BoolTrackerValue") + } + } + }) + } +} + +func TestNewTrackerPanics(t *testing.T) { + tests := []struct { + name string + value pflag.Value + provided *bool + panicMsg string + }{ + { + name: "nil-value", + value: nil, + provided: new(bool), + panicMsg: "value must not be nil", + }, + { + name: "nil-provided", + value: &boolFlagMockValue{val: "test"}, + provided: nil, + panicMsg: "provided boolean must not be nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic, but did not panic") + } else if r != tt.panicMsg { + t.Errorf("expected panic message %q, got %q", tt.panicMsg, r) + } + }() + NewTracker(tt.value, tt.provided) + }) + } +} + +func TestTrackerValue_String(t *testing.T) { + testCases := []struct { + name string + mockValue pflag.Value + want string + }{ + { + name: "bool-flag", + mockValue: &boolFlagMockValue{val: "bool-test"}, + want: "bool-test", + }, + { + name: "non-bool-flag", + mockValue: &nonBoolFlagMockValue{val: "non-bool-test"}, + want: "non-bool-test", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tracker := NewTracker(tc.mockValue, new(bool)) + result := tracker.String() + if result != tc.want { + t.Errorf("Want %q, but got %q", tc.want, result) + } + }) + } +} + +func TestTrackerValue_Set(t *testing.T) { + testCases := []struct { + name string + mockValue pflag.Value + provided *bool + mockErr error + wantProvided bool + wantErr bool + }{ + { + name: "success-bool-flag", + mockValue: &boolFlagMockValue{val: "bool-test"}, + provided: new(bool), + wantProvided: true, + wantErr: false, + }, + { + name: "success-non-bool-flag", + mockValue: &nonBoolFlagMockValue{val: "bool-test"}, + provided: new(bool), + wantProvided: true, + wantErr: false, + }, + { + name: "error-bool-flag", + mockValue: &boolFlagMockValue{val: "bool-test", err: fmt.Errorf("set error")}, + provided: new(bool), + wantProvided: false, + wantErr: true, + }, + { + name: "error-non-bool-flag", + mockValue: &nonBoolFlagMockValue{val: "bool-test", err: fmt.Errorf("set error")}, + provided: new(bool), + wantProvided: false, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tracker := NewTracker(tc.mockValue, tc.provided) + err := tracker.Set("new value") + + if (err != nil) != tc.wantErr { + t.Errorf("Want error: %v, got: %v", tc.wantErr, err != nil) + } + + if *tc.provided != tc.wantProvided { + t.Errorf("Want provided to be %v, got: %v", tc.wantProvided, *tc.provided) + } + }) + } +} + +func TestTrackerValue_MultipleSetCalls(t *testing.T) { + provided := false + mock := &boolFlagMockValue{val: "initial"} + tracker := NewTracker(mock, &provided) + + err := tracker.Set("new value") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if mock.val != "new value" { + t.Errorf("Expected mock value to be 'new value', got '%s'", mock.val) + } + if !provided { + t.Error("Expected 'provided' to be true, got false") + } + + provided = false // reset + mock.err = fmt.Errorf("set error") + err = tracker.Set("failed set") + + if err == nil { + t.Errorf("Expected an error, got nil") + } + if provided { + t.Error("Expected 'provided' to be false after error, got true") + } +} + +func TestTrackerValue_Type(t *testing.T) { + testCases := []struct { + name string + mockValue pflag.Value + want string + }{ + { + name: "success-bool-flag", + mockValue: &boolFlagMockValue{typ: "mockBoolType"}, + want: "mockBoolType", + }, + { + name: "success-non-bool-flag", + mockValue: &nonBoolFlagMockValue{typ: "mockNonBoolType"}, + want: "mockNonBoolType", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tracker := NewTracker(tc.mockValue, new(bool)) + result := tracker.Type() + if result != tc.want { + t.Errorf("Want %q, but got %q", tc.want, result) + } + }) + } +} + +type boolFlagMockValue struct { + val string + typ string + isBool bool + err error +} + +func (m *boolFlagMockValue) String() string { + return m.val +} + +func (m *boolFlagMockValue) Set(value string) error { + m.val = value + return m.err +} + +func (m *boolFlagMockValue) Type() string { + return m.typ +} + +func (m *boolFlagMockValue) IsBoolFlag() bool { + return m.isBool +} + +type nonBoolFlagMockValue struct { + val string + typ string + err error +} + +func (m *nonBoolFlagMockValue) String() string { + return m.val +} + +func (m *nonBoolFlagMockValue) Set(value string) error { + m.val = value + return m.err +} + +func (m *nonBoolFlagMockValue) Type() string { + return m.typ +}