Merge pull request #107698 from tallclair/psa-overrides

[PodSecurity] Deduplicate errors between baseline & restricted checks
This commit is contained in:
Kubernetes Prow Robot 2022-02-09 14:09:49 -08:00 committed by GitHub
commit 6ab748eeec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 289 additions and 61 deletions

View File

@ -41,11 +41,13 @@ func init() {
addCheck(CheckCapabilitiesBaseline)
}
const checkCapabilitiesBaselineID CheckID = "capabilities_baseline"
// CheckCapabilitiesBaseline returns a baseline level check
// that limits the capabilities that can be added in 1.0+
func CheckCapabilitiesBaseline() Check {
return Check{
ID: "capabilities_baseline",
ID: checkCapabilitiesBaselineID,
Level: api.LevelBaseline,
Versions: []VersionedCheck{
{

View File

@ -64,6 +64,7 @@ func CheckCapabilitiesRestricted() Check {
{
MinimumVersion: api.MajorMinorVersion(1, 22),
CheckPod: capabilitiesRestricted_1_22,
OverrideCheckIDs: []CheckID{checkCapabilitiesBaselineID},
},
},
}

View File

@ -38,11 +38,13 @@ func init() {
addCheck(CheckHostPathVolumes)
}
const checkHostPathVolumesID CheckID = "hostPathVolumes"
// CheckHostPathVolumes returns a baseline level check
// that requires hostPath=undefined/null in 1.0+
func CheckHostPathVolumes() Check {
return Check{
ID: "hostPathVolumes",
ID: checkHostPathVolumesID,
Level: api.LevelBaseline,
Versions: []VersionedCheck{
{

View File

@ -20,7 +20,6 @@ import (
"fmt"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/pod-security-admission/api"
@ -71,7 +70,7 @@ func procMount_1_0(podMetadata *metav1.ObjectMeta, podSpec *corev1.PodSpec) Chec
return
}
// check if the value of the proc mount type is valid.
if *container.SecurityContext.ProcMount != v1.DefaultProcMount {
if *container.SecurityContext.ProcMount != corev1.DefaultProcMount {
badContainers = append(badContainers, container.Name)
forbiddenProcMountTypes.Insert(string(*container.SecurityContext.ProcMount))
}

View File

@ -78,6 +78,7 @@ func CheckRestrictedVolumes() Check {
{
MinimumVersion: api.MajorMinorVersion(1, 0),
CheckPod: restrictedVolumes_1_0,
OverrideCheckIDs: []CheckID{checkHostPathVolumesID},
},
},
}

View File

@ -49,6 +49,8 @@ spec.initContainers[*].securityContext.seccompProfile.type
const (
annotationKeyPod = "seccomp.security.alpha.kubernetes.io/pod"
annotationKeyContainerPrefix = "container.seccomp.security.alpha.kubernetes.io/"
checkSeccompBaselineID CheckID = "seccompProfile_baseline"
)
func init() {
@ -57,7 +59,7 @@ func init() {
func CheckSeccompBaseline() Check {
return Check{
ID: "seccompProfile_baseline",
ID: checkSeccompBaselineID,
Level: api.LevelBaseline,
Versions: []VersionedCheck{
{

View File

@ -53,6 +53,7 @@ func CheckSeccompProfileRestricted() Check {
{
MinimumVersion: api.MajorMinorVersion(1, 19),
CheckPod: seccompProfileRestricted_1_19,
OverrideCheckIDs: []CheckID{checkSeccompBaselineID},
},
},
}

View File

@ -26,7 +26,7 @@ import (
type Check struct {
// ID is the unique ID of the check.
ID string
ID CheckID
// Level is the policy level this check belongs to.
// Must be Baseline or Restricted.
// Baseline checks are evaluated for baseline and restricted namespaces.
@ -45,10 +45,15 @@ type VersionedCheck struct {
MinimumVersion api.Version
// CheckPod determines if the pod is allowed.
CheckPod CheckPodFn
// OverrideCheckIDs is an optional list of checks that should be skipped when this check is run.
// Overrides may only be set on restricted checks, and may only override baseline checks.
OverrideCheckIDs []CheckID
}
type CheckPodFn func(*metav1.ObjectMeta, *corev1.PodSpec) CheckResult
type CheckID string
// CheckResult contains the result of checking a pod and indicates whether the pod is allowed,
// and if not, why it was forbidden.
//

View File

@ -0,0 +1,43 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package policy
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestValidChecks ensures that all registered checks are valid.
func TestValidChecks(t *testing.T) {
allChecks := append(DefaultChecks(), ExperimentalChecks()...)
assert.NoError(t, validateChecks(allChecks))
// Ensure that all overrides map to existing checks.
allIDs := map[CheckID]bool{}
for _, check := range allChecks {
allIDs[check.ID] = true
}
for _, check := range allChecks {
for _, c := range check.Versions {
for _, override := range c.OverrideCheckIDs {
assert.Contains(t, allIDs, override, "check %s overrides non-existent check %s", check.ID, override)
}
}
}
}

View File

@ -18,6 +18,7 @@ package policy
import (
"fmt"
"sort"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -32,7 +33,7 @@ type Evaluator interface {
// checkRegistry provides a default implementation of an Evaluator.
type checkRegistry struct {
// The checks are a map of check_ID -> sorted slice of versioned checks, newest first
// The checks are a map policy version to a slice of checks registered for that version.
baselineChecks, restrictedChecks map[api.Version][]CheckPodFn
// maxVersion is the maximum version that is cached, guaranteed to be at least
// the max MinimumVersion of all registered checks.
@ -64,26 +65,29 @@ func (r *checkRegistry) EvaluatePod(lv api.LevelVersion, podMetadata *metav1.Obj
if r.maxVersion.Older(lv.Version) {
lv.Version = r.maxVersion
}
results := []CheckResult{}
for _, check := range r.baselineChecks[lv.Version] {
results = append(results, check(podMetadata, podSpec))
}
var checks []CheckPodFn
if lv.Level == api.LevelBaseline {
return results
checks = r.baselineChecks[lv.Version]
} else {
// includes non-overridden baseline checks
checks = r.restrictedChecks[lv.Version]
}
for _, check := range r.restrictedChecks[lv.Version] {
var results []CheckResult
for _, check := range checks {
results = append(results, check(podMetadata, podSpec))
}
return results
}
func validateChecks(checks []Check) error {
ids := map[string]bool{}
ids := map[CheckID]api.Level{}
for _, check := range checks {
if ids[check.ID] {
if _, ok := ids[check.ID]; ok {
return fmt.Errorf("multiple checks registered for ID %s", check.ID)
}
ids[check.ID] = true
ids[check.ID] = check.Level
if check.Level != api.LevelBaseline && check.Level != api.LevelRestricted {
return fmt.Errorf("check %s: invalid level %s", check.ID, check.Level)
}
@ -107,6 +111,23 @@ func validateChecks(checks []Check) error {
maxVersion = c.MinimumVersion
}
}
// Second pass to validate overrides.
for _, check := range checks {
for _, c := range check.Versions {
if len(c.OverrideCheckIDs) == 0 {
continue
}
if check.Level != api.LevelRestricted {
return fmt.Errorf("check %s: only restricted checks may set overrides", check.ID)
}
for _, override := range c.OverrideCheckIDs {
if overriddenLevel, ok := ids[override]; ok && overriddenLevel != api.LevelBaseline {
return fmt.Errorf("check %s: overrides %s check %s", check.ID, overriddenLevel, override)
}
}
}
}
return nil
}
@ -119,28 +140,87 @@ func populate(r *checkRegistry, validChecks []Check) {
}
}
var (
restrictedVersionedChecks = map[api.Version]map[CheckID]VersionedCheck{}
baselineVersionedChecks = map[api.Version]map[CheckID]VersionedCheck{}
baselineIDs, restrictedIDs []CheckID
)
for _, c := range validChecks {
if c.Level == api.LevelRestricted {
inflateVersions(c, r.restrictedChecks, r.maxVersion)
restrictedIDs = append(restrictedIDs, c.ID)
inflateVersions(c, restrictedVersionedChecks, r.maxVersion)
} else {
inflateVersions(c, r.baselineChecks, r.maxVersion)
}
baselineIDs = append(baselineIDs, c.ID)
inflateVersions(c, baselineVersionedChecks, r.maxVersion)
}
}
func inflateVersions(check Check, versions map[api.Version][]CheckPodFn, maxVersion api.Version) {
// Sort the IDs to maintain consistent error messages.
sort.Slice(restrictedIDs, func(i, j int) bool { return restrictedIDs[i] < restrictedIDs[j] })
sort.Slice(baselineIDs, func(i, j int) bool { return baselineIDs[i] < baselineIDs[j] })
orderedIDs := append(baselineIDs, restrictedIDs...) // Baseline checks first, then restricted.
for v := api.MajorMinorVersion(1, 0); v.Older(nextMinor(r.maxVersion)); v = nextMinor(v) {
// Aggregate all the overridden baseline check ids.
overrides := map[CheckID]bool{}
for _, c := range restrictedVersionedChecks[v] {
for _, override := range c.OverrideCheckIDs {
overrides[override] = true
}
}
// Add the filtered baseline checks to restricted.
for id, c := range baselineVersionedChecks[v] {
if overrides[id] {
continue // Overridden check: skip it.
}
if restrictedVersionedChecks[v] == nil {
restrictedVersionedChecks[v] = map[CheckID]VersionedCheck{}
}
restrictedVersionedChecks[v][id] = c
}
r.restrictedChecks[v] = mapCheckPodFns(restrictedVersionedChecks[v], orderedIDs)
r.baselineChecks[v] = mapCheckPodFns(baselineVersionedChecks[v], orderedIDs)
}
}
func inflateVersions(check Check, versions map[api.Version]map[CheckID]VersionedCheck, maxVersion api.Version) {
for i, c := range check.Versions {
var nextVersion api.Version
if i+1 < len(check.Versions) {
nextVersion = check.Versions[i+1].MinimumVersion
} else {
// Assumes only 1 Major version.
nextVersion = api.MajorMinorVersion(1, maxVersion.Minor()+1)
nextVersion = nextMinor(maxVersion)
}
// Iterate over all versions from the minimum of the current check, to the minimum of the
// next check, or the maxVersion++.
for v := c.MinimumVersion; v.Older(nextVersion); v = api.MajorMinorVersion(1, v.Minor()+1) {
versions[v] = append(versions[v], check.Versions[i].CheckPod)
for v := c.MinimumVersion; v.Older(nextVersion); v = nextMinor(v) {
if versions[v] == nil {
versions[v] = map[CheckID]VersionedCheck{}
}
versions[v][check.ID] = check.Versions[i]
}
}
}
// mapCheckPodFns converts the versioned check map to an ordered slice of CheckPodFn,
// using the order specified by orderedIDs. All checks must have a corresponding ID in orderedIDs.
func mapCheckPodFns(checks map[CheckID]VersionedCheck, orderedIDs []CheckID) []CheckPodFn {
fns := make([]CheckPodFn, 0, len(checks))
for _, id := range orderedIDs {
if check, ok := checks[id]; ok {
fns = append(fns, check.CheckPod)
}
}
return fns
}
// nextMinor increments the minor version
func nextMinor(v api.Version) api.Version {
if v.Latest() {
return v
}
return api.MajorMinorVersion(v.Major(), v.Minor()+1)
}

View File

@ -35,16 +35,18 @@ func TestCheckRegistry(t *testing.T) {
generateCheck("d", api.LevelBaseline, []string{"v1.11", "v1.15", "v1.20"}),
generateCheck("e", api.LevelRestricted, []string{"v1.0"}),
generateCheck("f", api.LevelRestricted, []string{"v1.12", "v1.16", "v1.21"}),
withOverrides(generateCheck("g", api.LevelRestricted, []string{"v1.10"}), []CheckID{"a"}),
withOverrides(generateCheck("h", api.LevelRestricted, []string{"v1.0"}), []CheckID{"b"}),
}
multiOverride := generateCheck("i", api.LevelRestricted, []string{"v1.10", "v1.21"})
multiOverride.Versions[0].OverrideCheckIDs = []CheckID{"c"}
multiOverride.Versions[1].OverrideCheckIDs = []CheckID{"d"}
checks = append(checks, multiOverride)
reg, err := NewEvaluator(checks)
require.NoError(t, err)
levelCases := []struct {
level api.Level
version string
expectedReasons []string
}{
levelCases := []registryTestCase{
{api.LevelPrivileged, "v1.0", nil},
{api.LevelPrivileged, "latest", nil},
{api.LevelBaseline, "v1.0", []string{"a:v1.0", "c:v1.0"}},
@ -53,29 +55,112 @@ func TestCheckRegistry(t *testing.T) {
{api.LevelBaseline, "v1.10", []string{"a:v1.0", "b:v1.10", "c:v1.10"}},
{api.LevelBaseline, "v1.11", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.11"}},
{api.LevelBaseline, "latest", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20"}},
{api.LevelRestricted, "v1.0", []string{"a:v1.0", "c:v1.0", "e:v1.0"}},
{api.LevelRestricted, "v1.4", []string{"a:v1.0", "c:v1.0", "e:v1.0"}},
{api.LevelRestricted, "v1.5", []string{"a:v1.0", "c:v1.5", "e:v1.0"}},
{api.LevelRestricted, "v1.10", []string{"a:v1.0", "b:v1.10", "c:v1.10", "e:v1.0"}},
{api.LevelRestricted, "v1.11", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.11", "e:v1.0"}},
{api.LevelRestricted, "latest", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20", "e:v1.0", "f:v1.21"}},
{api.LevelRestricted, "v1.10000", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20", "e:v1.0", "f:v1.21"}},
{api.LevelRestricted, "v1.0", []string{"a:v1.0", "c:v1.0", "e:v1.0", "h:v1.0"}},
{api.LevelRestricted, "v1.4", []string{"a:v1.0", "c:v1.0", "e:v1.0", "h:v1.0"}},
{api.LevelRestricted, "v1.5", []string{"a:v1.0", "c:v1.5", "e:v1.0", "h:v1.0"}},
{api.LevelRestricted, "v1.10", []string{"e:v1.0", "g:v1.10", "h:v1.0", "i:v1.10"}},
{api.LevelRestricted, "v1.11", []string{"d:v1.11", "e:v1.0", "g:v1.10", "h:v1.0", "i:v1.10"}},
{api.LevelRestricted, "latest", []string{"c:v1.10", "e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0", "i:v1.21"}},
{api.LevelRestricted, "v1.10000", []string{"c:v1.10", "e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0", "i:v1.21"}},
}
for _, test := range levelCases {
t.Run(fmt.Sprintf("%s:%s", test.level, test.version), func(t *testing.T) {
results := reg.EvaluatePod(api.LevelVersion{test.level, versionOrPanic(test.version)}, nil, nil)
test.Run(t, reg)
}
}
func TestCheckRegistry_NoBaseline(t *testing.T) {
checks := []Check{
generateCheck("e", api.LevelRestricted, []string{"v1.0"}),
generateCheck("f", api.LevelRestricted, []string{"v1.12", "v1.16", "v1.21"}),
withOverrides(generateCheck("g", api.LevelRestricted, []string{"v1.10"}), []CheckID{"a"}),
withOverrides(generateCheck("h", api.LevelRestricted, []string{"v1.0"}), []CheckID{"b"}),
}
reg, err := NewEvaluator(checks)
require.NoError(t, err)
levelCases := []registryTestCase{
{api.LevelPrivileged, "v1.0", nil},
{api.LevelPrivileged, "latest", nil},
{api.LevelBaseline, "v1.0", nil},
{api.LevelBaseline, "v1.10", nil},
{api.LevelBaseline, "latest", nil},
{api.LevelRestricted, "v1.0", []string{"e:v1.0", "h:v1.0"}},
{api.LevelRestricted, "v1.10", []string{"e:v1.0", "g:v1.10", "h:v1.0"}},
{api.LevelRestricted, "latest", []string{"e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0"}},
{api.LevelRestricted, "v1.10000", []string{"e:v1.0", "f:v1.21", "g:v1.10", "h:v1.0"}},
}
for _, test := range levelCases {
test.Run(t, reg)
}
}
func TestCheckRegistry_NoRestricted(t *testing.T) {
checks := []Check{
generateCheck("a", api.LevelBaseline, []string{"v1.0"}),
generateCheck("b", api.LevelBaseline, []string{"v1.10"}),
generateCheck("c", api.LevelBaseline, []string{"v1.0", "v1.5", "v1.10"}),
generateCheck("d", api.LevelBaseline, []string{"v1.11", "v1.15", "v1.20"}),
}
reg, err := NewEvaluator(checks)
require.NoError(t, err)
levelCases := []registryTestCase{
{api.LevelBaseline, "v1.0", []string{"a:v1.0", "c:v1.0"}},
{api.LevelBaseline, "v1.4", []string{"a:v1.0", "c:v1.0"}},
{api.LevelBaseline, "v1.5", []string{"a:v1.0", "c:v1.5"}},
{api.LevelBaseline, "v1.10", []string{"a:v1.0", "b:v1.10", "c:v1.10"}},
{api.LevelBaseline, "v1.11", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.11"}},
{api.LevelBaseline, "latest", []string{"a:v1.0", "b:v1.10", "c:v1.10", "d:v1.20"}},
}
for _, test := range levelCases {
test.Run(t, reg)
// Restricted results should be identical to baseline.
restrictedTest := test
restrictedTest.level = api.LevelRestricted
restrictedTest.Run(t, reg)
}
}
func TestCheckRegistry_Empty(t *testing.T) {
reg, err := NewEvaluator(nil)
require.NoError(t, err)
levelCases := []registryTestCase{
{api.LevelPrivileged, "latest", nil},
{api.LevelBaseline, "latest", nil},
{api.LevelRestricted, "latest", nil},
}
for _, test := range levelCases {
test.Run(t, reg)
// Restricted results should be identical to baseline.
restrictedTest := test
restrictedTest.level = api.LevelRestricted
restrictedTest.Run(t, reg)
}
}
type registryTestCase struct {
level api.Level
version string
expectedReasons []string
}
func (tc *registryTestCase) Run(t *testing.T, registry Evaluator) {
t.Run(fmt.Sprintf("%s:%s", tc.level, tc.version), func(t *testing.T) {
results := registry.EvaluatePod(api.LevelVersion{tc.level, versionOrPanic(tc.version)}, nil, nil)
// Set extract the ForbiddenReasons from the results.
var actualReasons []string
for _, result := range results {
actualReasons = append(actualReasons, result.ForbiddenReason)
}
assert.ElementsMatch(t, test.expectedReasons, actualReasons)
assert.Equal(t, tc.expectedReasons, actualReasons)
})
}
}
func generateCheck(id string, level api.Level, versions []string) Check {
func generateCheck(id CheckID, level api.Level, versions []string) Check {
c := Check{
ID: id,
Level: level,
@ -94,6 +179,13 @@ func generateCheck(id string, level api.Level, versions []string) Check {
return c
}
func withOverrides(c Check, overrides []CheckID) Check {
for i := range c.Versions {
c.Versions[i].OverrideCheckIDs = overrides
}
return c
}
func versionOrPanic(v string) api.Version {
ver, err := api.ParseVersion(v)
if err != nil {

View File

@ -99,7 +99,7 @@ var fixtureGenerators = map[fixtureKey]fixtureGenerator{}
type fixtureKey struct {
version api.Version
level api.Level
check string
check policy.CheckID
}
// fixtureGenerator holds generators for valid and invalid fixtures.
@ -184,7 +184,7 @@ func getFixtures(key fixtureKey) (fixtureData, error) {
fail: generator.generateFail(validPodForLevel.DeepCopy()),
}
if len(data.expectErrorSubstring) == 0 {
data.expectErrorSubstring = key.check
data.expectErrorSubstring = string(key.check)
}
if len(data.fail) == 0 {
return fixtureData{}, fmt.Errorf("generateFail for %#v must return at least one pod", key)

View File

@ -47,7 +47,7 @@ func ensureCapabilities(p *corev1.Pod) *corev1.Pod {
func init() {
fixtureData_1_0 := fixtureGenerator{
expectErrorSubstring: "non-default capabilities",
expectErrorSubstring: "capabilities",
generatePass: func(p *corev1.Pod) []*corev1.Pod {
// don't generate fixtures if minimal valid pod drops ALL
if p.Spec.Containers[0].SecurityContext != nil && p.Spec.Containers[0].SecurityContext.Capabilities != nil {

View File

@ -28,7 +28,7 @@ TODO: include field paths in reflect-based unit test
func init() {
fixtureData_1_0 := fixtureGenerator{
expectErrorSubstring: "hostPath volumes",
expectErrorSubstring: "hostPath",
generatePass: func(p *corev1.Pod) []*corev1.Pod {
// minimal valid pod already captures all valid combinations
return nil

View File

@ -82,10 +82,10 @@ func TestFixtures(t *testing.T) {
}
for i, pod := range checkData.pass {
expectedFiles.Insert(testFixtureFile(t, passDir, fmt.Sprintf("%s%d", strings.ToLower(checkID), i), pod))
expectedFiles.Insert(testFixtureFile(t, passDir, fmt.Sprintf("%s%d", strings.ToLower(string(checkID)), i), pod))
}
for i, pod := range checkData.fail {
expectedFiles.Insert(testFixtureFile(t, failDir, fmt.Sprintf("%s%d", strings.ToLower(checkID), i), pod))
expectedFiles.Insert(testFixtureFile(t, failDir, fmt.Sprintf("%s%d", strings.ToLower(string(checkID)), i), pod))
}
}
}

View File

@ -73,8 +73,8 @@ func toJSON(pod *corev1.Pod) string {
// checksForLevelAndVersion returns the set of check IDs that apply when evaluating the given level and version.
// checks are assumed to be well-formed and valid to pass to policy.NewEvaluator().
// level must be api.LevelRestricted or api.LevelBaseline
func checksForLevelAndVersion(checks []policy.Check, level api.Level, version api.Version) ([]string, error) {
retval := []string{}
func checksForLevelAndVersion(checks []policy.Check, level api.Level, version api.Version) ([]policy.CheckID, error) {
retval := []policy.CheckID{}
for _, check := range checks {
if !version.Older(check.Versions[0].MinimumVersion) && (level == check.Level || level == api.LevelRestricted) {
retval = append(retval, check.ID)
@ -318,7 +318,7 @@ func Run(t *testing.T, opts Options) {
t.Fatal(err)
}
t.Run(ns+"_pass_"+checkID, func(t *testing.T) {
t.Run(ns+"_pass_"+string(checkID), func(t *testing.T) {
for i, pod := range checkData.pass {
createPod(t, i, pod, true, "")
createController(t, i, pod, true, "")
@ -332,7 +332,7 @@ func Run(t *testing.T, opts Options) {
disabledRequiredFeatures = append(disabledRequiredFeatures, f)
}
}
t.Run(ns+"_fail_"+checkID, func(t *testing.T) {
t.Run(ns+"_fail_"+string(checkID), func(t *testing.T) {
if len(disabledRequiredFeatures) > 0 {
t.Skipf("features required for failure cases are disabled: %v", disabledRequiredFeatures)
}