registry/auth: pass request to AccessController

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider
2023-10-24 14:08:04 -04:00
parent 9157226e7b
commit 49e22cbf3e
8 changed files with 23 additions and 47 deletions

View File

@@ -13,7 +13,6 @@ import (
"os"
"strings"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3"
)
@@ -292,7 +291,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized handles checking whether the given request is authorized
// for actions on resources described by the given access items.
func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) {
func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (context.Context, error) {
challenge := &authChallenge{
realm: ac.realm,
autoRedirect: ac.autoRedirect,
@@ -300,11 +299,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
accessSet: newAccessSet(accessItems...),
}
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ")
if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") {
challenge.err = ErrTokenRequired
@@ -338,7 +332,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
}
}
ctx = auth.WithResources(ctx, claims.resources())
ctx := auth.WithResources(req.Context(), claims.resources())
return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil
}

View File

@@ -18,7 +18,6 @@ import (
"testing"
"time"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
@@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) {
Action: "baz",
}
ctx := dcontext.WithRequest(dcontext.Background(), req)
authCtx, err := accessController.Authorized(ctx, testAccess)
authCtx, err := accessController.Authorized(req, testAccess)
challenge, ok := err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
authCtx, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
authCtx, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@@ -564,7 +562,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
authCtx, err = accessController.Authorized(req, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}
@@ -594,7 +592,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
_, err = accessController.Authorized(ctx, testAccess)
_, err = accessController.Authorized(req, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}