diff --git a/staging/src/k8s.io/apiserver/pkg/audit/context.go b/staging/src/k8s.io/apiserver/pkg/audit/context.go index eada7add9a8..1167f67f210 100644 --- a/staging/src/k8s.io/apiserver/pkg/audit/context.go +++ b/staging/src/k8s.io/apiserver/pkg/audit/context.go @@ -18,9 +18,11 @@ package audit import ( "context" + "sync" auditinternal "k8s.io/apiserver/pkg/apis/audit" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/klog/v2" ) // The key type is unexported to prevent collisions @@ -37,6 +39,9 @@ const ( // auditKey is the context key for storing the audit event that is being // captured and the evaluated policy that applies to the given request. auditKey + + // auditAnnotationsMutexKey is the context key for the audit annotations mutex. + auditAnnotationsMutexKey ) // annotations = *[]annotation instead of a map to preserve order of insertions @@ -54,6 +59,7 @@ func WithAuditAnnotations(parent context.Context) context.Context { if _, ok := parent.Value(auditAnnotationsKey).(*[]annotation); ok { return parent } + parent = withAuditAnnotationsMutex(parent) var annotations []annotation // avoid allocations until we actually need it return genericapirequest.WithValue(parent, auditAnnotationsKey, &annotations) @@ -67,6 +73,15 @@ func WithAuditAnnotations(parent context.Context) context.Context { // Handlers that are unaware of their position in the overall request flow should // prefer AddAuditAnnotation over LogAnnotation to avoid dropping annotations. func AddAuditAnnotation(ctx context.Context, key, value string) { + mutex, ok := auditAnnotationsMutex(ctx) + if !ok { + klog.Errorf("Attempted to add audit annotation from unsupported request chain: %q=%q", key, value) + return + } + + mutex.Lock() + defer mutex.Unlock() + // use the audit event directly if we have it if ae := AuditEventFrom(ctx); ae != nil { LogAnnotation(ae, key, value) @@ -83,19 +98,31 @@ func AddAuditAnnotation(ctx context.Context, key, value string) { // This is private to prevent reads/write to the slice from outside of this package. // The audit event should be directly read to get access to the annotations. -func auditAnnotationsFrom(ctx context.Context) []annotation { - annotations, ok := ctx.Value(auditAnnotationsKey).(*[]annotation) +func addAuditAnnotationsFrom(ctx context.Context, ev *auditinternal.Event) { + mutex, ok := auditAnnotationsMutex(ctx) if !ok { - return nil // adding audit annotation is not supported at this call site + klog.Errorf("Attempted to copy audit annotations from unsupported request chain") + return } - return *annotations + mutex.Lock() + defer mutex.Unlock() + + annotations, ok := ctx.Value(auditAnnotationsKey).(*[]annotation) + if !ok { + return // no annotations to copy + } + + for _, kv := range *annotations { + LogAnnotation(ev, kv.key, kv.value) + } } // WithAuditContext returns a new context that stores the pair of the audit // configuration object that applies to the given request and // the audit event that is going to be written to the API audit log. func WithAuditContext(parent context.Context, ev *AuditContext) context.Context { + parent = withAuditAnnotationsMutex(parent) return genericapirequest.WithValue(parent, auditKey, ev) } @@ -114,3 +141,18 @@ func AuditContextFrom(ctx context.Context) *AuditContext { ev, _ := ctx.Value(auditKey).(*AuditContext) return ev } + +// WithAuditAnnotationMutex adds a mutex for guarding context.AddAuditAnnotation. +func withAuditAnnotationsMutex(parent context.Context) context.Context { + if _, ok := parent.Value(auditAnnotationsMutexKey).(*sync.Mutex); ok { + return parent + } + var mutex sync.Mutex + return genericapirequest.WithValue(parent, auditAnnotationsMutexKey, &mutex) +} + +// AuditAnnotationsMutex returns the audit annotations mutex from the context. +func auditAnnotationsMutex(ctx context.Context) (*sync.Mutex, bool) { + mutex, ok := ctx.Value(auditAnnotationsMutexKey).(*sync.Mutex) + return mutex, ok +} diff --git a/staging/src/k8s.io/apiserver/pkg/audit/context_test.go b/staging/src/k8s.io/apiserver/pkg/audit/context_test.go new file mode 100644 index 00000000000..a593f90fe48 --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/audit/context_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package audit + +import ( + "context" + "fmt" + "sync" + "testing" + + auditinternal "k8s.io/apiserver/pkg/apis/audit" + + "github.com/stretchr/testify/assert" +) + +func TestAddAuditAnnotation(t *testing.T) { + const ( + annotationKeyTemplate = "test-annotation-%d" + annotationValue = "test-annotation-value" + numAnnotations = 10 + ) + + expectAnnotations := func(t *testing.T, annotations map[string]string) { + assert.Len(t, annotations, numAnnotations) + } + noopValidator := func(_ *testing.T, _ context.Context) {} + preEventValidator := func(t *testing.T, ctx context.Context) { + ev := auditinternal.Event{ + Level: auditinternal.LevelMetadata, + } + addAuditAnnotationsFrom(ctx, &ev) + expectAnnotations(t, ev.Annotations) + } + postEventValidator := func(t *testing.T, ctx context.Context) { + ev := AuditEventFrom(ctx) + expectAnnotations(t, ev.Annotations) + } + postEventEmptyValidator := func(t *testing.T, ctx context.Context) { + ev := AuditEventFrom(ctx) + assert.Empty(t, ev.Annotations) + } + + tests := []struct { + description string + ctx context.Context + validator func(t *testing.T, ctx context.Context) + }{{ + description: "no audit", + ctx: context.Background(), + validator: noopValidator, + }, { + description: "no annotations context", + ctx: WithAuditContext(context.Background(), newAuditContext(auditinternal.LevelMetadata)), + validator: postEventValidator, + }, { + description: "no audit context", + ctx: WithAuditAnnotations(context.Background()), + validator: preEventValidator, + }, { + description: "both contexts metadata level", + ctx: WithAuditContext(WithAuditAnnotations(context.Background()), newAuditContext(auditinternal.LevelMetadata)), + validator: postEventValidator, + }, { + description: "both contexts none level", + ctx: WithAuditContext(WithAuditAnnotations(context.Background()), newAuditContext(auditinternal.LevelNone)), + validator: postEventEmptyValidator, + }} + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(numAnnotations) + for i := 0; i < numAnnotations; i++ { + go func(i int) { + AddAuditAnnotation(test.ctx, fmt.Sprintf(annotationKeyTemplate, i), annotationValue) + wg.Done() + }(i) + } + wg.Wait() + + test.validator(t, test.ctx) + }) + } +} + +func newAuditContext(l auditinternal.Level) *AuditContext { + return &AuditContext{ + Event: &auditinternal.Event{ + Level: l, + }, + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/audit/request.go b/staging/src/k8s.io/apiserver/pkg/audit/request.go index f48566576bc..19bb9993a39 100644 --- a/staging/src/k8s.io/apiserver/pkg/audit/request.go +++ b/staging/src/k8s.io/apiserver/pkg/audit/request.go @@ -87,9 +87,7 @@ func NewEventFromRequest(req *http.Request, requestReceivedTimestamp time.Time, } } - for _, kv := range auditAnnotationsFrom(req.Context()) { - LogAnnotation(ev, kv.key, kv.value) - } + addAuditAnnotationsFrom(req.Context(), ev) return ev, nil }