pref: 优化MFA (#7153)

* perf: 优化mfa 和登录

* perf: stash

* stash

* pref: 基本完成

* perf: remove init function

* perf: 优化命名

* perf: 优化backends

* perf: 基本完成优化

* perf: 修复首页登录时没有 toastr 的问题

Co-authored-by: ibuler <ibuler@qq.com>
Co-authored-by: Jiangjie.Bai <32935519+BaiJiangJie@users.noreply.github.com>
This commit is contained in:
fit2bot
2021-11-10 11:30:48 +08:00
committed by GitHub
parent bac974b4f2
commit 17303c0550
44 changed files with 1373 additions and 977 deletions

View File

@@ -1,24 +1,26 @@
# -*- coding: utf-8 -*-
#
import inspect
from django.utils.http import urlencode
from functools import partial
import time
from typing import Callable
from django.utils.http import urlencode
from django.core.cache import cache
from django.conf import settings
from django.urls import reverse_lazy
from django.contrib import auth
from django.utils.translation import ugettext as _
from rest_framework.request import Request
from django.contrib.auth import (
BACKEND_SESSION_KEY, _get_backends,
PermissionDenied, user_login_failed, _clean_credentials
)
from django.shortcuts import reverse, redirect
from django.shortcuts import reverse, redirect, get_object_or_404
from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get, FlashMessageUtil
from acls.models import LoginACL
from users.models import User, MFAType
from users.models import User
from users.utils import LoginBlockUtil, MFABlockUtils
from . import errors
from .utils import rsa_decrypt, gen_key_pair
@@ -32,8 +34,7 @@ def check_backend_can_auth(username, backend_path, allowed_auth_backends):
if allowed_auth_backends is not None and backend_path not in allowed_auth_backends:
logger.debug('Skip user auth backend: {}, {} not in'.format(
username, backend_path, ','.join(allowed_auth_backends)
)
)
))
return False
return True
@@ -109,17 +110,18 @@ class PasswordEncryptionViewMixin:
def decrypt_passwd(self, raw_passwd):
# 获取解密密钥,对密码进行解密
rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY)
if rsa_private_key is not None:
try:
return rsa_decrypt(raw_passwd, rsa_private_key)
except Exception as e:
logger.error(e, exc_info=True)
logger.error(
f'Decrypt password failed: password[{raw_passwd}] '
f'rsa_private_key[{rsa_private_key}]'
)
return None
return raw_passwd
if rsa_private_key is None:
return raw_passwd
try:
return rsa_decrypt(raw_passwd, rsa_private_key)
except Exception as e:
logger.error(e, exc_info=True)
logger.error(
f'Decrypt password failed: password[{raw_passwd}] '
f'rsa_private_key[{rsa_private_key}]'
)
return None
def get_request_ip(self):
ip = ''
@@ -132,7 +134,7 @@ class PasswordEncryptionViewMixin:
# 生成加解密密钥对public_key传递给前端private_key存入session中供解密使用
rsa_public_key = self.request.session.get(RSA_PUBLIC_KEY)
rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY)
if not all((rsa_private_key, rsa_public_key)):
if not all([rsa_private_key, rsa_public_key]):
rsa_private_key, rsa_public_key = gen_key_pair()
rsa_public_key = rsa_public_key.replace('\n', '\\n')
self.request.session[RSA_PRIVATE_KEY] = rsa_private_key
@@ -144,49 +146,9 @@ class PasswordEncryptionViewMixin:
return super().get_context_data(**kwargs)
class AuthMixin(PasswordEncryptionViewMixin):
request = None
partial_credential_error = None
key_prefix_captcha = "_LOGIN_INVALID_{}"
def get_user_from_session(self):
if self.request.session.is_empty():
raise errors.SessionEmptyError()
if all((self.request.user,
not self.request.user.is_anonymous,
BACKEND_SESSION_KEY in self.request.session)):
user = self.request.user
user.backend = self.request.session[BACKEND_SESSION_KEY]
return user
user_id = self.request.session.get('user_id')
if not user_id:
user = None
else:
user = get_object_or_none(User, pk=user_id)
if not user:
raise errors.SessionEmptyError()
user.backend = self.request.session.get("auth_backend")
return user
def _check_is_block(self, username, raise_exception=True):
ip = self.get_request_ip()
if LoginBlockUtil(username, ip).is_block():
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockLoginError(username=username, ip=ip)
if raise_exception:
raise errors.BlockLoginError(username=username, ip=ip)
else:
return exception
def check_is_block(self, raise_exception=True):
if hasattr(self.request, 'data'):
username = self.request.data.get("username")
else:
username = self.request.POST.get("username")
self._check_is_block(username, raise_exception)
class CommonMixin(PasswordEncryptionViewMixin):
request: Request
get_request_ip: Callable
def raise_credential_error(self, error):
raise self.partial_credential_error(error=error)
@@ -197,6 +159,31 @@ class AuthMixin(PasswordEncryptionViewMixin):
ip=ip, request=request
)
def get_user_from_session(self):
if self.request.session.is_empty():
raise errors.SessionEmptyError()
if all([
self.request.user,
not self.request.user.is_anonymous,
BACKEND_SESSION_KEY in self.request.session
]):
user = self.request.user
user.backend = self.request.session[BACKEND_SESSION_KEY]
return user
user_id = self.request.session.get('user_id')
auth_password = self.request.session.get('auth_password')
auth_expired_at = self.request.session.get('auth_password_expired_at')
auth_expired = auth_expired_at < time.time() if auth_expired_at else False
if not user_id or not auth_password or auth_expired:
raise errors.SessionEmptyError()
user = get_object_or_404(User, pk=user_id)
user.backend = self.request.session.get("auth_backend")
return user
def get_auth_data(self, decrypt_passwd=False):
request = self.request
if hasattr(request, 'data'):
@@ -214,6 +201,31 @@ class AuthMixin(PasswordEncryptionViewMixin):
password = password + challenge.strip()
return username, password, public_key, ip, auto_login
class AuthPreCheckMixin:
request: Request
get_request_ip: Callable
raise_credential_error: Callable
def _check_is_block(self, username, raise_exception=True):
ip = self.get_request_ip()
is_block = LoginBlockUtil(username, ip).is_block()
if not is_block:
return
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockLoginError(username=username, ip=ip)
if raise_exception:
raise errors.BlockLoginError(username=username, ip=ip)
else:
return exception
def check_is_block(self, raise_exception=True):
if hasattr(self.request, 'data'):
username = self.request.data.get("username")
else:
username = self.request.POST.get("username")
self._check_is_block(username, raise_exception)
def _check_only_allow_exists_user_auth(self, username):
# 仅允许预先存在的用户认证
if not settings.ONLY_ALLOW_EXIST_USER_AUTH:
@@ -224,105 +236,92 @@ class AuthMixin(PasswordEncryptionViewMixin):
logger.error(f"Only allow exist user auth, login failed: {username}")
self.raise_credential_error(errors.reason_user_not_exist)
def _check_auth_user_is_valid(self, username, password, public_key):
user = authenticate(self.request, username=username, password=password, public_key=public_key)
if not user:
self.raise_credential_error(errors.reason_password_failed)
elif user.is_expired:
self.raise_credential_error(errors.reason_user_expired)
elif not user.is_active:
self.raise_credential_error(errors.reason_user_inactive)
return user
def _check_login_mfa_login_if_need(self, user):
class MFAMixin:
request: Request
get_user_from_session: Callable
get_request_ip: Callable
def _check_login_page_mfa_if_need(self, user):
if not settings.SECURITY_MFA_IN_LOGIN_PAGE:
return
request = self.request
if hasattr(request, 'data'):
data = request.data
else:
data = request.POST
data = request.data if hasattr(request, 'data') else request.POST
code = data.get('code')
mfa_type = data.get('mfa_type')
if settings.SECURITY_MFA_IN_LOGIN_PAGE and mfa_type:
if not code:
if mfa_type == MFAType.OTP and bool(user.otp_secret_key):
raise errors.OTPCodeRequiredError
elif mfa_type == MFAType.SMS_CODE:
raise errors.SMSCodeRequiredError
self.check_user_mfa(code, mfa_type, user=user)
mfa_type = data.get('mfa_type', 'otp')
if not code:
raise errors.MFACodeRequiredError
self._do_check_user_mfa(code, mfa_type, user=user)
def _check_login_acl(self, user, ip):
# ACL 限制用户登录
is_allowed, limit_type = LoginACL.allow_user_to_login(user, ip)
if not is_allowed:
if limit_type == 'ip':
raise errors.LoginIPNotAllowed(username=user.username, request=self.request)
elif limit_type == 'time':
raise errors.TimePeriodNotAllowed(username=user.username, request=self.request)
def check_user_mfa_if_need(self, user):
if self.request.session.get('auth_mfa'):
return
if not user.mfa_enabled:
return
def set_login_failed_mark(self):
active_mfa_mapper = user.active_mfa_backends_mapper
if not active_mfa_mapper:
url = reverse('authentication:user-otp-enable-start')
raise errors.MFAUnsetError(user, self.request, url)
raise errors.MFARequiredError(mfa_types=tuple(active_mfa_mapper.keys()))
def mark_mfa_ok(self, mfa_type):
self.request.session['auth_mfa'] = 1
self.request.session['auth_mfa_time'] = time.time()
self.request.session['auth_mfa_required'] = 0
self.request.session['auth_mfa_type'] = mfa_type
def clean_mfa_mark(self):
keys = ['auth_mfa', 'auth_mfa_time', 'auth_mfa_required', 'auth_mfa_type']
for k in keys:
self.request.session.pop(k, '')
def check_mfa_is_block(self, username, ip, raise_exception=True):
blocked = MFABlockUtils(username, ip).is_block()
if not blocked:
return
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockMFAError(username=username, request=self.request, ip=ip)
if raise_exception:
raise exception
else:
return exception
def _do_check_user_mfa(self, code, mfa_type, user=None):
user = user if user else self.get_user_from_session()
if not user.mfa_enabled:
return
# 监测 MFA 是不是屏蔽了
ip = self.get_request_ip()
cache.set(self.key_prefix_captcha.format(ip), 1, 3600)
self.check_mfa_is_block(user.username, ip)
def set_passwd_verify_on_session(self, user: User):
self.request.session['user_id'] = str(user.id)
self.request.session['auth_password'] = 1
self.request.session['auth_password_expired_at'] = time.time() + settings.AUTH_EXPIRED_SECONDS
ok = False
mfa_backend = user.get_mfa_backend_by_type(mfa_type)
if mfa_backend:
ok, msg = mfa_backend.check_code(code)
else:
msg = _('The MFA type({}) is not supported'.format(mfa_type))
def check_is_need_captcha(self):
# 最近有登录失败时需要填写验证码
ip = get_request_ip(self.request)
need = cache.get(self.key_prefix_captcha.format(ip))
return need
if ok:
self.mark_mfa_ok(mfa_type)
return
def check_user_auth(self, decrypt_passwd=False):
self.check_is_block()
username, password, public_key, ip, auto_login = self.get_auth_data(decrypt_passwd)
raise errors.MFAFailedError(
username=user.username,
request=self.request,
ip=ip, mfa_type=mfa_type,
error=msg
)
self._check_only_allow_exists_user_auth(username)
user = self._check_auth_user_is_valid(username, password, public_key)
# 校验login-acl规则
self._check_login_acl(user, ip)
self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password)
self._check_passwd_need_update(user)
@staticmethod
def get_user_mfa_context(user=None):
mfa_backends = User.get_user_mfa_backends(user)
return {'mfa_backends': mfa_backends}
# 校验login-mfa, 如果登录页面上显示 mfa 的话
self._check_login_mfa_login_if_need(user)
LoginBlockUtil(username, ip).clean_failed_count()
request = self.request
request.session['auth_password'] = 1
request.session['user_id'] = str(user.id)
request.session['auto_login'] = auto_login
request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)
return user
def _check_is_local_user(self, user: User):
if user.source != User.Source.local:
raise self.raise_credential_error(error=errors.only_local_users_are_allowed)
def check_oauth2_auth(self, user: User, auth_backend):
ip = self.get_request_ip()
request = self.request
self._set_partial_credential_error(user.username, ip, request)
if user.is_expired:
self.raise_credential_error(errors.reason_user_expired)
elif not user.is_active:
self.raise_credential_error(errors.reason_user_inactive)
self._check_is_block(user.username)
self._check_login_acl(user, ip)
LoginBlockUtil(user.username, ip).clean_failed_count()
MFABlockUtils(user.username, ip).clean_failed_count()
request.session['auth_password'] = 1
request.session['user_id'] = str(user.id)
request.session['auth_backend'] = auth_backend
return user
class AuthPostCheckMixin:
@classmethod
def generate_reset_password_url_with_flash_msg(cls, user, message):
reset_passwd_url = reverse('authentication:reset-password')
@@ -344,14 +343,14 @@ class AuthMixin(PasswordEncryptionViewMixin):
if user.is_superuser and password == 'admin':
message = _('Your password is too simple, please change it for security')
url = cls.generate_reset_password_url_with_flash_msg(user, message=message)
raise errors.PasswdTooSimple(url)
raise errors.PasswordTooSimple(url)
@classmethod
def _check_passwd_need_update(cls, user: User):
if user.need_update_password:
message = _('You should to change your password before login')
url = cls.generate_reset_password_url_with_flash_msg(user, message)
raise errors.PasswdNeedUpdate(url)
raise errors.PasswordNeedUpdate(url)
@classmethod
def _check_password_require_reset_or_not(cls, user: User):
@@ -360,76 +359,20 @@ class AuthMixin(PasswordEncryptionViewMixin):
url = cls.generate_reset_password_url_with_flash_msg(user, message)
raise errors.PasswordRequireResetError(url)
def check_user_auth_if_need(self, decrypt_passwd=False):
request = self.request
if request.session.get('auth_password') and \
request.session.get('user_id'):
user = self.get_user_from_session()
if user:
return user
return self.check_user_auth(decrypt_passwd=decrypt_passwd)
def check_user_mfa_if_need(self, user):
if self.request.session.get('auth_mfa'):
class AuthACLMixin:
request: Request
get_request_ip: Callable
def _check_login_acl(self, user, ip):
# ACL 限制用户登录
is_allowed, limit_type = LoginACL.allow_user_to_login(user, ip)
if is_allowed:
return
if settings.OTP_IN_RADIUS:
return
if not user.mfa_enabled:
return
unset, url = user.mfa_enabled_but_not_set()
if unset:
raise errors.MFAUnsetError(user, self.request, url)
raise errors.MFARequiredError(mfa_types=user.get_supported_mfa_types())
def mark_mfa_ok(self, mfa_type=MFAType.OTP):
self.request.session['auth_mfa'] = 1
self.request.session['auth_mfa_time'] = time.time()
self.request.session['auth_mfa_required'] = ''
self.request.session['auth_mfa_type'] = mfa_type
def clean_mfa_mark(self):
self.request.session['auth_mfa'] = ''
self.request.session['auth_mfa_time'] = ''
self.request.session['auth_mfa_required'] = ''
self.request.session['auth_mfa_type'] = ''
def check_mfa_is_block(self, username, ip, raise_exception=True):
blocked = MFABlockUtils(username, ip).is_block()
if not blocked:
return
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockMFAError(username=username, request=self.request, ip=ip)
if raise_exception:
raise exception
else:
return exception
def check_user_mfa(self, code, mfa_type=MFAType.OTP, user=None):
user = user if user else self.get_user_from_session()
if not user.mfa_enabled:
return
if not bool(user.phone) and mfa_type == MFAType.SMS_CODE:
raise errors.UserPhoneNotSet
if not bool(user.otp_secret_key) and mfa_type == MFAType.OTP:
self.set_passwd_verify_on_session(user)
raise errors.OTPBindRequiredError(reverse_lazy('authentication:user-otp-enable-bind'))
ip = self.get_request_ip()
self.check_mfa_is_block(user.username, ip)
ok = user.check_mfa(code, mfa_type=mfa_type)
if ok:
self.mark_mfa_ok()
return
raise errors.MFAFailedError(
username=user.username,
request=self.request,
ip=ip, mfa_type=mfa_type,
)
if limit_type == 'ip':
raise errors.LoginIPNotAllowed(username=user.username, request=self.request)
elif limit_type == 'time':
raise errors.TimePeriodNotAllowed(username=user.username, request=self.request)
def get_ticket(self):
from tickets.models import Ticket
@@ -480,11 +423,99 @@ class AuthMixin(PasswordEncryptionViewMixin):
self.get_ticket_or_create(confirm_setting)
self.check_user_login_confirm()
class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPostCheckMixin):
request = None
partial_credential_error = None
key_prefix_captcha = "_LOGIN_INVALID_{}"
def _check_auth_user_is_valid(self, username, password, public_key):
user = authenticate(
self.request, username=username,
password=password, public_key=public_key
)
if not user:
self.raise_credential_error(errors.reason_password_failed)
elif user.is_expired:
self.raise_credential_error(errors.reason_user_expired)
elif not user.is_active:
self.raise_credential_error(errors.reason_user_inactive)
return user
def set_login_failed_mark(self):
ip = self.get_request_ip()
cache.set(self.key_prefix_captcha.format(ip), 1, 3600)
def check_is_need_captcha(self):
# 最近有登录失败时需要填写验证码
ip = get_request_ip(self.request)
need = cache.get(self.key_prefix_captcha.format(ip))
return need
def check_user_auth(self, decrypt_passwd=False):
# pre check
self.check_is_block()
username, password, public_key, ip, auto_login = self.get_auth_data(decrypt_passwd)
self._check_only_allow_exists_user_auth(username)
# check auth
user = self._check_auth_user_is_valid(username, password, public_key)
# 校验login-acl规则
self._check_login_acl(user, ip)
# post check
self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password)
self._check_passwd_need_update(user)
# 校验login-mfa, 如果登录页面上显示 mfa 的话
self._check_login_page_mfa_if_need(user)
# 标记密码验证成功
self.mark_password_ok(user=user, auto_login=auto_login)
LoginBlockUtil(user.username, ip).clean_failed_count()
return user
def mark_password_ok(self, user, auto_login=False):
request = self.request
request.session['auth_password'] = 1
request.session['auth_password_expired_at'] = time.time() + settings.AUTH_EXPIRED_SECONDS
request.session['user_id'] = str(user.id)
request.session['auto_login'] = auto_login
request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)
def check_oauth2_auth(self, user: User, auth_backend):
ip = self.get_request_ip()
request = self.request
self._set_partial_credential_error(user.username, ip, request)
if user.is_expired:
self.raise_credential_error(errors.reason_user_expired)
elif not user.is_active:
self.raise_credential_error(errors.reason_user_inactive)
self._check_is_block(user.username)
self._check_login_acl(user, ip)
LoginBlockUtil(user.username, ip).clean_failed_count()
MFABlockUtils(user.username, ip).clean_failed_count()
self.mark_password_ok(user, False)
return user
def check_user_auth_if_need(self, decrypt_passwd=False):
request = self.request
if not request.session.get('auth_password'):
return self.check_user_auth(decrypt_passwd=decrypt_passwd)
return self.get_user_from_session()
def clear_auth_mark(self):
self.request.session['auth_password'] = ''
self.request.session['auth_user_id'] = ''
self.request.session['auth_confirm'] = ''
self.request.session['auth_ticket_id'] = ''
keys = ['auth_password', 'user_id', 'auth_confirm', 'auth_ticket_id']
for k in keys:
self.request.session.pop(k, '')
def send_auth_signal(self, success=True, user=None, username='', reason=''):
if success:
@@ -503,31 +534,3 @@ class AuthMixin(PasswordEncryptionViewMixin):
if args:
guard_url = "%s?%s" % (guard_url, args)
return redirect(guard_url)
@staticmethod
def get_user_mfa_methods(user=None):
otp_enabled = user.otp_secret_key if user else True
# 没有用户时,或者有用户并且有电话配置
sms_enabled = any([user and user.phone, not user]) \
and settings.SMS_ENABLED and settings.XPACK_ENABLED
methods = [
{
'name': 'otp',
'label': 'MFA',
'enable': otp_enabled,
'selected': False,
},
{
'name': 'sms',
'label': _('SMS'),
'enable': sms_enabled,
'selected': False,
},
]
for item in methods:
if item['enable']:
item['selected'] = True
break
return methods