From 83bdfc2a571b03f32d8a7d8f21f1fbd00f8a99d5 Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Thu, 23 Apr 2026 05:33:27 +0800 Subject: [PATCH] Support for Custom URI Schemes in OAuth2 Redirect URIs (#37356) Fix #34349 By the way, remove `(ctx *APIContext) HasAPIError() ` and `(ctx *APIContext) GetErrMsg()` because they do nothing, the error handling has been done in API's middeware The existing OAuth2 tests were not quite right, refactored them together --- custom/conf/app.example.ini | 5 + models/auth/oauth2_test.go | 52 ++--- models/fixtures/oauth2_application.yml | 2 +- models/fixtures/oauth2_authorization_code.yml | 17 +- modules/setting/oauth2.go | 1 + modules/validation/binding.go | 29 --- modules/validation/binding_test.go | 1 - modules/validation/validurllist_test.go | 157 -------------- modules/web/middleware/binding.go | 145 +++++++------ routers/api/v1/misc/markup.go | 14 -- routers/api/v1/repo/migrate.go | 5 - routers/api/v1/user/app.go | 12 +- routers/web/repo/pull_review.go | 4 +- routers/web/shared/label/label.go | 2 +- services/context/api.go | 18 -- services/forms/user_form.go | 21 +- services/forms/user_form_test.go | 35 ++- templates/swagger/v1_json.tmpl | 3 + tests/integration/api_oauth2_apps_test.go | 113 ++++------ tests/integration/oauth_test.go | 204 ++++++++++-------- tests/integration/user_settings_test.go | 12 +- 21 files changed, 340 insertions(+), 512 deletions(-) delete mode 100644 modules/validation/validurllist_test.go diff --git a/custom/conf/app.example.ini b/custom/conf/app.example.ini index ef276e4da58..2c789dd1032 100644 --- a/custom/conf/app.example.ini +++ b/custom/conf/app.example.ini @@ -592,6 +592,11 @@ ENABLED = true ;; * https://github.com/git-ecosystem/git-credential-manager ;; * https://gitea.com/gitea/tea ;DEFAULT_APPLICATIONS = git-credential-oauth, git-credential-manager, tea +;; +;; By default, OAuth2 applications can only use "http" and "https" as their redirect URI schemes. +;; If you need to use other schemes (e.g. for desktop applications), you can specify them here as a comma-separated list. +;; For example: set "my-scheme, com.example.app" to support "my-scheme://..." and "com.example.app://..." redirect URIs. +;CUSTOM_SCHEMES = ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/models/auth/oauth2_test.go b/models/auth/oauth2_test.go index 88ae065652c..d72e1cb1d52 100644 --- a/models/auth/oauth2_test.go +++ b/models/auth/oauth2_test.go @@ -12,19 +12,30 @@ import ( "code.gitea.io/gitea/modules/timeutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestOAuth2AuthorizationCodeValidity(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) +func TestOAuth2AuthorizationCode(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) t.Run("GenerateSetsValidUntil", func(t *testing.T) { grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1}) expectedValidUntil := timeutil.TimeStamp(time.Now().Unix() + 600) code, err := grant.GenerateNewAuthorizationCode(t.Context(), "http://127.0.0.1/", "", "") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedValidUntil, code.ValidUntil) assert.False(t, code.IsExpired()) + assert.Equal(t, int64(1), code.ID) + + code2, err := auth_model.GetOAuth2AuthorizationByCode(t.Context(), code.Code) + require.NoError(t, err) + assert.Equal(t, code.Code, code2.Code) + assert.NoError(t, code.Invalidate(t.Context())) + + code, err = auth_model.GetOAuth2AuthorizationByCode(t.Context(), "does not exist") + require.NoError(t, err) + require.Nil(t, code) }) t.Run("Expired", func(t *testing.T) { @@ -34,13 +45,14 @@ func TestOAuth2AuthorizationCodeValidity(t *testing.T) { assert.True(t, code.IsExpired()) }) - t.Run("InvalidateTwice", func(t *testing.T) { - code, err := auth_model.GetOAuth2AuthorizationByCode(t.Context(), "authcode") - assert.NoError(t, err) - if assert.NotNil(t, code) { - assert.NoError(t, code.Invalidate(t.Context())) - assert.ErrorIs(t, code.Invalidate(t.Context()), auth_model.ErrOAuth2AuthorizationCodeInvalidated) - } + t.Run("Invalidate", func(t *testing.T) { + grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1}) + code, err := grant.GenerateNewAuthorizationCode(t.Context(), "http://127.0.0.1/", "", "") + require.NoError(t, err) + require.NotNil(t, code) + require.NoError(t, code.Invalidate(t.Context())) + unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: code.Code}) + assert.ErrorIs(t, code.Invalidate(t.Context()), auth_model.ErrOAuth2AuthorizationCodeInvalidated) }) } @@ -224,19 +236,6 @@ func TestRevokeOAuth2Grant(t *testing.T) { //////////////////// Authorization Code -func TestGetOAuth2AuthorizationByCode(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - code, err := auth_model.GetOAuth2AuthorizationByCode(t.Context(), "authcode") - assert.NoError(t, err) - assert.NotNil(t, code) - assert.Equal(t, "authcode", code.Code) - assert.Equal(t, int64(1), code.ID) - - code, err = auth_model.GetOAuth2AuthorizationByCode(t.Context(), "does not exist") - assert.NoError(t, err) - assert.Nil(t, code) -} - func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) { // test plain code := &auth_model.OAuth2AuthorizationCode{ @@ -284,13 +283,6 @@ func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) { assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String()) } -func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - code := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"}) - assert.NoError(t, code.Invalidate(t.Context())) - unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"}) -} - func TestOAuth2AuthorizationCode_TableName(t *testing.T) { assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName()) } diff --git a/models/fixtures/oauth2_application.yml b/models/fixtures/oauth2_application.yml index 5b3b00b16e8..3426ccffac4 100644 --- a/models/fixtures/oauth2_application.yml +++ b/models/fixtures/oauth2_application.yml @@ -4,7 +4,7 @@ name: "Test" client_id: "da7da3ba-9a13-4167-856f-3899de0b0138" client_secret: "$2a$10$UYRgUSgekzBp6hYe8pAdc.cgB4Gn06QRKsORUnIYTYQADs.YR/uvi" # bcrypt of "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA= - redirect_uris: '["a", "https://example.com/xyzzy"]' + redirect_uris: '["https://example.com"]' created_unix: 1546869730 updated_unix: 1546869730 confidential_client: true diff --git a/models/fixtures/oauth2_authorization_code.yml b/models/fixtures/oauth2_authorization_code.yml index 64d8b175077..01918b35eeb 100644 --- a/models/fixtures/oauth2_authorization_code.yml +++ b/models/fixtures/oauth2_authorization_code.yml @@ -1,17 +1,2 @@ -- id: 1 - grant_id: 1 - code: "authcode" - code_challenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg" # Code Verifier: N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt - code_challenge_method: "S256" - redirect_uri: "a" - valid_until: 3546869730 - -- id: 2 - grant_id: 4 - code: "authcodepublic" - code_challenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg" # Code Verifier: N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt - code_challenge_method: "S256" - redirect_uri: "http://127.0.0.1/" - valid_until: 3546869730 - +[] # DO NOT add more test data in the fixtures, test case should prepare their own test data separately and clearly diff --git a/modules/setting/oauth2.go b/modules/setting/oauth2.go index 8e0210aa518..83891387a77 100644 --- a/modules/setting/oauth2.go +++ b/modules/setting/oauth2.go @@ -99,6 +99,7 @@ var OAuth2 = struct { JWTClaimIssuer string `ini:"JWT_CLAIM_ISSUER"` MaxTokenLength int DefaultApplications []string + CustomSchemes []string }{ Enabled: true, AccessTokenExpirationTime: 3600, diff --git a/modules/validation/binding.go b/modules/validation/binding.go index 1a830ed2ebe..c9de1be96c1 100644 --- a/modules/validation/binding.go +++ b/modules/validation/binding.go @@ -13,7 +13,6 @@ import ( "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/glob" "code.gitea.io/gitea/modules/json" - "code.gitea.io/gitea/modules/util" "gitea.com/go-chi/binding" ) @@ -51,7 +50,6 @@ func (j jsonProvider) NewEncoder(writer io.Writer) binding.JSONEncoder { func AddBindingRules() { binding.JSONProvider = jsonProvider{} addGitRefNameBindingRule() - addValidURLListBindingRule() addValidURLBindingRule() addValidSiteURLBindingRule() addGlobPatternRule() @@ -80,33 +78,6 @@ func addGitRefNameBindingRule() { }) } -func addValidURLListBindingRule() { - // URL validation rule - binding.AddRule(&binding.Rule{ - IsMatch: func(rule string) bool { - return rule == "ValidUrlList" - }, - IsValid: func(errs binding.Errors, name string, val any) (bool, binding.Errors) { - str := fmt.Sprintf("%v", val) - if len(str) == 0 { - errs.Add([]string{name}, binding.ERR_URL, "Url") - return false, errs - } - - ok := true - urls := util.SplitTrimSpace(str, "\n") - for _, u := range urls { - if !IsValidURL(u) { - ok = false - errs.Add([]string{name}, binding.ERR_URL, u) - } - } - - return ok, errs - }, - }) -} - func addValidURLBindingRule() { // URL validation rule binding.AddRule(&binding.Rule{ diff --git a/modules/validation/binding_test.go b/modules/validation/binding_test.go index 0cd328f312a..d30eb1bbb1c 100644 --- a/modules/validation/binding_test.go +++ b/modules/validation/binding_test.go @@ -27,7 +27,6 @@ type ( TestForm struct { BranchName string `form:"BranchName" binding:"GitRefName"` URL string `form:"ValidUrl" binding:"ValidUrl"` - URLs string `form:"ValidUrls" binding:"ValidUrlList"` GlobPattern string `form:"GlobPattern" binding:"GlobPattern"` RegexPattern string `form:"RegexPattern" binding:"RegexPattern"` } diff --git a/modules/validation/validurllist_test.go b/modules/validation/validurllist_test.go deleted file mode 100644 index cccc570a1a8..00000000000 --- a/modules/validation/validurllist_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2024 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package validation - -import ( - "testing" - - "gitea.com/go-chi/binding" -) - -func Test_ValidURLListValidation(t *testing.T) { - AddBindingRules() - - // This is a copy of all the URL tests cases, plus additional ones to - // account for multiple URLs - urlListValidationTestCases := []validationTestCase{ - { - description: "Empty URL", - data: TestForm{ - URLs: "", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "URL without port", - data: TestForm{ - URLs: "http://test.lan/", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "URL with port", - data: TestForm{ - URLs: "http://test.lan:3000/", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "URL with IPv6 address without port", - data: TestForm{ - URLs: "http://[::1]/", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "URL with IPv6 address with port", - data: TestForm{ - URLs: "http://[::1]:3000/", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "Invalid URL", - data: TestForm{ - URLs: "http//test.lan/", - }, - expectedErrors: binding.Errors{ - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "http//test.lan/", - }, - }, - }, - { - description: "Invalid schema", - data: TestForm{ - URLs: "ftp://test.lan/", - }, - expectedErrors: binding.Errors{ - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "ftp://test.lan/", - }, - }, - }, - { - description: "Invalid port", - data: TestForm{ - URLs: "http://test.lan:3x4/", - }, - expectedErrors: binding.Errors{ - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "http://test.lan:3x4/", - }, - }, - }, - { - description: "Invalid port with IPv6 address", - data: TestForm{ - URLs: "http://[::1]:3x4/", - }, - expectedErrors: binding.Errors{ - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "http://[::1]:3x4/", - }, - }, - }, - { - description: "Multi URLs", - data: TestForm{ - URLs: "http://test.lan:3000/\nhttp://test.local/", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "Multi URLs with newline", - data: TestForm{ - URLs: "http://test.lan:3000/\nhttp://test.local/\n", - }, - expectedErrors: binding.Errors{}, - }, - { - description: "List with invalid entry", - data: TestForm{ - URLs: "http://test.lan:3000/\nhttp://[::1]:3x4/", - }, - expectedErrors: binding.Errors{ - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "http://[::1]:3x4/", - }, - }, - }, - { - description: "List with two invalid entries", - data: TestForm{ - URLs: "ftp://test.lan:3000/\nhttp://[::1]:3x4/\n", - }, - expectedErrors: binding.Errors{ - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "ftp://test.lan:3000/", - }, - binding.Error{ - FieldNames: []string{"URLs"}, - Classification: binding.ERR_URL, - Message: "http://[::1]:3x4/", - }, - }, - }, - } - - for _, testCase := range urlListValidationTestCases { - t.Run(testCase.description, func(t *testing.T) { - performValidationTest(t, testCase) - }) - } -} diff --git a/modules/web/middleware/binding.go b/modules/web/middleware/binding.go index 05047ad3bdb..988beb47c52 100644 --- a/modules/web/middleware/binding.go +++ b/modules/web/middleware/binding.go @@ -78,82 +78,97 @@ func GetInclude(field reflect.StructField) string { return getRuleBody(field, "Include(") } -// Validate validate +func ReportValidationError(errs binding.Errors, data map[string]any, fieldName, classification, errorMsg string) binding.Errors { + errs.Add([]string{fieldName}, classification, errorMsg) + + data["HasError"] = true + data["ErrorMsg"] = fieldName + ": " + errorMsg + data["Err_"+fieldName] = true + // there is already a reported validation error, so no need to generate default error messages in Validate() + data["HasErrorFormValidation"] = true + return errs +} + func Validate(errs binding.Errors, data map[string]any, f Form, l translation.Locale) binding.Errors { - if errs.Len() == 0 { + // try to restore the form's values as much as possible, + // especially for RenderWithErrDeprecated to re-render the form with errors + AssignForm(f, data) + + if errs.Len() == 0 || data["HasErrorFormValidation"] == true { return errs } + // if HasError=true, then must set default error message + // because still a lot of places use `ctx.Data["ErrorMsg"].(string)` even if the error fields can't be found data["HasError"] = true - // If the field with name errs[0].FieldNames[0] is not found in form - // somehow, some code later on will panic on Data["ErrorMsg"].(string). - // So initialize it to some default. - data["ErrorMsg"] = l.Tr("form.unknown_error") - AssignForm(f, data) + data["ErrorMsg"] = l.TrString("form.unknown_error") typ := reflect.TypeOf(f) - if typ.Kind() == reflect.Ptr { typ = typ.Elem() } - if field, ok := typ.FieldByName(errs[0].FieldNames[0]); ok { - fieldName := field.Tag.Get("form") - if fieldName != "-" { - data["Err_"+field.Name] = true - - trName := field.Tag.Get("locale") - if len(trName) == 0 { - trName = l.TrString("form." + field.Name) - } else { - trName = l.TrString(trName) - } - - switch errs[0].Classification { - case binding.ERR_REQUIRED: - data["ErrorMsg"] = trName + l.TrString("form.require_error") - case binding.ERR_ALPHA_DASH: - data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_error") - case binding.ERR_ALPHA_DASH_DOT: - data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_dot_error") - case validation.ErrGitRefName: - data["ErrorMsg"] = trName + l.TrString("form.git_ref_name_error") - case binding.ERR_SIZE: - data["ErrorMsg"] = trName + l.TrString("form.size_error", GetSize(field)) - case binding.ERR_MIN_SIZE: - data["ErrorMsg"] = trName + l.TrString("form.min_size_error", GetMinSize(field)) - case binding.ERR_MAX_SIZE: - data["ErrorMsg"] = trName + l.TrString("form.max_size_error", GetMaxSize(field)) - case binding.ERR_EMAIL: - data["ErrorMsg"] = trName + l.TrString("form.email_error") - case binding.ERR_URL: - data["ErrorMsg"] = trName + l.TrString("form.url_error", errs[0].Message) - case binding.ERR_INCLUDE: - data["ErrorMsg"] = trName + l.TrString("form.include_error", GetInclude(field)) - case validation.ErrGlobPattern: - data["ErrorMsg"] = trName + l.TrString("form.glob_pattern_error", errs[0].Message) - case validation.ErrRegexPattern: - data["ErrorMsg"] = trName + l.TrString("form.regex_pattern_error", errs[0].Message) - case validation.ErrUsername: - data["ErrorMsg"] = trName + l.TrString("form.username_error") - case validation.ErrInvalidGroupTeamMap: - data["ErrorMsg"] = trName + l.TrString("form.invalid_group_team_map_error", errs[0].Message) - case validation.ErrInvalidBadgeSlug: - data["ErrorMsg"] = trName + l.TrString("form.invalid_slug_error") - default: - msg := errs[0].Classification - if msg != "" && errs[0].Message != "" { - msg += ": " - } - - msg += errs[0].Message - if msg == "" { - msg = l.TrString("form.unknown_error") - } - data["ErrorMsg"] = trName + ": " + msg - } - return errs - } + field, fieldExists := typ.FieldByName(errs[0].FieldNames[0]) + if !fieldExists { + return errs } + + if field.Tag.Get("form") == "-" { + return errs + } + + data["Err_"+field.Name] = true + + trName := field.Tag.Get("locale") + if len(trName) == 0 { + trName = l.TrString("form." + field.Name) + } else { + trName = l.TrString(trName) + } + + switch errs[0].Classification { + case binding.ERR_REQUIRED: + data["ErrorMsg"] = trName + l.TrString("form.require_error") + case binding.ERR_ALPHA_DASH: + data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_error") + case binding.ERR_ALPHA_DASH_DOT: + data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_dot_error") + case validation.ErrGitRefName: + data["ErrorMsg"] = trName + l.TrString("form.git_ref_name_error") + case binding.ERR_SIZE: + data["ErrorMsg"] = trName + l.TrString("form.size_error", GetSize(field)) + case binding.ERR_MIN_SIZE: + data["ErrorMsg"] = trName + l.TrString("form.min_size_error", GetMinSize(field)) + case binding.ERR_MAX_SIZE: + data["ErrorMsg"] = trName + l.TrString("form.max_size_error", GetMaxSize(field)) + case binding.ERR_EMAIL: + data["ErrorMsg"] = trName + l.TrString("form.email_error") + case binding.ERR_URL: + data["ErrorMsg"] = trName + l.TrString("form.url_error", errs[0].Message) + case binding.ERR_INCLUDE: + data["ErrorMsg"] = trName + l.TrString("form.include_error", GetInclude(field)) + case validation.ErrGlobPattern: + data["ErrorMsg"] = trName + l.TrString("form.glob_pattern_error", errs[0].Message) + case validation.ErrRegexPattern: + data["ErrorMsg"] = trName + l.TrString("form.regex_pattern_error", errs[0].Message) + case validation.ErrUsername: + data["ErrorMsg"] = trName + l.TrString("form.username_error") + case validation.ErrInvalidGroupTeamMap: + data["ErrorMsg"] = trName + l.TrString("form.invalid_group_team_map_error", errs[0].Message) + case validation.ErrInvalidBadgeSlug: + data["ErrorMsg"] = trName + l.TrString("form.invalid_slug_error") + default: + msg := errs[0].Classification + if msg != "" && errs[0].Message != "" { + msg += ": " + } + + msg += errs[0].Message + if msg == "" { + msg = l.TrString("form.unknown_error") + } + data["ErrorMsg"] = trName + ": " + msg + } + return errs } diff --git a/routers/api/v1/misc/markup.go b/routers/api/v1/misc/markup.go index 909310b4c86..f7623b9105b 100644 --- a/routers/api/v1/misc/markup.go +++ b/routers/api/v1/misc/markup.go @@ -4,8 +4,6 @@ package misc import ( - "net/http" - "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/markup/markdown" api "code.gitea.io/gitea/modules/structs" @@ -36,12 +34,6 @@ func Markup(ctx *context.APIContext) { // "$ref": "#/responses/validationError" form := web.GetForm(ctx).(*api.MarkupOption) - - if ctx.HasAPIError() { - ctx.APIError(http.StatusUnprocessableEntity, ctx.GetErrMsg()) - return - } - mode := util.Iif(form.Wiki, "wiki", form.Mode) //nolint:staticcheck // form.Wiki is deprecated common.RenderMarkup(ctx.Base, ctx.Repo, mode, form.Text, form.Context, form.FilePath) } @@ -67,12 +59,6 @@ func Markdown(ctx *context.APIContext) { // "$ref": "#/responses/validationError" form := web.GetForm(ctx).(*api.MarkdownOption) - - if ctx.HasAPIError() { - ctx.APIError(http.StatusUnprocessableEntity, ctx.GetErrMsg()) - return - } - mode := util.Iif(form.Wiki, "wiki", form.Mode) //nolint:staticcheck // form.Wiki is deprecated common.RenderMarkup(ctx.Base, ctx.Repo, mode, form.Text, form.Context, "") } diff --git a/routers/api/v1/repo/migrate.go b/routers/api/v1/repo/migrate.go index 9355177fce6..dc99cf8c162 100644 --- a/routers/api/v1/repo/migrate.go +++ b/routers/api/v1/repo/migrate.go @@ -79,11 +79,6 @@ func Migrate(ctx *context.APIContext) { return } - if ctx.HasAPIError() { - ctx.APIError(http.StatusUnprocessableEntity, ctx.GetErrMsg()) - return - } - if !ctx.Doer.IsAdmin { if !repoOwner.IsOrganization() && ctx.Doer.ID != repoOwner.ID { ctx.APIError(http.StatusForbidden, "Given user is not an organization.") diff --git a/routers/api/v1/user/app.go b/routers/api/v1/user/app.go index 6f1053e7ac9..474680adec5 100644 --- a/routers/api/v1/user/app.go +++ b/routers/api/v1/user/app.go @@ -18,6 +18,7 @@ import ( "code.gitea.io/gitea/routers/api/v1/utils" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/convert" + "code.gitea.io/gitea/services/forms" ) // ListAccessTokens list all the access tokens @@ -228,7 +229,10 @@ func CreateOauth2Application(ctx *context.APIContext) { // "$ref": "#/responses/error" data := web.GetForm(ctx).(*api.CreateOAuth2ApplicationOptions) - + if invalidURI := forms.DetectInvalidOAuth2ApplicationRedirectURI(data.RedirectURIs); invalidURI != "" { + ctx.APIError(http.StatusBadRequest, "invalid redirect URI: "+invalidURI) + return + } app, err := auth_model.CreateOAuth2Application(ctx, auth_model.CreateOAuth2ApplicationOptions{ Name: data.Name, UserID: ctx.Doer.ID, @@ -382,11 +386,17 @@ func UpdateOauth2Application(ctx *context.APIContext) { // responses: // "200": // "$ref": "#/responses/OAuth2Application" + // "400": + // "$ref": "#/responses/error" // "404": // "$ref": "#/responses/notFound" appID := ctx.PathParamInt64("id") data := web.GetForm(ctx).(*api.CreateOAuth2ApplicationOptions) + if invalidURI := forms.DetectInvalidOAuth2ApplicationRedirectURI(data.RedirectURIs); invalidURI != "" { + ctx.APIError(http.StatusBadRequest, "invalid redirect URI: "+invalidURI) + return + } app, err := auth_model.UpdateOAuth2Application(ctx, auth_model.UpdateOAuth2ApplicationOptions{ Name: data.Name, diff --git a/routers/web/repo/pull_review.go b/routers/web/repo/pull_review.go index f064058221e..eb8e8fa677e 100644 --- a/routers/web/repo/pull_review.go +++ b/routers/web/repo/pull_review.go @@ -72,7 +72,7 @@ func CreateCodeComment(ctx *context.Context) { } if ctx.HasError() { - ctx.Flash.Error(ctx.Data["ErrorMsg"].(string)) + ctx.Flash.Error(ctx.GetErrMsg()) ctx.Redirect(fmt.Sprintf("%s/pulls/%d/files", ctx.Repo.RepoLink, issue.Index)) return } @@ -230,7 +230,7 @@ func SubmitReview(ctx *context.Context) { return } if ctx.HasError() { - ctx.Flash.Error(ctx.Data["ErrorMsg"].(string)) + ctx.Flash.Error(ctx.GetErrMsg()) ctx.JSONRedirect(fmt.Sprintf("%s/pulls/%d/files", ctx.Repo.RepoLink, issue.Index)) return } diff --git a/routers/web/shared/label/label.go b/routers/web/shared/label/label.go index 6968a318c47..4a3c84e32ad 100644 --- a/routers/web/shared/label/label.go +++ b/routers/web/shared/label/label.go @@ -13,7 +13,7 @@ import ( func GetLabelEditForm(ctx *context.Context) *forms.CreateLabelForm { form := web.GetForm(ctx).(*forms.CreateLabelForm) if ctx.HasError() { - ctx.JSONError(ctx.Data["ErrorMsg"].(string)) + ctx.JSONError(ctx.GetErrMsg()) return nil } var err error diff --git a/services/context/api.go b/services/context/api.go index b49bf9b42c6..7f1429a89a3 100644 --- a/services/context/api.go +++ b/services/context/api.go @@ -322,24 +322,6 @@ func RepoRefForAPI(next http.Handler) http.Handler { }) } -// HasAPIError returns true if error occurs in form validation. -func (ctx *APIContext) HasAPIError() bool { - hasErr, ok := ctx.Data["HasError"] - if !ok { - return false - } - return hasErr.(bool) -} - -// GetErrMsg returns error message in form validation. -func (ctx *APIContext) GetErrMsg() string { - msg, _ := ctx.Data["ErrorMsg"].(string) - if msg == "" { - msg = "invalid form data" - } - return msg -} - // NotFoundOrServerError use error check function to determine if the error // is about not found. It responds with 404 status code for not found error, // or error context description for logging purpose of 500 server error. diff --git a/services/forms/user_form.go b/services/forms/user_form.go index cc514a2e279..3f65e8c551f 100644 --- a/services/forms/user_form.go +++ b/services/forms/user_form.go @@ -7,9 +7,13 @@ package forms import ( "mime/multipart" "net/http" + "strings" user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/structs" + "code.gitea.io/gitea/modules/util" + "code.gitea.io/gitea/modules/validation" "code.gitea.io/gitea/modules/web/middleware" "code.gitea.io/gitea/services/context" @@ -356,14 +360,29 @@ func (f *NewAccessTokenForm) Validate(req *http.Request, errs binding.Errors) bi // EditOAuth2ApplicationForm form for editing oauth2 applications type EditOAuth2ApplicationForm struct { Name string `binding:"Required;MaxSize(255)" form:"application_name"` - RedirectURIs string `binding:"Required;ValidUrlList" form:"redirect_uris"` + RedirectURIs string `binding:"Required" form:"redirect_uris"` ConfidentialClient bool `form:"confidential_client"` SkipSecondaryAuthorization bool `form:"skip_secondary_authorization"` } +func DetectInvalidOAuth2ApplicationRedirectURI(uris []string) (invalidURL string) { + for _, u := range uris { + scheme, _, ok := strings.Cut(u, ":") + valid := ok && (validation.IsValidURL(u) || util.SliceContainsString(setting.OAuth2.CustomSchemes, scheme)) + if !valid { + return u + } + } + return "" +} + // Validate validates the fields func (f *EditOAuth2ApplicationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { ctx := context.GetValidateContext(req) + invalidURI := DetectInvalidOAuth2ApplicationRedirectURI(util.SplitTrimSpace(f.RedirectURIs, "\n")) + if invalidURI != "" { + errs = middleware.ReportValidationError(errs, ctx.Data, "RedirectURIs", binding.ERR_URL, ctx.Locale.TrString("form.url_error", invalidURI)) + } return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/user_form_test.go b/services/forms/user_form_test.go index 4246f955b3e..c082f40294b 100644 --- a/services/forms/user_form_test.go +++ b/services/forms/user_form_test.go @@ -14,15 +14,9 @@ import ( ) func TestRegisterForm_IsDomainAllowed_Empty(t *testing.T) { - oldService := setting.Service - defer func() { - setting.Service = oldService - }() - + defer test.MockVariableValue(&setting.Service)() setting.Service.EmailDomainAllowList = nil - form := RegisterForm{} - assert.True(t, form.IsEmailDomainAllowed()) } @@ -87,3 +81,30 @@ func TestRegisterForm_IsDomainAllowed_BlockedEmail(t *testing.T) { assert.Equal(t, v.valid, form.IsEmailDomainAllowed()) } } + +func TestDetectInvalidOAuth2ApplicationRedirectURI(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.CustomSchemes)() + setting.OAuth2.CustomSchemes = []string{"my-app"} + assertValid := func(t *testing.T, s string, valid bool) { + ret := DetectInvalidOAuth2ApplicationRedirectURI([]string{s}) + if valid { + assert.Empty(t, ret) + } else { + assert.Equal(t, s, ret) + } + } + assertValid(t, "my-app:", true) + assertValid(t, "my-app:/foo", true) + assertValid(t, "http://foo", true) + assertValid(t, "https://foo", true) + + assertValid(t, "my-app", false) + assertValid(t, "ftp:", false) + assertValid(t, "ftp://foo", false) + assertValid(t, "https://[invalid", false) + + ret := DetectInvalidOAuth2ApplicationRedirectURI([]string{"my-app:", "http://foo", "https://foo"}) + assert.Empty(t, ret) + ret = DetectInvalidOAuth2ApplicationRedirectURI([]string{"my-app:", "http://foo", "invalid", "https://foo"}) + assert.Equal(t, "invalid", ret) +} diff --git a/templates/swagger/v1_json.tmpl b/templates/swagger/v1_json.tmpl index 29075a8021f..2748179cac6 100644 --- a/templates/swagger/v1_json.tmpl +++ b/templates/swagger/v1_json.tmpl @@ -19354,6 +19354,9 @@ "200": { "$ref": "#/responses/OAuth2Application" }, + "400": { + "$ref": "#/responses/error" + }, "404": { "$ref": "#/responses/notFound" } diff --git a/tests/integration/api_oauth2_apps_test.go b/tests/integration/api_oauth2_apps_test.go index a0f1a2cb66b..32b732b2a83 100644 --- a/tests/integration/api_oauth2_apps_test.go +++ b/tests/integration/api_oauth2_apps_test.go @@ -11,10 +11,13 @@ import ( auth_model "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/setting" api "code.gitea.io/gitea/modules/structs" + "code.gitea.io/gitea/modules/test" "code.gitea.io/gitea/tests" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOAuth2Application(t *testing.T) { @@ -28,18 +31,17 @@ func TestOAuth2Application(t *testing.T) { func testAPICreateOAuth2Application(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) - appBody := api.CreateOAuth2ApplicationOptions{ - Name: "test-app-1", - RedirectURIs: []string{ - "http://www.google.com", - }, - ConfidentialClient: true, - } + redirectURIs := []string{"http://www.google.com", "my-app:foo"} + appBody := api.CreateOAuth2ApplicationOptions{Name: "test-app-1", RedirectURIs: redirectURIs, ConfidentialClient: true} - req := NewRequestWithJSON(t, "POST", "/api/v1/user/applications/oauth2", &appBody). - AddBasicAuth(user.Name) + // no custom scheme + req := NewRequestWithJSON(t, "POST", "/api/v1/user/applications/oauth2", &appBody).AddBasicAuth(user.Name) + MakeRequest(t, req, http.StatusBadRequest) + + // with custom scheme + defer test.MockVariableValue(&setting.OAuth2.CustomSchemes, []string{"my-app"})() + req = NewRequestWithJSON(t, "POST", "/api/v1/user/applications/oauth2", &appBody).AddBasicAuth(user.Name) resp := MakeRequest(t, req, http.StatusCreated) - createdApp := DecodeJSON(t, resp, &api.OAuth2Application{}) assert.Equal(t, appBody.Name, createdApp.Name) @@ -47,7 +49,7 @@ func testAPICreateOAuth2Application(t *testing.T) { assert.Len(t, createdApp.ClientID, 36) assert.True(t, createdApp.ConfidentialClient) assert.NotEmpty(t, createdApp.Created) - assert.Equal(t, appBody.RedirectURIs[0], createdApp.RedirectURIs[0]) + assert.Equal(t, redirectURIs, createdApp.RedirectURIs) unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{UID: user.ID, Name: createdApp.Name}) } @@ -56,21 +58,13 @@ func testAPIListOAuth2Applications(t *testing.T) { session := loginUser(t, user.Name) token := getTokenForLoggedInUser(t, session, auth_model.AccessTokenScopeReadUser) - existApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ - UID: user.ID, - Name: "test-app-1", - RedirectURIs: []string{ - "http://www.google.com", - }, - ConfidentialClient: true, - }) + existApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{UID: user.ID, Name: "test-app-1", ConfidentialClient: true}) + require.NotEmpty(t, existApp.RedirectURIs) - req := NewRequest(t, "GET", "/api/v1/user/applications/oauth2"). - AddTokenAuth(token) + req := NewRequest(t, "GET", "/api/v1/user/applications/oauth2").AddTokenAuth(token) resp := MakeRequest(t, req, http.StatusOK) - var appList api.OAuth2ApplicationList - DecodeJSON(t, resp, &appList) + appList := DecodeJSON(t, resp, api.OAuth2ApplicationList{}) expectedApp := appList[0] assert.Equal(t, expectedApp.Name, existApp.Name) @@ -78,7 +72,7 @@ func testAPIListOAuth2Applications(t *testing.T) { assert.Equal(t, expectedApp.ConfidentialClient, existApp.ConfidentialClient) assert.Len(t, expectedApp.ClientID, 36) assert.Empty(t, expectedApp.ClientSecret) - assert.Equal(t, existApp.RedirectURIs[0], expectedApp.RedirectURIs[0]) + assert.Equal(t, expectedApp.RedirectURIs, existApp.RedirectURIs) unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) } @@ -87,21 +81,16 @@ func testAPIDeleteOAuth2Application(t *testing.T) { session := loginUser(t, user.Name) token := getTokenForLoggedInUser(t, session, auth_model.AccessTokenScopeWriteUser) - oldApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ - UID: user.ID, - Name: "test-app-1", - }) + oldApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{UID: user.ID, Name: "test-app-1"}) urlStr := fmt.Sprintf("/api/v1/user/applications/oauth2/%d", oldApp.ID) - req := NewRequest(t, "DELETE", urlStr). - AddTokenAuth(token) + req := NewRequest(t, "DELETE", urlStr).AddTokenAuth(token) MakeRequest(t, req, http.StatusNoContent) unittest.AssertNotExistsBean(t, &auth_model.OAuth2Application{UID: oldApp.UID, Name: oldApp.Name}) // Delete again will return not found - req = NewRequest(t, "DELETE", urlStr). - AddTokenAuth(token) + req = NewRequest(t, "DELETE", urlStr).AddTokenAuth(token) MakeRequest(t, req, http.StatusNotFound) } @@ -110,65 +99,41 @@ func testAPIGetOAuth2Application(t *testing.T) { session := loginUser(t, user.Name) token := getTokenForLoggedInUser(t, session, auth_model.AccessTokenScopeReadUser) - existApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ - UID: user.ID, - Name: "test-app-1", - RedirectURIs: []string{ - "http://www.google.com", - }, - ConfidentialClient: true, - }) + existApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{UID: user.ID, Name: "test-app-1", ConfidentialClient: true}) + require.NotEmpty(t, existApp.RedirectURIs) - req := NewRequest(t, "GET", fmt.Sprintf("/api/v1/user/applications/oauth2/%d", existApp.ID)). - AddTokenAuth(token) + req := NewRequest(t, "GET", fmt.Sprintf("/api/v1/user/applications/oauth2/%d", existApp.ID)).AddTokenAuth(token) resp := MakeRequest(t, req, http.StatusOK) - - var app api.OAuth2Application - DecodeJSON(t, resp, &app) - expectedApp := app + expectedApp := DecodeJSON(t, resp, &api.OAuth2Application{}) assert.Equal(t, expectedApp.Name, existApp.Name) assert.Equal(t, expectedApp.ClientID, existApp.ClientID) assert.Equal(t, expectedApp.ConfidentialClient, existApp.ConfidentialClient) assert.Len(t, expectedApp.ClientID, 36) assert.Empty(t, expectedApp.ClientSecret) - assert.Len(t, expectedApp.RedirectURIs, 1) - assert.Equal(t, expectedApp.RedirectURIs[0], existApp.RedirectURIs[0]) + assert.Equal(t, expectedApp.RedirectURIs, existApp.RedirectURIs) unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) } func testAPIUpdateOAuth2Application(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) - existApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ - UID: user.ID, - Name: "test-app-1", - RedirectURIs: []string{ - "http://www.google.com", - }, - }) - - appBody := api.CreateOAuth2ApplicationOptions{ - Name: "test-app-1", - RedirectURIs: []string{ - "http://www.google.com/", - "http://www.github.com/", - }, - ConfidentialClient: true, - } - + existApp := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{UID: user.ID, Name: "test-app-1"}) + redirectURIs := []string{"https://www.google.com", "my-app:foo"} + appBody := api.CreateOAuth2ApplicationOptions{Name: "test-app-1", RedirectURIs: redirectURIs, ConfidentialClient: true} urlStr := fmt.Sprintf("/api/v1/user/applications/oauth2/%d", existApp.ID) - req := NewRequestWithJSON(t, "PATCH", urlStr, &appBody). - AddBasicAuth(user.Name) + + // no custom scheme + req := NewRequestWithJSON(t, "PATCH", urlStr, &appBody).AddBasicAuth(user.Name) + MakeRequest(t, req, http.StatusBadRequest) + + // with custom scheme + defer test.MockVariableValue(&setting.OAuth2.CustomSchemes, []string{"my-app"})() + req = NewRequestWithJSON(t, "PATCH", urlStr, &appBody).AddBasicAuth(user.Name) resp := MakeRequest(t, req, http.StatusOK) - var app api.OAuth2Application - DecodeJSON(t, resp, &app) - expectedApp := app - - assert.Len(t, expectedApp.RedirectURIs, 2) - assert.Equal(t, expectedApp.RedirectURIs[0], appBody.RedirectURIs[0]) - assert.Equal(t, expectedApp.RedirectURIs[1], appBody.RedirectURIs[1]) + expectedApp := DecodeJSON(t, resp, &api.OAuth2Application{}) + assert.Equal(t, expectedApp.RedirectURIs, appBody.RedirectURIs) assert.Equal(t, expectedApp.ConfidentialClient, appBody.ConfidentialClient) unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) } diff --git a/tests/integration/oauth_test.go b/tests/integration/oauth_test.go index 2b10f5f61d2..7090c4c2389 100644 --- a/tests/integration/oauth_test.go +++ b/tests/integration/oauth_test.go @@ -23,6 +23,7 @@ import ( "code.gitea.io/gitea/modules/setting" api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/test" + "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/auth/source/oauth2" "code.gitea.io/gitea/services/oauth2_provider" @@ -35,17 +36,57 @@ import ( "github.com/stretchr/testify/require" ) -func TestOAuth2Provider(t *testing.T) { +func testOAuth2PrepareTestCode(t *testing.T) { + require.NoError(t, db.TruncateBeans(t.Context(), &auth_model.OAuth2AuthorizationCode{})) + err := db.Insert(t.Context(), &auth_model.OAuth2AuthorizationCode{ + GrantID: 1, + Code: "authcode", + CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", // Code Verifier: N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt + CodeChallengeMethod: "S256", + RedirectURI: "https://example.com", + ValidUntil: timeutil.TimeStampNow() + 86400, + }, &auth_model.OAuth2AuthorizationCode{ + GrantID: 4, + Code: "authcodepublic", + CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", //# Code Verifier: N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt + CodeChallengeMethod: "S256", + RedirectURI: "http://127.0.0.1/", + ValidUntil: timeutil.TimeStampNow() + 86400, + }) + require.NoError(t, err) +} + +func TestOAuth2(t *testing.T) { defer tests.PrepareTestEnv(t)() - t.Run("AuthorizeNoClientID", testAuthorizeNoClientID) - t.Run("AuthorizeUnregisteredRedirect", testAuthorizeUnregisteredRedirect) - t.Run("AuthorizeUnsupportedResponseType", testAuthorizeUnsupportedResponseType) - t.Run("AuthorizeUnsupportedCodeChallengeMethod", testAuthorizeUnsupportedCodeChallengeMethod) - t.Run("AuthorizeLoginRedirect", testAuthorizeLoginRedirect) - - t.Run("OAuth2WellKnown", testOAuth2WellKnown) - t.Run("OAuthSourceSpecialChars", testOAuthSourceSpecialChars) + t.Run("Provider", func(t *testing.T) { + t.Run("AuthorizeNoClientID", testAuthorizeNoClientID) + t.Run("AuthorizeUnregisteredRedirect", testAuthorizeUnregisteredRedirect) + t.Run("AuthorizeUnsupportedResponseType", testAuthorizeUnsupportedResponseType) + t.Run("AuthorizeUnsupportedCodeChallengeMethod", testAuthorizeUnsupportedCodeChallengeMethod) + t.Run("AuthorizeLoginRedirect", testAuthorizeLoginRedirect) + t.Run("AuthorizeShow", testAuthorizeShow) + t.Run("AuthorizeGrantS256RequiresVerifier", testAuthorizeGrantS256RequiresVerifier) + t.Run("AuthorizeRedirectWithExistingGrant", testAuthorizeRedirectWithExistingGrant) + t.Run("AuthorizePKCERequiredForPublicClient", testAuthorizePKCERequiredForPublicClient) + t.Run("AccessTokenExchange", testAccessTokenExchange) + t.Run("AccessTokenExchangeWithPublicClient", testAccessTokenExchangeWithPublicClient) + t.Run("AccessTokenExchangeJSON", testAccessTokenExchangeJSON) + t.Run("AccessTokenExchangeWithoutPKCE", testAccessTokenExchangeWithoutPKCE) + t.Run("AccessTokenExchangeWithInvalidCredentials", testAccessTokenExchangeWithInvalidCredentials) + t.Run("AccessTokenExchangeWithBasicAuth", testAccessTokenExchangeWithBasicAuth) + t.Run("RefreshTokenInvalidation", testRefreshTokenInvalidation) + t.Run("OAuthIntrospection", testOAuthIntrospection) + t.Run("OAuthGrantScopesReadUserFailRepos", testOAuthGrantScopesReadUserFailRepos) + t.Run("OAuthGrantScopesReadRepositoryFailOrganization", testOAuthGrantScopesReadRepositoryFailOrganization) + t.Run("OAuthGrantScopesClaimPublicOnlyGroups", testOAuthGrantScopesClaimPublicOnlyGroups) + t.Run("OAuthGrantScopesClaimAllGroups", testOAuthGrantScopesClaimAllGroups) + t.Run("OAuth2WellKnown", testOAuth2WellKnown) + }) + t.Run("Client", func(t *testing.T) { + t.Run("OAuthSourceSpecialChars", testOAuthSourceSpecialChars) + t.Run("SignInOauthCallbackSyncSSHKeys", testSignInOauthCallbackSyncSSHKeys) + }) // TODO: move more tests as sub-tests here, avoid unnecessary PrepareTestEnv } @@ -64,7 +105,7 @@ func testAuthorizeUnregisteredRedirect(t *testing.T) { } func testAuthorizeUnsupportedResponseType(t *testing.T) { - req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=a&response_type=UNEXPECTED&state=thestate") + req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=https://example.com&response_type=UNEXPECTED&state=thestate") ctx := loginUser(t, "user1") resp := ctx.MakeRequest(t, req, http.StatusSeeOther) u, err := resp.Result().Location() @@ -74,7 +115,7 @@ func testAuthorizeUnsupportedResponseType(t *testing.T) { } func testAuthorizeUnsupportedCodeChallengeMethod(t *testing.T) { - req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=a&response_type=code&state=thestate&code_challenge_method=UNEXPECTED") + req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=https://example.com&response_type=code&state=thestate&code_challenge_method=UNEXPECTED") ctx := loginUser(t, "user1") resp := ctx.MakeRequest(t, req, http.StatusSeeOther) u, err := resp.Result().Location() @@ -88,9 +129,8 @@ func testAuthorizeLoginRedirect(t *testing.T) { assert.Contains(t, MakeRequest(t, req, http.StatusSeeOther).Body.String(), "/user/login") } -func TestAuthorizeShow(t *testing.T) { - defer tests.PrepareTestEnv(t)() - req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=a&response_type=code&state=thestate") +func testAuthorizeShow(t *testing.T) { + req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=https://example.com&response_type=code&state=thestate") ctx := loginUser(t, "user4") resp := ctx.MakeRequest(t, req, http.StatusOK) @@ -98,11 +138,10 @@ func TestAuthorizeShow(t *testing.T) { AssertHTMLElement(t, htmlDoc, "#authorize-app", true) } -func TestAuthorizeGrantS256RequiresVerifier(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAuthorizeGrantS256RequiresVerifier(t *testing.T) { ctx := loginUser(t, "user4") codeChallenge := "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg" - req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=a&response_type=code&state=thestate&code_challenge_method=S256&code_challenge="+url.QueryEscape(codeChallenge)) + req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=https://example.com&response_type=code&state=thestate&code_challenge_method=S256&code_challenge="+url.QueryEscape(codeChallenge)) resp := ctx.MakeRequest(t, req, http.StatusOK) htmlDoc := NewHTMLParser(t, resp.Body) @@ -113,7 +152,7 @@ func TestAuthorizeGrantS256RequiresVerifier(t *testing.T) { "state": "thestate", "scope": "", "nonce": "", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "granted": "true", }) grantResp := ctx.MakeRequest(t, grantReq, http.StatusSeeOther) @@ -126,7 +165,7 @@ func TestAuthorizeGrantS256RequiresVerifier(t *testing.T) { "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": code, }) accessResp := MakeRequest(t, accessReq, http.StatusBadRequest) @@ -136,9 +175,8 @@ func TestAuthorizeGrantS256RequiresVerifier(t *testing.T) { assert.Equal(t, "failed PKCE code challenge", parsedError.ErrorDescription) } -func TestAuthorizeRedirectWithExistingGrant(t *testing.T) { - defer tests.PrepareTestEnv(t)() - req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=https%3A%2F%2Fexample.com%2Fxyzzy&response_type=code&state=thestate") +func testAuthorizeRedirectWithExistingGrant(t *testing.T) { + req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=da7da3ba-9a13-4167-856f-3899de0b0138&redirect_uri=https://example.com/&response_type=code&state=thestate") ctx := loginUser(t, "user1") resp := ctx.MakeRequest(t, req, http.StatusSeeOther) u, err := resp.Result().Location() @@ -146,11 +184,11 @@ func TestAuthorizeRedirectWithExistingGrant(t *testing.T) { assert.Equal(t, "thestate", u.Query().Get("state")) assert.Greaterf(t, len(u.Query().Get("code")), 30, "authorization code '%s' should be longer then 30", u.Query().Get("code")) u.RawQuery = "" - assert.Equal(t, "https://example.com/xyzzy", u.String()) + assert.Equal(t, "https://example.com/", u.String()) } -func TestAuthorizePKCERequiredForPublicClient(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAuthorizePKCERequiredForPublicClient(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequest(t, "GET", "/login/oauth/authorize?client_id=ce5a1322-42a7-11ed-b878-0242ac120002&redirect_uri=http%3A%2F%2F127.0.0.1&response_type=code&state=thestate") ctx := loginUser(t, "user1") resp := ctx.MakeRequest(t, req, http.StatusSeeOther) @@ -160,13 +198,13 @@ func TestAuthorizePKCERequiredForPublicClient(t *testing.T) { assert.Equal(t, "PKCE is required for public clients", u.Query().Get("error_description")) } -func TestAccessTokenExchange(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAccessTokenExchange(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -184,8 +222,8 @@ func TestAccessTokenExchange(t *testing.T) { assert.Greater(t, len(parsed.RefreshToken), 10) } -func TestAccessTokenExchangeWithPublicClient(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAccessTokenExchangeWithPublicClient(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "ce5a1322-42a7-11ed-b878-0242ac120002", @@ -207,13 +245,13 @@ func TestAccessTokenExchangeWithPublicClient(t *testing.T) { assert.Greater(t, len(parsed.RefreshToken), 10) } -func TestAccessTokenExchangeJSON(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAccessTokenExchangeJSON(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithJSON(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -231,13 +269,13 @@ func TestAccessTokenExchangeJSON(t *testing.T) { assert.Greater(t, len(parsed.RefreshToken), 10) } -func TestAccessTokenExchangeWithoutPKCE(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAccessTokenExchangeWithoutPKCE(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", }) resp := MakeRequest(t, req, http.StatusBadRequest) @@ -247,14 +285,14 @@ func TestAccessTokenExchangeWithoutPKCE(t *testing.T) { assert.Equal(t, "failed PKCE code challenge", parsedError.ErrorDescription) } -func TestAccessTokenExchangeWithInvalidCredentials(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAccessTokenExchangeWithInvalidCredentials(t *testing.T) { + testOAuth2PrepareTestCode(t) // invalid client id req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "???", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -269,7 +307,7 @@ func TestAccessTokenExchangeWithInvalidCredentials(t *testing.T) { "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "???", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -299,7 +337,7 @@ func TestAccessTokenExchangeWithInvalidCredentials(t *testing.T) { "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "???", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -314,7 +352,7 @@ func TestAccessTokenExchangeWithInvalidCredentials(t *testing.T) { "grant_type": "???", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -325,11 +363,11 @@ func TestAccessTokenExchangeWithInvalidCredentials(t *testing.T) { assert.Equal(t, "Only refresh_token or authorization_code grant type is supported", parsedError.ErrorDescription) } -func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testAccessTokenExchangeWithBasicAuth(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -350,7 +388,7 @@ func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { // use wrong client_secret req = NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -364,7 +402,7 @@ func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { // missing header req = NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -377,7 +415,7 @@ func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { // client_id inconsistent with Authorization header req = NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "client_id": "inconsistent", }) @@ -391,7 +429,7 @@ func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { // client_secret inconsistent with Authorization header req = NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "client_secret": "inconsistent", }) @@ -403,13 +441,13 @@ func TestAccessTokenExchangeWithBasicAuth(t *testing.T) { assert.Equal(t, "client_secret in request body inconsistent with Authorization header", parsedError.ErrorDescription) } -func TestRefreshTokenInvalidation(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testRefreshTokenInvalidation(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -431,7 +469,7 @@ func TestRefreshTokenInvalidation(t *testing.T) { "grant_type": "refresh_token", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", // omit secret - "redirect_uri": "a", + "redirect_uri": "https://example.com", "refresh_token": parsed.RefreshToken, }) resp = MakeRequest(t, req, http.StatusBadRequest) @@ -444,7 +482,7 @@ func TestRefreshTokenInvalidation(t *testing.T) { "grant_type": "refresh_token", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "refresh_token": "UNEXPECTED", }) resp = MakeRequest(t, req, http.StatusBadRequest) @@ -457,7 +495,7 @@ func TestRefreshTokenInvalidation(t *testing.T) { "grant_type": "refresh_token", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "refresh_token": parsed.RefreshToken, }) @@ -484,13 +522,13 @@ func TestRefreshTokenInvalidation(t *testing.T) { assert.Equal(t, "token was already used", parsedError.ErrorDescription) } -func TestOAuthIntrospection(t *testing.T) { - defer tests.PrepareTestEnv(t)() +func testOAuthIntrospection(t *testing.T) { + testOAuth2PrepareTestCode(t) req := NewRequestWithValues(t, "POST", "/login/oauth/access_token", map[string]string{ "grant_type": "authorization_code", "client_id": "da7da3ba-9a13-4167-856f-3899de0b0138", "client_secret": "4MK8Na6R55smdCY0WuCCumZ6hjRPnGY5saWVRHHjJiA=", - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": "authcode", "code_verifier": "N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt", }) @@ -542,14 +580,12 @@ func TestOAuthIntrospection(t *testing.T) { assert.Contains(t, resp.Body.String(), "no valid authorization") } -func TestOAuth_GrantScopesReadUserFailRepos(t *testing.T) { - defer tests.PrepareTestEnv(t)() - +func testOAuthGrantScopesReadUserFailRepos(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) appBody := api.CreateOAuth2ApplicationOptions{ Name: "oauth-provider-scopes-test", RedirectURIs: []string{ - "a", + "https://example.com", }, ConfidentialClient: true, } @@ -573,7 +609,7 @@ func TestOAuth_GrantScopesReadUserFailRepos(t *testing.T) { ctx := loginUser(t, user.Name) - authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=a&response_type=code&state=thestate", app.ClientID) + authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=https://example.com&response_type=code&state=thestate", app.ClientID) authorizeReq := NewRequest(t, "GET", authorizeURL) authorizeResp := ctx.MakeRequest(t, authorizeReq, http.StatusSeeOther) @@ -583,7 +619,7 @@ func TestOAuth_GrantScopesReadUserFailRepos(t *testing.T) { "grant_type": "authorization_code", "client_id": app.ClientID, "client_secret": app.ClientSecret, - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": authcode, }) accessTokenResp := ctx.MakeRequest(t, accessTokenReq, 200) @@ -622,14 +658,12 @@ func TestOAuth_GrantScopesReadUserFailRepos(t *testing.T) { assert.Contains(t, errorParsed.Message, "token does not have at least one of required scope(s), required=[read:repository]") } -func TestOAuth_GrantScopesReadRepositoryFailOrganization(t *testing.T) { - defer tests.PrepareTestEnv(t)() - +func testOAuthGrantScopesReadRepositoryFailOrganization(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) appBody := api.CreateOAuth2ApplicationOptions{ Name: "oauth-provider-scopes-test", RedirectURIs: []string{ - "a", + "https://example.com", }, ConfidentialClient: true, } @@ -653,7 +687,7 @@ func TestOAuth_GrantScopesReadRepositoryFailOrganization(t *testing.T) { ctx := loginUser(t, user.Name) - authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=a&response_type=code&state=thestate", app.ClientID) + authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=https://example.com&response_type=code&state=thestate", app.ClientID) authorizeReq := NewRequest(t, "GET", authorizeURL) authorizeResp := ctx.MakeRequest(t, authorizeReq, http.StatusSeeOther) @@ -662,7 +696,7 @@ func TestOAuth_GrantScopesReadRepositoryFailOrganization(t *testing.T) { "grant_type": "authorization_code", "client_id": app.ClientID, "client_secret": app.ClientSecret, - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": authcode, }) accessTokenResp := ctx.MakeRequest(t, accessTokenReq, http.StatusOK) @@ -760,15 +794,13 @@ func TestOAuth_GrantScopesReadRepositoryFailOrganization(t *testing.T) { assert.Contains(t, errorParsed.Message, "token does not have at least one of required scope(s), required=[read:user read:organization]") } -func TestOAuth_GrantScopesClaimPublicOnlyGroups(t *testing.T) { - defer tests.PrepareTestEnv(t)() - +func testOAuthGrantScopesClaimPublicOnlyGroups(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: "user2"}) appBody := api.CreateOAuth2ApplicationOptions{ Name: "oauth-provider-scopes-test", RedirectURIs: []string{ - "a", + "https://example.com", }, ConfidentialClient: true, } @@ -792,7 +824,7 @@ func TestOAuth_GrantScopesClaimPublicOnlyGroups(t *testing.T) { ctx := loginUser(t, user.Name) - authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=a&response_type=code&state=thestate", app.ClientID) + authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=https://example.com&response_type=code&state=thestate", app.ClientID) authorizeReq := NewRequest(t, "GET", authorizeURL) authorizeResp := ctx.MakeRequest(t, authorizeReq, http.StatusSeeOther) @@ -802,7 +834,7 @@ func TestOAuth_GrantScopesClaimPublicOnlyGroups(t *testing.T) { "grant_type": "authorization_code", "client_id": app.ClientID, "client_secret": app.ClientSecret, - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": authcode, }) accessTokenResp := ctx.MakeRequest(t, accessTokenReq, http.StatusOK) @@ -860,15 +892,13 @@ func TestOAuth_GrantScopesClaimPublicOnlyGroups(t *testing.T) { } } -func TestOAuth_GrantScopesClaimAllGroups(t *testing.T) { - defer tests.PrepareTestEnv(t)() - +func testOAuthGrantScopesClaimAllGroups(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: "user2"}) appBody := api.CreateOAuth2ApplicationOptions{ Name: "oauth-provider-scopes-test", RedirectURIs: []string{ - "a", + "https://example.com", }, ConfidentialClient: true, } @@ -892,7 +922,7 @@ func TestOAuth_GrantScopesClaimAllGroups(t *testing.T) { ctx := loginUser(t, user.Name) - authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=a&response_type=code&state=thestate", app.ClientID) + authorizeURL := fmt.Sprintf("/login/oauth/authorize?client_id=%s&redirect_uri=https://example.com&response_type=code&state=thestate", app.ClientID) authorizeReq := NewRequest(t, "GET", authorizeURL) authorizeResp := ctx.MakeRequest(t, authorizeReq, http.StatusSeeOther) @@ -902,7 +932,7 @@ func TestOAuth_GrantScopesClaimAllGroups(t *testing.T) { "grant_type": "authorization_code", "client_id": app.ClientID, "client_secret": app.ClientSecret, - "redirect_uri": "a", + "redirect_uri": "https://example.com", "code": authcode, }) accessTokenResp := ctx.MakeRequest(t, accessTokenReq, http.StatusOK) @@ -998,7 +1028,7 @@ func addOAuth2Source(t *testing.T, authName string, cfg oauth2.Source) { require.NoError(t, err) } -func createMockServer() *httptest.Server { +func createOAuth2MockProvider() *httptest.Server { var mockServer *httptest.Server mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -1017,10 +1047,8 @@ func createMockServer() *httptest.Server { return mockServer } -func TestSignInOauthCallbackSyncSSHKeys(t *testing.T) { - defer tests.PrepareTestEnv(t)() - - mockServer := createMockServer() +func testSignInOauthCallbackSyncSSHKeys(t *testing.T) { + mockServer := createOAuth2MockProvider() defer mockServer.Close() ctx := t.Context() @@ -1100,7 +1128,7 @@ func TestSignInOauthCallbackSyncSSHKeys(t *testing.T) { // Checks if an OAuth provider with spaces within the name does work, // with the encoding of its names in the URL (PR#37327) func testOAuthSourceSpecialChars(t *testing.T) { - mockServer := createMockServer() + mockServer := createOAuth2MockProvider() defer mockServer.Close() addOAuth2Source(t, "test space", oauth2.Source{ diff --git a/tests/integration/user_settings_test.go b/tests/integration/user_settings_test.go index b3527dd467a..b0e23b6e7ff 100644 --- a/tests/integration/user_settings_test.go +++ b/tests/integration/user_settings_test.go @@ -10,6 +10,7 @@ import ( "code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/test" "code.gitea.io/gitea/tests" "github.com/stretchr/testify/assert" @@ -311,18 +312,25 @@ func TestUserSettingsApplications(t *testing.T) { resp := session.MakeRequest(t, req, http.StatusOK) doc := NewHTMLParser(t, resp.Body) msg := strings.TrimSpace(doc.Find(".ui.message.flash-message").Text()) - assert.Equal(t, `form.RedirectURIs"ftp://127.0.0.1" is not a valid URL.`, msg) + assert.Equal(t, `RedirectURIs: "ftp://127.0.0.1" is not a valid URL.`, msg) }) t.Run("OK", func(t *testing.T) { defer tests.PrintCurrentTest(t)() - + defer test.MockVariableValue(&setting.OAuth2.CustomSchemes, []string{"my-app"})() req := NewRequestWithValues(t, "POST", "/user/settings/applications/oauth2/2", map[string]string{ "application_name": "Test native app", "redirect_uris": "http://127.0.0.1", "confidential_client": "false", }) session.MakeRequest(t, req, http.StatusSeeOther) + + req = NewRequestWithValues(t, "POST", "/user/settings/applications/oauth2/2", map[string]string{ + "application_name": "Test native app", + "redirect_uris": "my-app://127.0.0.1", + "confidential_client": "false", + }) + session.MakeRequest(t, req, http.StatusSeeOther) }) }) })