diff --git a/routers/common/middleware.go b/routers/common/middleware.go index 603e18e7bd6..b750e85ad8b 100644 --- a/routers/common/middleware.go +++ b/routers/common/middleware.go @@ -147,6 +147,9 @@ func MustInitSessioner() func(next http.Handler) http.Handler { Secure: setting.SessionConfig.Secure, SameSite: setting.SessionConfig.SameSite, Domain: setting.SessionConfig.Domain, + + // in the future, if websocket is used, the websocket handler should manage its own session sync (release) + IgnoreReleaseForWebSocket: true, }) if err != nil { log.Fatal("common.Sessioner failed: %v", err) diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index 18218e92daf..4c8a379ad0b 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -118,7 +118,7 @@ func autoSignIn(ctx *context.Context) (bool, error) { ctx.SetSiteCookie(setting.CookieRememberName, nt.ID+":"+token, setting.LogInRememberDays*timeutil.Day) - if err := updateSession(ctx, nil, map[string]any{ + if err := regenerateSession(ctx, nil, map[string]any{ session.KeyUID: u.ID, session.KeyUname: u.Name, session.KeyUserHasTwoFactorAuth: userHasTwoFactorAuth, @@ -357,7 +357,7 @@ func SignInPost(ctx *context.Context) { // User will need to use WebAuthn, save data updates["totpEnrolled"] = u.ID } - if err := updateSession(ctx, nil, updates); err != nil { + if err := regenerateSession(ctx, nil, updates); err != nil { ctx.ServerError("UserSignIn: Unable to update session", err) return } @@ -398,7 +398,7 @@ func handleSignInFull(ctx *context.Context, u *user_model.User, remember bool) { return } - if err := updateSession(ctx, []string{ + if err := regenerateSession(ctx, []string{ // Delete the openid, 2fa and link_account data "openid_verified_uri", "openid_signin_remember", @@ -884,7 +884,7 @@ func handleAccountActivation(ctx *context.Context, user *user_model.User) { log.Trace("User activated: %s", user.Name) - if err := updateSession(ctx, nil, map[string]any{ + if err := regenerateSession(ctx, nil, map[string]any{ "uid": user.ID, "uname": user.Name, }); err != nil { @@ -936,7 +936,7 @@ func ActivateEmail(ctx *context.Context) { ctx.Redirect(setting.AppSubURL + "/user/settings/account") } -func updateSession(ctx *context.Context, deletes []string, updates map[string]any) error { +func regenerateSession(ctx *context.Context, deletes []string, updates map[string]any) error { if _, err := session.RegenerateSession(ctx.Resp, ctx.Req); err != nil { return fmt.Errorf("regenerate session: %w", err) } diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 1885cc5fdf7..e2adac1bdde 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -164,7 +164,12 @@ func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData return } - if err := updateSession(ctx, nil, map[string]any{ + if err := Oauth2SetLinkAccountData(ctx, *linkAccountData); err != nil { + ctx.ServerError("Oauth2SetLinkAccountData", err) + return + } + + if err := regenerateSession(ctx, nil, map[string]any{ // User needs to use 2FA, save data and redirect to 2FA page. "twofaUid": u.ID, "twofaRemember": remember, diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 7099aa55dd2..bf4c877069d 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -285,9 +285,7 @@ func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData { } func Oauth2SetLinkAccountData(ctx *context.Context, linkAccountData LinkAccountData) error { - return updateSession(ctx, nil, map[string]any{ - "linkAccountData": linkAccountData, - }) + return ctx.Session.Set("linkAccountData", linkAccountData) } func showLinkingLogin(ctx *context.Context, authSourceID int64, gothUser goth.User) { @@ -409,7 +407,7 @@ func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_m return } - if err := updateSession(ctx, nil, map[string]any{ + if err := regenerateSession(ctx, nil, map[string]any{ session.KeyUID: u.ID, session.KeyUname: u.Name, session.KeyUserHasTwoFactorAuth: userHasTwoFactorAuth, @@ -434,7 +432,7 @@ func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_m } } - if err := updateSession(ctx, nil, map[string]any{ + if err := regenerateSession(ctx, nil, map[string]any{ // User needs to use 2FA, save data and redirect to 2FA page. "twofaUid": u.ID, "twofaRemember": false, diff --git a/routers/web/auth/openid.go b/routers/web/auth/openid.go index 7d67b2fa3f7..9492462afa9 100644 --- a/routers/web/auth/openid.go +++ b/routers/web/auth/openid.go @@ -213,7 +213,7 @@ func signInOpenIDVerify(ctx *context.Context) { if u != nil { nickname = u.LowerName } - if err := updateSession(ctx, nil, map[string]any{ + if err := regenerateSession(ctx, nil, map[string]any{ "openid_verified_uri": id, "openid_determined_email": email, "openid_determined_username": nickname, diff --git a/tests/integration/auth_oauth2_test.go b/tests/integration/auth_oauth2_test.go index 7978c3bd048..d81be14272d 100644 --- a/tests/integration/auth_oauth2_test.go +++ b/tests/integration/auth_oauth2_test.go @@ -22,6 +22,7 @@ import ( "gitea.dev/services/auth/source/oauth2" "gitea.dev/tests" + "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "xorm.io/builder" @@ -489,3 +490,73 @@ func TestOAuth2GroupClaimsManualLinking(t *testing.T) { }) } } + +// TestOAuth2AutoLinkWithTwoFactor verifies that automatic account linking completes +// after the user passes local 2FA when an OIDC identity matches an existing account. +func TestOAuth2AutoLinkWithTwoFactor(t *testing.T) { + defer tests.PrepareTestEnv(t)() + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + defer test.MockVariableValue(&setting.OAuth2Client.AccountLinking, setting.OAuth2AccountLinkingAuto)() + defer test.MockVariableValue(&setting.OAuth2Client.Username, setting.OAuth2UsernameEmail)() + + const ( + sourceName = "test-oauth-auto-link-2fa" + sub = "oidc-auto-link-2fa-sub" + email = "oidc-auto-link-2fa@example.com" + userName = "oidc-auto-link-2fa" + ) + + srv := newFakeOIDCServer(t, FakeOIDCConfig{Sub: sub, Email: email, Name: "OIDC Auto Link 2FA"}) + addOAuth2Source(t, sourceName, oauth2.Source{ + Provider: "openidConnect", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + OpenIDConnectAutoDiscoveryURL: srv.URL + "/.well-known/openid-configuration", + }) + authSource, err := auth_model.GetActiveOAuth2SourceByAuthName(t.Context(), sourceName) + require.NoError(t, err) + + localUser := &user_model.User{Name: userName, Email: email} + require.NoError(t, user_model.CreateUser(t.Context(), localUser, &user_model.Meta{})) + + otpKey, err := totp.Generate(totp.GenerateOpts{ + SecretSize: 40, + Issuer: "gitea-test", + AccountName: localUser.Name, + }) + require.NoError(t, err) + + tfa := &auth_model.TwoFactor{UID: localUser.ID} + require.NoError(t, tfa.SetSecret(otpKey.Secret())) + require.NoError(t, auth_model.NewTwoFactor(t.Context(), tfa)) + + unittest.AssertNotExistsBean(t, &user_model.ExternalLoginUser{ExternalID: sub, LoginSourceID: authSource.ID}, unittest.OrderBy("external_id ASC")) + + session := emptyTestSession(t) + resp := session.MakeRequest(t, NewRequest(t, "GET", "/user/oauth2/"+sourceName), http.StatusTemporaryRedirect) + + location := resp.Header().Get("Location") + u, err := url.Parse(location) + require.NoError(t, err) + state := u.Query().Get("state") + require.NotEmpty(t, state) + + callbackURL := fmt.Sprintf("/user/oauth2/%s/callback?code=test-code&state=%s", sourceName, url.QueryEscape(state)) + resp = session.MakeRequest(t, NewRequest(t, "GET", callbackURL), http.StatusSeeOther) + assert.Contains(t, resp.Header().Get("Location"), "/user/two_factor") + + session.MakeRequest(t, NewRequest(t, "GET", "/user/two_factor"), http.StatusOK) + + passcode, err := totp.GenerateCode(otpKey.Secret(), time.Now()) + require.NoError(t, err) + + req := NewRequestWithValues(t, "POST", "/user/two_factor", map[string]string{ + "passcode": passcode, + }) + session.MakeRequest(t, req, http.StatusSeeOther) + + externalLink := unittest.AssertExistsAndLoadBean(t, &user_model.ExternalLoginUser{ExternalID: sub, LoginSourceID: authSource.ID}, unittest.OrderBy("external_id ASC")) + assert.Equal(t, localUser.ID, externalLink.UserID) + + session.MakeRequest(t, NewRequest(t, "GET", "/user/settings"), http.StatusOK) +}