Merge pull request #111095 from tallclair/audit-single-context

Audit single context
This commit is contained in:
Kubernetes Prow Robot 2022-10-26 23:56:32 -07:00 committed by GitHub
commit 6e31c6531f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 286 additions and 472 deletions

View File

@ -143,7 +143,9 @@ func TestWithAudit(t *testing.T) {
for tcName, tc := range testCases { for tcName, tc := range testCases {
var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles} var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles}
ae := &auditinternal.Event{Level: auditinternal.LevelMetadata} ae := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx := audit.WithAuditContext(context.Background(), &audit.AuditContext{Event: ae}) ctx := audit.WithAuditContext(context.Background())
ac := audit.AuditContextFrom(ctx)
ac.Event = ae
auditHandler := WithAudit(handler) auditHandler := WithAudit(handler)
a := attributes() a := attributes()
@ -184,8 +186,9 @@ func TestWithAuditConcurrency(t *testing.T) {
"plugin.example.com/qux": "qux", "plugin.example.com/qux": "qux",
} }
var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true} var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true}
ae := &auditinternal.Event{Level: auditinternal.LevelMetadata} ctx := audit.WithAuditContext(context.Background())
ctx := audit.WithAuditContext(context.Background(), &audit.AuditContext{Event: ae}) ac := audit.AuditContextFrom(ctx)
ac.Event = &auditinternal.Event{Level: auditinternal.LevelMetadata}
auditHandler := WithAudit(handler) auditHandler := WithAudit(handler)
a := attributes() a := attributes()

View File

@ -20,6 +20,7 @@ import (
"context" "context"
"sync" "sync"
"k8s.io/apimachinery/pkg/types"
auditinternal "k8s.io/apiserver/pkg/apis/audit" auditinternal "k8s.io/apiserver/pkg/apis/audit"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request" genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2" "k8s.io/klog/v2"
@ -28,38 +29,31 @@ import (
// The key type is unexported to prevent collisions // The key type is unexported to prevent collisions
type key int type key int
const ( // auditKey is the context key for storing the audit context that is being
// auditAnnotationsKey is the context key for the audit annotations. // captured and the evaluated policy that applies to the given request.
// TODO: consolidate all audit info under the AuditContext, rather than storing 3 separate keys. const auditKey key = iota
auditAnnotationsKey key = iota
// auditKey is the context key for storing the audit event that is being // AuditContext holds the information for constructing the audit events for the current request.
// captured and the evaluated policy that applies to the given request. type AuditContext struct {
auditKey // RequestAuditConfig is the audit configuration that applies to the request
RequestAuditConfig RequestAuditConfig
// auditAnnotationsMutexKey is the context key for the audit annotations mutex. // Event is the audit Event object that is being captured to be written in
auditAnnotationsMutexKey // the API audit log. It is set to nil when the request is not being audited.
) Event *auditinternal.Event
// annotations = *[]annotation instead of a map to preserve order of insertions // annotations holds audit annotations that are recorded before the event has been initialized.
type annotation struct { // This is represented as a slice rather than a map to preserve order.
key, value string annotations []annotation
// annotationMutex guards annotations AND event.Annotations
annotationMutex sync.Mutex
// auditID is the Audit ID associated with this request.
auditID types.UID
} }
// WithAuditAnnotations returns a new context that can store audit annotations type annotation struct {
// via the AddAuditAnnotation function. This function is meant to be called from key, value string
// an early request handler to allow all later layers to set audit annotations.
// This is required to support flows where handlers that come before WithAudit
// (such as WithAuthentication) wish to set audit annotations.
func WithAuditAnnotations(parent context.Context) context.Context {
// this should never really happen, but prevent double registration of this slice
if _, ok := parent.Value(auditAnnotationsKey).(*[]annotation); ok {
return parent
}
parent = withAuditAnnotationsMutex(parent)
var annotations []annotation // avoid allocations until we actually need it
return genericapirequest.WithValue(parent, auditAnnotationsKey, &annotations)
} }
// AddAuditAnnotation sets the audit annotation for the given key, value pair. // AddAuditAnnotation sets the audit annotation for the given key, value pair.
@ -70,102 +64,79 @@ func WithAuditAnnotations(parent context.Context) context.Context {
// Handlers that are unaware of their position in the overall request flow should // Handlers that are unaware of their position in the overall request flow should
// prefer AddAuditAnnotation over LogAnnotation to avoid dropping annotations. // prefer AddAuditAnnotation over LogAnnotation to avoid dropping annotations.
func AddAuditAnnotation(ctx context.Context, key, value string) { func AddAuditAnnotation(ctx context.Context, key, value string) {
mutex, ok := auditAnnotationsMutex(ctx) ac := AuditContextFrom(ctx)
if !ok { if ac == nil {
// auditing is not enabled // auditing is not enabled
return return
} }
mutex.Lock() ac.annotationMutex.Lock()
defer mutex.Unlock() defer ac.annotationMutex.Unlock()
ae := AuditEventFrom(ctx) addAuditAnnotationLocked(ac, key, value)
var ctxAnnotations *[]annotation
if ae == nil {
ctxAnnotations, _ = ctx.Value(auditAnnotationsKey).(*[]annotation)
}
addAuditAnnotationLocked(ae, ctxAnnotations, key, value)
} }
// AddAuditAnnotations is a bulk version of AddAuditAnnotation. Refer to AddAuditAnnotation for // AddAuditAnnotations is a bulk version of AddAuditAnnotation. Refer to AddAuditAnnotation for
// restrictions on when this can be called. // restrictions on when this can be called.
// keysAndValues are the key-value pairs to add, and must have an even number of items. // keysAndValues are the key-value pairs to add, and must have an even number of items.
func AddAuditAnnotations(ctx context.Context, keysAndValues ...string) { func AddAuditAnnotations(ctx context.Context, keysAndValues ...string) {
mutex, ok := auditAnnotationsMutex(ctx) ac := AuditContextFrom(ctx)
if !ok { if ac == nil {
// auditing is not enabled // auditing is not enabled
return return
} }
mutex.Lock() ac.annotationMutex.Lock()
defer mutex.Unlock() defer ac.annotationMutex.Unlock()
ae := AuditEventFrom(ctx)
var ctxAnnotations *[]annotation
if ae == nil {
ctxAnnotations, _ = ctx.Value(auditAnnotationsKey).(*[]annotation)
}
if len(keysAndValues)%2 != 0 { if len(keysAndValues)%2 != 0 {
klog.Errorf("Dropping mismatched audit annotation %q", keysAndValues[len(keysAndValues)-1]) klog.Errorf("Dropping mismatched audit annotation %q", keysAndValues[len(keysAndValues)-1])
} }
for i := 0; i < len(keysAndValues); i += 2 { for i := 0; i < len(keysAndValues); i += 2 {
addAuditAnnotationLocked(ae, ctxAnnotations, keysAndValues[i], keysAndValues[i+1]) addAuditAnnotationLocked(ac, keysAndValues[i], keysAndValues[i+1])
} }
} }
// AddAuditAnnotationsMap is a bulk version of AddAuditAnnotation. Refer to AddAuditAnnotation for // AddAuditAnnotationsMap is a bulk version of AddAuditAnnotation. Refer to AddAuditAnnotation for
// restrictions on when this can be called. // restrictions on when this can be called.
func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) { func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) {
mutex, ok := auditAnnotationsMutex(ctx) ac := AuditContextFrom(ctx)
if !ok { if ac == nil {
// auditing is not enabled // auditing is not enabled
return return
} }
mutex.Lock() ac.annotationMutex.Lock()
defer mutex.Unlock() defer ac.annotationMutex.Unlock()
ae := AuditEventFrom(ctx)
var ctxAnnotations *[]annotation
if ae == nil {
ctxAnnotations, _ = ctx.Value(auditAnnotationsKey).(*[]annotation)
}
for k, v := range annotations { for k, v := range annotations {
addAuditAnnotationLocked(ae, ctxAnnotations, k, v) addAuditAnnotationLocked(ac, k, v)
} }
} }
// addAuditAnnotationLocked is the shared code for recording an audit annotation. This method should // addAuditAnnotationLocked is the shared code for recording an audit annotation. This method should
// only be called while the auditAnnotationsMutex is locked. // only be called while the auditAnnotationsMutex is locked.
func addAuditAnnotationLocked(ae *auditinternal.Event, annotations *[]annotation, key, value string) { func addAuditAnnotationLocked(ac *AuditContext, key, value string) {
if ae != nil { if ac.Event != nil {
logAnnotation(ae, key, value) logAnnotation(ac.Event, key, value)
} else if annotations != nil { } else {
*annotations = append(*annotations, annotation{key: key, value: value}) ac.annotations = append(ac.annotations, annotation{key: key, value: value})
} }
} }
// This is private to prevent reads/write to the slice from outside of this package. // This is private to prevent reads/write to the slice from outside of this package.
// The audit event should be directly read to get access to the annotations. // The audit event should be directly read to get access to the annotations.
func addAuditAnnotationsFrom(ctx context.Context, ev *auditinternal.Event) { func addAuditAnnotationsFrom(ctx context.Context, ev *auditinternal.Event) {
mutex, ok := auditAnnotationsMutex(ctx) ac := AuditContextFrom(ctx)
if !ok { if ac == nil {
// auditing is not enabled // auditing is not enabled
return return
} }
mutex.Lock() ac.annotationMutex.Lock()
defer mutex.Unlock() defer ac.annotationMutex.Unlock()
annotations, ok := ctx.Value(auditAnnotationsKey).(*[]annotation) for _, kv := range ac.annotations {
if !ok {
return // no annotations to copy
}
for _, kv := range *annotations {
logAnnotation(ev, kv.key, kv.value) logAnnotation(ev, kv.key, kv.value)
} }
} }
@ -185,12 +156,13 @@ func logAnnotation(ae *auditinternal.Event, key, value string) {
ae.Annotations[key] = value ae.Annotations[key] = value
} }
// WithAuditContext returns a new context that stores the pair of the audit // WithAuditContext returns a new context that stores the AuditContext.
// configuration object that applies to the given request and func WithAuditContext(parent context.Context) context.Context {
// the audit event that is going to be written to the API audit log. if AuditContextFrom(parent) != nil {
func WithAuditContext(parent context.Context, ev *AuditContext) context.Context { return parent // Avoid double registering.
parent = withAuditAnnotationsMutex(parent) }
return genericapirequest.WithValue(parent, auditKey, ev)
return genericapirequest.WithValue(parent, auditKey, &AuditContext{})
} }
// AuditEventFrom returns the audit event struct on the ctx // AuditEventFrom returns the audit event struct on the ctx
@ -209,17 +181,46 @@ func AuditContextFrom(ctx context.Context) *AuditContext {
return ev return ev
} }
// WithAuditAnnotationMutex adds a mutex for guarding context.AddAuditAnnotation. // WithAuditID sets the AuditID on the AuditContext. The AuditContext must already be present in the
func withAuditAnnotationsMutex(parent context.Context) context.Context { // request context. If the specified auditID is empty, no value is set.
if _, ok := parent.Value(auditAnnotationsMutexKey).(*sync.Mutex); ok { func WithAuditID(ctx context.Context, auditID types.UID) {
return parent if auditID == "" {
return
}
ac := AuditContextFrom(ctx)
if ac == nil {
return
}
ac.auditID = auditID
if ac.Event != nil {
ac.Event.AuditID = auditID
} }
var mutex sync.Mutex
return genericapirequest.WithValue(parent, auditAnnotationsMutexKey, &mutex)
} }
// AuditAnnotationsMutex returns the audit annotations mutex from the context. // AuditIDFrom returns the value of the audit ID from the request context.
func auditAnnotationsMutex(ctx context.Context) (*sync.Mutex, bool) { func AuditIDFrom(ctx context.Context) (types.UID, bool) {
mutex, ok := ctx.Value(auditAnnotationsMutexKey).(*sync.Mutex) if ac := AuditContextFrom(ctx); ac != nil {
return mutex, ok return ac.auditID, ac.auditID != ""
}
return "", false
}
// GetAuditIDTruncated returns the audit ID (truncated) from the request context.
// If the length of the Audit-ID value exceeds the limit, we truncate it to keep
// the first N (maxAuditIDLength) characters.
// This is intended to be used in logging only.
func GetAuditIDTruncated(ctx context.Context) string {
auditID, ok := AuditIDFrom(ctx)
if !ok {
return ""
}
// if the user has specified a very long audit ID then we will use the first N characters
// Note: assuming Audit-ID header is in ASCII
const maxAuditIDLength = 64
if len(auditID) > maxAuditIDLength {
auditID = auditID[:maxAuditIDLength]
}
return string(auditID)
} }

View File

@ -63,20 +63,16 @@ func TestAddAuditAnnotation(t *testing.T) {
ctx: context.Background(), ctx: context.Background(),
validator: noopValidator, validator: noopValidator,
}, { }, {
description: "no annotations context", description: "empty audit context",
ctx: WithAuditContext(context.Background(), newAuditContext(auditinternal.LevelMetadata)), ctx: WithAuditContext(context.Background()),
validator: postEventValidator,
}, {
description: "no audit context",
ctx: WithAuditAnnotations(context.Background()),
validator: preEventValidator, validator: preEventValidator,
}, { }, {
description: "both contexts metadata level", description: "with metadata level",
ctx: WithAuditContext(WithAuditAnnotations(context.Background()), newAuditContext(auditinternal.LevelMetadata)), ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelMetadata),
validator: postEventValidator, validator: postEventValidator,
}, { }, {
description: "both contexts none level", description: "with none level",
ctx: WithAuditContext(WithAuditAnnotations(context.Background()), newAuditContext(auditinternal.LevelNone)), ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelNone),
validator: postEventEmptyValidator, validator: postEventEmptyValidator,
}} }}
@ -111,10 +107,11 @@ func TestLogAnnotation(t *testing.T) {
assert.Equal(t, "", ev.Annotations["qux"], "audit annotation should not be overwritten.") assert.Equal(t, "", ev.Annotations["qux"], "audit annotation should not be overwritten.")
} }
func newAuditContext(l auditinternal.Level) *AuditContext { func withAuditContextAndLevel(ctx context.Context, l auditinternal.Level) context.Context {
return &AuditContext{ ctx = WithAuditContext(ctx)
Event: &auditinternal.Event{ ac := AuditContextFrom(ctx)
Level: l, ac.Event = &auditinternal.Event{
}, Level: l,
} }
return ctx
} }

View File

@ -21,18 +21,6 @@ import (
"k8s.io/apiserver/pkg/authorization/authorizer" "k8s.io/apiserver/pkg/authorization/authorizer"
) )
// AuditContext is a pair of the audit configuration object that applies to
// a given request and the audit Event object that is being captured.
// It's a convenient placeholder to store both these objects in the request context.
type AuditContext struct {
// RequestAuditConfig is the audit configuration that applies to the request
RequestAuditConfig RequestAuditConfig
// Event is the audit Event object that is being captured to be written in
// the API audit log. It is set to nil when the request is not being audited.
Event *audit.Event
}
// RequestAuditConfig is the evaluated audit configuration that is applicable to // RequestAuditConfig is the evaluated audit configuration that is applicable to
// a given request. PolicyRuleEvaluator evaluates the audit policy against the // a given request. PolicyRuleEvaluator evaluates the audit policy against the
// authorizer attributes and returns a RequestAuditConfig that applies to the request. // authorizer attributes and returns a RequestAuditConfig that applies to the request.

View File

@ -33,7 +33,6 @@ import (
auditinternal "k8s.io/apiserver/pkg/apis/audit" auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/authorization/authorizer" "k8s.io/apiserver/pkg/authorization/authorizer"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2" "k8s.io/klog/v2"
"github.com/google/uuid" "github.com/google/uuid"
@ -53,7 +52,7 @@ func NewEventFromRequest(req *http.Request, requestReceivedTimestamp time.Time,
Level: level, Level: level,
} }
auditID, found := request.AuditIDFrom(req.Context()) auditID, found := AuditIDFrom(req.Context())
if !found { if !found {
auditID = types.UID(uuid.New().String()) auditID = types.UID(uuid.New().String())
} }

View File

@ -200,7 +200,9 @@ func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, toke
// since this is shared work between multiple requests, we have no way of knowing if any // since this is shared work between multiple requests, we have no way of knowing if any
// particular request supports audit annotations. thus we always attempt to record them. // particular request supports audit annotations. thus we always attempt to record them.
ev := &auditinternal.Event{Level: auditinternal.LevelMetadata} ev := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx = audit.WithAuditContext(ctx, &audit.AuditContext{Event: ev}) ctx = audit.WithAuditContext(ctx)
ac := audit.AuditContextFrom(ctx)
ac.Event = ev
record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token) record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token)
record.annotations = ev.Annotations record.annotations = ev.Annotations

View File

@ -300,26 +300,10 @@ func TestCachedAuditAnnotations(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
// exercise both ways of tracking audit annotations ctx := withAudit(context.Background())
r := mathrand.New(mathrand.NewSource(mathrand.Int63()))
randomChoice := r.Int()%2 == 0
ctx := context.Background()
if randomChoice {
ctx = audit.WithAuditAnnotations(ctx)
} else {
ctx = audit.WithAuditContext(ctx, &audit.AuditContext{
Event: &auditinternal.Event{Level: auditinternal.LevelMetadata},
})
}
_, _, _ = a.AuthenticateToken(ctx, "token") _, _, _ = a.AuthenticateToken(ctx, "token")
if randomChoice { allAnnotations <- audit.AuditEventFrom(ctx).Annotations
allAnnotations <- extractAnnotations(ctx)
} else {
allAnnotations <- audit.AuditEventFrom(ctx).Annotations
}
}() }()
} }
@ -354,9 +338,9 @@ func TestCachedAuditAnnotations(t *testing.T) {
allAnnotations := make([]map[string]string, 0, 10) allAnnotations := make([]map[string]string, 0, 10)
for i := 0; i < cap(allAnnotations); i++ { for i := 0; i < cap(allAnnotations); i++ {
ctx := audit.WithAuditAnnotations(context.Background()) ctx := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx, "token") _, _, _ = a.AuthenticateToken(ctx, "token")
allAnnotations = append(allAnnotations, extractAnnotations(ctx)) allAnnotations = append(allAnnotations, audit.AuditEventFrom(ctx).Annotations)
} }
if len(allAnnotations) != cap(allAnnotations) { if len(allAnnotations) != cap(allAnnotations) {
@ -381,16 +365,16 @@ func TestCachedAuditAnnotations(t *testing.T) {
return snorlax, true, nil return snorlax, true, nil
}), false, time.Minute, 0) }), false, time.Minute, 0)
ctx1 := audit.WithAuditAnnotations(context.Background()) ctx1 := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx1, "token1") _, _, _ = a.AuthenticateToken(ctx1, "token1")
annotations1 := extractAnnotations(ctx1) annotations1 := audit.AuditEventFrom(ctx1).Annotations
// guarantee different now times // guarantee different now times
time.Sleep(time.Second) time.Sleep(time.Second)
ctx2 := audit.WithAuditAnnotations(context.Background()) ctx2 := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx2, "token2") _, _, _ = a.AuthenticateToken(ctx2, "token2")
annotations2 := extractAnnotations(ctx2) annotations2 := audit.AuditEventFrom(ctx2).Annotations
if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok { if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 1: %v", annotations1) t.Errorf("invalid annotations 1: %v", annotations1)
@ -405,18 +389,6 @@ func TestCachedAuditAnnotations(t *testing.T) {
}) })
} }
func extractAnnotations(ctx context.Context) map[string]string {
annotationsSlice := reflect.ValueOf(ctx).Elem().FieldByName("val").Elem().Elem()
annotations := map[string]string{}
for i := 0; i < annotationsSlice.Len(); i++ {
annotation := annotationsSlice.Index(i)
key := annotation.FieldByName("key").String()
val := annotation.FieldByName("value").String()
annotations[key] = val
}
return annotations
}
func BenchmarkCachedTokenAuthenticator(b *testing.B) { func BenchmarkCachedTokenAuthenticator(b *testing.B) {
tokenCount := []int{100, 500, 2500, 12500, 62500} tokenCount := []int{100, 500, 2500, 12500, 62500}
threadCount := []int{1, 16, 256} threadCount := []int{1, 16, 256}
@ -566,3 +538,12 @@ func (s *singleBenchmark) bench(b *testing.B) {
b.ReportMetric(float64(lookups)/float64(b.N), "lookups/op") b.ReportMetric(float64(lookups)/float64(b.N), "lookups/op")
} }
// Add a test version of the audit context with a pre-populated event for easy annotation
// extraction.
func withAudit(ctx context.Context) context.Context {
ctx = audit.WithAuditContext(ctx)
ac := audit.AuditContextFrom(ctx)
ac.Event = &auditinternal.Event{Level: auditinternal.LevelMetadata}
return ctx
}

View File

@ -290,6 +290,7 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
handler := genericapifilters.WithAudit(mux, auditSink, fakeRuleEvaluator, longRunningCheck) handler := genericapifilters.WithAudit(mux, auditSink, fakeRuleEvaluator, longRunningCheck)
handler = genericapifilters.WithRequestDeadline(handler, auditSink, fakeRuleEvaluator, longRunningCheck, codecs, 60*time.Second) handler = genericapifilters.WithRequestDeadline(handler, auditSink, fakeRuleEvaluator, longRunningCheck, codecs, 60*time.Second)
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver()) handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver())
handler = genericapifilters.WithAuditInit(handler)
return &defaultAPIServer{handler, container} return &defaultAPIServer{handler, container}
} }

View File

@ -286,81 +286,97 @@ func TestAudit(t *testing.T) {
}, },
}, },
} { } {
sink := &fakeAuditSink{} t.Run(test.desc, func(t *testing.T) {
handler := handleInternal(map[string]rest.Storage{ sink := &fakeAuditSink{}
"simple": &SimpleRESTStorage{ handler := handleInternal(map[string]rest.Storage{
list: []genericapitesting.Simple{ "simple": &SimpleRESTStorage{
{ list: []genericapitesting.Simple{
ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "other"}, {
Other: "foo", ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "other"},
Other: "foo",
},
{
ObjectMeta: metav1.ObjectMeta{Name: "b", Namespace: "other"},
Other: "foo",
},
}, },
{ item: genericapitesting.Simple{
ObjectMeta: metav1.ObjectMeta{Name: "b", Namespace: "other"}, ObjectMeta: metav1.ObjectMeta{Name: "c", Namespace: "other", UID: "uid"},
Other: "foo", Other: "foo",
}, },
}, },
item: genericapitesting.Simple{ }, admissionControl, sink)
ObjectMeta: metav1.ObjectMeta{Name: "c", Namespace: "other", UID: "uid"},
Other: "foo",
},
},
}, admissionControl, sink)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{Timeout: 2 * time.Second} client := http.Client{Timeout: 2 * time.Second}
req, err := test.req(server.URL) req, err := test.req(server.URL)
if err != nil { if err != nil {
t.Errorf("[%s] error creating the request: %v", test.desc, err) t.Errorf("[%s] error creating the request: %v", test.desc, err)
} }
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
response, err := client.Do(req) response, err := client.Do(req)
if err != nil { if err != nil {
t.Errorf("[%s] error: %v", test.desc, err) t.Errorf("[%s] error: %v", test.desc, err)
} }
if response.StatusCode != test.code { if response.StatusCode != test.code {
t.Errorf("[%s] expected http code %d, got %#v", test.desc, test.code, response) t.Errorf("[%s] expected http code %d, got %#v", test.desc, test.code, response)
} }
// close body because the handler might block in Flush, unable to send the remaining event. // close body because the handler might block in Flush, unable to send the remaining event.
response.Body.Close() response.Body.Close()
// wait for events to arrive, at least the given number in the test // wait for events to arrive, at least the given number in the test
events := []*auditinternal.Event{} events := []*auditinternal.Event{}
err = wait.Poll(50*time.Millisecond, wait.ForeverTestTimeout, wait.ConditionFunc(func() (done bool, err error) { err = wait.Poll(50*time.Millisecond, testTimeout(t), wait.ConditionFunc(func() (done bool, err error) {
events = sink.Events() events = sink.Events()
return len(events) >= test.events, nil return len(events) >= test.events, nil
})) }))
if err != nil { if err != nil {
t.Errorf("[%s] timeout waiting for events", test.desc) t.Errorf("[%s] timeout waiting for events", test.desc)
} }
if got := len(events); got != test.events { if got := len(events); got != test.events {
t.Errorf("[%s] expected %d audit events, got %d", test.desc, test.events, got) t.Errorf("[%s] expected %d audit events, got %d", test.desc, test.events, got)
} else { } else {
for i, check := range test.checks { for i, check := range test.checks {
err := check(events) err := check(events)
if err != nil { if err != nil {
t.Errorf("[%s,%d] %v", test.desc, i, err) t.Errorf("[%s,%d] %v", test.desc, i, err)
}
}
if err := requestUserAgentMatches(userAgent)(events); err != nil {
t.Errorf("[%s] %v", test.desc, err)
} }
} }
if err := requestUserAgentMatches(userAgent)(events); err != nil { if len(events) > 0 {
t.Errorf("[%s] %v", test.desc, err) status := events[len(events)-1].ResponseStatus
if status == nil {
t.Errorf("[%s] expected non-nil ResponseStatus in last event", test.desc)
} else if int(status.Code) != test.code {
t.Errorf("[%s] expected ResponseStatus.Code=%d, got %d", test.desc, test.code, status.Code)
}
} }
} })
if len(events) > 0 {
status := events[len(events)-1].ResponseStatus
if status == nil {
t.Errorf("[%s] expected non-nil ResponseStatus in last event", test.desc)
} else if int(status.Code) != test.code {
t.Errorf("[%s] expected ResponseStatus.Code=%d, got %d", test.desc, test.code, status.Code)
}
}
} }
} }
// testTimeout returns the minimimum of the "ForeverTestTimeout" and the testing deadline (with
// cleanup time).
func testTimeout(t *testing.T) time.Duration {
defaultTimeout := wait.ForeverTestTimeout
const cleanupTime = 5 * time.Second
if deadline, ok := t.Deadline(); ok {
maxTimeout := time.Until(deadline) - cleanupTime
if maxTimeout < defaultTimeout {
return maxTimeout
}
}
return defaultTimeout
}

View File

@ -44,23 +44,21 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva
return handler return handler
} }
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
auditContext, err := evaluatePolicyAndCreateAuditEvent(req, policy) ac, err := evaluatePolicyAndCreateAuditEvent(req, policy)
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event")) responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return return
} }
ev := auditContext.Event if ac == nil || ac.Event == nil {
if ev == nil || req.Context() == nil {
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
return return
} }
ev := ac.Event
req = req.WithContext(audit.WithAuditContext(req.Context(), auditContext))
ctx := req.Context() ctx := req.Context()
omitStages := auditContext.RequestAuditConfig.OmitStages omitStages := ac.RequestAuditConfig.OmitStages
ev.Stage = auditinternal.StageRequestReceived ev.Stage = auditinternal.StageRequestReceived
if processed := processAuditEvent(ctx, sink, ev, omitStages); !processed { if processed := processAuditEvent(ctx, sink, ev, omitStages); !processed {
@ -124,19 +122,23 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva
// - error if anything bad happened // - error if anything bad happened
func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator) (*audit.AuditContext, error) { func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator) (*audit.AuditContext, error) {
ctx := req.Context() ctx := req.Context()
ac := audit.AuditContextFrom(ctx)
if ac == nil {
// Auditing not enabled.
return nil, nil
}
attribs, err := GetAuthorizerAttributes(ctx) attribs, err := GetAuthorizerAttributes(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to GetAuthorizerAttributes: %v", err) return ac, fmt.Errorf("failed to GetAuthorizerAttributes: %v", err)
} }
ls := policy.EvaluatePolicyRule(attribs) ls := policy.EvaluatePolicyRule(attribs)
audit.ObservePolicyLevel(ctx, ls.Level) audit.ObservePolicyLevel(ctx, ls.Level)
ac.RequestAuditConfig = ls.RequestAuditConfig
if ls.Level == auditinternal.LevelNone { if ls.Level == auditinternal.LevelNone {
// Don't audit. // Don't audit.
return &audit.AuditContext{ return ac, nil
RequestAuditConfig: ls.RequestAuditConfig,
}, nil
} }
requestReceivedTimestamp, ok := request.ReceivedTimestampFrom(ctx) requestReceivedTimestamp, ok := request.ReceivedTimestampFrom(ctx)
@ -148,10 +150,9 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul
return nil, fmt.Errorf("failed to complete audit event from request: %v", err) return nil, fmt.Errorf("failed to complete audit event from request: %v", err)
} }
return &audit.AuditContext{ ac.Event = ev
RequestAuditConfig: ls.RequestAuditConfig,
Event: ev, return ac, nil
}, nil
} }
// writeLatencyToAnnotation writes the latency incurred in different // writeLatencyToAnnotation writes the latency incurred in different

View File

@ -1,38 +0,0 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"net/http"
"k8s.io/apiserver/pkg/audit"
)
// WithAuditAnnotations decorates a http.Handler with a []{key, value} that is merged
// with the audit.Event.Annotations map. This allows layers that run before WithAudit
// (such as authentication) to assert annotations.
// If sink or audit policy is nil, no decoration takes place.
func WithAuditAnnotations(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEvaluator) http.Handler {
// no need to wrap if auditing is disabled
if sink == nil || policy == nil {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = req.WithContext(audit.WithAuditAnnotations(req.Context()))
handler.ServeHTTP(w, req)
})
}

View File

@ -21,28 +21,25 @@ import (
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
auditinternal "k8s.io/apiserver/pkg/apis/audit" auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/audit"
"github.com/google/uuid" "github.com/google/uuid"
) )
// WithAuditID attaches the Audit-ID associated with a request to the context. // WithAuditInit initializes the audit context and attaches the Audit-ID associated with a request.
// //
// a. If the caller does not specify a value for Audit-ID in the request header, we generate a new audit ID // a. If the caller does not specify a value for Audit-ID in the request header, we generate a new audit ID
// b. We echo the Audit-ID value to the caller via the response Header 'Audit-ID'. // b. We echo the Audit-ID value to the caller via the response Header 'Audit-ID'.
func WithAuditID(handler http.Handler) http.Handler { func WithAuditInit(handler http.Handler) http.Handler {
return withAuditID(handler, func() string { return withAuditInit(handler, func() string {
return uuid.New().String() return uuid.New().String()
}) })
} }
func withAuditID(handler http.Handler, newAuditIDFunc func() string) http.Handler { func withAuditInit(handler http.Handler, newAuditIDFunc func() string) http.Handler {
if newAuditIDFunc == nil {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := audit.WithAuditContext(r.Context())
r = r.WithContext(ctx)
auditID := r.Header.Get(auditinternal.HeaderAuditID) auditID := r.Header.Get(auditinternal.HeaderAuditID)
if len(auditID) == 0 { if len(auditID) == 0 {
@ -50,7 +47,7 @@ func withAuditID(handler http.Handler, newAuditIDFunc func() string) http.Handle
} }
// Note: we save the user specified value of the Audit-ID header as is, no truncation is performed. // Note: we save the user specified value of the Audit-ID header as is, no truncation is performed.
r = r.WithContext(request.WithAuditID(ctx, types.UID(auditID))) audit.WithAuditID(ctx, types.UID(auditID))
// We echo the Audit-ID in to the response header. // We echo the Audit-ID in to the response header.
// It's not guaranteed Audit-ID http header is sent for all requests. // It's not guaranteed Audit-ID http header is sent for all requests.

View File

@ -23,7 +23,7 @@ import (
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/audit"
) )
func TestWithAuditID(t *testing.T) { func TestWithAuditID(t *testing.T) {
@ -72,15 +72,15 @@ func TestWithAuditID(t *testing.T) {
innerHandlerCallCount++ innerHandlerCallCount++
// does the inner handler see the audit ID? // does the inner handler see the audit ID?
v, ok := request.AuditIDFrom(req.Context()) v, ok := audit.AuditIDFrom(req.Context())
found = ok found = ok
auditIDGot = string(v) auditIDGot = string(v)
}) })
wrapped := WithAuditID(handler) wrapped := WithAuditInit(handler)
if test.newAuditIDFunc != nil { if test.newAuditIDFunc != nil {
wrapped = withAuditID(handler, test.newAuditIDFunc) wrapped = withAuditInit(handler, test.newAuditIDFunc)
} }
testRequest, err := http.NewRequest(http.MethodGet, "/api/v1/namespaces", nil) testRequest, err := http.NewRequest(http.MethodGet, "/api/v1/namespaces", nil)

View File

@ -676,7 +676,7 @@ func TestAudit(t *testing.T) {
// simplified long-running check // simplified long-running check
return ri.Verb == "watch" return ri.Verb == "watch"
}) })
handler = WithAuditID(handler) handler = WithAuditInit(handler)
req, _ := http.NewRequest(test.verb, test.path, nil) req, _ := http.NewRequest(test.verb, test.path, nil)
req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil) req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil)
@ -812,7 +812,7 @@ func TestAuditIDHttpHeader(t *testing.T) {
}) })
fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(test.level, nil) fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(test.level, nil)
handler = WithAudit(handler, sink, fakeRuleEvaluator, nil) handler = WithAudit(handler, sink, fakeRuleEvaluator, nil)
handler = WithAuditID(handler) handler = WithAuditInit(handler)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil) req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1" req.RemoteAddr = "127.0.0.1"
@ -843,14 +843,13 @@ func TestAuditIDHttpHeader(t *testing.T) {
} }
func withTestContext(req *http.Request, user user.Info, ae *auditinternal.Event) *http.Request { func withTestContext(req *http.Request, user user.Info, ae *auditinternal.Event) *http.Request {
ctx := req.Context() ctx := audit.WithAuditContext(req.Context())
if user != nil { if user != nil {
ctx = request.WithUser(ctx, user) ctx = request.WithUser(ctx, user)
} }
if ae != nil { if ae != nil {
ctx = audit.WithAuditContext(ctx, &audit.AuditContext{ ac := audit.AuditContextFrom(ctx)
Event: ae, ac.Event = ae
})
} }
if info, err := newTestRequestInfoResolver().NewRequestInfo(req); err == nil { if info, err := newTestRequestInfoResolver().NewRequestInfo(req); err == nil {
ctx = request.WithRequestInfo(ctx, info) ctx = request.WithRequestInfo(ctx, info)

View File

@ -36,26 +36,24 @@ func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink,
return failedHandler return failedHandler
} }
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
a, err := evaluatePolicyAndCreateAuditEvent(req, policy) ac, err := evaluatePolicyAndCreateAuditEvent(req, policy)
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event")) responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return return
} }
ev := a.Event if ac == nil || ac.Event == nil {
if ev == nil {
failedHandler.ServeHTTP(w, req) failedHandler.ServeHTTP(w, req)
return return
} }
ev := ac.Event
req = req.WithContext(audit.WithAuditContext(req.Context(), a))
ev.ResponseStatus = &metav1.Status{} ev.ResponseStatus = &metav1.Status{}
ev.ResponseStatus.Message = getAuthMethods(req) ev.ResponseStatus.Message = getAuthMethods(req)
ev.Stage = auditinternal.StageResponseStarted ev.Stage = auditinternal.StageResponseStarted
rw := decorateResponseWriter(req.Context(), w, ev, sink, a.RequestAuditConfig.OmitStages) rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages)
failedHandler.ServeHTTP(rw, req) failedHandler.ServeHTTP(rw, req)
}) })
} }

View File

@ -108,20 +108,18 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta
return failedHandler return failedHandler
} }
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
a, err := evaluatePolicyAndCreateAuditEvent(req, policy) ac, err := evaluatePolicyAndCreateAuditEvent(req, policy)
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event")) responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return return
} }
ev := a.Event if ac == nil || ac.Event == nil {
if ev == nil {
failedHandler.ServeHTTP(w, req) failedHandler.ServeHTTP(w, req)
return return
} }
ev := ac.Event
req = req.WithContext(audit.WithAuditContext(req.Context(), a))
ev.ResponseStatus = &metav1.Status{} ev.ResponseStatus = &metav1.Status{}
ev.Stage = auditinternal.StageResponseStarted ev.Stage = auditinternal.StageResponseStarted
@ -129,7 +127,7 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta
ev.ResponseStatus.Message = statusErr.Error() ev.ResponseStatus.Message = statusErr.Error()
} }
rw := decorateResponseWriter(req.Context(), w, ev, sink, a.RequestAuditConfig.OmitStages) rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages)
failedHandler.ServeHTTP(rw, req) failedHandler.ServeHTTP(rw, req)
}) })
} }

View File

@ -253,7 +253,7 @@ func TestWithRequestDeadlineWithClock(t *testing.T) {
} }
} }
func TestWithRequestDeadlineWithFailedRequestIsAudited(t *testing.T) { func TestWithRequestDeadlineWithInvalidTimeoutIsAudited(t *testing.T) {
var handlerInvoked bool var handlerInvoked bool
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
handlerInvoked = true handlerInvoked = true
@ -386,10 +386,7 @@ func TestWithFailedRequestAudit(t *testing.T) {
withAudit := withFailedRequestAudit(errorHandler, test.statusErr, fakeSink, fakeRuleEvaluator) withAudit := withFailedRequestAudit(errorHandler, test.statusErr, fakeSink, fakeRuleEvaluator)
w := httptest.NewRecorder() w := httptest.NewRecorder()
testRequest, err := http.NewRequest(http.MethodGet, "/apis/v1/namespaces/default/pods", nil) testRequest := newRequest(t, "/apis/v1/namespaces/default/pods")
if err != nil {
t.Fatalf("failed to create new http testRequest - %v", err)
}
info := request.RequestInfo{} info := request.RequestInfo{}
testRequest = testRequest.WithContext(request.WithRequestInfo(testRequest.Context(), &info)) testRequest = testRequest.WithContext(request.WithRequestInfo(testRequest.Context(), &info))
@ -446,8 +443,8 @@ func newRequest(t *testing.T, requestURL string) *http.Request {
if err != nil { if err != nil {
t.Fatalf("failed to create new http request - %v", err) t.Fatalf("failed to create new http request - %v", err)
} }
ctx := audit.WithAuditContext(req.Context())
return req return req.WithContext(ctx)
} }
func message(err error) string { func message(err error) string {

View File

@ -60,11 +60,11 @@ func (s *mockCodecs) EncoderForVersion(encoder runtime.Encoder, gv runtime.Group
func TestDeleteResourceAuditLogRequestObject(t *testing.T) { func TestDeleteResourceAuditLogRequestObject(t *testing.T) {
ctx := audit.WithAuditContext(context.TODO(), &audit.AuditContext{ ctx := audit.WithAuditContext(context.TODO())
Event: &auditapis.Event{ ac := audit.AuditContextFrom(ctx)
Level: auditapis.LevelRequestResponse, ac.Event = &auditapis.Event{
}, Level: auditapis.LevelRequestResponse,
}) }
policy := metav1.DeletePropagationBackground policy := metav1.DeletePropagationBackground
deleteOption := &metav1.DeleteOptions{ deleteOption := &metav1.DeleteOptions{

View File

@ -20,7 +20,7 @@ import (
"net/http" "net/http"
utilnet "k8s.io/apimachinery/pkg/util/net" utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/audit"
) )
const ( const (
@ -83,7 +83,7 @@ type lazyAuditID struct {
func (lazy *lazyAuditID) String() string { func (lazy *lazyAuditID) String() string {
if lazy.req != nil { if lazy.req != nil {
return request.GetAuditIDTruncated(lazy.req.Context()) return audit.GetAuditIDTruncated(lazy.req.Context())
} }
return "unknown" return "unknown"

View File

@ -88,7 +88,7 @@ func StreamObject(statusCode int, gv schema.GroupVersion, s runtime.NegotiatedSe
// a client and the feature gate for APIResponseCompression is enabled. // a client and the feature gate for APIResponseCompression is enabled.
func SerializeObject(mediaType string, encoder runtime.Encoder, hw http.ResponseWriter, req *http.Request, statusCode int, object runtime.Object) { func SerializeObject(mediaType string, encoder runtime.Encoder, hw http.ResponseWriter, req *http.Request, statusCode int, object runtime.Object) {
trace := utiltrace.New("SerializeObject", trace := utiltrace.New("SerializeObject",
utiltrace.Field{"audit-id", request.GetAuditIDTruncated(req.Context())}, utiltrace.Field{"audit-id", audit.GetAuditIDTruncated(req.Context())},
utiltrace.Field{"method", req.Method}, utiltrace.Field{"method", req.Method},
utiltrace.Field{"url", req.URL.Path}, utiltrace.Field{"url", req.URL.Path},
utiltrace.Field{"protocol", req.Proto}, utiltrace.Field{"protocol", req.Proto},

View File

@ -1,65 +0,0 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package request
import (
"context"
"k8s.io/apimachinery/pkg/types"
)
type auditIDKeyType int
// auditIDKey is the key to associate the Audit-ID value of a request.
const auditIDKey auditIDKeyType = iota
// WithAuditID returns a copy of the parent context into which the Audit-ID
// associated with the request is set.
//
// If the specified auditID is empty, no value is set and the parent context is returned as is.
func WithAuditID(parent context.Context, auditID types.UID) context.Context {
if auditID == "" {
return parent
}
return WithValue(parent, auditIDKey, auditID)
}
// AuditIDFrom returns the value of the audit ID from the request context.
func AuditIDFrom(ctx context.Context) (types.UID, bool) {
auditID, ok := ctx.Value(auditIDKey).(types.UID)
return auditID, ok
}
// GetAuditIDTruncated returns the audit ID (truncated) from the request context.
// If the length of the Audit-ID value exceeds the limit, we truncate it to keep
// the first N (maxAuditIDLength) characters.
// This is intended to be used in logging only.
func GetAuditIDTruncated(ctx context.Context) string {
auditID, ok := AuditIDFrom(ctx)
if !ok {
return ""
}
// if the user has specified a very long audit ID then we will use the first N characters
// Note: assuming Audit-ID header is in ASCII
const maxAuditIDLength = 64
if len(auditID) > maxAuditIDLength {
auditID = auditID[0:maxAuditIDLength]
}
return string(auditID)
}

View File

@ -1,68 +0,0 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package request
import (
"context"
"testing"
"k8s.io/apimachinery/pkg/types"
)
func TestAuditIDFrom(t *testing.T) {
tests := []struct {
name string
auditID string
auditIDExpected string
expected bool
}{
{
name: "empty audit ID",
auditID: "",
auditIDExpected: "",
expected: false,
},
{
name: "non empty audit ID",
auditID: "foo-bar-baz",
auditIDExpected: "foo-bar-baz",
expected: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parent := context.TODO()
ctx := WithAuditID(parent, types.UID(test.auditID))
// for an empty audit ID we don't expect a copy of the parent context.
if len(test.auditID) == 0 && parent != ctx {
t.Error("expected no copy of the parent context with an empty audit ID")
}
value, ok := AuditIDFrom(ctx)
if test.expected != ok {
t.Errorf("expected AuditIDFrom to return: %t, but got: %t", test.expected, ok)
}
auditIDGot := string(value)
if test.auditIDExpected != auditIDGot {
t.Errorf("expected audit ID: %q, but got: %q", test.auditIDExpected, auditIDGot)
}
})
}
}

View File

@ -854,7 +854,6 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 { if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 {
handler = genericfilters.WithProbabilisticGoaway(handler, c.GoawayChance) handler = genericfilters.WithProbabilisticGoaway(handler, c.GoawayChance)
} }
handler = genericapifilters.WithAuditAnnotations(handler, c.AuditBackend, c.AuditPolicyRuleEvaluator)
handler = genericapifilters.WithWarningRecorder(handler) handler = genericapifilters.WithWarningRecorder(handler)
handler = genericapifilters.WithCacheControl(handler) handler = genericapifilters.WithCacheControl(handler)
handler = genericfilters.WithHSTS(handler, c.HSTSDirectives) handler = genericfilters.WithHSTS(handler, c.HSTSDirectives)
@ -870,7 +869,7 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
handler = genericapifilters.WithRequestReceivedTimestamp(handler) handler = genericapifilters.WithRequestReceivedTimestamp(handler)
handler = genericapifilters.WithMuxAndDiscoveryComplete(handler, c.lifecycleSignals.MuxAndDiscoveryComplete.Signaled()) handler = genericapifilters.WithMuxAndDiscoveryComplete(handler, c.lifecycleSignals.MuxAndDiscoveryComplete.Signaled())
handler = genericfilters.WithPanicRecovery(handler, c.RequestInfoResolver) handler = genericfilters.WithPanicRecovery(handler, c.RequestInfoResolver)
handler = genericapifilters.WithAuditID(handler) handler = genericapifilters.WithAuditInit(handler)
return handler return handler
} }

View File

@ -306,7 +306,7 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) {
h := DefaultBuildHandlerChain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := DefaultBuildHandlerChain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// confirm this is a no-op // confirm this is a no-op
if r.Context() != audit.WithAuditAnnotations(r.Context()) { if r.Context() != audit.WithAuditContext(r.Context()) {
t.Error("unexpected double wrapping of context") t.Error("unexpected double wrapping of context")
} }

View File

@ -1197,7 +1197,7 @@ func newHandlerChain(t *testing.T, handler http.Handler, filter utilflowcontrol.
handler = apifilters.WithRequestDeadline(handler, nil, nil, longRunningRequestCheck, nil, requestTimeout) handler = apifilters.WithRequestDeadline(handler, nil, nil, longRunningRequestCheck, nil, requestTimeout)
handler = apifilters.WithRequestInfo(handler, requestInfoFactory) handler = apifilters.WithRequestInfo(handler, requestInfoFactory)
handler = WithPanicRecovery(handler, requestInfoFactory) handler = WithPanicRecovery(handler, requestInfoFactory)
handler = apifilters.WithAuditID(handler) handler = apifilters.WithAuditInit(handler)
return handler return handler
} }

View File

@ -21,6 +21,7 @@ import (
"net/http" "net/http"
"k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/metrics" "k8s.io/apiserver/pkg/endpoints/metrics"
"k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/server/httplog" "k8s.io/apiserver/pkg/server/httplog"
@ -50,11 +51,11 @@ func WithPanicRecovery(handler http.Handler, resolver request.RequestInfoResolve
// This call can have different handlers, but the default chain rate limits. Call it after the metrics are updated // This call can have different handlers, but the default chain rate limits. Call it after the metrics are updated
// in case the rate limit delays it. If you outrun the rate for this one timed out requests, something has gone // in case the rate limit delays it. If you outrun the rate for this one timed out requests, something has gone
// seriously wrong with your server, but generally having a logging signal for timeouts is useful. // seriously wrong with your server, but generally having a logging signal for timeouts is useful.
runtime.HandleError(fmt.Errorf("timeout or abort while handling: method=%v URI=%q audit-ID=%q", req.Method, req.RequestURI, request.GetAuditIDTruncated(req.Context()))) runtime.HandleError(fmt.Errorf("timeout or abort while handling: method=%v URI=%q audit-ID=%q", req.Method, req.RequestURI, audit.GetAuditIDTruncated(req.Context())))
return return
} }
http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError) http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError)
klog.ErrorS(nil, "apiserver panic'd", "method", req.Method, "URI", req.RequestURI, "audit-ID", request.GetAuditIDTruncated(req.Context())) klog.ErrorS(nil, "apiserver panic'd", "method", req.Method, "URI", req.RequestURI, "audit-ID", audit.GetAuditIDTruncated(req.Context()))
}) })
} }

View File

@ -27,6 +27,7 @@ import (
"sync" "sync"
"time" "time"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/metrics" "k8s.io/apiserver/pkg/endpoints/metrics"
"k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/endpoints/responsewriter" "k8s.io/apiserver/pkg/endpoints/responsewriter"
@ -241,7 +242,7 @@ func SetStacktracePredicate(ctx context.Context, pred StacktracePred) {
// Log is intended to be called once at the end of your request handler, via defer // Log is intended to be called once at the end of your request handler, via defer
func (rl *respLogger) Log() { func (rl *respLogger) Log() {
latency := time.Since(rl.startTime) latency := time.Since(rl.startTime)
auditID := request.GetAuditIDTruncated(rl.req.Context()) auditID := audit.GetAuditIDTruncated(rl.req.Context())
verb := rl.req.Method verb := rl.req.Method
if requestInfo, ok := request.RequestInfoFrom(rl.req.Context()); ok { if requestInfo, ok := request.RequestInfoFrom(rl.req.Context()); ok {

View File

@ -35,7 +35,7 @@ import (
utilruntime "k8s.io/apimachinery/pkg/util/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/watch" "k8s.io/apimachinery/pkg/watch"
endpointsrequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/features" "k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/storage" "k8s.io/apiserver/pkg/storage"
"k8s.io/apiserver/pkg/storage/cacher/metrics" "k8s.io/apiserver/pkg/storage/cacher/metrics"
@ -670,7 +670,7 @@ func (c *Cacher) GetList(ctx context.Context, key string, opts storage.ListOptio
} }
trace := utiltrace.New("cacher list", trace := utiltrace.New("cacher list",
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)}, utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "type", Value: c.groupResource.String()}) utiltrace.Field{Key: "type", Value: c.groupResource.String()})
defer trace.LogIfLong(500 * time.Millisecond) defer trace.LogIfLong(500 * time.Millisecond)

View File

@ -37,7 +37,7 @@ import (
"k8s.io/apimachinery/pkg/conversion" "k8s.io/apimachinery/pkg/conversion"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch" "k8s.io/apimachinery/pkg/watch"
endpointsrequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/features" "k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/storage" "k8s.io/apiserver/pkg/storage"
"k8s.io/apiserver/pkg/storage/etcd3/metrics" "k8s.io/apiserver/pkg/storage/etcd3/metrics"
@ -153,7 +153,7 @@ func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, ou
// Create implements storage.Interface.Create. // Create implements storage.Interface.Create.
func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, ttl uint64) error { func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, ttl uint64) error {
trace := utiltrace.New("Create etcd3", trace := utiltrace.New("Create etcd3",
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)}, utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "key", Value: key}, utiltrace.Field{Key: "key", Value: key},
utiltrace.Field{Key: "type", Value: getTypeName(obj)}, utiltrace.Field{Key: "type", Value: getTypeName(obj)},
utiltrace.Field{Key: "resource", Value: s.groupResourceString}, utiltrace.Field{Key: "resource", Value: s.groupResourceString},
@ -332,7 +332,7 @@ func (s *store) GuaranteedUpdate(
ctx context.Context, key string, destination runtime.Object, ignoreNotFound bool, ctx context.Context, key string, destination runtime.Object, ignoreNotFound bool,
preconditions *storage.Preconditions, tryUpdate storage.UpdateFunc, cachedExistingObject runtime.Object) error { preconditions *storage.Preconditions, tryUpdate storage.UpdateFunc, cachedExistingObject runtime.Object) error {
trace := utiltrace.New("GuaranteedUpdate etcd3", trace := utiltrace.New("GuaranteedUpdate etcd3",
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)}, utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "key", Value: key}, utiltrace.Field{Key: "key", Value: key},
utiltrace.Field{Key: "type", Value: getTypeName(destination)}, utiltrace.Field{Key: "type", Value: getTypeName(destination)},
utiltrace.Field{Key: "resource", Value: s.groupResourceString}) utiltrace.Field{Key: "resource", Value: s.groupResourceString})
@ -529,7 +529,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption
match := opts.ResourceVersionMatch match := opts.ResourceVersionMatch
pred := opts.Predicate pred := opts.Predicate
trace := utiltrace.New(fmt.Sprintf("List(recursive=%v) etcd3", recursive), trace := utiltrace.New(fmt.Sprintf("List(recursive=%v) etcd3", recursive),
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)}, utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "key", Value: key}, utiltrace.Field{Key: "key", Value: key},
utiltrace.Field{Key: "resourceVersion", Value: resourceVersion}, utiltrace.Field{Key: "resourceVersion", Value: resourceVersion},
utiltrace.Field{Key: "resourceVersionMatch", Value: match}, utiltrace.Field{Key: "resourceVersionMatch", Value: match},

View File

@ -245,8 +245,9 @@ func TestCheckForHostnameError(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create an http request: %v", err) t.Fatalf("failed to create an http request: %v", err)
} }
auditCtx := &audit.AuditContext{Event: &auditapi.Event{Level: auditapi.LevelMetadata}} req = req.WithContext(audit.WithAuditContext(req.Context()))
req = req.WithContext(audit.WithAuditContext(req.Context(), auditCtx)) auditCtx := audit.AuditContextFrom(req.Context())
auditCtx.Event = &auditapi.Event{Level: auditapi.LevelMetadata}
_, err = client.Transport.RoundTrip(req) _, err = client.Transport.RoundTrip(req)
@ -387,8 +388,9 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create an http request: %v", err) t.Fatalf("failed to create an http request: %v", err)
} }
auditCtx := &audit.AuditContext{Event: &auditapi.Event{Level: auditapi.LevelMetadata}} req = req.WithContext(audit.WithAuditContext(req.Context()))
req = req.WithContext(audit.WithAuditContext(req.Context(), auditCtx)) auditCtx := audit.AuditContextFrom(req.Context())
auditCtx.Event = &auditapi.Event{Level: auditapi.LevelMetadata}
// can't use tlsServer.Client() as it contains the server certificate // can't use tlsServer.Client() as it contains the server certificate
// in tls.Config.Certificates. The signatures are, however, only checked // in tls.Config.Certificates. The signatures are, however, only checked

View File

@ -29,6 +29,7 @@ import (
utilnet "k8s.io/apimachinery/pkg/util/net" utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/proxy" "k8s.io/apimachinery/pkg/util/proxy"
auditinternal "k8s.io/apiserver/pkg/apis/audit" auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
endpointmetrics "k8s.io/apiserver/pkg/endpoints/metrics" endpointmetrics "k8s.io/apiserver/pkg/endpoints/metrics"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request" genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
@ -206,7 +207,7 @@ func newRequestForProxy(location *url.URL, req *http.Request) (*http.Request, co
// If the original request has an audit ID, let's make sure we propagate this // If the original request has an audit ID, let's make sure we propagate this
// to the aggregated server. // to the aggregated server.
if auditID, found := genericapirequest.AuditIDFrom(req.Context()); found { if auditID, found := audit.AuditIDFrom(req.Context()); found {
newReq.Header.Set(auditinternal.HeaderAuditID, string(auditID)) newReq.Header.Set(auditinternal.HeaderAuditID, string(auditID))
} }

View File

@ -34,6 +34,7 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/server/dynamiccertificates" "k8s.io/apiserver/pkg/server/dynamiccertificates"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
@ -790,7 +791,9 @@ func TestNewRequestForProxyWithAuditID(t *testing.T) {
req = req.WithContext(genericapirequest.WithRequestInfo(req.Context(), &genericapirequest.RequestInfo{Path: req.URL.Path})) req = req.WithContext(genericapirequest.WithRequestInfo(req.Context(), &genericapirequest.RequestInfo{Path: req.URL.Path}))
if len(test.auditID) > 0 { if len(test.auditID) > 0 {
req = req.WithContext(genericapirequest.WithAuditID(req.Context(), types.UID(test.auditID))) ctx := audit.WithAuditContext(req.Context())
audit.WithAuditID(ctx, types.UID(test.auditID))
req = req.WithContext(ctx)
} }
newReq, _ := newRequestForProxy(req.URL, req) newReq, _ := newRequestForProxy(req.URL, req)