Move context package internal

Our context package predates the establishment of current best practices
regarding context usage and it shows. It encourages bad practices such
as using contexts to propagate non-request-scoped values like the
application version and using string-typed keys for context values. Move
the package internal to remove it from the API surface of
distribution/v3@v3.0.0 so we are free to iterate on it without being
constrained by compatibility.

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider
2023-10-24 13:16:58 -04:00
parent 6c694cbcf6
commit d0f5aa670b
61 changed files with 151 additions and 151 deletions

View File

@@ -17,7 +17,7 @@ import (
"time"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/distribution/distribution/v3/registry/api/errcode"
@@ -108,7 +108,7 @@ func TestBlobServeBlob(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@@ -157,7 +157,7 @@ func TestBlobServeBlobHEAD(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@@ -250,7 +250,7 @@ func TestBlobResume(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -307,7 +307,7 @@ func TestBlobDelete(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -327,7 +327,7 @@ func TestBlobFetch(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@@ -382,7 +382,7 @@ func TestBlobExistsNoContentLength(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -406,7 +406,7 @@ func TestBlobExists(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@@ -512,7 +512,7 @@ func TestBlobUploadChunked(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -622,7 +622,7 @@ func TestBlobUploadMonolithic(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -728,7 +728,7 @@ func TestBlobUploadMonolithicDockerUploadUUIDFromURL(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -833,7 +833,7 @@ func TestBlobUploadMonolithicNoDockerUploadUUID(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -891,7 +891,7 @@ func TestBlobMount(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -1066,7 +1066,7 @@ func checkEqualManifest(m1, m2 *ocischema.DeserializedManifest) error {
}
func TestOCIManifestFetch(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo")
m1, dgst, pl := newRandomOCIManifest(t, 6)
var m testutil.RequestResponseMap
@@ -1149,7 +1149,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -1171,7 +1171,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
}
func TestManifestFetchWithAccept(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo")
_, dgst, _ := newRandomOCIManifest(t, 6)
headers := make(chan []string, 1)
@@ -1258,7 +1258,7 @@ func TestManifestDelete(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
@@ -1315,7 +1315,7 @@ func TestManifestPut(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
@@ -1372,7 +1372,7 @@ func TestManifestTags(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
tagService := r.Tags(ctx)
tags, err := tagService.All(ctx)
@@ -1423,7 +1423,7 @@ func TestTagDelete(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ts := r.Tags(ctx)
if err := ts.Untag(ctx, tag); err != nil {
@@ -1460,7 +1460,7 @@ func TestObtainsErrorForMissingTag(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -1487,7 +1487,7 @@ func TestObtainsManifestForTagWithoutHeaders(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@@ -1566,7 +1566,7 @@ func TestManifestTagsPaginated(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
tagService := r.Tags(ctx)
tags, err := tagService.All(ctx)
@@ -1614,7 +1614,7 @@ func TestManifestUnauthorized(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
@@ -1652,7 +1652,7 @@ func TestCatalog(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
numFilled, err := r.Repositories(ctx, entries, "")
if err != io.EOF {
t.Fatal(err)
@@ -1684,7 +1684,7 @@ func TestCatalogInParts(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
numFilled, err := r.Repositories(ctx, entries, "")
if err != nil {
t.Fatal(err)

View File

@@ -0,0 +1,73 @@
package dcontext
import (
"context"
"sync"
"github.com/google/uuid"
)
// instanceContext is a context that provides only an instance id. It is
// provided as the main background context.
type instanceContext struct {
context.Context
id string // id of context, logged as "instance.id"
once sync.Once // once protect generation of the id
}
func (ic *instanceContext) Value(key interface{}) interface{} {
if key == "instance.id" {
ic.once.Do(func() {
// We want to lazy initialize the UUID such that we don't
// call a random generator from the package initialization
// code. For various reasons random could not be available
// https://github.com/distribution/distribution/issues/782
ic.id = uuid.NewString()
})
return ic.id
}
return ic.Context.Value(key)
}
var background = &instanceContext{
Context: context.Background(),
}
// Background returns a non-nil, empty Context. The background context
// provides a single key, "instance.id" that is globally unique to the
// process.
func Background() context.Context {
return background
}
// stringMapContext is a simple context implementation that checks a map for a
// key, falling back to a parent if not present.
type stringMapContext struct {
context.Context
m map[string]interface{}
}
// WithValues returns a context that proxies lookups through a map. Only
// supports string keys.
func WithValues(ctx context.Context, m map[string]interface{}) context.Context {
mo := make(map[string]interface{}, len(m)) // make our own copy.
for k, v := range m {
mo[k] = v
}
return stringMapContext{
Context: ctx,
m: mo,
}
}
func (smc stringMapContext) Value(key interface{}) interface{} {
if ks, ok := key.(string); ok {
if v, ok := smc.m[ks]; ok {
return v
}
}
return smc.Context.Value(key)
}

88
internal/dcontext/doc.go Normal file
View File

@@ -0,0 +1,88 @@
// Package dcontext provides several utilities for working with
// Go's context in http requests. Primarily, the focus is on logging relevant
// request information but this package is not limited to that purpose.
//
// The easiest way to get started is to get the background context:
//
// ctx := dcontext.Background()
//
// The returned context should be passed around your application and be the
// root of all other context instances. If the application has a version, this
// line should be called before anything else:
//
// ctx := dcontext.WithVersion(dcontext.Background(), version)
//
// The above will store the version in the context and will be available to
// the logger.
//
// # Logging
//
// The most useful aspect of this package is GetLogger. This function takes
// any context.Context interface and returns the current logger from the
// context. Canonical usage looks like this:
//
// GetLogger(ctx).Infof("something interesting happened")
//
// GetLogger also takes optional key arguments. The keys will be looked up in
// the context and reported with the logger. The following example would
// return a logger that prints the version with each log message:
//
// ctx := context.WithValue(dcontext.Background(), "version", version)
// GetLogger(ctx, "version").Infof("this log message has a version field")
//
// The above would print out a log message like this:
//
// INFO[0000] this log message has a version field version=v2.0.0-alpha.2.m
//
// When used with WithLogger, we gain the ability to decorate the context with
// loggers that have information from disparate parts of the call stack.
// Following from the version example, we can build a new context with the
// configured logger such that we always print the version field:
//
// ctx = WithLogger(ctx, GetLogger(ctx, "version"))
//
// Since the logger has been pushed to the context, we can now get the version
// field for free with our log messages. Future calls to GetLogger on the new
// context will have the version field:
//
// GetLogger(ctx).Infof("this log message has a version field")
//
// This becomes more powerful when we start stacking loggers. Let's say we
// have the version logger from above but also want a request id. Using the
// context above, in our request scoped function, we place another logger in
// the context:
//
// ctx = context.WithValue(ctx, "http.request.id", "unique id") // called when building request context
// ctx = WithLogger(ctx, GetLogger(ctx, "http.request.id"))
//
// When GetLogger is called on the new context, "http.request.id" will be
// included as a logger field, along with the original "version" field:
//
// INFO[0000] this log message has a version field http.request.id=unique id version=v2.0.0-alpha.2.m
//
// Note that this only affects the new context, the previous context, with the
// version field, can be used independently. Put another way, the new logger,
// added to the request context, is unique to that context and can have
// request scoped variables.
//
// # HTTP Requests
//
// This package also contains several methods for working with http requests.
// The concepts are very similar to those described above. We simply place the
// request in the context using WithRequest. This makes the request variables
// available. GetRequestLogger can then be called to get request specific
// variables in a log line:
//
// ctx = WithRequest(ctx, req)
// GetRequestLogger(ctx).Infof("request variables")
//
// Like above, if we want to include the request data in all log messages in
// the context, we push the logger to a new context and use that one:
//
// ctx = WithLogger(ctx, GetRequestLogger(ctx))
//
// The concept is fairly powerful and ensures that calls throughout the stack
// can be traced in log messages. Using the fields like "http.request.id", one
// can analyze call flow for a particular request with a simple grep of the
// logs.
package dcontext

309
internal/dcontext/http.go Normal file
View File

@@ -0,0 +1,309 @@
package dcontext
import (
"context"
"errors"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
)
// Common errors used with this package.
var (
ErrNoRequestContext = errors.New("no http request in context")
ErrNoResponseWriterContext = errors.New("no http response in context")
)
func parseIP(ipStr string) net.IP {
ip := net.ParseIP(ipStr)
if ip == nil {
log.Warnf("invalid remote IP address: %q", ipStr)
}
return ip
}
// RemoteAddr extracts the remote address of the request, taking into
// account proxy headers.
func RemoteAddr(r *http.Request) string {
if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
remoteAddr, _, _ := strings.Cut(prior, ",")
remoteAddr = strings.Trim(remoteAddr, " ")
if parseIP(remoteAddr) != nil {
return remoteAddr
}
}
// X-Real-Ip is less supported, but worth checking in the
// absence of X-Forwarded-For
if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
if parseIP(realIP) != nil {
return realIP
}
}
return r.RemoteAddr
}
// RemoteIP extracts the remote IP of the request, taking into
// account proxy headers.
func RemoteIP(r *http.Request) string {
addr := RemoteAddr(r)
// Try parsing it as "IP:port"
if ip, _, err := net.SplitHostPort(addr); err == nil {
return ip
}
return addr
}
// WithRequest places the request on the context. The context of the request
// is assigned a unique id, available at "http.request.id". The request itself
// is available at "http.request". Other common attributes are available under
// the prefix "http.request.". If a request is already present on the context,
// this method will panic.
func WithRequest(ctx context.Context, r *http.Request) context.Context {
if ctx.Value("http.request") != nil {
// NOTE(stevvooe): This needs to be considered a programming error. It
// is unlikely that we'd want to have more than one request in
// context.
panic("only one request per context")
}
return &httpRequestContext{
Context: ctx,
startedAt: time.Now(),
id: uuid.NewString(),
r: r,
}
}
// GetRequest returns the http request in the given context. Returns
// ErrNoRequestContext if the context does not have an http request associated
// with it.
func GetRequest(ctx context.Context) (*http.Request, error) {
if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok {
return r, nil
}
return nil, ErrNoRequestContext
}
// GetRequestID attempts to resolve the current request id, if possible. An
// error is return if it is not available on the context.
func GetRequestID(ctx context.Context) string {
return GetStringValue(ctx, "http.request.id")
}
// WithResponseWriter returns a new context and response writer that makes
// interesting response statistics available within the context.
func WithResponseWriter(ctx context.Context, w http.ResponseWriter) (context.Context, http.ResponseWriter) {
irw := instrumentedResponseWriter{
ResponseWriter: w,
Context: ctx,
}
return &irw, &irw
}
// GetResponseWriter returns the http.ResponseWriter from the provided
// context. If not present, ErrNoResponseWriterContext is returned. The
// returned instance provides instrumentation in the context.
func GetResponseWriter(ctx context.Context) (http.ResponseWriter, error) {
v := ctx.Value("http.response")
rw, ok := v.(http.ResponseWriter)
if !ok || rw == nil {
return nil, ErrNoResponseWriterContext
}
return rw, nil
}
// getVarsFromRequest let's us change request vars implementation for testing
// and maybe future changes.
var getVarsFromRequest = mux.Vars
// WithVars extracts gorilla/mux vars and makes them available on the returned
// context. Variables are available at keys with the prefix "vars.". For
// example, if looking for the variable "name", it can be accessed as
// "vars.name". Implementations that are accessing values need not know that
// the underlying context is implemented with gorilla/mux vars.
func WithVars(ctx context.Context, r *http.Request) context.Context {
return &muxVarsContext{
Context: ctx,
vars: getVarsFromRequest(r),
}
}
// GetRequestLogger returns a logger that contains fields from the request in
// the current context. If the request is not available in the context, no
// fields will display. Request loggers can safely be pushed onto the context.
func GetRequestLogger(ctx context.Context) Logger {
return GetLogger(ctx,
"http.request.id",
"http.request.method",
"http.request.host",
"http.request.uri",
"http.request.referer",
"http.request.useragent",
"http.request.remoteaddr",
"http.request.contenttype")
}
// GetResponseLogger reads the current response stats and builds a logger.
// Because the values are read at call time, pushing a logger returned from
// this function on the context will lead to missing or invalid data. Only
// call this at the end of a request, after the response has been written.
func GetResponseLogger(ctx context.Context) Logger {
l := getLogrusLogger(ctx,
"http.response.written",
"http.response.status",
"http.response.contenttype")
duration := Since(ctx, "http.request.startedat")
if duration > 0 {
l = l.WithField("http.response.duration", duration.String())
}
return l
}
// httpRequestContext makes information about a request available to context.
type httpRequestContext struct {
context.Context
startedAt time.Time
id string
r *http.Request
}
// Value returns a keyed element of the request for use in the context. To get
// the request itself, query "request". For other components, access them as
// "request.<component>". For example, r.RequestURI
func (ctx *httpRequestContext) Value(key interface{}) interface{} {
if keyStr, ok := key.(string); ok {
switch keyStr {
case "http.request":
return ctx.r
case "http.request.uri":
return ctx.r.RequestURI
case "http.request.remoteaddr":
return RemoteAddr(ctx.r)
case "http.request.method":
return ctx.r.Method
case "http.request.host":
return ctx.r.Host
case "http.request.referer":
referer := ctx.r.Referer()
if referer != "" {
return referer
}
case "http.request.useragent":
return ctx.r.UserAgent()
case "http.request.id":
return ctx.id
case "http.request.startedat":
return ctx.startedAt
case "http.request.contenttype":
if ct := ctx.r.Header.Get("Content-Type"); ct != "" {
return ct
}
default:
// no match; fall back to standard behavior below
}
}
return ctx.Context.Value(key)
}
type muxVarsContext struct {
context.Context
vars map[string]string
}
func (ctx *muxVarsContext) Value(key interface{}) interface{} {
if keyStr, ok := key.(string); ok {
if keyStr == "vars" {
return ctx.vars
}
// TODO(thaJeztah): this considers "vars.FOO" and "FOO" to be equal.
// We need to check if that's intentional (could be a bug).
if v, ok := ctx.vars[strings.TrimPrefix(keyStr, "vars.")]; ok {
return v
}
}
return ctx.Context.Value(key)
}
// instrumentedResponseWriter provides response writer information in a
// context. This variant is only used in the case where CloseNotifier is not
// implemented by the parent ResponseWriter.
type instrumentedResponseWriter struct {
http.ResponseWriter
context.Context
mu sync.Mutex
status int
written int64
}
func (irw *instrumentedResponseWriter) Write(p []byte) (n int, err error) {
n, err = irw.ResponseWriter.Write(p)
irw.mu.Lock()
irw.written += int64(n)
// Guess the likely status if not set.
if irw.status == 0 {
irw.status = http.StatusOK
}
irw.mu.Unlock()
return
}
func (irw *instrumentedResponseWriter) WriteHeader(status int) {
irw.ResponseWriter.WriteHeader(status)
irw.mu.Lock()
irw.status = status
irw.mu.Unlock()
}
func (irw *instrumentedResponseWriter) Flush() {
if flusher, ok := irw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (irw *instrumentedResponseWriter) Value(key interface{}) interface{} {
if keyStr, ok := key.(string); ok {
switch keyStr {
case "http.response":
return irw
case "http.response.written":
irw.mu.Lock()
defer irw.mu.Unlock()
return irw.written
case "http.response.status":
irw.mu.Lock()
defer irw.mu.Unlock()
return irw.status
case "http.response.contenttype":
if ct := irw.Header().Get("Content-Type"); ct != "" {
return ct
}
default:
// no match; fall back to standard behavior below
}
}
return irw.Context.Value(key)
}

View File

@@ -0,0 +1,288 @@
package dcontext
import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"reflect"
"testing"
"time"
)
func TestWithRequest(t *testing.T) {
var req http.Request
start := time.Now()
req.Method = http.MethodGet
req.Host = "example.com"
req.RequestURI = "/test-test"
req.Header = make(http.Header)
req.Header.Set("Referer", "foo.com/referer")
req.Header.Set("User-Agent", "test/0.1")
ctx := WithRequest(Background(), &req)
for _, tc := range []struct {
key string
expected interface{}
}{
{
key: "http.request",
expected: &req,
},
{
key: "http.request.id",
},
{
key: "http.request.method",
expected: req.Method,
},
{
key: "http.request.host",
expected: req.Host,
},
{
key: "http.request.uri",
expected: req.RequestURI,
},
{
key: "http.request.referer",
expected: req.Referer(),
},
{
key: "http.request.useragent",
expected: req.UserAgent(),
},
{
key: "http.request.remoteaddr",
expected: req.RemoteAddr,
},
{
key: "http.request.startedat",
},
} {
v := ctx.Value(tc.key)
if v == nil {
t.Fatalf("value not found for %q", tc.key)
}
if tc.expected != nil && v != tc.expected {
t.Fatalf("%s: %v != %v", tc.key, v, tc.expected)
}
// Key specific checks!
switch tc.key {
case "http.request.id":
if _, ok := v.(string); !ok {
t.Fatalf("request id not a string: %v", v)
}
case "http.request.startedat":
vt, ok := v.(time.Time)
if !ok {
t.Fatalf("value not a time: %v", v)
}
now := time.Now()
if vt.After(now) {
t.Fatalf("time generated too late: %v > %v", vt, now)
}
if vt.Before(start) {
t.Fatalf("time generated too early: %v < %v", vt, start)
}
}
}
}
type testResponseWriter struct {
flushed bool
status int
written int64
header http.Header
}
func (trw *testResponseWriter) Header() http.Header {
if trw.header == nil {
trw.header = make(http.Header)
}
return trw.header
}
func (trw *testResponseWriter) Write(p []byte) (n int, err error) {
if trw.status == 0 {
trw.status = http.StatusOK
}
n = len(p)
trw.written += int64(n)
return
}
func (trw *testResponseWriter) WriteHeader(status int) {
trw.status = status
}
func (trw *testResponseWriter) Flush() {
trw.flushed = true
}
func TestWithResponseWriter(t *testing.T) {
trw := testResponseWriter{}
ctx, rw := WithResponseWriter(Background(), &trw)
if ctx.Value("http.response") != rw {
t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), rw)
}
grw, err := GetResponseWriter(ctx)
if err != nil {
t.Fatalf("error getting response writer: %v", err)
}
if grw != rw {
t.Fatalf("unexpected response writer returned: %#v != %#v", grw, rw)
}
if ctx.Value("http.response.status") != 0 {
t.Fatalf("response status should always be a number and should be zero here: %v != 0", ctx.Value("http.response.status"))
}
if n, err := rw.Write(make([]byte, 1024)); err != nil {
t.Fatalf("unexpected error writing: %v", err)
} else if n != 1024 {
t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
}
if ctx.Value("http.response.status") != http.StatusOK {
t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
}
if ctx.Value("http.response.written") != int64(1024) {
t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
}
// Make sure flush propagates
rw.(http.Flusher).Flush()
if !trw.flushed {
t.Fatalf("response writer not flushed")
}
// Write another status and make sure context is correct. This normally
// wouldn't work except for in this contrived testcase.
rw.WriteHeader(http.StatusBadRequest)
if ctx.Value("http.response.status") != http.StatusBadRequest {
t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
}
}
func TestWithVars(t *testing.T) {
var req http.Request
vars := map[string]string{
"foo": "asdf",
"bar": "qwer",
}
getVarsFromRequest = func(r *http.Request) map[string]string {
if r != &req {
t.Fatalf("unexpected request: %v != %v", r, req)
}
return vars
}
ctx := WithVars(Background(), &req)
for _, tc := range []struct {
key string
expected interface{}
}{
{
key: "vars",
expected: vars,
},
{
key: "vars.foo",
expected: "asdf",
},
{
key: "vars.bar",
expected: "qwer",
},
} {
v := ctx.Value(tc.key)
if !reflect.DeepEqual(v, tc.expected) {
t.Fatalf("%q: %v != %v", tc.key, v, tc.expected)
}
}
}
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
// just contains the IP address, it is different enough for testing.
func TestRemoteAddr(t *testing.T) {
var expectedRemote string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if r.RemoteAddr == expectedRemote {
t.Errorf("Unexpected matching remote addresses")
}
actualRemote := RemoteAddr(r)
if expectedRemote != actualRemote {
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
}
w.WriteHeader(200)
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxy := httputil.NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(proxy)
defer frontend.Close()
// X-Forwarded-For set by proxy
expectedRemote = "127.0.0.1"
proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// RemoteAddr in X-Real-Ip
getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedRemote = "1.2.3.4"
getReq.Header["X-Real-ip"] = []string{expectedRemote}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// Valid X-Real-Ip and invalid X-Forwarded-For
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
}

139
internal/dcontext/logger.go Normal file
View File

@@ -0,0 +1,139 @@
package dcontext
import (
"context"
"fmt"
"runtime"
"sync"
"github.com/sirupsen/logrus"
)
var (
defaultLogger *logrus.Entry = logrus.StandardLogger().WithField("go.version", runtime.Version())
defaultLoggerMu sync.RWMutex
)
// Logger provides a leveled-logging interface.
type Logger interface {
// standard logger methods
Print(args ...interface{})
Printf(format string, args ...interface{})
Println(args ...interface{})
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
Fatalln(args ...interface{})
Panic(args ...interface{})
Panicf(format string, args ...interface{})
Panicln(args ...interface{})
// Leveled methods, from logrus
Debug(args ...interface{})
Debugf(format string, args ...interface{})
Debugln(args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
Errorln(args ...interface{})
Info(args ...interface{})
Infof(format string, args ...interface{})
Infoln(args ...interface{})
Warn(args ...interface{})
Warnf(format string, args ...interface{})
Warnln(args ...interface{})
WithError(err error) *logrus.Entry
}
type loggerKey struct{}
// WithLogger creates a new context with provided logger.
func WithLogger(ctx context.Context, logger Logger) context.Context {
return context.WithValue(ctx, loggerKey{}, logger)
}
// GetLoggerWithField returns a logger instance with the specified field key
// and value without affecting the context. Extra specified keys will be
// resolved from the context.
func GetLoggerWithField(ctx context.Context, key, value interface{}, keys ...interface{}) Logger {
return getLogrusLogger(ctx, keys...).WithField(fmt.Sprint(key), value)
}
// GetLoggerWithFields returns a logger instance with the specified fields
// without affecting the context. Extra specified keys will be resolved from
// the context.
func GetLoggerWithFields(ctx context.Context, fields map[interface{}]interface{}, keys ...interface{}) Logger {
// must convert from interface{} -> interface{} to string -> interface{} for logrus.
lfields := make(logrus.Fields, len(fields))
for key, value := range fields {
lfields[fmt.Sprint(key)] = value
}
return getLogrusLogger(ctx, keys...).WithFields(lfields)
}
// GetLogger returns the logger from the current context, if present. If one
// or more keys are provided, they will be resolved on the context and
// included in the logger. While context.Value takes an interface, any key
// argument passed to GetLogger will be passed to fmt.Sprint when expanded as
// a logging key field. If context keys are integer constants, for example,
// its recommended that a String method is implemented.
func GetLogger(ctx context.Context, keys ...interface{}) Logger {
return getLogrusLogger(ctx, keys...)
}
// SetDefaultLogger sets the default logger upon which to base new loggers.
func SetDefaultLogger(logger Logger) {
entry, ok := logger.(*logrus.Entry)
if !ok {
return
}
defaultLoggerMu.Lock()
defaultLogger = entry
defaultLoggerMu.Unlock()
}
// GetLogrusLogger returns the logrus logger for the context. If one more keys
// are provided, they will be resolved on the context and included in the
// logger. Only use this function if specific logrus functionality is
// required.
func getLogrusLogger(ctx context.Context, keys ...interface{}) *logrus.Entry {
var logger *logrus.Entry
// Get a logger, if it is present.
loggerInterface := ctx.Value(loggerKey{})
if loggerInterface != nil {
if lgr, ok := loggerInterface.(*logrus.Entry); ok {
logger = lgr
}
}
if logger == nil {
fields := logrus.Fields{}
// Fill in the instance id, if we have it.
instanceID := ctx.Value("instance.id")
if instanceID != nil {
fields["instance.id"] = instanceID
}
defaultLoggerMu.RLock()
logger = defaultLogger.WithFields(fields)
defaultLoggerMu.RUnlock()
}
fields := logrus.Fields{}
for _, key := range keys {
v := ctx.Value(key)
if v != nil {
fields[fmt.Sprint(key)] = v
}
}
return logger.WithFields(fields)
}

105
internal/dcontext/trace.go Normal file
View File

@@ -0,0 +1,105 @@
package dcontext
import (
"context"
"runtime"
"time"
"github.com/google/uuid"
)
// WithTrace allocates a traced timing span in a new context. This allows a
// caller to track the time between calling WithTrace and the returned done
// function. When the done function is called, a log message is emitted with a
// "trace.duration" field, corresponding to the elapsed time and a
// "trace.func" field, corresponding to the function that called WithTrace.
//
// The logging keys "trace.id" and "trace.parent.id" are provided to implement
// dapper-like tracing. This function should be complemented with a WithSpan
// method that could be used for tracing distributed RPC calls.
//
// The main benefit of this function is to post-process log messages or
// intercept them in a hook to provide timing data. Trace ids and parent ids
// can also be linked to provide call tracing, if so required.
//
// Here is an example of the usage:
//
// func timedOperation(ctx Context) {
// ctx, done := WithTrace(ctx)
// defer done("this will be the log message")
// // ... function body ...
// }
//
// If the function ran for roughly 1s, such a usage would emit a log message
// as follows:
//
// INFO[0001] this will be the log message trace.duration=1.004575763s trace.func=github.com/distribution/distribution/context.traceOperation trace.id=<id> ...
//
// Notice that the function name is automatically resolved, along with the
// package and a trace id is emitted that can be linked with parent ids.
func WithTrace(ctx context.Context) (context.Context, func(format string, a ...interface{})) {
if ctx == nil {
ctx = Background()
}
pc, file, line, _ := runtime.Caller(1)
f := runtime.FuncForPC(pc)
ctx = &traced{
Context: ctx,
id: uuid.NewString(),
start: time.Now(),
parent: GetStringValue(ctx, "trace.id"),
fnname: f.Name(),
file: file,
line: line,
}
return ctx, func(format string, a ...interface{}) {
GetLogger(ctx,
"trace.duration",
"trace.id",
"trace.parent.id",
"trace.func",
"trace.file",
"trace.line").
Debugf(format, a...)
}
}
// traced represents a context that is traced for function call timing. It
// also provides fast lookup for the various attributes that are available on
// the trace.
type traced struct {
context.Context
id string
parent string
start time.Time
fnname string
file string
line int
}
func (ts *traced) Value(key interface{}) interface{} {
switch key {
case "trace.start":
return ts.start
case "trace.duration":
return time.Since(ts.start)
case "trace.id":
return ts.id
case "trace.parent.id":
if ts.parent == "" {
return nil // must return nil to signal no parent.
}
return ts.parent
case "trace.func":
return ts.fnname
case "trace.file":
return ts.file
case "trace.line":
return ts.line
}
return ts.Context.Value(key)
}

View File

@@ -0,0 +1,103 @@
package dcontext
import (
"runtime"
"testing"
"time"
)
// TestWithTrace ensures that tracing has the expected values in the context.
func TestWithTrace(t *testing.T) {
t.Parallel()
pc, file, _, _ := runtime.Caller(0) // get current caller.
f := runtime.FuncForPC(pc)
base := []valueTestCase{
{
key: "trace.id",
notnilorempty: true,
},
{
key: "trace.file",
expected: file,
notnilorempty: true,
},
{
key: "trace.line",
notnilorempty: true,
},
{
key: "trace.start",
notnilorempty: true,
},
}
ctx, done := WithTrace(Background())
defer done("this will be emitted at end of test")
tests := append(base, valueTestCase{
key: "trace.func",
expected: f.Name(),
})
for _, tc := range tests {
tc := tc
t.Run(tc.key, func(t *testing.T) {
t.Parallel()
v := ctx.Value(tc.key)
if tc.notnilorempty {
if v == nil || v == "" {
t.Fatalf("value was nil or empty: %#v", v)
}
return
}
if v != tc.expected {
t.Fatalf("unexpected value: %v != %v", v, tc.expected)
}
})
}
tracedFn := func() {
parentID := ctx.Value("trace.id") // ensure the parent trace id is correct.
pc, _, _, _ := runtime.Caller(0) // get current caller.
f := runtime.FuncForPC(pc)
ctx, done := WithTrace(ctx)
defer done("this should be subordinate to the other trace")
time.Sleep(time.Second)
tests := append(base, valueTestCase{
key: "trace.func",
expected: f.Name(),
}, valueTestCase{
key: "trace.parent.id",
expected: parentID,
})
for _, tc := range tests {
tc := tc
t.Run(tc.key, func(t *testing.T) {
t.Parallel()
v := ctx.Value(tc.key)
if tc.notnilorempty {
if v == nil || v == "" {
t.Fatalf("value was nil or empty: %#v", v)
}
return
}
if v != tc.expected {
t.Fatalf("unexpected value: %v != %v", v, tc.expected)
}
})
}
}
tracedFn()
time.Sleep(time.Second)
}
type valueTestCase struct {
key string
expected interface{}
notnilorempty bool // just check not empty/not nil
}

25
internal/dcontext/util.go Normal file
View File

@@ -0,0 +1,25 @@
package dcontext
import (
"context"
"time"
)
// Since looks up key, which should be a time.Time, and returns the duration
// since that time. If the key is not found, the value returned will be zero.
// This is helpful when inferring metrics related to context execution times.
func Since(ctx context.Context, key interface{}) time.Duration {
if startedAt, ok := ctx.Value(key).(time.Time); ok {
return time.Since(startedAt)
}
return 0
}
// GetStringValue returns a string value from the context. The empty string
// will be returned if not found.
func GetStringValue(ctx context.Context, key interface{}) (value string) {
if valuev, ok := ctx.Value(key).(string); ok {
value = valuev
}
return value
}

View File

@@ -0,0 +1,22 @@
package dcontext
import "context"
type versionKey struct{}
func (versionKey) String() string { return "version" }
// WithVersion stores the application version in the context. The new context
// gets a logger to ensure log messages are marked with the application
// version.
func WithVersion(ctx context.Context, version string) context.Context {
ctx = context.WithValue(ctx, versionKey{}, version)
// push a new logger onto the stack
return WithLogger(ctx, GetLogger(ctx, versionKey{}))
}
// GetVersion returns the application version from the context. An empty
// string may returned if the version was not set on the context.
func GetVersion(ctx context.Context) string {
return GetStringValue(ctx, versionKey{})
}

View File

@@ -0,0 +1,19 @@
package dcontext
import "testing"
func TestVersionContext(t *testing.T) {
ctx := Background()
if GetVersion(ctx) != "" {
t.Fatalf("context should not yet have a version")
}
expected := "2.1-whatever"
ctx = WithVersion(ctx, expected)
version := GetVersion(ctx)
if version != expected {
t.Fatalf("version was not set: %q != %q", version, expected)
}
}