diff --git a/apps/assets/api/accounts.py b/apps/assets/api/accounts.py index 778916e64..e05bee3d2 100644 --- a/apps/assets/api/accounts.py +++ b/apps/assets/api/accounts.py @@ -64,8 +64,8 @@ class AccountViewSet(OrgBulkModelViewSet): permission_classes = (IsOrgAdmin,) def get_queryset(self): - queryset = super().get_queryset()\ - .annotate(ip=F('asset__ip'))\ + queryset = super().get_queryset() \ + .annotate(ip=F('asset__ip')) \ .annotate(hostname=F('asset__hostname')) return queryset @@ -110,4 +110,5 @@ class AccountTaskCreateAPI(CreateAPIView): def get_exception_handler(self): def handler(e, context): return Response({"error": str(e)}, status=400) + return handler diff --git a/apps/authentication/backends/cas/__init__.py b/apps/authentication/backends/cas/__init__.py index bf0101c81..bbdbdb814 100644 --- a/apps/authentication/backends/cas/__init__.py +++ b/apps/authentication/backends/cas/__init__.py @@ -1,4 +1,3 @@ # -*- coding: utf-8 -*- # from .backends import * -from .callback import * diff --git a/apps/authentication/backends/cas/callback.py b/apps/authentication/backends/cas/callback.py deleted file mode 100644 index 64201e607..000000000 --- a/apps/authentication/backends/cas/callback.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# -from django.contrib.auth import get_user_model - - -User = get_user_model() - - -def cas_callback(response): - username = response['username'] - user, user_created = User.objects.get_or_create(username=username) - profile, created = user.get_profile() - - profile.role = response['attributes']['role'] - profile.birth_date = response['attributes']['birth_date'] - profile.save() diff --git a/apps/authentication/middleware.py b/apps/authentication/middleware.py new file mode 100644 index 000000000..59eabff75 --- /dev/null +++ b/apps/authentication/middleware.py @@ -0,0 +1,14 @@ +from django.shortcuts import redirect + + +class MFAMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + response = self.get_response(request) + if request.path.find('/auth/login/otp/') > -1: + return response + if request.session.get('auth_mfa_required'): + return redirect('authentication:login-otp') + return response diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index dd68cd483..30e5d63cd 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -315,6 +315,7 @@ class AuthMixin: self.request.session['auth_mfa'] = 1 self.request.session['auth_mfa_time'] = time.time() self.request.session['auth_mfa_type'] = 'otp' + self.request.session['auth_mfa_required'] = '' def check_mfa_is_block(self, username, ip, raise_exception=True): if MFABlockUtils(username, ip).is_block(): @@ -391,7 +392,6 @@ class AuthMixin: def clear_auth_mark(self): self.request.session['auth_password'] = '' self.request.session['auth_user_id'] = '' - self.request.session['auth_mfa'] = '' self.request.session['auth_confirm'] = '' self.request.session['auth_ticket_id'] = '' diff --git a/apps/authentication/signals_handlers.py b/apps/authentication/signals_handlers.py index 8e353ddf6..c6c1db680 100644 --- a/apps/authentication/signals_handlers.py +++ b/apps/authentication/signals_handlers.py @@ -13,6 +13,10 @@ from .signals import post_auth_success, post_auth_failed @receiver(user_logged_in) def on_user_auth_login_success(sender, user, request, **kwargs): + # 开启了 MFA,且没有校验过 + if user.mfa_enabled and not request.session.get('auth_mfa'): + request.session['auth_mfa_required'] = 1 + if settings.USER_LOGIN_SINGLE_MACHINE_ENABLED: user_id = 'single_machine_login_' + str(user.id) session_key = cache.get(user_id) diff --git a/apps/jumpserver/settings/base.py b/apps/jumpserver/settings/base.py index a3b2c7cf6..fd8feeaad 100644 --- a/apps/jumpserver/settings/base.py +++ b/apps/jumpserver/settings/base.py @@ -87,6 +87,7 @@ MIDDLEWARE = [ 'orgs.middleware.OrgMiddleware', 'authentication.backends.oidc.middleware.OIDCRefreshIDTokenMiddleware', 'authentication.backends.cas.middleware.CASMiddleware', + 'authentication.middleware.MFAMiddleware', 'simple_history.middleware.HistoryRequestMiddleware', ]