Merge pull request #111095 from tallclair/audit-single-context

Audit single context
This commit is contained in:
Kubernetes Prow Robot 2022-10-26 23:56:32 -07:00 committed by GitHub
commit 6e31c6531f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 286 additions and 472 deletions

View File

@ -143,7 +143,9 @@ func TestWithAudit(t *testing.T) {
for tcName, tc := range testCases {
var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles}
ae := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx := audit.WithAuditContext(context.Background(), &audit.AuditContext{Event: ae})
ctx := audit.WithAuditContext(context.Background())
ac := audit.AuditContextFrom(ctx)
ac.Event = ae
auditHandler := WithAudit(handler)
a := attributes()
@ -184,8 +186,9 @@ func TestWithAuditConcurrency(t *testing.T) {
"plugin.example.com/qux": "qux",
}
var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true}
ae := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx := audit.WithAuditContext(context.Background(), &audit.AuditContext{Event: ae})
ctx := audit.WithAuditContext(context.Background())
ac := audit.AuditContextFrom(ctx)
ac.Event = &auditinternal.Event{Level: auditinternal.LevelMetadata}
auditHandler := WithAudit(handler)
a := attributes()

View File

@ -20,6 +20,7 @@ import (
"context"
"sync"
"k8s.io/apimachinery/pkg/types"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2"
@ -28,38 +29,31 @@ import (
// The key type is unexported to prevent collisions
type key int
const (
// auditAnnotationsKey is the context key for the audit annotations.
// TODO: consolidate all audit info under the AuditContext, rather than storing 3 separate keys.
auditAnnotationsKey key = iota
// auditKey is the context key for storing the audit context that is being
// captured and the evaluated policy that applies to the given request.
const auditKey key = iota
// auditKey is the context key for storing the audit event that is being
// captured and the evaluated policy that applies to the given request.
auditKey
// AuditContext holds the information for constructing the audit events for the current request.
type AuditContext struct {
// RequestAuditConfig is the audit configuration that applies to the request
RequestAuditConfig RequestAuditConfig
// auditAnnotationsMutexKey is the context key for the audit annotations mutex.
auditAnnotationsMutexKey
)
// Event is the audit Event object that is being captured to be written in
// the API audit log. It is set to nil when the request is not being audited.
Event *auditinternal.Event
// annotations = *[]annotation instead of a map to preserve order of insertions
type annotation struct {
key, value string
// annotations holds audit annotations that are recorded before the event has been initialized.
// This is represented as a slice rather than a map to preserve order.
annotations []annotation
// annotationMutex guards annotations AND event.Annotations
annotationMutex sync.Mutex
// auditID is the Audit ID associated with this request.
auditID types.UID
}
// WithAuditAnnotations returns a new context that can store audit annotations
// via the AddAuditAnnotation function. This function is meant to be called from
// an early request handler to allow all later layers to set audit annotations.
// This is required to support flows where handlers that come before WithAudit
// (such as WithAuthentication) wish to set audit annotations.
func WithAuditAnnotations(parent context.Context) context.Context {
// this should never really happen, but prevent double registration of this slice
if _, ok := parent.Value(auditAnnotationsKey).(*[]annotation); ok {
return parent
}
parent = withAuditAnnotationsMutex(parent)
var annotations []annotation // avoid allocations until we actually need it
return genericapirequest.WithValue(parent, auditAnnotationsKey, &annotations)
type annotation struct {
key, value string
}
// AddAuditAnnotation sets the audit annotation for the given key, value pair.
@ -70,102 +64,79 @@ func WithAuditAnnotations(parent context.Context) context.Context {
// Handlers that are unaware of their position in the overall request flow should
// prefer AddAuditAnnotation over LogAnnotation to avoid dropping annotations.
func AddAuditAnnotation(ctx context.Context, key, value string) {
mutex, ok := auditAnnotationsMutex(ctx)
if !ok {
ac := AuditContextFrom(ctx)
if ac == nil {
// auditing is not enabled
return
}
mutex.Lock()
defer mutex.Unlock()
ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
ae := AuditEventFrom(ctx)
var ctxAnnotations *[]annotation
if ae == nil {
ctxAnnotations, _ = ctx.Value(auditAnnotationsKey).(*[]annotation)
}
addAuditAnnotationLocked(ae, ctxAnnotations, key, value)
addAuditAnnotationLocked(ac, key, value)
}
// AddAuditAnnotations is a bulk version of AddAuditAnnotation. Refer to AddAuditAnnotation for
// restrictions on when this can be called.
// keysAndValues are the key-value pairs to add, and must have an even number of items.
func AddAuditAnnotations(ctx context.Context, keysAndValues ...string) {
mutex, ok := auditAnnotationsMutex(ctx)
if !ok {
ac := AuditContextFrom(ctx)
if ac == nil {
// auditing is not enabled
return
}
mutex.Lock()
defer mutex.Unlock()
ae := AuditEventFrom(ctx)
var ctxAnnotations *[]annotation
if ae == nil {
ctxAnnotations, _ = ctx.Value(auditAnnotationsKey).(*[]annotation)
}
ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
if len(keysAndValues)%2 != 0 {
klog.Errorf("Dropping mismatched audit annotation %q", keysAndValues[len(keysAndValues)-1])
}
for i := 0; i < len(keysAndValues); i += 2 {
addAuditAnnotationLocked(ae, ctxAnnotations, keysAndValues[i], keysAndValues[i+1])
addAuditAnnotationLocked(ac, keysAndValues[i], keysAndValues[i+1])
}
}
// AddAuditAnnotationsMap is a bulk version of AddAuditAnnotation. Refer to AddAuditAnnotation for
// restrictions on when this can be called.
func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) {
mutex, ok := auditAnnotationsMutex(ctx)
if !ok {
ac := AuditContextFrom(ctx)
if ac == nil {
// auditing is not enabled
return
}
mutex.Lock()
defer mutex.Unlock()
ae := AuditEventFrom(ctx)
var ctxAnnotations *[]annotation
if ae == nil {
ctxAnnotations, _ = ctx.Value(auditAnnotationsKey).(*[]annotation)
}
ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
for k, v := range annotations {
addAuditAnnotationLocked(ae, ctxAnnotations, k, v)
addAuditAnnotationLocked(ac, k, v)
}
}
// addAuditAnnotationLocked is the shared code for recording an audit annotation. This method should
// only be called while the auditAnnotationsMutex is locked.
func addAuditAnnotationLocked(ae *auditinternal.Event, annotations *[]annotation, key, value string) {
if ae != nil {
logAnnotation(ae, key, value)
} else if annotations != nil {
*annotations = append(*annotations, annotation{key: key, value: value})
func addAuditAnnotationLocked(ac *AuditContext, key, value string) {
if ac.Event != nil {
logAnnotation(ac.Event, key, value)
} else {
ac.annotations = append(ac.annotations, annotation{key: key, value: value})
}
}
// This is private to prevent reads/write to the slice from outside of this package.
// The audit event should be directly read to get access to the annotations.
func addAuditAnnotationsFrom(ctx context.Context, ev *auditinternal.Event) {
mutex, ok := auditAnnotationsMutex(ctx)
if !ok {
ac := AuditContextFrom(ctx)
if ac == nil {
// auditing is not enabled
return
}
mutex.Lock()
defer mutex.Unlock()
ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
annotations, ok := ctx.Value(auditAnnotationsKey).(*[]annotation)
if !ok {
return // no annotations to copy
}
for _, kv := range *annotations {
for _, kv := range ac.annotations {
logAnnotation(ev, kv.key, kv.value)
}
}
@ -185,12 +156,13 @@ func logAnnotation(ae *auditinternal.Event, key, value string) {
ae.Annotations[key] = value
}
// WithAuditContext returns a new context that stores the pair of the audit
// configuration object that applies to the given request and
// the audit event that is going to be written to the API audit log.
func WithAuditContext(parent context.Context, ev *AuditContext) context.Context {
parent = withAuditAnnotationsMutex(parent)
return genericapirequest.WithValue(parent, auditKey, ev)
// WithAuditContext returns a new context that stores the AuditContext.
func WithAuditContext(parent context.Context) context.Context {
if AuditContextFrom(parent) != nil {
return parent // Avoid double registering.
}
return genericapirequest.WithValue(parent, auditKey, &AuditContext{})
}
// AuditEventFrom returns the audit event struct on the ctx
@ -209,17 +181,46 @@ func AuditContextFrom(ctx context.Context) *AuditContext {
return ev
}
// WithAuditAnnotationMutex adds a mutex for guarding context.AddAuditAnnotation.
func withAuditAnnotationsMutex(parent context.Context) context.Context {
if _, ok := parent.Value(auditAnnotationsMutexKey).(*sync.Mutex); ok {
return parent
// WithAuditID sets the AuditID on the AuditContext. The AuditContext must already be present in the
// request context. If the specified auditID is empty, no value is set.
func WithAuditID(ctx context.Context, auditID types.UID) {
if auditID == "" {
return
}
ac := AuditContextFrom(ctx)
if ac == nil {
return
}
ac.auditID = auditID
if ac.Event != nil {
ac.Event.AuditID = auditID
}
var mutex sync.Mutex
return genericapirequest.WithValue(parent, auditAnnotationsMutexKey, &mutex)
}
// AuditAnnotationsMutex returns the audit annotations mutex from the context.
func auditAnnotationsMutex(ctx context.Context) (*sync.Mutex, bool) {
mutex, ok := ctx.Value(auditAnnotationsMutexKey).(*sync.Mutex)
return mutex, ok
// AuditIDFrom returns the value of the audit ID from the request context.
func AuditIDFrom(ctx context.Context) (types.UID, bool) {
if ac := AuditContextFrom(ctx); ac != nil {
return ac.auditID, ac.auditID != ""
}
return "", false
}
// GetAuditIDTruncated returns the audit ID (truncated) from the request context.
// If the length of the Audit-ID value exceeds the limit, we truncate it to keep
// the first N (maxAuditIDLength) characters.
// This is intended to be used in logging only.
func GetAuditIDTruncated(ctx context.Context) string {
auditID, ok := AuditIDFrom(ctx)
if !ok {
return ""
}
// if the user has specified a very long audit ID then we will use the first N characters
// Note: assuming Audit-ID header is in ASCII
const maxAuditIDLength = 64
if len(auditID) > maxAuditIDLength {
auditID = auditID[:maxAuditIDLength]
}
return string(auditID)
}

View File

@ -63,20 +63,16 @@ func TestAddAuditAnnotation(t *testing.T) {
ctx: context.Background(),
validator: noopValidator,
}, {
description: "no annotations context",
ctx: WithAuditContext(context.Background(), newAuditContext(auditinternal.LevelMetadata)),
validator: postEventValidator,
}, {
description: "no audit context",
ctx: WithAuditAnnotations(context.Background()),
description: "empty audit context",
ctx: WithAuditContext(context.Background()),
validator: preEventValidator,
}, {
description: "both contexts metadata level",
ctx: WithAuditContext(WithAuditAnnotations(context.Background()), newAuditContext(auditinternal.LevelMetadata)),
description: "with metadata level",
ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelMetadata),
validator: postEventValidator,
}, {
description: "both contexts none level",
ctx: WithAuditContext(WithAuditAnnotations(context.Background()), newAuditContext(auditinternal.LevelNone)),
description: "with none level",
ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelNone),
validator: postEventEmptyValidator,
}}
@ -111,10 +107,11 @@ func TestLogAnnotation(t *testing.T) {
assert.Equal(t, "", ev.Annotations["qux"], "audit annotation should not be overwritten.")
}
func newAuditContext(l auditinternal.Level) *AuditContext {
return &AuditContext{
Event: &auditinternal.Event{
Level: l,
},
func withAuditContextAndLevel(ctx context.Context, l auditinternal.Level) context.Context {
ctx = WithAuditContext(ctx)
ac := AuditContextFrom(ctx)
ac.Event = &auditinternal.Event{
Level: l,
}
return ctx
}

View File

@ -21,18 +21,6 @@ import (
"k8s.io/apiserver/pkg/authorization/authorizer"
)
// AuditContext is a pair of the audit configuration object that applies to
// a given request and the audit Event object that is being captured.
// It's a convenient placeholder to store both these objects in the request context.
type AuditContext struct {
// RequestAuditConfig is the audit configuration that applies to the request
RequestAuditConfig RequestAuditConfig
// Event is the audit Event object that is being captured to be written in
// the API audit log. It is set to nil when the request is not being audited.
Event *audit.Event
}
// RequestAuditConfig is the evaluated audit configuration that is applicable to
// a given request. PolicyRuleEvaluator evaluates the audit policy against the
// authorizer attributes and returns a RequestAuditConfig that applies to the request.

View File

@ -33,7 +33,6 @@ import (
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/authorization/authorizer"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2"
"github.com/google/uuid"
@ -53,7 +52,7 @@ func NewEventFromRequest(req *http.Request, requestReceivedTimestamp time.Time,
Level: level,
}
auditID, found := request.AuditIDFrom(req.Context())
auditID, found := AuditIDFrom(req.Context())
if !found {
auditID = types.UID(uuid.New().String())
}

View File

@ -200,7 +200,9 @@ func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, toke
// since this is shared work between multiple requests, we have no way of knowing if any
// particular request supports audit annotations. thus we always attempt to record them.
ev := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx = audit.WithAuditContext(ctx, &audit.AuditContext{Event: ev})
ctx = audit.WithAuditContext(ctx)
ac := audit.AuditContextFrom(ctx)
ac.Event = ev
record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token)
record.annotations = ev.Annotations

View File

@ -300,26 +300,10 @@ func TestCachedAuditAnnotations(t *testing.T) {
go func() {
defer wg.Done()
// exercise both ways of tracking audit annotations
r := mathrand.New(mathrand.NewSource(mathrand.Int63()))
randomChoice := r.Int()%2 == 0
ctx := context.Background()
if randomChoice {
ctx = audit.WithAuditAnnotations(ctx)
} else {
ctx = audit.WithAuditContext(ctx, &audit.AuditContext{
Event: &auditinternal.Event{Level: auditinternal.LevelMetadata},
})
}
ctx := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx, "token")
if randomChoice {
allAnnotations <- extractAnnotations(ctx)
} else {
allAnnotations <- audit.AuditEventFrom(ctx).Annotations
}
allAnnotations <- audit.AuditEventFrom(ctx).Annotations
}()
}
@ -354,9 +338,9 @@ func TestCachedAuditAnnotations(t *testing.T) {
allAnnotations := make([]map[string]string, 0, 10)
for i := 0; i < cap(allAnnotations); i++ {
ctx := audit.WithAuditAnnotations(context.Background())
ctx := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx, "token")
allAnnotations = append(allAnnotations, extractAnnotations(ctx))
allAnnotations = append(allAnnotations, audit.AuditEventFrom(ctx).Annotations)
}
if len(allAnnotations) != cap(allAnnotations) {
@ -381,16 +365,16 @@ func TestCachedAuditAnnotations(t *testing.T) {
return snorlax, true, nil
}), false, time.Minute, 0)
ctx1 := audit.WithAuditAnnotations(context.Background())
ctx1 := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx1, "token1")
annotations1 := extractAnnotations(ctx1)
annotations1 := audit.AuditEventFrom(ctx1).Annotations
// guarantee different now times
time.Sleep(time.Second)
ctx2 := audit.WithAuditAnnotations(context.Background())
ctx2 := withAudit(context.Background())
_, _, _ = a.AuthenticateToken(ctx2, "token2")
annotations2 := extractAnnotations(ctx2)
annotations2 := audit.AuditEventFrom(ctx2).Annotations
if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 1: %v", annotations1)
@ -405,18 +389,6 @@ func TestCachedAuditAnnotations(t *testing.T) {
})
}
func extractAnnotations(ctx context.Context) map[string]string {
annotationsSlice := reflect.ValueOf(ctx).Elem().FieldByName("val").Elem().Elem()
annotations := map[string]string{}
for i := 0; i < annotationsSlice.Len(); i++ {
annotation := annotationsSlice.Index(i)
key := annotation.FieldByName("key").String()
val := annotation.FieldByName("value").String()
annotations[key] = val
}
return annotations
}
func BenchmarkCachedTokenAuthenticator(b *testing.B) {
tokenCount := []int{100, 500, 2500, 12500, 62500}
threadCount := []int{1, 16, 256}
@ -566,3 +538,12 @@ func (s *singleBenchmark) bench(b *testing.B) {
b.ReportMetric(float64(lookups)/float64(b.N), "lookups/op")
}
// Add a test version of the audit context with a pre-populated event for easy annotation
// extraction.
func withAudit(ctx context.Context) context.Context {
ctx = audit.WithAuditContext(ctx)
ac := audit.AuditContextFrom(ctx)
ac.Event = &auditinternal.Event{Level: auditinternal.LevelMetadata}
return ctx
}

View File

@ -290,6 +290,7 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
handler := genericapifilters.WithAudit(mux, auditSink, fakeRuleEvaluator, longRunningCheck)
handler = genericapifilters.WithRequestDeadline(handler, auditSink, fakeRuleEvaluator, longRunningCheck, codecs, 60*time.Second)
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver())
handler = genericapifilters.WithAuditInit(handler)
return &defaultAPIServer{handler, container}
}

View File

@ -286,81 +286,97 @@ func TestAudit(t *testing.T) {
},
},
} {
sink := &fakeAuditSink{}
handler := handleInternal(map[string]rest.Storage{
"simple": &SimpleRESTStorage{
list: []genericapitesting.Simple{
{
ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "other"},
Other: "foo",
t.Run(test.desc, func(t *testing.T) {
sink := &fakeAuditSink{}
handler := handleInternal(map[string]rest.Storage{
"simple": &SimpleRESTStorage{
list: []genericapitesting.Simple{
{
ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "other"},
Other: "foo",
},
{
ObjectMeta: metav1.ObjectMeta{Name: "b", Namespace: "other"},
Other: "foo",
},
},
{
ObjectMeta: metav1.ObjectMeta{Name: "b", Namespace: "other"},
item: genericapitesting.Simple{
ObjectMeta: metav1.ObjectMeta{Name: "c", Namespace: "other", UID: "uid"},
Other: "foo",
},
},
item: genericapitesting.Simple{
ObjectMeta: metav1.ObjectMeta{Name: "c", Namespace: "other", UID: "uid"},
Other: "foo",
},
},
}, admissionControl, sink)
}, admissionControl, sink)
server := httptest.NewServer(handler)
defer server.Close()
client := http.Client{Timeout: 2 * time.Second}
server := httptest.NewServer(handler)
defer server.Close()
client := http.Client{Timeout: 2 * time.Second}
req, err := test.req(server.URL)
if err != nil {
t.Errorf("[%s] error creating the request: %v", test.desc, err)
}
req, err := test.req(server.URL)
if err != nil {
t.Errorf("[%s] error creating the request: %v", test.desc, err)
}
req.Header.Set("User-Agent", userAgent)
req.Header.Set("User-Agent", userAgent)
response, err := client.Do(req)
if err != nil {
t.Errorf("[%s] error: %v", test.desc, err)
}
response, err := client.Do(req)
if err != nil {
t.Errorf("[%s] error: %v", test.desc, err)
}
if response.StatusCode != test.code {
t.Errorf("[%s] expected http code %d, got %#v", test.desc, test.code, response)
}
if response.StatusCode != test.code {
t.Errorf("[%s] expected http code %d, got %#v", test.desc, test.code, response)
}
// close body because the handler might block in Flush, unable to send the remaining event.
response.Body.Close()
// close body because the handler might block in Flush, unable to send the remaining event.
response.Body.Close()
// wait for events to arrive, at least the given number in the test
events := []*auditinternal.Event{}
err = wait.Poll(50*time.Millisecond, wait.ForeverTestTimeout, wait.ConditionFunc(func() (done bool, err error) {
events = sink.Events()
return len(events) >= test.events, nil
}))
if err != nil {
t.Errorf("[%s] timeout waiting for events", test.desc)
}
// wait for events to arrive, at least the given number in the test
events := []*auditinternal.Event{}
err = wait.Poll(50*time.Millisecond, testTimeout(t), wait.ConditionFunc(func() (done bool, err error) {
events = sink.Events()
return len(events) >= test.events, nil
}))
if err != nil {
t.Errorf("[%s] timeout waiting for events", test.desc)
}
if got := len(events); got != test.events {
t.Errorf("[%s] expected %d audit events, got %d", test.desc, test.events, got)
} else {
for i, check := range test.checks {
err := check(events)
if err != nil {
t.Errorf("[%s,%d] %v", test.desc, i, err)
if got := len(events); got != test.events {
t.Errorf("[%s] expected %d audit events, got %d", test.desc, test.events, got)
} else {
for i, check := range test.checks {
err := check(events)
if err != nil {
t.Errorf("[%s,%d] %v", test.desc, i, err)
}
}
if err := requestUserAgentMatches(userAgent)(events); err != nil {
t.Errorf("[%s] %v", test.desc, err)
}
}
if err := requestUserAgentMatches(userAgent)(events); err != nil {
t.Errorf("[%s] %v", test.desc, err)
if len(events) > 0 {
status := events[len(events)-1].ResponseStatus
if status == nil {
t.Errorf("[%s] expected non-nil ResponseStatus in last event", test.desc)
} else if int(status.Code) != test.code {
t.Errorf("[%s] expected ResponseStatus.Code=%d, got %d", test.desc, test.code, status.Code)
}
}
}
if len(events) > 0 {
status := events[len(events)-1].ResponseStatus
if status == nil {
t.Errorf("[%s] expected non-nil ResponseStatus in last event", test.desc)
} else if int(status.Code) != test.code {
t.Errorf("[%s] expected ResponseStatus.Code=%d, got %d", test.desc, test.code, status.Code)
}
}
})
}
}
// testTimeout returns the minimimum of the "ForeverTestTimeout" and the testing deadline (with
// cleanup time).
func testTimeout(t *testing.T) time.Duration {
defaultTimeout := wait.ForeverTestTimeout
const cleanupTime = 5 * time.Second
if deadline, ok := t.Deadline(); ok {
maxTimeout := time.Until(deadline) - cleanupTime
if maxTimeout < defaultTimeout {
return maxTimeout
}
}
return defaultTimeout
}

View File

@ -44,23 +44,21 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
auditContext, err := evaluatePolicyAndCreateAuditEvent(req, policy)
ac, err := evaluatePolicyAndCreateAuditEvent(req, policy)
if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return
}
ev := auditContext.Event
if ev == nil || req.Context() == nil {
if ac == nil || ac.Event == nil {
handler.ServeHTTP(w, req)
return
}
req = req.WithContext(audit.WithAuditContext(req.Context(), auditContext))
ev := ac.Event
ctx := req.Context()
omitStages := auditContext.RequestAuditConfig.OmitStages
omitStages := ac.RequestAuditConfig.OmitStages
ev.Stage = auditinternal.StageRequestReceived
if processed := processAuditEvent(ctx, sink, ev, omitStages); !processed {
@ -124,19 +122,23 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva
// - error if anything bad happened
func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator) (*audit.AuditContext, error) {
ctx := req.Context()
ac := audit.AuditContextFrom(ctx)
if ac == nil {
// Auditing not enabled.
return nil, nil
}
attribs, err := GetAuthorizerAttributes(ctx)
if err != nil {
return nil, fmt.Errorf("failed to GetAuthorizerAttributes: %v", err)
return ac, fmt.Errorf("failed to GetAuthorizerAttributes: %v", err)
}
ls := policy.EvaluatePolicyRule(attribs)
audit.ObservePolicyLevel(ctx, ls.Level)
ac.RequestAuditConfig = ls.RequestAuditConfig
if ls.Level == auditinternal.LevelNone {
// Don't audit.
return &audit.AuditContext{
RequestAuditConfig: ls.RequestAuditConfig,
}, nil
return ac, nil
}
requestReceivedTimestamp, ok := request.ReceivedTimestampFrom(ctx)
@ -148,10 +150,9 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul
return nil, fmt.Errorf("failed to complete audit event from request: %v", err)
}
return &audit.AuditContext{
RequestAuditConfig: ls.RequestAuditConfig,
Event: ev,
}, nil
ac.Event = ev
return ac, nil
}
// writeLatencyToAnnotation writes the latency incurred in different

View File

@ -1,38 +0,0 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"net/http"
"k8s.io/apiserver/pkg/audit"
)
// WithAuditAnnotations decorates a http.Handler with a []{key, value} that is merged
// with the audit.Event.Annotations map. This allows layers that run before WithAudit
// (such as authentication) to assert annotations.
// If sink or audit policy is nil, no decoration takes place.
func WithAuditAnnotations(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEvaluator) http.Handler {
// no need to wrap if auditing is disabled
if sink == nil || policy == nil {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = req.WithContext(audit.WithAuditAnnotations(req.Context()))
handler.ServeHTTP(w, req)
})
}

View File

@ -21,28 +21,25 @@ import (
"k8s.io/apimachinery/pkg/types"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/audit"
"github.com/google/uuid"
)
// WithAuditID attaches the Audit-ID associated with a request to the context.
// WithAuditInit initializes the audit context and attaches the Audit-ID associated with a request.
//
// a. If the caller does not specify a value for Audit-ID in the request header, we generate a new audit ID
// b. We echo the Audit-ID value to the caller via the response Header 'Audit-ID'.
func WithAuditID(handler http.Handler) http.Handler {
return withAuditID(handler, func() string {
func WithAuditInit(handler http.Handler) http.Handler {
return withAuditInit(handler, func() string {
return uuid.New().String()
})
}
func withAuditID(handler http.Handler, newAuditIDFunc func() string) http.Handler {
if newAuditIDFunc == nil {
return handler
}
func withAuditInit(handler http.Handler, newAuditIDFunc func() string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx := audit.WithAuditContext(r.Context())
r = r.WithContext(ctx)
auditID := r.Header.Get(auditinternal.HeaderAuditID)
if len(auditID) == 0 {
@ -50,7 +47,7 @@ func withAuditID(handler http.Handler, newAuditIDFunc func() string) http.Handle
}
// Note: we save the user specified value of the Audit-ID header as is, no truncation is performed.
r = r.WithContext(request.WithAuditID(ctx, types.UID(auditID)))
audit.WithAuditID(ctx, types.UID(auditID))
// We echo the Audit-ID in to the response header.
// It's not guaranteed Audit-ID http header is sent for all requests.

View File

@ -23,7 +23,7 @@ import (
"testing"
"github.com/google/uuid"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/audit"
)
func TestWithAuditID(t *testing.T) {
@ -72,15 +72,15 @@ func TestWithAuditID(t *testing.T) {
innerHandlerCallCount++
// does the inner handler see the audit ID?
v, ok := request.AuditIDFrom(req.Context())
v, ok := audit.AuditIDFrom(req.Context())
found = ok
auditIDGot = string(v)
})
wrapped := WithAuditID(handler)
wrapped := WithAuditInit(handler)
if test.newAuditIDFunc != nil {
wrapped = withAuditID(handler, test.newAuditIDFunc)
wrapped = withAuditInit(handler, test.newAuditIDFunc)
}
testRequest, err := http.NewRequest(http.MethodGet, "/api/v1/namespaces", nil)

View File

@ -676,7 +676,7 @@ func TestAudit(t *testing.T) {
// simplified long-running check
return ri.Verb == "watch"
})
handler = WithAuditID(handler)
handler = WithAuditInit(handler)
req, _ := http.NewRequest(test.verb, test.path, nil)
req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil)
@ -812,7 +812,7 @@ func TestAuditIDHttpHeader(t *testing.T) {
})
fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(test.level, nil)
handler = WithAudit(handler, sink, fakeRuleEvaluator, nil)
handler = WithAuditID(handler)
handler = WithAuditInit(handler)
req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil)
req.RemoteAddr = "127.0.0.1"
@ -843,14 +843,13 @@ func TestAuditIDHttpHeader(t *testing.T) {
}
func withTestContext(req *http.Request, user user.Info, ae *auditinternal.Event) *http.Request {
ctx := req.Context()
ctx := audit.WithAuditContext(req.Context())
if user != nil {
ctx = request.WithUser(ctx, user)
}
if ae != nil {
ctx = audit.WithAuditContext(ctx, &audit.AuditContext{
Event: ae,
})
ac := audit.AuditContextFrom(ctx)
ac.Event = ae
}
if info, err := newTestRequestInfoResolver().NewRequestInfo(req); err == nil {
ctx = request.WithRequestInfo(ctx, info)

View File

@ -36,26 +36,24 @@ func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink,
return failedHandler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
a, err := evaluatePolicyAndCreateAuditEvent(req, policy)
ac, err := evaluatePolicyAndCreateAuditEvent(req, policy)
if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return
}
ev := a.Event
if ev == nil {
if ac == nil || ac.Event == nil {
failedHandler.ServeHTTP(w, req)
return
}
req = req.WithContext(audit.WithAuditContext(req.Context(), a))
ev := ac.Event
ev.ResponseStatus = &metav1.Status{}
ev.ResponseStatus.Message = getAuthMethods(req)
ev.Stage = auditinternal.StageResponseStarted
rw := decorateResponseWriter(req.Context(), w, ev, sink, a.RequestAuditConfig.OmitStages)
rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages)
failedHandler.ServeHTTP(rw, req)
})
}

View File

@ -108,20 +108,18 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta
return failedHandler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
a, err := evaluatePolicyAndCreateAuditEvent(req, policy)
ac, err := evaluatePolicyAndCreateAuditEvent(req, policy)
if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return
}
ev := a.Event
if ev == nil {
if ac == nil || ac.Event == nil {
failedHandler.ServeHTTP(w, req)
return
}
req = req.WithContext(audit.WithAuditContext(req.Context(), a))
ev := ac.Event
ev.ResponseStatus = &metav1.Status{}
ev.Stage = auditinternal.StageResponseStarted
@ -129,7 +127,7 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta
ev.ResponseStatus.Message = statusErr.Error()
}
rw := decorateResponseWriter(req.Context(), w, ev, sink, a.RequestAuditConfig.OmitStages)
rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages)
failedHandler.ServeHTTP(rw, req)
})
}

View File

@ -253,7 +253,7 @@ func TestWithRequestDeadlineWithClock(t *testing.T) {
}
}
func TestWithRequestDeadlineWithFailedRequestIsAudited(t *testing.T) {
func TestWithRequestDeadlineWithInvalidTimeoutIsAudited(t *testing.T) {
var handlerInvoked bool
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
handlerInvoked = true
@ -386,10 +386,7 @@ func TestWithFailedRequestAudit(t *testing.T) {
withAudit := withFailedRequestAudit(errorHandler, test.statusErr, fakeSink, fakeRuleEvaluator)
w := httptest.NewRecorder()
testRequest, err := http.NewRequest(http.MethodGet, "/apis/v1/namespaces/default/pods", nil)
if err != nil {
t.Fatalf("failed to create new http testRequest - %v", err)
}
testRequest := newRequest(t, "/apis/v1/namespaces/default/pods")
info := request.RequestInfo{}
testRequest = testRequest.WithContext(request.WithRequestInfo(testRequest.Context(), &info))
@ -446,8 +443,8 @@ func newRequest(t *testing.T, requestURL string) *http.Request {
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}
return req
ctx := audit.WithAuditContext(req.Context())
return req.WithContext(ctx)
}
func message(err error) string {

View File

@ -60,11 +60,11 @@ func (s *mockCodecs) EncoderForVersion(encoder runtime.Encoder, gv runtime.Group
func TestDeleteResourceAuditLogRequestObject(t *testing.T) {
ctx := audit.WithAuditContext(context.TODO(), &audit.AuditContext{
Event: &auditapis.Event{
Level: auditapis.LevelRequestResponse,
},
})
ctx := audit.WithAuditContext(context.TODO())
ac := audit.AuditContextFrom(ctx)
ac.Event = &auditapis.Event{
Level: auditapis.LevelRequestResponse,
}
policy := metav1.DeletePropagationBackground
deleteOption := &metav1.DeleteOptions{

View File

@ -20,7 +20,7 @@ import (
"net/http"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/audit"
)
const (
@ -83,7 +83,7 @@ type lazyAuditID struct {
func (lazy *lazyAuditID) String() string {
if lazy.req != nil {
return request.GetAuditIDTruncated(lazy.req.Context())
return audit.GetAuditIDTruncated(lazy.req.Context())
}
return "unknown"

View File

@ -88,7 +88,7 @@ func StreamObject(statusCode int, gv schema.GroupVersion, s runtime.NegotiatedSe
// a client and the feature gate for APIResponseCompression is enabled.
func SerializeObject(mediaType string, encoder runtime.Encoder, hw http.ResponseWriter, req *http.Request, statusCode int, object runtime.Object) {
trace := utiltrace.New("SerializeObject",
utiltrace.Field{"audit-id", request.GetAuditIDTruncated(req.Context())},
utiltrace.Field{"audit-id", audit.GetAuditIDTruncated(req.Context())},
utiltrace.Field{"method", req.Method},
utiltrace.Field{"url", req.URL.Path},
utiltrace.Field{"protocol", req.Proto},

View File

@ -1,65 +0,0 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package request
import (
"context"
"k8s.io/apimachinery/pkg/types"
)
type auditIDKeyType int
// auditIDKey is the key to associate the Audit-ID value of a request.
const auditIDKey auditIDKeyType = iota
// WithAuditID returns a copy of the parent context into which the Audit-ID
// associated with the request is set.
//
// If the specified auditID is empty, no value is set and the parent context is returned as is.
func WithAuditID(parent context.Context, auditID types.UID) context.Context {
if auditID == "" {
return parent
}
return WithValue(parent, auditIDKey, auditID)
}
// AuditIDFrom returns the value of the audit ID from the request context.
func AuditIDFrom(ctx context.Context) (types.UID, bool) {
auditID, ok := ctx.Value(auditIDKey).(types.UID)
return auditID, ok
}
// GetAuditIDTruncated returns the audit ID (truncated) from the request context.
// If the length of the Audit-ID value exceeds the limit, we truncate it to keep
// the first N (maxAuditIDLength) characters.
// This is intended to be used in logging only.
func GetAuditIDTruncated(ctx context.Context) string {
auditID, ok := AuditIDFrom(ctx)
if !ok {
return ""
}
// if the user has specified a very long audit ID then we will use the first N characters
// Note: assuming Audit-ID header is in ASCII
const maxAuditIDLength = 64
if len(auditID) > maxAuditIDLength {
auditID = auditID[0:maxAuditIDLength]
}
return string(auditID)
}

View File

@ -1,68 +0,0 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package request
import (
"context"
"testing"
"k8s.io/apimachinery/pkg/types"
)
func TestAuditIDFrom(t *testing.T) {
tests := []struct {
name string
auditID string
auditIDExpected string
expected bool
}{
{
name: "empty audit ID",
auditID: "",
auditIDExpected: "",
expected: false,
},
{
name: "non empty audit ID",
auditID: "foo-bar-baz",
auditIDExpected: "foo-bar-baz",
expected: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parent := context.TODO()
ctx := WithAuditID(parent, types.UID(test.auditID))
// for an empty audit ID we don't expect a copy of the parent context.
if len(test.auditID) == 0 && parent != ctx {
t.Error("expected no copy of the parent context with an empty audit ID")
}
value, ok := AuditIDFrom(ctx)
if test.expected != ok {
t.Errorf("expected AuditIDFrom to return: %t, but got: %t", test.expected, ok)
}
auditIDGot := string(value)
if test.auditIDExpected != auditIDGot {
t.Errorf("expected audit ID: %q, but got: %q", test.auditIDExpected, auditIDGot)
}
})
}
}

View File

@ -854,7 +854,6 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 {
handler = genericfilters.WithProbabilisticGoaway(handler, c.GoawayChance)
}
handler = genericapifilters.WithAuditAnnotations(handler, c.AuditBackend, c.AuditPolicyRuleEvaluator)
handler = genericapifilters.WithWarningRecorder(handler)
handler = genericapifilters.WithCacheControl(handler)
handler = genericfilters.WithHSTS(handler, c.HSTSDirectives)
@ -870,7 +869,7 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
handler = genericapifilters.WithRequestReceivedTimestamp(handler)
handler = genericapifilters.WithMuxAndDiscoveryComplete(handler, c.lifecycleSignals.MuxAndDiscoveryComplete.Signaled())
handler = genericfilters.WithPanicRecovery(handler, c.RequestInfoResolver)
handler = genericapifilters.WithAuditID(handler)
handler = genericapifilters.WithAuditInit(handler)
return handler
}

View File

@ -306,7 +306,7 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) {
h := DefaultBuildHandlerChain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// confirm this is a no-op
if r.Context() != audit.WithAuditAnnotations(r.Context()) {
if r.Context() != audit.WithAuditContext(r.Context()) {
t.Error("unexpected double wrapping of context")
}

View File

@ -1197,7 +1197,7 @@ func newHandlerChain(t *testing.T, handler http.Handler, filter utilflowcontrol.
handler = apifilters.WithRequestDeadline(handler, nil, nil, longRunningRequestCheck, nil, requestTimeout)
handler = apifilters.WithRequestInfo(handler, requestInfoFactory)
handler = WithPanicRecovery(handler, requestInfoFactory)
handler = apifilters.WithAuditID(handler)
handler = apifilters.WithAuditInit(handler)
return handler
}

View File

@ -21,6 +21,7 @@ import (
"net/http"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/metrics"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/server/httplog"
@ -50,11 +51,11 @@ func WithPanicRecovery(handler http.Handler, resolver request.RequestInfoResolve
// This call can have different handlers, but the default chain rate limits. Call it after the metrics are updated
// in case the rate limit delays it. If you outrun the rate for this one timed out requests, something has gone
// seriously wrong with your server, but generally having a logging signal for timeouts is useful.
runtime.HandleError(fmt.Errorf("timeout or abort while handling: method=%v URI=%q audit-ID=%q", req.Method, req.RequestURI, request.GetAuditIDTruncated(req.Context())))
runtime.HandleError(fmt.Errorf("timeout or abort while handling: method=%v URI=%q audit-ID=%q", req.Method, req.RequestURI, audit.GetAuditIDTruncated(req.Context())))
return
}
http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError)
klog.ErrorS(nil, "apiserver panic'd", "method", req.Method, "URI", req.RequestURI, "audit-ID", request.GetAuditIDTruncated(req.Context()))
klog.ErrorS(nil, "apiserver panic'd", "method", req.Method, "URI", req.RequestURI, "audit-ID", audit.GetAuditIDTruncated(req.Context()))
})
}

View File

@ -27,6 +27,7 @@ import (
"sync"
"time"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/metrics"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/endpoints/responsewriter"
@ -241,7 +242,7 @@ func SetStacktracePredicate(ctx context.Context, pred StacktracePred) {
// Log is intended to be called once at the end of your request handler, via defer
func (rl *respLogger) Log() {
latency := time.Since(rl.startTime)
auditID := request.GetAuditIDTruncated(rl.req.Context())
auditID := audit.GetAuditIDTruncated(rl.req.Context())
verb := rl.req.Method
if requestInfo, ok := request.RequestInfoFrom(rl.req.Context()); ok {

View File

@ -35,7 +35,7 @@ import (
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/watch"
endpointsrequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/storage"
"k8s.io/apiserver/pkg/storage/cacher/metrics"
@ -670,7 +670,7 @@ func (c *Cacher) GetList(ctx context.Context, key string, opts storage.ListOptio
}
trace := utiltrace.New("cacher list",
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "type", Value: c.groupResource.String()})
defer trace.LogIfLong(500 * time.Millisecond)

View File

@ -37,7 +37,7 @@ import (
"k8s.io/apimachinery/pkg/conversion"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
endpointsrequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/storage"
"k8s.io/apiserver/pkg/storage/etcd3/metrics"
@ -153,7 +153,7 @@ func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, ou
// Create implements storage.Interface.Create.
func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, ttl uint64) error {
trace := utiltrace.New("Create etcd3",
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "key", Value: key},
utiltrace.Field{Key: "type", Value: getTypeName(obj)},
utiltrace.Field{Key: "resource", Value: s.groupResourceString},
@ -332,7 +332,7 @@ func (s *store) GuaranteedUpdate(
ctx context.Context, key string, destination runtime.Object, ignoreNotFound bool,
preconditions *storage.Preconditions, tryUpdate storage.UpdateFunc, cachedExistingObject runtime.Object) error {
trace := utiltrace.New("GuaranteedUpdate etcd3",
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "key", Value: key},
utiltrace.Field{Key: "type", Value: getTypeName(destination)},
utiltrace.Field{Key: "resource", Value: s.groupResourceString})
@ -529,7 +529,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption
match := opts.ResourceVersionMatch
pred := opts.Predicate
trace := utiltrace.New(fmt.Sprintf("List(recursive=%v) etcd3", recursive),
utiltrace.Field{Key: "audit-id", Value: endpointsrequest.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "audit-id", Value: audit.GetAuditIDTruncated(ctx)},
utiltrace.Field{Key: "key", Value: key},
utiltrace.Field{Key: "resourceVersion", Value: resourceVersion},
utiltrace.Field{Key: "resourceVersionMatch", Value: match},

View File

@ -245,8 +245,9 @@ func TestCheckForHostnameError(t *testing.T) {
if err != nil {
t.Fatalf("failed to create an http request: %v", err)
}
auditCtx := &audit.AuditContext{Event: &auditapi.Event{Level: auditapi.LevelMetadata}}
req = req.WithContext(audit.WithAuditContext(req.Context(), auditCtx))
req = req.WithContext(audit.WithAuditContext(req.Context()))
auditCtx := audit.AuditContextFrom(req.Context())
auditCtx.Event = &auditapi.Event{Level: auditapi.LevelMetadata}
_, err = client.Transport.RoundTrip(req)
@ -387,8 +388,9 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) {
if err != nil {
t.Fatalf("failed to create an http request: %v", err)
}
auditCtx := &audit.AuditContext{Event: &auditapi.Event{Level: auditapi.LevelMetadata}}
req = req.WithContext(audit.WithAuditContext(req.Context(), auditCtx))
req = req.WithContext(audit.WithAuditContext(req.Context()))
auditCtx := audit.AuditContextFrom(req.Context())
auditCtx.Event = &auditapi.Event{Level: auditapi.LevelMetadata}
// can't use tlsServer.Client() as it contains the server certificate
// in tls.Config.Certificates. The signatures are, however, only checked

View File

@ -29,6 +29,7 @@ import (
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/proxy"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
endpointmetrics "k8s.io/apiserver/pkg/endpoints/metrics"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
@ -206,7 +207,7 @@ func newRequestForProxy(location *url.URL, req *http.Request) (*http.Request, co
// If the original request has an audit ID, let's make sure we propagate this
// to the aggregated server.
if auditID, found := genericapirequest.AuditIDFrom(req.Context()); found {
if auditID, found := audit.AuditIDFrom(req.Context()); found {
newReq.Header.Set(auditinternal.HeaderAuditID, string(auditID))
}

View File

@ -34,6 +34,7 @@ import (
"sync/atomic"
"testing"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/server/dynamiccertificates"
"golang.org/x/net/websocket"
@ -790,7 +791,9 @@ func TestNewRequestForProxyWithAuditID(t *testing.T) {
req = req.WithContext(genericapirequest.WithRequestInfo(req.Context(), &genericapirequest.RequestInfo{Path: req.URL.Path}))
if len(test.auditID) > 0 {
req = req.WithContext(genericapirequest.WithAuditID(req.Context(), types.UID(test.auditID)))
ctx := audit.WithAuditContext(req.Context())
audit.WithAuditID(ctx, types.UID(test.auditID))
req = req.WithContext(ctx)
}
newReq, _ := newRequestForProxy(req.URL, req)