diff --git a/pkg/apis/policy/validation/BUILD b/pkg/apis/policy/validation/BUILD index 0b60688425e..76b6fe1d96f 100644 --- a/pkg/apis/policy/validation/BUILD +++ b/pkg/apis/policy/validation/BUILD @@ -38,6 +38,8 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/intstr:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/validation/field:go_default_library", + "//vendor/github.com/stretchr/testify/assert:go_default_library", + "//vendor/k8s.io/utils/pointer:go_default_library", ], ) diff --git a/pkg/apis/policy/validation/validation.go b/pkg/apis/policy/validation/validation.go index 9b88c21a574..b3c90ade846 100644 --- a/pkg/apis/policy/validation/validation.go +++ b/pkg/apis/policy/validation/validation.go @@ -125,6 +125,7 @@ func ValidatePodSecurityPolicySpec(spec *policy.PodSecurityPolicySpec, fldPath * allErrs = append(allErrs, validatePodSecurityPolicySysctls(fldPath.Child("allowedUnsafeSysctls"), spec.AllowedUnsafeSysctls)...) allErrs = append(allErrs, validatePodSecurityPolicySysctls(fldPath.Child("forbiddenSysctls"), spec.ForbiddenSysctls)...) allErrs = append(allErrs, validatePodSecurityPolicySysctlListsDoNotOverlap(fldPath.Child("allowedUnsafeSysctls"), fldPath.Child("forbiddenSysctls"), spec.AllowedUnsafeSysctls, spec.ForbiddenSysctls)...) + allErrs = append(allErrs, validateRuntimeClassStrategy(fldPath.Child("runtimeClass"), spec.RuntimeClass)...) return allErrs } @@ -476,6 +477,40 @@ func validatePSPCapsAgainstDrops(requiredDrops []core.Capability, capsToCheck [] return allErrs } +// validateRuntimeClassStrategy ensures all the RuntimeClass restrictions are valid. +func validateRuntimeClassStrategy(fldPath *field.Path, rc *policy.RuntimeClassStrategyOptions) field.ErrorList { + if rc == nil { + return nil + } + + var allErrs field.ErrorList + + allowed := map[string]bool{} + for i, name := range rc.AllowedRuntimeClassNames { + if name != policy.AllowAllRuntimeClassNames { + allErrs = append(allErrs, apivalidation.ValidateRuntimeClassName(name, fldPath.Child("allowedRuntimeClassNames").Index(i))...) + } + if allowed[name] { + allErrs = append(allErrs, field.Duplicate(fldPath.Child("allowedRuntimeClassNames").Index(i), name)) + } + allowed[name] = true + } + + if rc.DefaultRuntimeClassName != nil { + allErrs = append(allErrs, apivalidation.ValidateRuntimeClassName(*rc.DefaultRuntimeClassName, fldPath.Child("defaultRuntimeClassName"))...) + if !allowed[*rc.DefaultRuntimeClassName] && !allowed[policy.AllowAllRuntimeClassNames] { + allErrs = append(allErrs, field.Required(fldPath.Child("allowedRuntimeClassNames"), + fmt.Sprintf("default %q must be allowed", *rc.DefaultRuntimeClassName))) + } + } + + if allowed[policy.AllowAllRuntimeClassNames] && len(rc.AllowedRuntimeClassNames) > 1 { + allErrs = append(allErrs, field.Invalid(fldPath.Child("allowedRuntimeClassNames"), rc.AllowedRuntimeClassNames, "if '*' is present, must not specify other RuntimeClass names")) + } + + return allErrs +} + // ValidatePodSecurityPolicyUpdate validates a PSP for updates. func ValidatePodSecurityPolicyUpdate(old *policy.PodSecurityPolicy, new *policy.PodSecurityPolicy) field.ErrorList { allErrs := field.ErrorList{} diff --git a/pkg/apis/policy/validation/validation_test.go b/pkg/apis/policy/validation/validation_test.go index c17d24b5c17..82b52aaaa5c 100644 --- a/pkg/apis/policy/validation/validation_test.go +++ b/pkg/apis/policy/validation/validation_test.go @@ -20,6 +20,7 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/apimachinery/pkg/util/validation/field" @@ -28,6 +29,7 @@ import ( "k8s.io/kubernetes/pkg/security/apparmor" "k8s.io/kubernetes/pkg/security/podsecuritypolicy/seccomp" psputil "k8s.io/kubernetes/pkg/security/podsecuritypolicy/util" + "k8s.io/utils/pointer" ) func TestValidatePodDisruptionBudgetSpec(t *testing.T) { @@ -829,6 +831,7 @@ func TestValidatePSPRunAsUser(t *testing.T) { }) } } + func TestValidatePSPFSGroup(t *testing.T) { var testCases = []struct { name string @@ -853,7 +856,6 @@ func TestValidatePSPFSGroup(t *testing.T) { } }) } - } func TestValidatePSPSupplementalGroup(t *testing.T) { @@ -935,5 +937,88 @@ func TestValidatePSPSELinux(t *testing.T) { } }) } - +} + +func TestValidateRuntimeClassStrategy(t *testing.T) { + var testCases = []struct { + name string + strategy *policy.RuntimeClassStrategyOptions + expectErrors bool + }{{ + name: "nil strategy", + strategy: nil, + }, { + name: "empty strategy", + strategy: &policy.RuntimeClassStrategyOptions{}, + }, { + name: "allow all strategy", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"*"}, + }, + }, { + name: "valid defaulting & allow all", + strategy: &policy.RuntimeClassStrategyOptions{ + DefaultRuntimeClassName: pointer.StringPtr("native"), + AllowedRuntimeClassNames: []string{"*"}, + }, + }, { + name: "valid defaulting & allow explicit", + strategy: &policy.RuntimeClassStrategyOptions{ + DefaultRuntimeClassName: pointer.StringPtr("native"), + AllowedRuntimeClassNames: []string{"foo", "native", "sandboxed"}, + }, + }, { + name: "valid whitelisting", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"foo", "native", "sandboxed"}, + }, + }, { + name: "invalid default name", + strategy: &policy.RuntimeClassStrategyOptions{ + DefaultRuntimeClassName: pointer.StringPtr("foo bar"), + }, + expectErrors: true, + }, { + name: "disallowed default", + strategy: &policy.RuntimeClassStrategyOptions{ + DefaultRuntimeClassName: pointer.StringPtr("foo"), + AllowedRuntimeClassNames: []string{"native", "sandboxed"}, + }, + expectErrors: true, + }, { + name: "nothing allowed default", + strategy: &policy.RuntimeClassStrategyOptions{ + DefaultRuntimeClassName: pointer.StringPtr("foo"), + }, + expectErrors: true, + }, { + name: "invalid whitelist name", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"native", "sandboxed", "foo*"}, + }, + expectErrors: true, + }, { + name: "duplicate whitelist names", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"native", "sandboxed", "native"}, + }, + expectErrors: true, + }, { + name: "allow all redundant whitelist", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"*", "sandboxed", "native"}, + }, + expectErrors: true, + }} + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + errs := validateRuntimeClassStrategy(field.NewPath(""), test.strategy) + if test.expectErrors { + assert.NotEmpty(t, errs) + } else { + assert.Empty(t, errs) + } + }) + } }