diff --git a/pkg/security/podsecuritypolicy/BUILD b/pkg/security/podsecuritypolicy/BUILD index abfd148dc41..02668dc0486 100644 --- a/pkg/security/podsecuritypolicy/BUILD +++ b/pkg/security/podsecuritypolicy/BUILD @@ -49,8 +49,6 @@ go_test( "//staging/src/k8s.io/api/policy/v1beta1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/diff:go_default_library", - "//staging/src/k8s.io/apimachinery/pkg/util/validation/field:go_default_library", - "//vendor/github.com/davecgh/go-spew/spew:go_default_library", "//vendor/github.com/stretchr/testify/assert:go_default_library", "//vendor/github.com/stretchr/testify/require:go_default_library", ], diff --git a/pkg/security/podsecuritypolicy/provider.go b/pkg/security/podsecuritypolicy/provider.go index d347bcbf53d..9c120c23f80 100644 --- a/pkg/security/podsecuritypolicy/provider.go +++ b/pkg/security/podsecuritypolicy/provider.go @@ -59,10 +59,10 @@ func NewSimpleProvider(psp *policy.PodSecurityPolicy, namespace string, strategy }, nil } -// DefaultPodSecurityContext sets the default values of the required but not filled fields. -// It modifies the SecurityContext and annotations of the provided pod. Validation should be -// used after the context is defaulted to ensure it complies with the required restrictions. -func (s *simpleProvider) DefaultPodSecurityContext(pod *api.Pod) error { +// MutatePod sets the default values of the required but not filled fields. +// Validation should be used after the context is defaulted to ensure it +// complies with the required restrictions. +func (s *simpleProvider) MutatePod(pod *api.Pod) error { sc := securitycontext.NewPodSecurityContextMutator(pod.Spec.SecurityContext) if sc.SupplementalGroups() == nil { @@ -104,13 +104,25 @@ func (s *simpleProvider) DefaultPodSecurityContext(pod *api.Pod) error { pod.Spec.SecurityContext = sc.PodSecurityContext() + for i := range pod.Spec.InitContainers { + if err := s.mutateContainer(pod, &pod.Spec.InitContainers[i]); err != nil { + return err + } + } + + for i := range pod.Spec.Containers { + if err := s.mutateContainer(pod, &pod.Spec.Containers[i]); err != nil { + return err + } + } + return nil } -// DefaultContainerSecurityContext sets the default values of the required but not filled fields. +// mutateContainer sets the default values of the required but not filled fields. // It modifies the SecurityContext of the container and annotations of the pod. Validation should // be used after the context is defaulted to ensure it complies with the required restrictions. -func (s *simpleProvider) DefaultContainerSecurityContext(pod *api.Pod, container *api.Container) error { +func (s *simpleProvider) mutateContainer(pod *api.Pod, container *api.Container) error { sc := securitycontext.NewEffectiveContainerSecurityContextMutator( securitycontext.NewPodSecurityContextAccessor(pod.Spec.SecurityContext), securitycontext.NewContainerSecurityContextMutator(container.SecurityContext), @@ -282,11 +294,22 @@ func (s *simpleProvider) ValidatePod(pod *api.Pod) field.ErrorList { } } } + + fldPath := field.NewPath("spec", "initContainers") + for i := range pod.Spec.InitContainers { + allErrs = append(allErrs, s.validateContainer(pod, &pod.Spec.InitContainers[i], fldPath.Index(i))...) + } + + fldPath = field.NewPath("spec", "containers") + for i := range pod.Spec.Containers { + allErrs = append(allErrs, s.validateContainer(pod, &pod.Spec.Containers[i], fldPath.Index(i))...) + } + return allErrs } // Ensure a container's SecurityContext is in compliance with the given constraints -func (s *simpleProvider) ValidateContainer(pod *api.Pod, container *api.Container, containerPath *field.Path) field.ErrorList { +func (s *simpleProvider) validateContainer(pod *api.Pod, container *api.Container, containerPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} podSC := securitycontext.NewPodSecurityContextAccessor(pod.Spec.SecurityContext) diff --git a/pkg/security/podsecuritypolicy/provider_test.go b/pkg/security/podsecuritypolicy/provider_test.go index ad6b31ffcfc..8f7046b74e3 100644 --- a/pkg/security/podsecuritypolicy/provider_test.go +++ b/pkg/security/podsecuritypolicy/provider_test.go @@ -20,10 +20,8 @@ import ( "fmt" "reflect" "strconv" - "strings" "testing" - "github.com/davecgh/go-spew/spew" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,7 +29,6 @@ import ( policy "k8s.io/api/policy/v1beta1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/diff" - "k8s.io/apimachinery/pkg/util/validation/field" api "k8s.io/kubernetes/pkg/apis/core" k8s_api_v1 "k8s.io/kubernetes/pkg/apis/core/v1" "k8s.io/kubernetes/pkg/security/apparmor" @@ -41,7 +38,7 @@ import ( const defaultContainerName = "test-c" -func TestDefaultPodSecurityContextNonmutating(t *testing.T) { +func TestMutatePodNonmutating(t *testing.T) { // Create a pod with a security context that needs filling in createPod := func() *api.Pod { return &api.Pod{ @@ -86,26 +83,22 @@ func TestDefaultPodSecurityContextNonmutating(t *testing.T) { psp := createPSP() provider, err := NewSimpleProvider(psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Fatalf("unable to create provider %v", err) - } - err = provider.DefaultPodSecurityContext(pod) - if err != nil { - t.Fatalf("unable to create psc %v", err) - } + require.NoError(t, err, "unable to create provider") + err = provider.MutatePod(pod) + require.NoError(t, err, "unable to modify pod") // Creating the provider or the security context should not have mutated the psp or pod // since all the strategies were permissive if !reflect.DeepEqual(createPod(), pod) { diffs := diff.ObjectDiff(createPod(), pod) - t.Errorf("pod was mutated by DefaultPodSecurityContext. diff:\n%s", diffs) + t.Errorf("pod was mutated by MutatePod. diff:\n%s", diffs) } if !reflect.DeepEqual(createPSP(), psp) { - t.Error("psp was mutated by DefaultPodSecurityContext") + t.Error("psp was mutated by MutatePod") } } -func TestDefaultContainerSecurityContextNonmutating(t *testing.T) { +func TestMutateContainerNonmutating(t *testing.T) { untrue := false tests := []struct { security *api.SecurityContext @@ -134,7 +127,6 @@ func TestDefaultContainerSecurityContextNonmutating(t *testing.T) { Name: "psp-sa", Annotations: map[string]string{ seccomp.AllowedProfilesAnnotationKey: "*", - seccomp.DefaultProfileAnnotationKey: "foo", }, }, Spec: policy.PodSecurityPolicySpec{ @@ -162,27 +154,23 @@ func TestDefaultContainerSecurityContextNonmutating(t *testing.T) { psp := createPSP() provider, err := NewSimpleProvider(psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Fatalf("unable to create provider %v", err) - } - err = provider.DefaultContainerSecurityContext(pod, &pod.Spec.Containers[0]) - if err != nil { - t.Fatalf("unable to create container security context %v", err) - } + require.NoError(t, err, "unable to create provider") + err = provider.MutatePod(pod) + require.NoError(t, err, "unable to modify pod") // Creating the provider or the security context should not have mutated the psp or pod // since all the strategies were permissive if !reflect.DeepEqual(createPod(), pod) { diffs := diff.ObjectDiff(createPod(), pod) - t.Errorf("pod was mutated by DefaultContainerSecurityContext. diff:\n%s", diffs) + t.Errorf("pod was mutated. diff:\n%s", diffs) } if !reflect.DeepEqual(createPSP(), psp) { - t.Error("psp was mutated by DefaultContainerSecurityContext") + t.Error("psp was mutated") } } } -func TestValidatePodSecurityContextFailures(t *testing.T) { +func TestValidatePodFailures(t *testing.T) { failHostNetworkPod := defaultPod() failHostNetworkPod.Spec.SecurityContext.HostNetwork = true @@ -445,19 +433,14 @@ func TestValidatePodSecurityContextFailures(t *testing.T) { expectedError: "Flexvolume driver is not allowed to be used", }, } - for k, v := range errorCases { - provider, err := NewSimpleProvider(v.psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Fatalf("unable to create provider %v", err) - } - errs := provider.ValidatePod(v.pod) - if len(errs) == 0 { - t.Errorf("%s expected validation failure but did not receive errors", k) - continue - } - if !strings.Contains(errs[0].Error(), v.expectedError) { - t.Errorf("%s received unexpected error %v", k, errs) - } + for name, test := range errorCases { + t.Run(name, func(t *testing.T) { + provider, err := NewSimpleProvider(test.psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "unable to create provider") + errs := provider.ValidatePod(test.pod) + require.NotEmpty(t, errs, "expected validation failure but did not receive errors") + assert.Contains(t, errs[0].Error(), test.expectedError, "received unexpected error") + }) } } @@ -504,6 +487,7 @@ func TestValidateContainerFailures(t *testing.T) { }, } failSELinuxPod := defaultPod() + failSELinuxPod.Spec.SecurityContext.SELinuxOptions = &api.SELinuxOptions{Level: "foo"} failSELinuxPod.Spec.Containers[0].SecurityContext.SELinuxOptions = &api.SELinuxOptions{ Level: "bar", } @@ -619,23 +603,18 @@ func TestValidateContainerFailures(t *testing.T) { }, } - for k, v := range errorCases { - provider, err := NewSimpleProvider(v.psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Fatalf("unable to create provider %v", err) - } - errs := provider.ValidateContainer(v.pod, &v.pod.Spec.Containers[0], field.NewPath("")) - if len(errs) == 0 { - t.Errorf("%s expected validation failure but did not receive errors", k) - continue - } - if !strings.Contains(errs[0].Error(), v.expectedError) { - t.Errorf("%s received unexpected error %v\nexpected: %s", k, errs, v.expectedError) - } + for name, test := range errorCases { + t.Run(name, func(t *testing.T) { + provider, err := NewSimpleProvider(test.psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "unable to create provider") + errs := provider.ValidatePod(test.pod) + require.NotEmpty(t, errs, "expected validation failure but did not receive errors") + assert.Contains(t, errs[0].Error(), test.expectedError, "unexpected error") + }) } } -func TestValidatePodSecurityContextSuccess(t *testing.T) { +func TestValidatePodSuccess(t *testing.T) { hostNetworkPSP := defaultPSP() hostNetworkPSP.Spec.HostNetwork = true hostNetworkPod := defaultPod() @@ -908,16 +887,13 @@ func TestValidatePodSecurityContextSuccess(t *testing.T) { }, } - for k, v := range successCases { - provider, err := NewSimpleProvider(v.psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Fatalf("unable to create provider %v", err) - } - errs := provider.ValidatePod(v.pod) - if len(errs) != 0 { - t.Errorf("%s expected validation pass but received errors %v", k, errs) - continue - } + for name, test := range successCases { + t.Run(name, func(t *testing.T) { + provider, err := NewSimpleProvider(test.psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "unable to create provider") + errs := provider.ValidatePod(test.pod) + assert.Empty(t, errs, "expected validation pass but received errors") + }) } } @@ -941,6 +917,7 @@ func TestValidateContainerSuccess(t *testing.T) { }, } seLinuxPod := defaultPod() + seLinuxPod.Spec.SecurityContext.SELinuxOptions = &api.SELinuxOptions{Level: "foo"} seLinuxPod.Spec.Containers[0].SecurityContext.SELinuxOptions = &api.SELinuxOptions{ Level: "foo", } @@ -1007,6 +984,7 @@ func TestValidateContainerSuccess(t *testing.T) { seccompPod := defaultPod() seccompPod.Annotations = map[string]string{ + api.SeccompPodAnnotationKey: "foo", api.SeccompContainerAnnotationKeyPrefix + seccompPod.Spec.Containers[0].Name: "foo", } @@ -1073,16 +1051,13 @@ func TestValidateContainerSuccess(t *testing.T) { }, } - for k, v := range successCases { - provider, err := NewSimpleProvider(v.psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Fatalf("unable to create provider %v", err) - } - errs := provider.ValidateContainer(v.pod, &v.pod.Spec.Containers[0], field.NewPath("")) - if len(errs) != 0 { - t.Errorf("%s expected validation pass but received errors %v\n%s", k, errs, spew.Sdump(v.pod.ObjectMeta)) - continue - } + for name, test := range successCases { + t.Run(name, func(t *testing.T) { + provider, err := NewSimpleProvider(test.psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "unable to create provider") + errs := provider.ValidatePod(test.pod) + assert.Empty(t, errs, "expected validation pass but received errors") + }) } } @@ -1140,29 +1115,21 @@ func TestGenerateContainerSecurityContextReadOnlyRootFS(t *testing.T) { }, } - for k, v := range tests { - provider, err := NewSimpleProvider(v.psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Errorf("%s unable to create provider %v", k, err) - continue - } - err = provider.DefaultContainerSecurityContext(v.pod, &v.pod.Spec.Containers[0]) - if err != nil { - t.Errorf("%s unable to create container security context %v", k, err) - continue - } - - sc := v.pod.Spec.Containers[0].SecurityContext - if v.expected == nil && sc.ReadOnlyRootFilesystem != nil { - t.Errorf("%s expected a nil ReadOnlyRootFilesystem but got %t", k, *sc.ReadOnlyRootFilesystem) - } - if v.expected != nil && sc.ReadOnlyRootFilesystem == nil { - t.Errorf("%s expected a non nil ReadOnlyRootFilesystem but received nil", k) - } - if v.expected != nil && sc.ReadOnlyRootFilesystem != nil && (*v.expected != *sc.ReadOnlyRootFilesystem) { - t.Errorf("%s expected a non nil ReadOnlyRootFilesystem set to %t but got %t", k, *v.expected, *sc.ReadOnlyRootFilesystem) - } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + provider, err := NewSimpleProvider(test.psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "unable to create provider") + err = provider.MutatePod(test.pod) + require.NoError(t, err, "unable to mutate container") + sc := test.pod.Spec.Containers[0].SecurityContext + if test.expected == nil { + assert.Nil(t, sc.ReadOnlyRootFilesystem, "expected a nil ReadOnlyRootFilesystem") + } else { + require.NotNil(t, sc.ReadOnlyRootFilesystem, "expected a non nil ReadOnlyRootFilesystem") + assert.Equal(t, *test.expected, *sc.ReadOnlyRootFilesystem) + } + }) } } @@ -1252,55 +1219,42 @@ func TestValidateAllowedVolumes(t *testing.T) { // reflectively create the volume source fieldVal := val.Type().Field(i) - volumeSource := api.VolumeSource{} - volumeSourceVolume := reflect.New(fieldVal.Type.Elem()) + t.Run(fieldVal.Name, func(t *testing.T) { + volumeSource := api.VolumeSource{} + volumeSourceVolume := reflect.New(fieldVal.Type.Elem()) - reflect.ValueOf(&volumeSource).Elem().FieldByName(fieldVal.Name).Set(volumeSourceVolume) - volume := api.Volume{VolumeSource: volumeSource} + reflect.ValueOf(&volumeSource).Elem().FieldByName(fieldVal.Name).Set(volumeSourceVolume) + volume := api.Volume{VolumeSource: volumeSource} - // sanity check before moving on - fsType, err := psputil.GetVolumeFSType(volume) - if err != nil { - t.Errorf("error getting FSType for %s: %s", fieldVal.Name, err.Error()) - continue - } + // sanity check before moving on + fsType, err := psputil.GetVolumeFSType(volume) + require.NoError(t, err, "error getting FSType") - // add the volume to the pod - pod := defaultPod() - pod.Spec.Volumes = []api.Volume{volume} + // add the volume to the pod + pod := defaultPod() + pod.Spec.Volumes = []api.Volume{volume} - // create a PSP that allows no volumes - psp := defaultPSP() + // create a PSP that allows no volumes + psp := defaultPSP() - provider, err := NewSimpleProvider(psp, "namespace", NewSimpleStrategyFactory()) - if err != nil { - t.Errorf("error creating provider for %s: %s", fieldVal.Name, err.Error()) - continue - } + provider, err := NewSimpleProvider(psp, "namespace", NewSimpleStrategyFactory()) + require.NoError(t, err, "error creating provider") - // expect a denial for this PSP and test the error message to ensure it's related to the volumesource - errs := provider.ValidatePod(pod) - if len(errs) != 1 { - t.Errorf("expected exactly 1 error for %s but got %v", fieldVal.Name, errs) - } else { - if !strings.Contains(errs.ToAggregate().Error(), fmt.Sprintf("%s volumes are not allowed to be used", fsType)) { - t.Errorf("did not find the expected error, received: %v", errs) - } - } + // expect a denial for this PSP and test the error message to ensure it's related to the volumesource + errs := provider.ValidatePod(pod) + require.Len(t, errs, 1, "expected exactly 1 error") + assert.Contains(t, errs.ToAggregate().Error(), fmt.Sprintf("%s volumes are not allowed to be used", fsType), "did not find the expected error") - // now add the fstype directly to the psp and it should validate - psp.Spec.Volumes = []policy.FSType{fsType} - errs = provider.ValidatePod(pod) - if len(errs) != 0 { - t.Errorf("directly allowing volume expected no errors for %s but got %v", fieldVal.Name, errs) - } + // now add the fstype directly to the psp and it should validate + psp.Spec.Volumes = []policy.FSType{fsType} + errs = provider.ValidatePod(pod) + assert.Empty(t, errs, "directly allowing volume expected no errors") - // now change the psp to allow any volumes and the pod should still validate - psp.Spec.Volumes = []policy.FSType{policy.All} - errs = provider.ValidatePod(pod) - if len(errs) != 0 { - t.Errorf("wildcard volume expected no errors for %s but got %v", fieldVal.Name, errs) - } + // now change the psp to allow any volumes and the pod should still validate + psp.Spec.Volumes = []policy.FSType{policy.All} + errs = provider.ValidatePod(pod) + assert.Empty(t, errs, "wildcard volume expected no errors") + }) } } @@ -1351,10 +1305,10 @@ func TestAllowPrivilegeEscalation(t *testing.T) { provider, err := NewSimpleProvider(psp, "namespace", NewSimpleStrategyFactory()) require.NoError(t, err) - err = provider.DefaultContainerSecurityContext(pod, &pod.Spec.Containers[0]) + err = provider.MutatePod(pod) require.NoError(t, err) - errs := provider.ValidateContainer(pod, &pod.Spec.Containers[0], field.NewPath("")) + errs := provider.ValidatePod(pod) if test.expectErr { assert.NotEmpty(t, errs, "expected validation error") } else { diff --git a/pkg/security/podsecuritypolicy/types.go b/pkg/security/podsecuritypolicy/types.go index 8caba94dbaf..ea4694ae143 100644 --- a/pkg/security/podsecuritypolicy/types.go +++ b/pkg/security/podsecuritypolicy/types.go @@ -32,16 +32,12 @@ import ( // Provider provides the implementation to generate a new security // context based on constraints or validate an existing security context against constraints. type Provider interface { - // DefaultPodSecurityContext sets the default values of the required but not filled fields. - // It modifies the SecurityContext and annotations of the provided pod. - DefaultPodSecurityContext(pod *api.Pod) error - // DefaultContainerSecurityContext sets the default values of the required but not filled fields. - // It modifies the SecurityContext of the container and annotations of the pod. - DefaultContainerSecurityContext(pod *api.Pod, container *api.Container) error - // Ensure a pod is in compliance with the given constraints. + // MutatePod sets the default values of the required but not filled fields of the pod and all + // containers in the pod. + MutatePod(pod *api.Pod) error + // ValidatePod ensures a pod and all its containers are in compliance with the given constraints. + // ValidatePod MUST NOT mutate the pod. ValidatePod(pod *api.Pod) field.ErrorList - // Ensure a container's SecurityContext is in compliance with the given constraints. - ValidateContainer(pod *api.Pod, container *api.Container, containerPath *field.Path) field.ErrorList // Get the name of the PSP that this provider was initialized with. GetPSPName() string } diff --git a/plugin/pkg/admission/security/podsecuritypolicy/admission.go b/plugin/pkg/admission/security/podsecuritypolicy/admission.go index 4f28229d57d..6a0752272de 100644 --- a/plugin/pkg/admission/security/podsecuritypolicy/admission.go +++ b/plugin/pkg/admission/security/podsecuritypolicy/admission.go @@ -306,35 +306,14 @@ func (c *PodSecurityPolicyPlugin) computeSecurityContext(a admission.Attributes, func assignSecurityContext(provider psp.Provider, pod *api.Pod) field.ErrorList { errs := field.ErrorList{} - err := provider.DefaultPodSecurityContext(pod) - if err != nil { - errs = append(errs, field.Invalid(field.NewPath("spec", "securityContext"), pod.Spec.SecurityContext, err.Error())) + if err := provider.MutatePod(pod); err != nil { + // TODO(tallclair): MutatePod should return a field.ErrorList + errs = append(errs, field.Invalid(field.NewPath(""), pod, err.Error())) } errs = append(errs, provider.ValidatePod(pod)...) - for i := range pod.Spec.InitContainers { - err := provider.DefaultContainerSecurityContext(pod, &pod.Spec.InitContainers[i]) - if err != nil { - errs = append(errs, field.Invalid(field.NewPath("spec", "initContainers").Index(i).Child("securityContext"), "", err.Error())) - continue - } - errs = append(errs, provider.ValidateContainer(pod, &pod.Spec.InitContainers[i], field.NewPath("spec", "initContainers").Index(i))...) - } - - for i := range pod.Spec.Containers { - err := provider.DefaultContainerSecurityContext(pod, &pod.Spec.Containers[i]) - if err != nil { - errs = append(errs, field.Invalid(field.NewPath("spec", "containers").Index(i).Child("securityContext"), "", err.Error())) - continue - } - errs = append(errs, provider.ValidateContainer(pod, &pod.Spec.Containers[i], field.NewPath("spec", "containers").Index(i))...) - } - - if len(errs) > 0 { - return errs - } - return nil + return errs } // createProvidersFromPolicies creates providers from the constraints supplied.