diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_baseline.go b/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_baseline.go index 1fa62953ee6..aad61738b50 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_baseline.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_baseline.go @@ -41,11 +41,13 @@ func init() { addCheck(CheckCapabilitiesBaseline) } +const checkCapabilitiesBaselineID CheckID = "capabilities_baseline" + // CheckCapabilitiesBaseline returns a baseline level check // that limits the capabilities that can be added in 1.0+ func CheckCapabilitiesBaseline() Check { return Check{ - ID: "capabilities_baseline", + ID: checkCapabilitiesBaselineID, Level: api.LevelBaseline, Versions: []VersionedCheck{ { diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_restricted.go b/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_restricted.go index fd2e09729a6..48b1ea897b5 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_restricted.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_capabilities_restricted.go @@ -62,8 +62,9 @@ func CheckCapabilitiesRestricted() Check { Level: api.LevelRestricted, Versions: []VersionedCheck{ { - MinimumVersion: api.MajorMinorVersion(1, 22), - CheckPod: capabilitiesRestricted_1_22, + MinimumVersion: api.MajorMinorVersion(1, 22), + CheckPod: capabilitiesRestricted_1_22, + OverrideCheckIDs: []CheckID{checkCapabilitiesBaselineID}, }, }, } diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_hostPathVolumes.go b/staging/src/k8s.io/pod-security-admission/policy/check_hostPathVolumes.go index 600f3734516..3a419ff2495 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_hostPathVolumes.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_hostPathVolumes.go @@ -38,11 +38,13 @@ func init() { addCheck(CheckHostPathVolumes) } +const checkHostPathVolumesID CheckID = "hostPathVolumes" + // CheckHostPathVolumes returns a baseline level check // that requires hostPath=undefined/null in 1.0+ func CheckHostPathVolumes() Check { return Check{ - ID: "hostPathVolumes", + ID: checkHostPathVolumesID, Level: api.LevelBaseline, Versions: []VersionedCheck{ { diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_procMount.go b/staging/src/k8s.io/pod-security-admission/policy/check_procMount.go index a3ed8246162..282dbb32ce9 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_procMount.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_procMount.go @@ -20,7 +20,6 @@ import ( "fmt" corev1 "k8s.io/api/core/v1" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/pod-security-admission/api" @@ -71,7 +70,7 @@ func procMount_1_0(podMetadata *metav1.ObjectMeta, podSpec *corev1.PodSpec) Chec return } // check if the value of the proc mount type is valid. - if *container.SecurityContext.ProcMount != v1.DefaultProcMount { + if *container.SecurityContext.ProcMount != corev1.DefaultProcMount { badContainers = append(badContainers, container.Name) forbiddenProcMountTypes.Insert(string(*container.SecurityContext.ProcMount)) } diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_restrictedVolumes.go b/staging/src/k8s.io/pod-security-admission/policy/check_restrictedVolumes.go index a559a7f9c5e..e171cdd60f1 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_restrictedVolumes.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_restrictedVolumes.go @@ -76,8 +76,9 @@ func CheckRestrictedVolumes() Check { Level: api.LevelRestricted, Versions: []VersionedCheck{ { - MinimumVersion: api.MajorMinorVersion(1, 0), - CheckPod: restrictedVolumes_1_0, + MinimumVersion: api.MajorMinorVersion(1, 0), + CheckPod: restrictedVolumes_1_0, + OverrideCheckIDs: []CheckID{checkHostPathVolumesID}, }, }, } diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_baseline.go b/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_baseline.go index 0409f93e70d..55152b3e6a7 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_baseline.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_baseline.go @@ -49,6 +49,8 @@ spec.initContainers[*].securityContext.seccompProfile.type const ( annotationKeyPod = "seccomp.security.alpha.kubernetes.io/pod" annotationKeyContainerPrefix = "container.seccomp.security.alpha.kubernetes.io/" + + checkSeccompBaselineID CheckID = "seccompProfile_baseline" ) func init() { @@ -57,7 +59,7 @@ func init() { func CheckSeccompBaseline() Check { return Check{ - ID: "seccompProfile_baseline", + ID: checkSeccompBaselineID, Level: api.LevelBaseline, Versions: []VersionedCheck{ { diff --git a/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_restricted.go b/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_restricted.go index 66bec6e05d9..1a8535a0f37 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_restricted.go +++ b/staging/src/k8s.io/pod-security-admission/policy/check_seccompProfile_restricted.go @@ -51,8 +51,9 @@ func CheckSeccompProfileRestricted() Check { Level: api.LevelRestricted, Versions: []VersionedCheck{ { - MinimumVersion: api.MajorMinorVersion(1, 19), - CheckPod: seccompProfileRestricted_1_19, + MinimumVersion: api.MajorMinorVersion(1, 19), + CheckPod: seccompProfileRestricted_1_19, + OverrideCheckIDs: []CheckID{checkSeccompBaselineID}, }, }, } diff --git a/staging/src/k8s.io/pod-security-admission/policy/checks.go b/staging/src/k8s.io/pod-security-admission/policy/checks.go index f46b3b3e4ac..105c4ee0bf0 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/checks.go +++ b/staging/src/k8s.io/pod-security-admission/policy/checks.go @@ -26,7 +26,7 @@ import ( type Check struct { // ID is the unique ID of the check. - ID string + ID CheckID // Level is the policy level this check belongs to. // Must be Baseline or Restricted. // Baseline checks are evaluated for baseline and restricted namespaces. @@ -45,10 +45,15 @@ type VersionedCheck struct { MinimumVersion api.Version // CheckPod determines if the pod is allowed. CheckPod CheckPodFn + // OverrideCheckIDs is an optional list of checks that should be skipped when this check is run. + // Overrides may only be set on restricted checks, and may only override baseline checks. + OverrideCheckIDs []CheckID } type CheckPodFn func(*metav1.ObjectMeta, *corev1.PodSpec) CheckResult +type CheckID string + // CheckResult contains the result of checking a pod and indicates whether the pod is allowed, // and if not, why it was forbidden. // diff --git a/staging/src/k8s.io/pod-security-admission/policy/checks_test.go b/staging/src/k8s.io/pod-security-admission/policy/checks_test.go new file mode 100644 index 00000000000..8f601c77d40 --- /dev/null +++ b/staging/src/k8s.io/pod-security-admission/policy/checks_test.go @@ -0,0 +1,43 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package policy + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestValidChecks ensures that all registered checks are valid. +func TestValidChecks(t *testing.T) { + allChecks := append(DefaultChecks(), ExperimentalChecks()...) + + assert.NoError(t, validateChecks(allChecks)) + + // Ensure that all overrides map to existing checks. + allIDs := map[CheckID]bool{} + for _, check := range allChecks { + allIDs[check.ID] = true + } + for _, check := range allChecks { + for _, c := range check.Versions { + for _, override := range c.OverrideCheckIDs { + assert.Contains(t, allIDs, override, "check %s overrides non-existent check %s", check.ID, override) + } + } + } +} diff --git a/staging/src/k8s.io/pod-security-admission/policy/registry.go b/staging/src/k8s.io/pod-security-admission/policy/registry.go index 244bfe3787b..4b91bef8875 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/registry.go +++ b/staging/src/k8s.io/pod-security-admission/policy/registry.go @@ -18,6 +18,7 @@ package policy import ( "fmt" + "sort" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -32,7 +33,7 @@ type Evaluator interface { // checkRegistry provides a default implementation of an Evaluator. type checkRegistry struct { - // The checks are a map of check_ID -> sorted slice of versioned checks, newest first + // The checks are a map policy version to a slice of checks registered for that version. baselineChecks, restrictedChecks map[api.Version][]CheckPodFn // maxVersion is the maximum version that is cached, guaranteed to be at least // the max MinimumVersion of all registered checks. @@ -64,26 +65,29 @@ func (r *checkRegistry) EvaluatePod(lv api.LevelVersion, podMetadata *metav1.Obj if r.maxVersion.Older(lv.Version) { lv.Version = r.maxVersion } - results := []CheckResult{} - for _, check := range r.baselineChecks[lv.Version] { - results = append(results, check(podMetadata, podSpec)) - } + + var checks []CheckPodFn if lv.Level == api.LevelBaseline { - return results + checks = r.baselineChecks[lv.Version] + } else { + // includes non-overridden baseline checks + checks = r.restrictedChecks[lv.Version] } - for _, check := range r.restrictedChecks[lv.Version] { + + var results []CheckResult + for _, check := range checks { results = append(results, check(podMetadata, podSpec)) } return results } func validateChecks(checks []Check) error { - ids := map[string]bool{} + ids := map[CheckID]api.Level{} for _, check := range checks { - if ids[check.ID] { + if _, ok := ids[check.ID]; ok { return fmt.Errorf("multiple checks registered for ID %s", check.ID) } - ids[check.ID] = true + ids[check.ID] = check.Level if check.Level != api.LevelBaseline && check.Level != api.LevelRestricted { return fmt.Errorf("check %s: invalid level %s", check.ID, check.Level) } @@ -107,6 +111,23 @@ func validateChecks(checks []Check) error { maxVersion = c.MinimumVersion } } + // Second pass to validate overrides. + for _, check := range checks { + for _, c := range check.Versions { + if len(c.OverrideCheckIDs) == 0 { + continue + } + + if check.Level != api.LevelRestricted { + return fmt.Errorf("check %s: only restricted checks may set overrides", check.ID) + } + for _, override := range c.OverrideCheckIDs { + if overriddenLevel, ok := ids[override]; ok && overriddenLevel != api.LevelBaseline { + return fmt.Errorf("check %s: overrides %s check %s", check.ID, overriddenLevel, override) + } + } + } + } return nil } @@ -119,28 +140,87 @@ func populate(r *checkRegistry, validChecks []Check) { } } + var ( + restrictedVersionedChecks = map[api.Version]map[CheckID]VersionedCheck{} + baselineVersionedChecks = map[api.Version]map[CheckID]VersionedCheck{} + + baselineIDs, restrictedIDs []CheckID + ) for _, c := range validChecks { if c.Level == api.LevelRestricted { - inflateVersions(c, r.restrictedChecks, r.maxVersion) + restrictedIDs = append(restrictedIDs, c.ID) + inflateVersions(c, restrictedVersionedChecks, r.maxVersion) } else { - inflateVersions(c, r.baselineChecks, r.maxVersion) + baselineIDs = append(baselineIDs, c.ID) + inflateVersions(c, baselineVersionedChecks, r.maxVersion) } } + + // Sort the IDs to maintain consistent error messages. + sort.Slice(restrictedIDs, func(i, j int) bool { return restrictedIDs[i] < restrictedIDs[j] }) + sort.Slice(baselineIDs, func(i, j int) bool { return baselineIDs[i] < baselineIDs[j] }) + orderedIDs := append(baselineIDs, restrictedIDs...) // Baseline checks first, then restricted. + + for v := api.MajorMinorVersion(1, 0); v.Older(nextMinor(r.maxVersion)); v = nextMinor(v) { + // Aggregate all the overridden baseline check ids. + overrides := map[CheckID]bool{} + for _, c := range restrictedVersionedChecks[v] { + for _, override := range c.OverrideCheckIDs { + overrides[override] = true + } + } + // Add the filtered baseline checks to restricted. + for id, c := range baselineVersionedChecks[v] { + if overrides[id] { + continue // Overridden check: skip it. + } + if restrictedVersionedChecks[v] == nil { + restrictedVersionedChecks[v] = map[CheckID]VersionedCheck{} + } + restrictedVersionedChecks[v][id] = c + } + + r.restrictedChecks[v] = mapCheckPodFns(restrictedVersionedChecks[v], orderedIDs) + r.baselineChecks[v] = mapCheckPodFns(baselineVersionedChecks[v], orderedIDs) + } } -func inflateVersions(check Check, versions map[api.Version][]CheckPodFn, maxVersion api.Version) { +func inflateVersions(check Check, versions map[api.Version]map[CheckID]VersionedCheck, maxVersion api.Version) { for i, c := range check.Versions { var nextVersion api.Version if i+1 < len(check.Versions) { nextVersion = check.Versions[i+1].MinimumVersion } else { // Assumes only 1 Major version. - nextVersion = api.MajorMinorVersion(1, maxVersion.Minor()+1) + nextVersion = nextMinor(maxVersion) } // Iterate over all versions from the minimum of the current check, to the minimum of the // next check, or the maxVersion++. - for v := c.MinimumVersion; v.Older(nextVersion); v = api.MajorMinorVersion(1, v.Minor()+1) { - versions[v] = append(versions[v], check.Versions[i].CheckPod) + for v := c.MinimumVersion; v.Older(nextVersion); v = nextMinor(v) { + if versions[v] == nil { + versions[v] = map[CheckID]VersionedCheck{} + } + versions[v][check.ID] = check.Versions[i] } } } + +// mapCheckPodFns converts the versioned check map to an ordered slice of CheckPodFn, +// using the order specified by orderedIDs. All checks must have a corresponding ID in orderedIDs. +func mapCheckPodFns(checks map[CheckID]VersionedCheck, orderedIDs []CheckID) []CheckPodFn { + fns := make([]CheckPodFn, 0, len(checks)) + for _, id := range orderedIDs { + if check, ok := checks[id]; ok { + fns = append(fns, check.CheckPod) + } + } + return fns +} + +// nextMinor increments the minor version +func nextMinor(v api.Version) api.Version { + if v.Latest() { + return v + } + return api.MajorMinorVersion(v.Major(), v.Minor()+1) +} diff --git a/staging/src/k8s.io/pod-security-admission/policy/registry_test.go b/staging/src/k8s.io/pod-security-admission/policy/registry_test.go index cfcb15246d8..ee8471c1aa1 100644 --- a/staging/src/k8s.io/pod-security-admission/policy/registry_test.go +++ b/staging/src/k8s.io/pod-security-admission/policy/registry_test.go @@ -35,16 +35,18 @@ func TestCheckRegistry(t *testing.T) { generateCheck("d", api.LevelBaseline, []string{"v1.11", "v1.15", "v1.20"}), generateCheck("e", api.LevelRestricted, []string{"v1.0"}), generateCheck("f", api.LevelRestricted, []string{"v1.12", "v1.16", "v1.21"}), + withOverrides(generateCheck("g", api.LevelRestricted, []string{"v1.10"}), []CheckID{"a"}), + withOverrides(generateCheck("h", api.LevelRestricted, []string{"v1.0"}), []CheckID{"b"}), } + multiOverride := generateCheck("i", api.LevelRestricted, []string{"v1.10", "v1.21"}) + multiOverride.Versions[0].OverrideCheckIDs = []CheckID{"c"} + multiOverride.Versions[1].OverrideCheckIDs = []CheckID{"d"} + checks = append(checks, multiOverride) reg, err := NewEvaluator(checks) require.NoError(t, err) - levelCases := []struct { - level api.Level - version string - expectedReasons []string - }{ + levelCases := []registryTestCase{ {api.LevelPrivileged, "v1.0", nil}, {api.LevelPrivileged, "latest", nil}, {api.LevelBaseline, "v1.0", []string{"a:v1.0", "c:v1.0"}}, @@ -53,29 +55,112 @@ func TestCheckRegistry(t *testing.T) { {api.LevelBaseline, "v1.10", []string{"a:v1.0", "b:v1.10", "c:v1.10"}}, {api.LevelBaseline, "v1.11", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.11"}}, {api.LevelBaseline, "latest", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20"}}, - {api.LevelRestricted, "v1.0", []string{"a:v1.0", "c:v1.0", "e:v1.0"}}, - {api.LevelRestricted, "v1.4", []string{"a:v1.0", "c:v1.0", "e:v1.0"}}, - {api.LevelRestricted, "v1.5", []string{"a:v1.0", "c:v1.5", "e:v1.0"}}, - {api.LevelRestricted, "v1.10", []string{"a:v1.0", "b:v1.10", "c:v1.10", "e:v1.0"}}, - {api.LevelRestricted, "v1.11", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.11", "e:v1.0"}}, - {api.LevelRestricted, "latest", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20", "e:v1.0", "f:v1.21"}}, - {api.LevelRestricted, "v1.10000", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20", "e:v1.0", "f:v1.21"}}, + {api.LevelRestricted, "v1.0", []string{"a:v1.0", "c:v1.0", "e:v1.0", "h:v1.0"}}, + {api.LevelRestricted, "v1.4", []string{"a:v1.0", "c:v1.0", "e:v1.0", "h:v1.0"}}, + {api.LevelRestricted, "v1.5", []string{"a:v1.0", "c:v1.5", "e:v1.0", "h:v1.0"}}, + {api.LevelRestricted, "v1.10", []string{"e:v1.0", "g:v1.10", "h:v1.0", "i:v1.10"}}, + {api.LevelRestricted, "v1.11", []string{"d:v1.11", "e:v1.0", "g:v1.10", "h:v1.0", "i:v1.10"}}, + {api.LevelRestricted, "latest", []string{"c:v1.10", "e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0", "i:v1.21"}}, + {api.LevelRestricted, "v1.10000", []string{"c:v1.10", "e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0", "i:v1.21"}}, } for _, test := range levelCases { - t.Run(fmt.Sprintf("%s:%s", test.level, test.version), func(t *testing.T) { - results := reg.EvaluatePod(api.LevelVersion{test.level, versionOrPanic(test.version)}, nil, nil) - - // Set extract the ForbiddenReasons from the results. - var actualReasons []string - for _, result := range results { - actualReasons = append(actualReasons, result.ForbiddenReason) - } - assert.ElementsMatch(t, test.expectedReasons, actualReasons) - }) + test.Run(t, reg) } } -func generateCheck(id string, level api.Level, versions []string) Check { +func TestCheckRegistry_NoBaseline(t *testing.T) { + checks := []Check{ + generateCheck("e", api.LevelRestricted, []string{"v1.0"}), + generateCheck("f", api.LevelRestricted, []string{"v1.12", "v1.16", "v1.21"}), + withOverrides(generateCheck("g", api.LevelRestricted, []string{"v1.10"}), []CheckID{"a"}), + withOverrides(generateCheck("h", api.LevelRestricted, []string{"v1.0"}), []CheckID{"b"}), + } + + reg, err := NewEvaluator(checks) + require.NoError(t, err) + + levelCases := []registryTestCase{ + {api.LevelPrivileged, "v1.0", nil}, + {api.LevelPrivileged, "latest", nil}, + {api.LevelBaseline, "v1.0", nil}, + {api.LevelBaseline, "v1.10", nil}, + {api.LevelBaseline, "latest", nil}, + {api.LevelRestricted, "v1.0", []string{"e:v1.0", "h:v1.0"}}, + {api.LevelRestricted, "v1.10", []string{"e:v1.0", "g:v1.10", "h:v1.0"}}, + {api.LevelRestricted, "latest", []string{"e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0"}}, + {api.LevelRestricted, "v1.10000", []string{"e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0"}}, + } + for _, test := range levelCases { + test.Run(t, reg) + } +} + +func TestCheckRegistry_NoRestricted(t *testing.T) { + checks := []Check{ + generateCheck("a", api.LevelBaseline, []string{"v1.0"}), + generateCheck("b", api.LevelBaseline, []string{"v1.10"}), + generateCheck("c", api.LevelBaseline, []string{"v1.0", "v1.5", "v1.10"}), + generateCheck("d", api.LevelBaseline, []string{"v1.11", "v1.15", "v1.20"}), + } + + reg, err := NewEvaluator(checks) + require.NoError(t, err) + + levelCases := []registryTestCase{ + {api.LevelBaseline, "v1.0", []string{"a:v1.0", "c:v1.0"}}, + {api.LevelBaseline, "v1.4", []string{"a:v1.0", "c:v1.0"}}, + {api.LevelBaseline, "v1.5", []string{"a:v1.0", "c:v1.5"}}, + {api.LevelBaseline, "v1.10", []string{"a:v1.0", "b:v1.10", "c:v1.10"}}, + {api.LevelBaseline, "v1.11", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.11"}}, + {api.LevelBaseline, "latest", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20"}}, + } + for _, test := range levelCases { + test.Run(t, reg) + // Restricted results should be identical to baseline. + restrictedTest := test + restrictedTest.level = api.LevelRestricted + restrictedTest.Run(t, reg) + } +} + +func TestCheckRegistry_Empty(t *testing.T) { + reg, err := NewEvaluator(nil) + require.NoError(t, err) + + levelCases := []registryTestCase{ + {api.LevelPrivileged, "latest", nil}, + {api.LevelBaseline, "latest", nil}, + {api.LevelRestricted, "latest", nil}, + } + for _, test := range levelCases { + test.Run(t, reg) + // Restricted results should be identical to baseline. + restrictedTest := test + restrictedTest.level = api.LevelRestricted + restrictedTest.Run(t, reg) + } +} + +type registryTestCase struct { + level api.Level + version string + expectedReasons []string +} + +func (tc *registryTestCase) Run(t *testing.T, registry Evaluator) { + t.Run(fmt.Sprintf("%s:%s", tc.level, tc.version), func(t *testing.T) { + results := registry.EvaluatePod(api.LevelVersion{tc.level, versionOrPanic(tc.version)}, nil, nil) + + // Set extract the ForbiddenReasons from the results. + var actualReasons []string + for _, result := range results { + actualReasons = append(actualReasons, result.ForbiddenReason) + } + assert.Equal(t, tc.expectedReasons, actualReasons) + }) +} + +func generateCheck(id CheckID, level api.Level, versions []string) Check { c := Check{ ID: id, Level: level, @@ -94,6 +179,13 @@ func generateCheck(id string, level api.Level, versions []string) Check { return c } +func withOverrides(c Check, overrides []CheckID) Check { + for i := range c.Versions { + c.Versions[i].OverrideCheckIDs = overrides + } + return c +} + func versionOrPanic(v string) api.Version { ver, err := api.ParseVersion(v) if err != nil { diff --git a/staging/src/k8s.io/pod-security-admission/test/fixtures.go b/staging/src/k8s.io/pod-security-admission/test/fixtures.go index cd11f0d380a..4d00397a0cb 100644 --- a/staging/src/k8s.io/pod-security-admission/test/fixtures.go +++ b/staging/src/k8s.io/pod-security-admission/test/fixtures.go @@ -99,7 +99,7 @@ var fixtureGenerators = map[fixtureKey]fixtureGenerator{} type fixtureKey struct { version api.Version level api.Level - check string + check policy.CheckID } // fixtureGenerator holds generators for valid and invalid fixtures. @@ -184,7 +184,7 @@ func getFixtures(key fixtureKey) (fixtureData, error) { fail: generator.generateFail(validPodForLevel.DeepCopy()), } if len(data.expectErrorSubstring) == 0 { - data.expectErrorSubstring = key.check + data.expectErrorSubstring = string(key.check) } if len(data.fail) == 0 { return fixtureData{}, fmt.Errorf("generateFail for %#v must return at least one pod", key) diff --git a/staging/src/k8s.io/pod-security-admission/test/fixtures_capabilities_baseline.go b/staging/src/k8s.io/pod-security-admission/test/fixtures_capabilities_baseline.go index 0ba6aaf0e6d..1293c2dca70 100644 --- a/staging/src/k8s.io/pod-security-admission/test/fixtures_capabilities_baseline.go +++ b/staging/src/k8s.io/pod-security-admission/test/fixtures_capabilities_baseline.go @@ -47,7 +47,7 @@ func ensureCapabilities(p *corev1.Pod) *corev1.Pod { func init() { fixtureData_1_0 := fixtureGenerator{ - expectErrorSubstring: "non-default capabilities", + expectErrorSubstring: "capabilities", generatePass: func(p *corev1.Pod) []*corev1.Pod { // don't generate fixtures if minimal valid pod drops ALL if p.Spec.Containers[0].SecurityContext != nil && p.Spec.Containers[0].SecurityContext.Capabilities != nil { diff --git a/staging/src/k8s.io/pod-security-admission/test/fixtures_hostPathVolumes.go b/staging/src/k8s.io/pod-security-admission/test/fixtures_hostPathVolumes.go index 10c1a622fa6..7b891fdaff2 100644 --- a/staging/src/k8s.io/pod-security-admission/test/fixtures_hostPathVolumes.go +++ b/staging/src/k8s.io/pod-security-admission/test/fixtures_hostPathVolumes.go @@ -28,7 +28,7 @@ TODO: include field paths in reflect-based unit test func init() { fixtureData_1_0 := fixtureGenerator{ - expectErrorSubstring: "hostPath volumes", + expectErrorSubstring: "hostPath", generatePass: func(p *corev1.Pod) []*corev1.Pod { // minimal valid pod already captures all valid combinations return nil diff --git a/staging/src/k8s.io/pod-security-admission/test/fixtures_test.go b/staging/src/k8s.io/pod-security-admission/test/fixtures_test.go index b6a17b89b3d..959ab0d18f1 100644 --- a/staging/src/k8s.io/pod-security-admission/test/fixtures_test.go +++ b/staging/src/k8s.io/pod-security-admission/test/fixtures_test.go @@ -82,10 +82,10 @@ func TestFixtures(t *testing.T) { } for i, pod := range checkData.pass { - expectedFiles.Insert(testFixtureFile(t, passDir, fmt.Sprintf("%s%d", strings.ToLower(checkID), i), pod)) + expectedFiles.Insert(testFixtureFile(t, passDir, fmt.Sprintf("%s%d", strings.ToLower(string(checkID)), i), pod)) } for i, pod := range checkData.fail { - expectedFiles.Insert(testFixtureFile(t, failDir, fmt.Sprintf("%s%d", strings.ToLower(checkID), i), pod)) + expectedFiles.Insert(testFixtureFile(t, failDir, fmt.Sprintf("%s%d", strings.ToLower(string(checkID)), i), pod)) } } } diff --git a/staging/src/k8s.io/pod-security-admission/test/run.go b/staging/src/k8s.io/pod-security-admission/test/run.go index 030fe808c5b..ea7ea01507b 100644 --- a/staging/src/k8s.io/pod-security-admission/test/run.go +++ b/staging/src/k8s.io/pod-security-admission/test/run.go @@ -73,8 +73,8 @@ func toJSON(pod *corev1.Pod) string { // checksForLevelAndVersion returns the set of check IDs that apply when evaluating the given level and version. // checks are assumed to be well-formed and valid to pass to policy.NewEvaluator(). // level must be api.LevelRestricted or api.LevelBaseline -func checksForLevelAndVersion(checks []policy.Check, level api.Level, version api.Version) ([]string, error) { - retval := []string{} +func checksForLevelAndVersion(checks []policy.Check, level api.Level, version api.Version) ([]policy.CheckID, error) { + retval := []policy.CheckID{} for _, check := range checks { if !version.Older(check.Versions[0].MinimumVersion) && (level == check.Level || level == api.LevelRestricted) { retval = append(retval, check.ID) @@ -318,7 +318,7 @@ func Run(t *testing.T, opts Options) { t.Fatal(err) } - t.Run(ns+"_pass_"+checkID, func(t *testing.T) { + t.Run(ns+"_pass_"+string(checkID), func(t *testing.T) { for i, pod := range checkData.pass { createPod(t, i, pod, true, "") createController(t, i, pod, true, "") @@ -332,7 +332,7 @@ func Run(t *testing.T, opts Options) { disabledRequiredFeatures = append(disabledRequiredFeatures, f) } } - t.Run(ns+"_fail_"+checkID, func(t *testing.T) { + t.Run(ns+"_fail_"+string(checkID), func(t *testing.T) { if len(disabledRequiredFeatures) > 0 { t.Skipf("features required for failure cases are disabled: %v", disabledRequiredFeatures) }