Merge pull request #130916 from richabanker/oidc-flags-v3

Remove mutation of authn options by binding flag setters to a tracking bool in options
This commit is contained in:
Kubernetes Prow Robot 2025-03-19 19:04:37 -07:00 committed by GitHub
commit ce87977639
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 452 additions and 98 deletions

View File

@ -100,8 +100,9 @@ type BuiltInAuthenticationOptions struct {
// AnonymousAuthenticationOptions contains anonymous authentication options for API Server
type AnonymousAuthenticationOptions struct {
Allow bool
areFlagsSet func() bool
Allow bool
// FlagsSet tracks whether any of the configuration options were set via a command-line flag.
FlagsSet bool
}
// BootstrapTokenAuthenticationOptions contains bootstrap token authentication options for API Server
@ -121,8 +122,8 @@ type OIDCAuthenticationOptions struct {
SigningAlgs []string
RequiredClaims map[string]string
// areFlagsConfigured is a function that returns true if any of the oidc-* flags are configured.
areFlagsConfigured func() bool
// FlagsSet tracks whether any of the configuration options were set via a command-line flag.
FlagsSet bool
}
// ServiceAccountAuthenticationOptions contains service account authentication options for API Server
@ -183,8 +184,7 @@ func (o *BuiltInAuthenticationOptions) WithAll() *BuiltInAuthenticationOptions {
// WithAnonymous set default value for anonymous authentication
func (o *BuiltInAuthenticationOptions) WithAnonymous() *BuiltInAuthenticationOptions {
o.Anonymous = &AnonymousAuthenticationOptions{
Allow: true,
areFlagsSet: func() bool { return false },
Allow: true,
}
return o
}
@ -204,9 +204,8 @@ func (o *BuiltInAuthenticationOptions) WithClientCert() *BuiltInAuthenticationOp
// WithOIDC set default value for OIDC authentication
func (o *BuiltInAuthenticationOptions) WithOIDC() *BuiltInAuthenticationOptions {
o.OIDC = &OIDCAuthenticationOptions{
areFlagsConfigured: func() bool { return false },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
}
return o
}
@ -337,10 +336,7 @@ func (o *BuiltInAuthenticationOptions) AddFlags(fs *pflag.FlagSet) {
"Enables anonymous requests to the secure port of the API server. "+
"Requests that are not rejected by another authentication method are treated as anonymous requests. "+
"Anonymous requests have a username of system:anonymous, and a group name of system:unauthenticated.")
o.Anonymous.areFlagsSet = func() bool {
return fs.Changed("anonymous-auth")
}
trackProvidedFlag(fs, "anonymous-auth", &o.Anonymous.FlagsSet)
}
if o.BootstrapToken != nil {
@ -357,54 +353,51 @@ func (o *BuiltInAuthenticationOptions) AddFlags(fs *pflag.FlagSet) {
fs.StringVar(&o.OIDC.IssuerURL, oidcIssuerURLFlag, o.OIDC.IssuerURL, ""+
"The URL of the OpenID issuer, only HTTPS scheme will be accepted. "+
"If set, it will be used to verify the OIDC JSON Web Token (JWT).")
trackProvidedFlag(fs, oidcIssuerURLFlag, &o.OIDC.FlagsSet)
fs.StringVar(&o.OIDC.ClientID, oidcClientIDFlag, o.OIDC.ClientID, ""+
"The client ID for the OpenID Connect client, must be set if oidc-issuer-url is set.")
trackProvidedFlag(fs, oidcClientIDFlag, &o.OIDC.FlagsSet)
fs.StringVar(&o.OIDC.CAFile, oidcCAFileFlag, o.OIDC.CAFile, ""+
"If set, the OpenID server's certificate will be verified by one of the authorities "+
"in the oidc-ca-file, otherwise the host's root CA set will be used.")
trackProvidedFlag(fs, oidcCAFileFlag, &o.OIDC.FlagsSet)
fs.StringVar(&o.OIDC.UsernameClaim, oidcUsernameClaimFlag, o.OIDC.UsernameClaim, ""+
"The OpenID claim to use as the user name. Note that claims other than the default ('sub') "+
"is not guaranteed to be unique and immutable. This flag is experimental, please see "+
"the authentication documentation for further details.")
trackProvidedFlag(fs, oidcUsernameClaimFlag, &o.OIDC.FlagsSet)
fs.StringVar(&o.OIDC.UsernamePrefix, oidcUsernamePrefixFlag, o.OIDC.UsernamePrefix, ""+
"If provided, all usernames will be prefixed with this value. If not provided, "+
"username claims other than 'email' are prefixed by the issuer URL to avoid "+
"clashes. To skip any prefixing, provide the value '-'.")
trackProvidedFlag(fs, oidcUsernamePrefixFlag, &o.OIDC.FlagsSet)
fs.StringVar(&o.OIDC.GroupsClaim, oidcGroupsClaimFlag, o.OIDC.GroupsClaim, ""+
"If provided, the name of a custom OpenID Connect claim for specifying user groups. "+
"The claim value is expected to be a string or array of strings. This flag is experimental, "+
"please see the authentication documentation for further details.")
trackProvidedFlag(fs, oidcGroupsClaimFlag, &o.OIDC.FlagsSet)
fs.StringVar(&o.OIDC.GroupsPrefix, oidcGroupsPrefixFlag, o.OIDC.GroupsPrefix, ""+
"If provided, all groups will be prefixed with this value to prevent conflicts with "+
"other authentication strategies.")
trackProvidedFlag(fs, oidcGroupsPrefixFlag, &o.OIDC.FlagsSet)
fs.StringSliceVar(&o.OIDC.SigningAlgs, oidcSigningAlgsFlag, o.OIDC.SigningAlgs, ""+
"Comma-separated list of allowed JOSE asymmetric signing algorithms. JWTs with a "+
"supported 'alg' header values are: RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384, PS512. "+
"Values are defined by RFC 7518 https://tools.ietf.org/html/rfc7518#section-3.1.")
trackProvidedFlag(fs, oidcSigningAlgsFlag, &o.OIDC.FlagsSet)
fs.Var(cliflag.NewMapStringStringNoSplit(&o.OIDC.RequiredClaims), oidcRequiredClaimFlag, ""+
"A key=value pair that describes a required claim in the ID Token. "+
"If set, the claim is verified to be present in the ID Token with a matching value. "+
"Repeat this flag to specify multiple claims.")
o.OIDC.areFlagsConfigured = func() bool {
return fs.Changed(oidcIssuerURLFlag) ||
fs.Changed(oidcClientIDFlag) ||
fs.Changed(oidcCAFileFlag) ||
fs.Changed(oidcUsernameClaimFlag) ||
fs.Changed(oidcUsernamePrefixFlag) ||
fs.Changed(oidcGroupsClaimFlag) ||
fs.Changed(oidcGroupsPrefixFlag) ||
fs.Changed(oidcSigningAlgsFlag) ||
fs.Changed(oidcRequiredClaimFlag)
}
trackProvidedFlag(fs, oidcRequiredClaimFlag, &o.OIDC.FlagsSet)
}
if o.RequestHeader != nil {
@ -572,7 +565,7 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat
// Set up anonymous authenticator from config file or flags
if o.Anonymous != nil {
switch {
case ret.AuthenticationConfig.Anonymous != nil && o.Anonymous.areFlagsSet():
case ret.AuthenticationConfig.Anonymous != nil && o.Anonymous.FlagsSet:
// Flags and config file are mutually exclusive
return kubeauthenticator.Config{}, field.Forbidden(field.NewPath("anonymous"), "--anonynous-auth flag cannot be set when anonymous field is configured in authentication configuration file")
case ret.AuthenticationConfig.Anonymous != nil:
@ -823,12 +816,17 @@ func (o *BuiltInAuthenticationOptions) ApplyAuthorization(authorization *BuiltIn
}
}
func trackProvidedFlag(fs *pflag.FlagSet, flagName string, provided *bool) {
f := fs.Lookup(flagName)
f.Value = cliflag.NewTracker(f.Value, provided)
}
func (o *BuiltInAuthenticationOptions) validateOIDCOptions() []error {
var allErrors []error
// Existing validation when jwt authenticator is configured with oidc-* flags
if len(o.AuthenticationConfigFile) == 0 {
if o.OIDC != nil && o.OIDC.areFlagsConfigured() && (len(o.OIDC.IssuerURL) == 0 || len(o.OIDC.ClientID) == 0) {
if o.OIDC != nil && o.OIDC.FlagsSet && (len(o.OIDC.IssuerURL) == 0 || len(o.OIDC.ClientID) == 0) {
allErrors = append(allErrors, fmt.Errorf("oidc-issuer-url and oidc-client-id must be specified together when any oidc-* flags are set"))
}
@ -843,7 +841,7 @@ func (o *BuiltInAuthenticationOptions) validateOIDCOptions() []error {
}
// Authentication config file and oidc-* flags are mutually exclusive
if o.OIDC != nil && o.OIDC.areFlagsConfigured() {
if o.OIDC != nil && o.OIDC.FlagsSet {
allErrors = append(allErrors, fmt.Errorf("authentication-config file and oidc-* flags are mutually exclusive"))
}

View File

@ -71,11 +71,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when OIDC and ServiceAccounts are valid",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{"http://foo.bar.com"},
@ -85,10 +85,10 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when OIDC is invalid",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{"http://foo.bar.com"},
@ -99,11 +99,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccounts doesn't have key file",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{"http://foo.bar.com"},
@ -113,11 +113,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccounts doesn't have issuer",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{},
@ -127,11 +127,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccounts has empty string as issuer",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{""},
@ -141,11 +141,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccounts has duplicate issuers",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{"http://foo.bar.com", "http://foo.bar.com"},
@ -155,11 +155,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccount has bad issuer",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
Issuers: []string{"http://[::1]:namedport"},
@ -169,11 +169,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccounts has invalid JWKSURI",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
KeyFiles: []string{"cert", "key"},
@ -185,11 +185,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when ServiceAccounts has invalid JWKSURI (not https scheme)",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
KeyFiles: []string{"cert", "key"},
@ -201,11 +201,11 @@ func TestAuthenticationValidate(t *testing.T) {
{
name: "test when WebHook has invalid retry attempts",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
testSA: &ServiceAccountAuthenticationOptions{
KeyFiles: []string{"cert", "key"},
@ -234,11 +234,11 @@ func TestAuthenticationValidate(t *testing.T) {
name: "test when authentication config file and oidc-* flags are set",
testAuthenticationConfigFile: "configfile",
testOIDC: &OIDCAuthenticationOptions{
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
areFlagsConfigured: func() bool { return true },
UsernameClaim: "sub",
SigningAlgs: []string{"RS256"},
IssuerURL: "https://testIssuerURL",
ClientID: "testClientID",
FlagsSet: true,
},
expectErr: "authentication-config file and oidc-* flags are mutually exclusive",
},
@ -247,8 +247,8 @@ func TestAuthenticationValidate(t *testing.T) {
disabledFeatures: []featuregate.Feature{features.AnonymousAuthConfigurableEndpoints},
testAuthenticationConfigFile: "configfile",
testAnonymous: &AnonymousAuthenticationOptions{
Allow: true,
areFlagsSet: func() bool { return true },
Allow: true,
FlagsSet: true,
},
},
}
@ -413,7 +413,8 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) {
expected := &BuiltInAuthenticationOptions{
APIAudiences: []string{"foo"},
Anonymous: &AnonymousAuthenticationOptions{
Allow: true,
Allow: true,
FlagsSet: true,
},
BootstrapToken: &BootstrapTokenAuthenticationOptions{
Enable: true,
@ -428,6 +429,7 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) {
UsernameClaim: "sub",
UsernamePrefix: "-",
SigningAlgs: []string{"RS256"},
FlagsSet: true,
},
RequestHeader: &apiserveroptions.RequestHeaderAuthenticationOptions{
ClientCAFile: "testdata/root.pem",
@ -470,19 +472,6 @@ func TestBuiltInAuthenticationOptionsAddFlags(t *testing.T) {
t.Fatal(err)
}
if !opts.OIDC.areFlagsConfigured() {
t.Fatal("OIDC flags should be configured")
}
// nil these out because you cannot compare functions
opts.OIDC.areFlagsConfigured = nil
if !opts.Anonymous.areFlagsSet() {
t.Fatalf("Anonymous flags should be configured")
}
// nil these out because you cannot compare functions
opts.Anonymous.areFlagsSet = nil
if !reflect.DeepEqual(opts, expected) {
t.Error(cmp.Diff(opts, expected, cmp.AllowUnexported(OIDCAuthenticationOptions{}, AnonymousAuthenticationOptions{})))
}

View File

@ -0,0 +1,82 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
"github.com/spf13/pflag"
)
// TrackerValue wraps a non-boolean value and stores true in the provided boolean when it is set.
type TrackerValue struct {
value pflag.Value
provided *bool
}
// BoolTrackerValue wraps a boolean value and stores true in the provided boolean when it is set.
type BoolTrackerValue struct {
boolValue
provided *bool
}
type boolValue interface {
pflag.Value
IsBoolFlag() bool
}
var _ pflag.Value = &TrackerValue{}
var _ boolValue = &BoolTrackerValue{}
// NewTracker returns a Value wrapping the given value which stores true in the provided boolean when it is set.
func NewTracker(value pflag.Value, provided *bool) pflag.Value {
if value == nil {
panic("value must not be nil")
}
if provided == nil {
panic("provided boolean must not be nil")
}
if boolValue, ok := value.(boolValue); ok {
return &BoolTrackerValue{boolValue: boolValue, provided: provided}
}
return &TrackerValue{value: value, provided: provided}
}
func (f *TrackerValue) String() string {
return f.value.String()
}
func (f *TrackerValue) Set(value string) error {
err := f.value.Set(value)
if err == nil {
*f.provided = true
}
return err
}
func (f *TrackerValue) Type() string {
return f.value.Type()
}
func (f *BoolTrackerValue) Set(value string) error {
err := f.boolValue.Set(value)
if err == nil {
*f.provided = true
}
return err
}

View File

@ -0,0 +1,285 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
"fmt"
"testing"
"github.com/spf13/pflag"
)
func TestNewTracker(t *testing.T) {
tests := []struct {
name string
value pflag.Value
provided *bool
wantType string
}{
{
name: "non-bool-tracker",
value: &nonBoolFlagMockValue{val: "initial", typ: "string"},
provided: new(bool),
wantType: "string",
},
{
name: "bool-tracker",
value: &boolFlagMockValue{val: "false", typ: "bool", isBool: true},
provided: new(bool),
wantType: "bool",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewTracker(tt.value, tt.provided)
if tracker.Type() != tt.wantType {
t.Errorf("Want type %s, got %s", tt.wantType, tracker.Type())
}
if trackerValue, ok := tracker.(*TrackerValue); ok {
if trackerValue.provided != tt.provided {
t.Errorf("Provided pointer not stored correctly in TrackerValue")
}
} else if boolTrackerValue, ok := tracker.(*BoolTrackerValue); ok {
if boolTrackerValue.provided != tt.provided {
t.Errorf("Provided pointer not stored correctly in BoolTrackerValue")
}
}
})
}
}
func TestNewTrackerPanics(t *testing.T) {
tests := []struct {
name string
value pflag.Value
provided *bool
panicMsg string
}{
{
name: "nil-value",
value: nil,
provided: new(bool),
panicMsg: "value must not be nil",
},
{
name: "nil-provided",
value: &boolFlagMockValue{val: "test"},
provided: nil,
panicMsg: "provided boolean must not be nil",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic, but did not panic")
} else if r != tt.panicMsg {
t.Errorf("expected panic message %q, got %q", tt.panicMsg, r)
}
}()
NewTracker(tt.value, tt.provided)
})
}
}
func TestTrackerValue_String(t *testing.T) {
testCases := []struct {
name string
mockValue pflag.Value
want string
}{
{
name: "bool-flag",
mockValue: &boolFlagMockValue{val: "bool-test"},
want: "bool-test",
},
{
name: "non-bool-flag",
mockValue: &nonBoolFlagMockValue{val: "non-bool-test"},
want: "non-bool-test",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tracker := NewTracker(tc.mockValue, new(bool))
result := tracker.String()
if result != tc.want {
t.Errorf("Want %q, but got %q", tc.want, result)
}
})
}
}
func TestTrackerValue_Set(t *testing.T) {
testCases := []struct {
name string
mockValue pflag.Value
provided *bool
mockErr error
wantProvided bool
wantErr bool
}{
{
name: "success-bool-flag",
mockValue: &boolFlagMockValue{val: "bool-test"},
provided: new(bool),
wantProvided: true,
wantErr: false,
},
{
name: "success-non-bool-flag",
mockValue: &nonBoolFlagMockValue{val: "bool-test"},
provided: new(bool),
wantProvided: true,
wantErr: false,
},
{
name: "error-bool-flag",
mockValue: &boolFlagMockValue{val: "bool-test", err: fmt.Errorf("set error")},
provided: new(bool),
wantProvided: false,
wantErr: true,
},
{
name: "error-non-bool-flag",
mockValue: &nonBoolFlagMockValue{val: "bool-test", err: fmt.Errorf("set error")},
provided: new(bool),
wantProvided: false,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tracker := NewTracker(tc.mockValue, tc.provided)
err := tracker.Set("new value")
if (err != nil) != tc.wantErr {
t.Errorf("Want error: %v, got: %v", tc.wantErr, err != nil)
}
if *tc.provided != tc.wantProvided {
t.Errorf("Want provided to be %v, got: %v", tc.wantProvided, *tc.provided)
}
})
}
}
func TestTrackerValue_MultipleSetCalls(t *testing.T) {
provided := false
mock := &boolFlagMockValue{val: "initial"}
tracker := NewTracker(mock, &provided)
err := tracker.Set("new value")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if mock.val != "new value" {
t.Errorf("Expected mock value to be 'new value', got '%s'", mock.val)
}
if !provided {
t.Error("Expected 'provided' to be true, got false")
}
provided = false // reset
mock.err = fmt.Errorf("set error")
err = tracker.Set("failed set")
if err == nil {
t.Errorf("Expected an error, got nil")
}
if provided {
t.Error("Expected 'provided' to be false after error, got true")
}
}
func TestTrackerValue_Type(t *testing.T) {
testCases := []struct {
name string
mockValue pflag.Value
want string
}{
{
name: "success-bool-flag",
mockValue: &boolFlagMockValue{typ: "mockBoolType"},
want: "mockBoolType",
},
{
name: "success-non-bool-flag",
mockValue: &nonBoolFlagMockValue{typ: "mockNonBoolType"},
want: "mockNonBoolType",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tracker := NewTracker(tc.mockValue, new(bool))
result := tracker.Type()
if result != tc.want {
t.Errorf("Want %q, but got %q", tc.want, result)
}
})
}
}
type boolFlagMockValue struct {
val string
typ string
isBool bool
err error
}
func (m *boolFlagMockValue) String() string {
return m.val
}
func (m *boolFlagMockValue) Set(value string) error {
m.val = value
return m.err
}
func (m *boolFlagMockValue) Type() string {
return m.typ
}
func (m *boolFlagMockValue) IsBoolFlag() bool {
return m.isBool
}
type nonBoolFlagMockValue struct {
val string
typ string
err error
}
func (m *nonBoolFlagMockValue) String() string {
return m.val
}
func (m *nonBoolFlagMockValue) Set(value string) error {
m.val = value
return m.err
}
func (m *nonBoolFlagMockValue) Type() string {
return m.typ
}