diff --git a/pkg/security/podsecuritypolicy/BUILD b/pkg/security/podsecuritypolicy/BUILD index 02668dc0486..b4031c320d5 100644 --- a/pkg/security/podsecuritypolicy/BUILD +++ b/pkg/security/podsecuritypolicy/BUILD @@ -51,6 +51,7 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/util/diff:go_default_library", "//vendor/github.com/stretchr/testify/assert:go_default_library", "//vendor/github.com/stretchr/testify/require:go_default_library", + "//vendor/k8s.io/utils/pointer:go_default_library", ], ) diff --git a/pkg/security/podsecuritypolicy/provider.go b/pkg/security/podsecuritypolicy/provider.go index 9c120c23f80..b67014dd70e 100644 --- a/pkg/security/podsecuritypolicy/provider.go +++ b/pkg/security/podsecuritypolicy/provider.go @@ -104,6 +104,10 @@ func (s *simpleProvider) MutatePod(pod *api.Pod) error { pod.Spec.SecurityContext = sc.PodSecurityContext() + if s.psp.Spec.RuntimeClass != nil && pod.Spec.RuntimeClassName == nil { + pod.Spec.RuntimeClassName = s.psp.Spec.RuntimeClass.DefaultRuntimeClassName + } + for i := range pod.Spec.InitContainers { if err := s.mutateContainer(pod, &pod.Spec.InitContainers[i]); err != nil { return err @@ -295,6 +299,10 @@ func (s *simpleProvider) ValidatePod(pod *api.Pod) field.ErrorList { } } + if s.psp.Spec.RuntimeClass != nil { + allErrs = append(allErrs, validateRuntimeClassName(pod.Spec.RuntimeClassName, s.psp.Spec.RuntimeClass.AllowedRuntimeClassNames)...) + } + fldPath := field.NewPath("spec", "initContainers") for i := range pod.Spec.InitContainers { allErrs = append(allErrs, s.validateContainer(pod, &pod.Spec.InitContainers[i], fldPath.Index(i))...) @@ -413,3 +421,20 @@ func hostPortRangesToString(ranges []policy.HostPortRange) string { } return formattedString } + +// validates that the actual RuntimeClassName is contained in the list of valid names. +func validateRuntimeClassName(actual *string, validNames []string) field.ErrorList { + if actual == nil { + return nil // An unset RuntimeClassName is always allowed. + } + + for _, valid := range validNames { + if valid == policy.AllowAllRuntimeClassNames { + return nil + } + if *actual == valid { + return nil + } + } + return field.ErrorList{field.Invalid(field.NewPath("spec", "runtimeClassName"), *actual, "")} +} diff --git a/pkg/security/podsecuritypolicy/provider_test.go b/pkg/security/podsecuritypolicy/provider_test.go index 8f7046b74e3..cfac1cba442 100644 --- a/pkg/security/podsecuritypolicy/provider_test.go +++ b/pkg/security/podsecuritypolicy/provider_test.go @@ -34,6 +34,7 @@ import ( "k8s.io/kubernetes/pkg/security/apparmor" "k8s.io/kubernetes/pkg/security/podsecuritypolicy/seccomp" psputil "k8s.io/kubernetes/pkg/security/podsecuritypolicy/util" + "k8s.io/utils/pointer" ) const defaultContainerName = "test-c" @@ -1134,10 +1135,14 @@ func TestGenerateContainerSecurityContextReadOnlyRootFS(t *testing.T) { } func defaultPSP() *policy.PodSecurityPolicy { + return defaultNamedPSP("psp-sa") +} + +func defaultNamedPSP(name string) *policy.PodSecurityPolicy { allowPrivilegeEscalation := true return &policy.PodSecurityPolicy{ ObjectMeta: metav1.ObjectMeta{ - Name: "psp-sa", + Name: name, Annotations: map[string]string{}, }, Spec: policy.PodSecurityPolicySpec{ @@ -1259,7 +1264,7 @@ func TestValidateAllowedVolumes(t *testing.T) { } func TestAllowPrivilegeEscalation(t *testing.T) { - ptr := func(b bool) *bool { return &b } + ptr := pointer.BoolPtr tests := []struct { pspAPE bool // PSP AllowPrivilegeEscalation pspDAPE *bool // PSP DefaultAllowPrivilegeEscalation @@ -1319,3 +1324,129 @@ func TestAllowPrivilegeEscalation(t *testing.T) { }) } } + +func TestDefaultRuntimeClassName(t *testing.T) { + const ( + defaultedName = "foo" + presetName = "tim" + ) + + noRCS := defaultNamedPSP("nil-strategy") + emptyRCS := defaultNamedPSP("empty-strategy") + emptyRCS.Spec.RuntimeClass = &policy.RuntimeClassStrategyOptions{} + noDefaultRCS := defaultNamedPSP("no-default") + noDefaultRCS.Spec.RuntimeClass = &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"foo", "bar"}, + } + defaultRCS := defaultNamedPSP("defaulting") + defaultRCS.Spec.RuntimeClass = &policy.RuntimeClassStrategyOptions{ + DefaultRuntimeClassName: pointer.StringPtr(defaultedName), + } + + noRCPod := defaultPod() + noRCPod.Name = "no-runtimeclass" + rcPod := defaultPod() + rcPod.Name = "preset-runtimeclass" + rcPod.Spec.RuntimeClassName = pointer.StringPtr(presetName) + + type testcase struct { + psp *policy.PodSecurityPolicy + pod *api.Pod + expectedRuntimeClassName *string + } + tests := []testcase{{ + psp: defaultRCS, + pod: noRCPod, + expectedRuntimeClassName: pointer.StringPtr(defaultedName), + }} + // Non-defaulting no-preset cases + for _, psp := range []*policy.PodSecurityPolicy{noRCS, emptyRCS, noDefaultRCS} { + tests = append(tests, testcase{psp, noRCPod, nil}) + } + // Non-defaulting preset cases + for _, psp := range []*policy.PodSecurityPolicy{noRCS, emptyRCS, noDefaultRCS, defaultRCS} { + tests = append(tests, testcase{psp, rcPod, pointer.StringPtr(presetName)}) + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s-psp %s-pod", test.psp.Name, test.pod.Name), func(t *testing.T) { + provider, err := NewSimpleProvider(test.psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "error creating provider") + + actualPod := test.pod.DeepCopy() + require.NoError(t, provider.MutatePod(actualPod)) + + expectedPod := test.pod.DeepCopy() + expectedPod.Spec.RuntimeClassName = test.expectedRuntimeClassName + assert.Equal(t, expectedPod, actualPod) + }) + } +} + +func TestAllowedRuntimeClassNames(t *testing.T) { + const ( + goodName = "good" + ) + + noRCPod := defaultPod() + noRCPod.Name = "no-runtimeclass" + rcPod := defaultPod() + rcPod.Name = "good-runtimeclass" + rcPod.Spec.RuntimeClassName = pointer.StringPtr(goodName) + otherPod := defaultPod() + otherPod.Name = "bad-runtimeclass" + otherPod.Spec.RuntimeClassName = pointer.StringPtr("bad") + allPods := []*api.Pod{noRCPod, rcPod, otherPod} + + type testcase struct { + name string + strategy *policy.RuntimeClassStrategyOptions + validPods []*api.Pod + invalidPods []*api.Pod + } + tests := []testcase{{ + name: "nil-strategy", + validPods: allPods, + }, { + name: "empty-strategy", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{}, + }, + validPods: []*api.Pod{noRCPod}, + invalidPods: []*api.Pod{rcPod, otherPod}, + }, { + name: "allow-all-strategy", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{"*"}, + DefaultRuntimeClassName: pointer.StringPtr("foo"), + }, + validPods: allPods, + }, { + name: "named-allowed", + strategy: &policy.RuntimeClassStrategyOptions{ + AllowedRuntimeClassNames: []string{goodName}, + }, + validPods: []*api.Pod{rcPod}, + invalidPods: []*api.Pod{otherPod}, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + psp := defaultNamedPSP(test.name) + psp.Spec.RuntimeClass = test.strategy + provider, err := NewSimpleProvider(psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "error creating provider") + + for _, pod := range test.validPods { + copy := pod.DeepCopy() + assert.NoError(t, provider.ValidatePod(copy).ToAggregate(), "expected valid pod %s", pod.Name) + assert.Equal(t, pod, copy, "validate should not mutate!") + } + for _, pod := range test.invalidPods { + copy := pod.DeepCopy() + assert.Error(t, provider.ValidatePod(copy).ToAggregate(), "expected invalid pod %s", pod.Name) + assert.Equal(t, pod, copy, "validate should not mutate!") + } + }) + } +}