[StructuredAuthenticationConfig] wire request context to claim resolver

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
This commit is contained in:
Anish Ramasekar 2023-06-28 20:37:40 +00:00
parent 4036b6fb41
commit 150f732c7e
No known key found for this signature in database
GPG Key ID: F1F7F3518F1ECB0C
2 changed files with 14 additions and 11 deletions

View File

@ -464,7 +464,7 @@ func (r *claimResolver) Verifier(iss string) (*oidc.IDTokenVerifier, error) {
// },
// },
// }
func (r *claimResolver) expand(c claims) error {
func (r *claimResolver) expand(ctx context.Context, c claims) error {
const (
// The claim containing a map of endpoint references per claim.
// OIDC Connect Core 1.0, section 5.6.2.
@ -516,14 +516,14 @@ func (r *claimResolver) expand(c claims) error {
// This is maybe an aggregated claim (ep.JWT != "").
return nil
}
return r.resolve(ep, c)
return r.resolve(ctx, ep, c)
}
// resolve requests distributed claims from all endpoints passed in,
// and inserts the lookup results into allClaims.
func (r *claimResolver) resolve(endpoint endpoint, allClaims claims) error {
func (r *claimResolver) resolve(ctx context.Context, endpoint endpoint, allClaims claims) error {
// TODO: cache resolved claims.
jwt, err := getClaimJWT(r.client, endpoint.URL, endpoint.AccessToken)
jwt, err := getClaimJWT(ctx, r.client, endpoint.URL, endpoint.AccessToken)
if err != nil {
return fmt.Errorf("while getting distributed claim %q: %v", r.claim, err)
}
@ -535,7 +535,7 @@ func (r *claimResolver) resolve(endpoint endpoint, allClaims claims) error {
if err != nil {
return fmt.Errorf("verifying untrusted issuer %v failed: %v", untrustedIss, err)
}
t, err := v.Verify(context.Background(), jwt)
t, err := v.Verify(ctx, jwt)
if err != nil {
return fmt.Errorf("verify distributed claim token: %v", err)
}
@ -571,7 +571,7 @@ func (a *Authenticator) AuthenticateToken(ctx context.Context, token string) (*a
return nil, false, fmt.Errorf("oidc: parse claims: %v", err)
}
if a.resolver != nil {
if err := a.resolver.expand(c); err != nil {
if err := a.resolver.expand(ctx, c); err != nil {
return nil, false, fmt.Errorf("oidc: could not expand distributed claims: %v", err)
}
}
@ -645,10 +645,7 @@ func (a *Authenticator) AuthenticateToken(ctx context.Context, token string) (*a
// token as bearer token. If the access token is "", the authorization header
// will not be set.
// TODO: Allow passing in JSON hints to the IDP.
func getClaimJWT(client *http.Client, url, accessToken string) (string, error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
func getClaimJWT(ctx context.Context, client *http.Client, url, accessToken string) (string, error) {
// TODO: Allow passing request body with configurable information.
req, err := http.NewRequest("GET", url, nil)
if err != nil {

View File

@ -296,7 +296,7 @@ func (c *claimsTest) run(t *testing.T) {
t.Fatalf("serialize token: %v", err)
}
got, ok, err := a.AuthenticateToken(context.Background(), token)
got, ok, err := a.AuthenticateToken(testContext(t), token)
expectErr := len(c.wantErr) > 0
@ -1581,3 +1581,9 @@ type errTransport string
func (e errTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("%s", e)
}
func testContext(t *testing.T) context.Context {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
return ctx
}