# -*- coding: utf-8 -*-
#
import inspect
import threading
import time
import uuid
from functools import partial, wraps
from typing import Callable

from werkzeug.local import Local
from django.conf import settings
from django.contrib import auth
from django.contrib.auth import (
    BACKEND_SESSION_KEY, load_backend,
    PermissionDenied, user_login_failed, _clean_credentials,
)
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q
from django.shortcuts import reverse, redirect, get_object_or_404
from django.utils.http import urlencode
from django.utils.translation import gettext as _
from rest_framework.request import Request

from acls.models import LoginACL
from apps.jumpserver.settings.auth import AUTHENTICATION_BACKENDS_THIRD_PARTY
from common.utils import get_request_ip_or_data, get_request_ip, get_logger, bulk_get, FlashMessageUtil
from users.models import User
from users.utils import LoginBlockUtil, MFABlockUtils, LoginIpBlockUtil
from . import errors
from .signals import post_auth_success, post_auth_failed

logger = get_logger(__name__)

# Module-level context-local storage; authenticate() uses it to mark that the
# current thread/context is inside an authentication attempt.
_auth_thread_context = Local()

# The get_or_create patch is installed on a shared manager object, so the
# install/restore bookkeeping is guarded by a lock and an active-call counter.
_get_or_create_patch_lock = threading.Lock()
_get_or_create_patch_depth = 0


# Keep a reference to Django's original get_or_create (captured once, when the
# module is loaded).
def _save_original_get_or_create():
    """Return the user model manager's original ``get_or_create`` bound method."""
    from django.contrib.auth import get_user_model as get_user_model_func
    UserModel = get_user_model_func()
    return UserModel.objects.get_or_create


_django_original_get_or_create = _save_original_get_or_create()


class OnlyAllowExistUserAuthError(Exception):
    """Raised when a backend tries to create a user while only existing users may log in."""
    pass


def _guarded_get_or_create(*args, **kwargs):
    """
    Replacement for ``UserModel.objects.get_or_create`` used while authenticating.

    When the current context is inside ``authenticate`` and
    ``settings.ONLY_ALLOW_EXIST_USER_AUTH`` is truthy, the user is looked up by
    the ``username`` keyword first and ``OnlyAllowExistUserAuthError`` is raised
    if no such user exists. In every other case the call is forwarded unchanged
    to Django's original ``get_or_create``.
    """
    UserModel = get_user_model()
    create_username = kwargs.get('username')
    logger.debug(f"get_or_create: thread_id={threading.get_ident()}, username={create_username}")

    # Only enforce the restriction for contexts that are actually running
    # authenticate(); other callers of get_or_create must not be affected.
    if (
            getattr(_auth_thread_context, 'in_authenticate', False) and
            settings.ONLY_ALLOW_EXIST_USER_AUTH
    ):
        try:
            UserModel.objects.get(username=create_username)
        except UserModel.DoesNotExist:
            raise OnlyAllowExistUserAuthError

    # The saved original is already a bound method, so pass arguments through.
    return _django_original_get_or_create(*args, **kwargs)


def _authenticate_context(func):
    """
    Decorator that manages the execution context of ``authenticate``.

    Before the wrapped call:
        - mark the current context as "inside authenticate" in
          ``_auth_thread_context``;
        - replace ``UserModel.objects.get_or_create`` with the guarded version.

    After the wrapped call:
        - clear the context mark (unless an outer call in the same context set it);
        - restore ``get_or_create`` to Django's original once no authenticate
          call is active any more.

    The patch is applied to a shared manager object, so it is reference
    counted under a lock: a call that finishes early must not remove the guard
    while another thread is still in the middle of ``authenticate``.
    """

    @wraps(func)
    def wrapper(request=None, **credentials):
        global _get_or_create_patch_depth

        UserModel = get_user_model()
        # Remember whether an enclosing authenticate() in this context already
        # set the flag, so a nested call does not clear it for the outer one.
        was_in_authenticate = getattr(_auth_thread_context, 'in_authenticate', False)
        patched = False
        try:
            with _get_or_create_patch_lock:
                UserModel.objects.get_or_create = _guarded_get_or_create
                _get_or_create_patch_depth += 1
                patched = True
            setattr(_auth_thread_context, 'in_authenticate', True)
            return func(request, **credentials)
        finally:
            if not was_in_authenticate:
                try:
                    delattr(_auth_thread_context, 'in_authenticate')
                except AttributeError:
                    # The flag was never set (failure before setattr); nothing to clear.
                    pass
            if patched:
                with _get_or_create_patch_lock:
                    _get_or_create_patch_depth -= 1
                    if _get_or_create_patch_depth <= 0:
                        _get_or_create_patch_depth = 0
                        UserModel.objects.get_or_create = _django_original_get_or_create

    return wrapper


def _get_backends(return_tuples=False):
    """
    Load the enabled authentication backends from ``settings.AUTHENTICATION_BACKENDS``.

    :param return_tuples: when true, return ``(backend, backend_path)`` tuples
        instead of bare backend instances.
    :raises ImproperlyConfigured: if no backend is enabled.
    """
    backends = []
    for backend_path in settings.AUTHENTICATION_BACKENDS:
        backend = load_backend(backend_path)
        # Skip backends that report themselves as disabled.
        if not backend.is_enabled():
            continue
        backends.append((backend, backend_path) if return_tuples else backend)
    if not backends:
        raise ImproperlyConfigured(
            'No authentication backends have been defined. Does '
            'AUTHENTICATION_BACKENDS contain anything?'
        )
    return backends


# Replace Django's backend loader with the enabled-aware version above.
auth._get_backends = _get_backends


@_authenticate_context
def authenticate(request=None, **credentials):
    """
    If the given credentials are valid, return a User object.

    Each enabled backend is tried in order. A user that authenticated but is
    not valid is returned immediately with ``request.error_message`` set; a
    user rejected by ``backend.user_allow_authenticate`` is remembered and
    returned only if no other backend succeeds. When nothing matches,
    ``user_login_failed`` is sent and ``None`` is returned.
    """
    temp_user = None
    username = credentials.get('username')
    for backend, backend_path in _get_backends(return_tuples=True):
        # Check up front whether this username may use this backend, so no
        # time is wasted on the actual authentication.
        logger.info('Try using auth backend: {}'.format(str(backend)))
        if not backend.username_allow_authenticate(username):
            continue

        # Same flow as Django's own authenticate() from here on.
        backend_signature = inspect.signature(backend.authenticate)
        try:
            backend_signature.bind(request, **credentials)
        except TypeError:
            # This backend doesn't accept these credentials as arguments. Try the next one.
            continue
        try:
            user = backend.authenticate(request, **credentials)
        except PermissionDenied:
            # This backend says to stop in our tracks - this user should not be allowed in at all.
            break
        except OnlyAllowExistUserAuthError:
            if request:
                request.error_message = _(
                    '''The administrator has enabled "Only allow existing users to log in", and the current user is not in the user list. Please contact the administrator.'''
                )
            continue

        if user is None:
            continue

        if not user.is_valid:
            temp_user = user
            temp_user.backend = backend_path
            if request:
                request.error_message = _('User is invalid')
            return temp_user

        # Check whether this user is allowed to authenticate with this backend.
        if not backend.user_allow_authenticate(user):
            temp_user = user
            temp_user.backend = backend_path
            continue

        # Annotate the user object with the path of the backend.
        user.backend = backend_path
        return user
    else:
        # Loop finished without a break: no backend returned a usable user.
        if temp_user is not None:
            source_display = temp_user.source_display
            if request:
                request.error_message = _(
                    " The administrator has enabled 'Only allow login from user source'. "
                    "The current user source is {}. Please contact the administrator. \n"
                ).format(source_display)
            return temp_user

    # The credentials supplied are invalid to all backends, fire signal
    user_login_failed.send(sender=__name__, credentials=_clean_credentials(credentials), request=request)


# Make Django's auth module use the authenticate() defined above.
auth.authenticate = authenticate


class CommonMixin:
    """Shared helpers: request IP lookup, credential errors and session user retrieval."""
    request: Request
    _ip = ''

    def get_request_ip(self):
        """Return the request IP, resolving it once and caching it on the instance."""
        if not self._ip:
            self._ip = get_request_ip_or_data(self.request)
        return self._ip

    def raise_credential_error(self, error):
        """Raise the pre-bound credential error with the given reason."""
        raise self.partial_credential_error(error=error)

    def _set_partial_credential_error(self, username, ip, request):
        """Bind username, ip and request into a ``CredentialError`` factory for later use."""
        self.partial_credential_error = partial(
            errors.CredentialError, username=username,
            ip=ip, request=request
        )

    def get_user_from_session(self):
        """
        Return the user recorded in the current session.

        :raises errors.SessionEmptyError: if the session is empty, or the
            password mark is missing or has expired.
        """
        if self.request.session.is_empty():
            raise errors.SessionEmptyError()

        if all([
            self.request.user,
            not self.request.user.is_anonymous,
            BACKEND_SESSION_KEY in self.request.session
        ]):
            user = self.request.user
            user.backend = self.request.session[BACKEND_SESSION_KEY]
            return user

        user_id = self.request.session.get('user_id')
        auth_ok = self.request.session.get('auth_password')
        auth_expired_at = self.request.session.get('auth_password_expired_at')
        auth_expired = auth_expired_at < time.time() if auth_expired_at else False

        if not user_id or not auth_ok or auth_expired:
            raise errors.SessionEmptyError()

        user = get_object_or_404(User, pk=user_id)
        user.backend = self.request.session.get("auth_backend")
        return user

    def get_auth_data(self, data):
        """
        Extract login fields from ``data`` and prepare the credential error factory.

        :return: ``(username, password, public_key, ip, auto_login)``; the
            stripped challenge is appended to the password.
        """
        request = self.request

        items = ['username', 'password', 'challenge', 'public_key', 'auto_login']
        username, password, challenge, public_key, auto_login = bulk_get(data, items, default='')
        ip = self.get_request_ip()
        self._set_partial_credential_error(username=username, ip=ip, request=request)

        password = password + challenge.strip()
        return username, password, public_key, ip, auto_login


class AuthPreCheckMixin:
    """Checks performed before credentials are verified."""
    request: Request
    get_request_ip: Callable
    raise_credential_error: Callable

    def _check_is_block(self, username, raise_exception=True):
        """
        Check whether the IP, or the username/IP pair, is blocked from logging in.

        :return: ``None`` when not blocked; the ``BlockLoginError`` instance
            when blocked and ``raise_exception`` is false.
        :raises errors.BlockGlobalIpLoginError: if the IP itself is blocked.
        :raises errors.BlockLoginError: if blocked and ``raise_exception`` is true.
        """
        ip = self.get_request_ip()

        if LoginIpBlockUtil(ip).is_block():
            raise errors.BlockGlobalIpLoginError(username=username, ip=ip)

        is_block = LoginBlockUtil(username, ip).is_block()
        if not is_block:
            return
        # format() instead of "+" so a missing (None) username cannot raise TypeError.
        logger.warning('Ip was blocked: {}:{}'.format(username, ip))
        exception = errors.BlockLoginError(username=username, ip=ip, request=self.request)
        if raise_exception:
            raise exception
        else:
            return exception

    def check_is_block(self, raise_exception=True):
        """
        Run ``_check_is_block`` for the username submitted with the request.

        :return: whatever ``_check_is_block`` returns, so callers passing
            ``raise_exception=False`` receive the exception object.
        """
        if hasattr(self.request, 'data'):
            username = self.request.data.get("username")
        else:
            username = self.request.POST.get("username")

        return self._check_is_block(username, raise_exception)

    def _check_only_allow_exists_user_auth(self, username):
        """Raise a credential error if only pre-existing users may log in and this one does not exist."""
        # Only allow users that already exist to authenticate.
        if not settings.ONLY_ALLOW_EXIST_USER_AUTH:
            return

        q = Q(username=username) | Q(email=username)
        exist = User.objects.filter(q).exists()
        if not exist:
            logger.error(f"Only allow exist user auth, login failed: {username}")
            self.raise_credential_error(errors.reason_user_not_exist)


class MFAMixin:
    """MFA checks and the session marks that record their outcome."""
    request: Request
    get_user_from_session: Callable
    get_request_ip: Callable

    def _check_if_no_active_mfa(self, user):
        """Raise ``MFAUnsetError`` if the user has no active MFA backend."""
        active_mfa_mapper = user.active_mfa_backends_mapper
        if not active_mfa_mapper:
            set_url = reverse('authentication:user-otp-enable-start')
            raise errors.MFAUnsetError(set_url, user, self.request)

    def _check_login_page_mfa_if_need(self, user):
        """Verify an MFA code submitted on the login page, when that feature is enabled."""
        if not settings.SECURITY_MFA_IN_LOGIN_PAGE:
            return
        if not user.active_mfa_backends:
            return

        request = self.request
        data = request.data if hasattr(request, 'data') else request.POST
        code = data.get('code')
        mfa_type = data.get('mfa_type', 'otp')

        if not code:
            return
        self._do_check_user_mfa(code, mfa_type, user=user)

    def check_user_mfa_if_need(self, user):
        """
        Raise ``MFARequiredError`` if the user still has to pass MFA.

        Nothing happens when third-party MFA is disabled and the session
        backend is a third-party one, when MFA is already marked OK for this
        username, or when the user has MFA disabled.
        """
        # QR-code login flows call this function to check MFA; redirect-based
        # login flows have MFA checked by the ThirdPartyLoginMiddleware middleware.
        if not settings.SECURITY_MFA_AUTH_ENABLED_FOR_THIRD_PARTY and \
                self.request.session.get('auth_backend') in AUTHENTICATION_BACKENDS_THIRD_PARTY:
            return
        if self.request.session.get('auth_mfa') and \
                self.request.session.get('auth_mfa_username') == user.username:
            return
        if not user.mfa_enabled:
            return

        active_mfa_names = user.active_mfa_backends_mapper.keys()
        raise errors.MFARequiredError(mfa_types=tuple(active_mfa_names))

    def mark_mfa_ok(self, mfa_type, user):
        """Record a successful MFA check in the session and reset the MFA failure counter."""
        self.request.session['auth_mfa'] = 1
        self.request.session['auth_mfa_username'] = user.username
        self.request.session['auth_mfa_time'] = time.time()
        self.request.session['auth_mfa_required'] = 0
        self.request.session['auth_mfa_type'] = mfa_type
        MFABlockUtils(user.username, self.get_request_ip()).clean_failed_count()

    def clean_mfa_mark(self):
        """Remove every MFA-related key from the session."""
        keys = ['auth_mfa', 'auth_mfa_time', 'auth_mfa_required', 'auth_mfa_type', 'auth_mfa_username']
        for k in keys:
            self.request.session.pop(k, '')

    def check_mfa_is_block(self, username, ip, raise_exception=True):
        """
        Check whether MFA is blocked for this username/IP pair.

        :return: ``None`` when not blocked; the ``BlockMFAError`` instance when
            blocked and ``raise_exception`` is false.
        :raises errors.BlockMFAError: if blocked and ``raise_exception`` is true.
        """
        blocked = MFABlockUtils(username, ip).is_block()
        if not blocked:
            return
        # format() instead of "+" so a missing (None) username cannot raise TypeError.
        logger.warning('Ip was blocked: {}:{}'.format(username, ip))
        exception = errors.BlockMFAError(username=username, request=self.request, ip=ip)
        if raise_exception:
            raise exception
        else:
            return exception

    def _do_check_user_mfa(self, code, mfa_type, user=None):
        """
        Verify ``code`` with the user's MFA backend of type ``mfa_type``.

        Falls back to the session user when ``user`` is not given. Returns
        silently if the user has MFA disabled or the code is accepted.

        :raises errors.MFAFailedError: if the backend is missing, inactive, or
            rejects the code.
        """
        user = user if user else self.get_user_from_session()
        if not user.mfa_enabled:
            return

        # Check whether MFA is currently blocked for this user/IP.
        ip = self.get_request_ip()
        self.check_mfa_is_block(user.username, ip)

        ok = False
        mfa_backend = user.get_mfa_backend_by_type(mfa_type)
        backend_error = _('The MFA type ({}) is not enabled')
        if not mfa_backend:
            msg = backend_error.format(mfa_type)
        elif not mfa_backend.is_active():
            msg = backend_error.format(mfa_backend.display_name)
        else:
            mfa_backend.set_request(self.request)
            ok, msg = mfa_backend.check_code(code)

        if ok:
            self.mark_mfa_ok(mfa_type, user)
            return

        raise errors.MFAFailedError(
            username=user.username,
            request=self.request,
            ip=ip, mfa_type=mfa_type,
            error=msg
        )

    @staticmethod
    def get_user_mfa_context(user=None):
        """Return the template context holding the user's MFA backends."""
        mfa_backends = User.get_user_mfa_backends(user)
        return {'mfa_backends': mfa_backends}

    @staticmethod
    def incr_mfa_failed_time(username, ip):
        """Increase the MFA failure counter for this username/IP pair."""
        util = MFABlockUtils(username, ip)
        util.incr_failed_count()


class AuthPostCheckMixin:
    """Password policy checks performed after credentials are verified."""

    @classmethod
    def generate_reset_password_url_with_flash_msg(cls, user, message):
        """Build a flash-message URL that redirects to the reset-password page with a token."""
        reset_passwd_url = reverse('authentication:reset-password')
        query_str = urlencode({
            'token': user.generate_reset_token()
        })
        reset_passwd_url = f'{reset_passwd_url}?{query_str}'

        message_data = {
            'title': _('Please change your password'),
            'message': message,
            'interval': 3,
            'redirect_url': reset_passwd_url,
        }
        return FlashMessageUtil.gen_message_url(message_data)

    @classmethod
    def _check_passwd_is_too_simple(cls, user: User, password):
        """Raise ``PasswordTooSimple`` if the password is too simple or leaked."""
        if not user.is_auth_backend_model():
            return
        if user.check_passwd_too_simple(password) or user.check_leak_password(password):
            message = _('Your password is too simple, please change it for security')
            url = cls.generate_reset_password_url_with_flash_msg(user, message=message)
            raise errors.PasswordTooSimple(url)

    @classmethod
    def _check_passwd_need_update(cls, user: User):
        """Raise ``PasswordNeedUpdate`` if the user must change the password before login."""
        if not user.is_auth_backend_model():
            return
        if user.check_need_update_password():
            message = _('You should to change your password before login')
            url = cls.generate_reset_password_url_with_flash_msg(user, message)
            raise errors.PasswordNeedUpdate(url)

    @classmethod
    def _check_password_require_reset_or_not(cls, user: User):
        """Raise ``PasswordRequireResetError`` if the user's password has expired."""
        if not user.is_auth_backend_model():
            return
        if user.password_has_expired:
            message = _('Your password has expired, please reset before logging in')
            url = cls.generate_reset_password_url_with_flash_msg(user, message)
            raise errors.PasswordRequireResetError(url)


class AuthACLMixin:
    """Login ACL evaluation and the login-confirmation ticket flow."""
    request: Request
    get_request_ip: Callable

    def _check_login_acl(self, user, ip):
        """
        Apply the login ACL that matches this user and IP.

        ``reject`` raises ``LoginACLNotAllowed``; ``review`` and ``notice``
        record the ACL id and a marker in the session; ``accept`` or no match
        does nothing.
        """
        # Login ACL restrictions for the user.
        acl = LoginACL.get_match_rule_acls(user, ip)
        if not acl:
            return

        if acl.is_action(LoginACL.ActionChoices.accept):
            return

        if acl.is_action(LoginACL.ActionChoices.reject):
            raise errors.LoginACLNotAllowed(user.username, request=self.request)

        if acl.is_action(acl.ActionChoices.review):
            self.request.session['auth_confirm_required'] = '1'
            self.request.session['auth_acl_id'] = str(acl.id)
            return

        if acl.is_action(acl.ActionChoices.notice):
            self.request.session['auth_notice_required'] = '1'
            self.request.session['auth_acl_id'] = str(acl.id)

    def _check_third_party_login_acl(self):
        """Raise ``ValueError`` with ``request.error_message`` if one was set."""
        request = self.request
        error_message = getattr(request, 'error_message', None)
        if not error_message:
            return
        raise ValueError(error_message)

    def check_user_login_confirm_if_need(self, user):
        """Start or re-check the login confirmation ticket when the session requires one."""
        if not self.request.session.get("auth_confirm_required"):
            return
        acl_id = self.request.session.get('auth_acl_id')
        logger.debug('Login confirm acl id: {}'.format(acl_id))
        if not acl_id:
            return
        acl = LoginACL.get_user_acls(user).filter(id=acl_id).first()
        if not acl:
            return
        if not acl.is_action(acl.ActionChoices.review):
            return
        self.get_ticket_or_create(acl, user)
        self.check_user_login_confirm()

    def get_ticket_or_create(self, acl, user):
        """Return the session's ticket, creating a new one if none exists or it is closed."""
        ticket = self.get_ticket()
        if not ticket or ticket.is_state(ticket.State.closed):
            ticket = acl.create_confirm_ticket(self.request, user)
            self.request.session['auth_ticket_id'] = str(ticket.id)
        return ticket

    def check_user_login_confirm(self):
        """
        Inspect the session's confirmation ticket.

        Clears ``auth_confirm_required`` when the ticket is approved.

        :raises errors.LoginConfirmWaitError: if the ticket is still open.
        :raises errors.LoginConfirmOtherError: if there is no ticket, or it is
            in any other state.
        """
        ticket = self.get_ticket()
        if not ticket:
            raise errors.LoginConfirmOtherError('', "Not found", '')
        elif ticket.is_state(ticket.State.approved):
            self.request.session["auth_confirm_required"] = ''
            return
        elif ticket.is_status(ticket.Status.open):
            raise errors.LoginConfirmWaitError(ticket.id)
        else:
            # rejected, closed
            ticket_id = ticket.id
            status = ticket.get_state_display()
            username = ticket.applicant.username
            raise errors.LoginConfirmOtherError(ticket_id, status, username)

    def get_ticket(self):
        """Return the ``ApplyLoginTicket`` referenced by the session, or ``None``."""
        from tickets.models import ApplyLoginTicket
        ticket_id = self.request.session.get("auth_ticket_id")
        logger.debug('Login confirm ticket id: {}'.format(ticket_id))
        if not ticket_id:
            ticket = None
        else:
            ticket = ApplyLoginTicket.all().filter(id=ticket_id).first()
        return ticket


class AuthFaceMixin:
    """Face verification context kept in the cache and referenced from the session."""
    request: Request

    @staticmethod
    def _get_face_cache_key(token):
        """Return the cache key for the face context identified by ``token``."""
        from authentication.const import FACE_CONTEXT_CACHE_KEY_PREFIX
        return f"{FACE_CONTEXT_CACHE_KEY_PREFIX}_{token}"

    @staticmethod
    def _is_context_finished(context):
        """Return the context's ``is_finished`` value (``False`` if absent)."""
        return context.get('is_finished', False)

    @staticmethod
    def _is_context_success(context):
        """Return the context's ``success`` value (``False`` if absent)."""
        return context.get('success', False)

    def create_face_verify_context(self, data=None):
        """
        Create a face verification context, cache it and store its token in the session.

        :param data: optional mapping merged over the default context fields.
        :return: the newly generated token.
        """
        token = uuid.uuid4().hex
        context_data = {
            "action": "mfa",
            "token": token,
            "user_id": self.request.user.id,
            "is_finished": False
        }
        if data:
            context_data.update(data)

        cache_key = self._get_face_cache_key(token)
        from .const import FACE_CONTEXT_CACHE_TTL, FACE_SESSION_KEY
        cache.set(cache_key, context_data, FACE_CONTEXT_CACHE_TTL)
        self.request.session[FACE_SESSION_KEY] = token
        return token

    def get_face_token_from_session(self):
        """Return the face token stored in the session; raise ``ValueError`` if missing."""
        from authentication.const import FACE_SESSION_KEY
        token = self.request.session.get(FACE_SESSION_KEY)
        if not token:
            raise ValueError("Face recognition token is missing from the session.")
        return token

    def get_face_verify_context(self):
        """Return the cached face context for the session token; raise ``ValueError`` if absent."""
        token = self.get_face_token_from_session()
        cache_key = self._get_face_cache_key(token)
        context = cache.get(cache_key)
        if not context:
            raise ValueError(f"Face recognition context does not exist for token: {token}")
        return context

    def get_face_code(self):
        """
        Return the ``face_code`` from a finished, successful face context.

        :raises RuntimeError: if recognition is unfinished or did not succeed.
        :raises ValueError: if the context has no ``face_code``.
        """
        context = self.get_face_verify_context()

        if not self._is_context_finished(context):
            raise RuntimeError("Face recognition is not yet completed.")

        if not self._is_context_success(context):
            msg = context.get('error_message', '')
            raise RuntimeError(msg)

        face_code = context.get('face_code')
        if not face_code:
            raise ValueError("Face code is missing from the context.")
        return face_code


class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, AuthFaceMixin, MFAMixin, AuthPostCheckMixin, ):
    """Combines the mixins above into the complete login flow."""
    request = None
    partial_credential_error = None

    key_prefix_captcha = "_LOGIN_INVALID_{}"

    def _check_auth_user_is_valid(self, username, password, public_key):
        """
        Authenticate with the given credentials and return the user.

        Stores the backend used in the session. Raises a credential error if
        authentication fails or the user is expired or inactive.
        """
        credentials = {'username': username}
        if password:
            credentials['password'] = password
        if public_key:
            credentials['public_key'] = public_key
        user = authenticate(self.request, **credentials)
        if not user:
            self.raise_credential_error(errors.reason_password_failed)

        self.request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)

        if user.is_expired:
            self.raise_credential_error(errors.reason_user_expired)
        elif not user.is_active:
            self.raise_credential_error(errors.reason_user_inactive)
        return user

    def set_login_failed_mark(self):
        """Cache a login-failure mark for the request IP for 3600 seconds."""
        ip = self.get_request_ip()
        cache.set(self.key_prefix_captcha.format(ip), 1, 3600)

    def check_is_need_captcha(self):
        """Return the cached login-failure mark for the request IP (truthy means captcha needed)."""
        # A captcha is required when a login failed recently.
        # NOTE(review): this resolves the IP with get_request_ip(), while
        # set_login_failed_mark() uses self.get_request_ip()
        # (get_request_ip_or_data) - confirm both yield the same cache key.
        ip = get_request_ip(self.request)
        need = cache.get(self.key_prefix_captcha.format(ip))
        return need

    def check_user_auth(self, valid_data=None):
        """
        Run the full password login flow and return the authenticated user.

        Order: block check, existing-user check, authentication, login ACL,
        password policy checks, optional login-page MFA, then session marks
        and failure counter cleanup.
        """
        # pre check
        self.check_is_block()
        username, password, public_key, ip, auto_login = self.get_auth_data(valid_data)
        self._check_only_allow_exists_user_auth(username)

        # check auth
        user = self._check_auth_user_is_valid(username, password, public_key)

        # Check the login ACL rules.
        self._check_login_acl(user, ip)

        # post check
        self._check_password_require_reset_or_not(user)
        self._check_passwd_is_too_simple(user, password)
        self._check_passwd_need_update(user)

        user.cache_login_password_if_need(password)

        # Check login MFA, if MFA is shown on the login page.
        self._check_login_page_mfa_if_need(user)

        # Mark password verification as successful.
        self.mark_password_ok(user=user, auto_login=auto_login)
        LoginBlockUtil(user.username, ip).clean_failed_count()
        LoginIpBlockUtil(ip).clean_block_if_need()
        return user

    def mark_password_ok(self, user, auto_login=False, auth_backend=None):
        """Record a successful password check, its expiry and the backend in the session."""
        request = self.request
        request.session['auth_password'] = 1
        request.session['auth_password_expired_at'] = time.time() + settings.AUTH_EXPIRED_SECONDS
        request.session['user_id'] = str(user.id)
        request.session['auto_login'] = auto_login
        if not auth_backend:
            auth_backend = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)
        request.session['auth_backend'] = auth_backend

    def check_oauth2_auth(self, user: User, auth_backend):
        """
        Run the post-authentication checks for a user authenticated by ``auth_backend``.

        Raises a credential error if the user is expired or inactive, applies
        the block and ACL checks, resets failure counters and marks the
        password step as passed.
        """
        ip = self.get_request_ip()
        request = self.request

        self._set_partial_credential_error(user.username, ip, request)

        if user.is_expired:
            self.raise_credential_error(errors.reason_user_expired)
        elif not user.is_active:
            self.raise_credential_error(errors.reason_user_inactive)

        self._check_is_block(user.username)
        self._check_login_acl(user, ip)

        LoginBlockUtil(user.username, ip).clean_failed_count()
        LoginIpBlockUtil(ip).clean_block_if_need()
        MFABlockUtils(user.username, ip).clean_failed_count()

        self.mark_password_ok(user, False, auth_backend)
        return user

    def get_user_or_auth(self, valid_data):
        """Return the session user if the password step already passed, else authenticate."""
        request = self.request
        if request.session.get('auth_password'):
            return self.get_user_from_session()
        else:
            return self.check_user_auth(valid_data)

    def clear_auth_mark(self):
        """Remove the authentication-related keys from the session."""
        keys = [
            'auth_password', 'user_id', 'auth_confirm_required',
            'auth_notice_required', 'auth_ticket_id', 'auth_acl_id',
            'user_session_id', 'user_log_id', 'can_send_notifications'
        ]
        for k in keys:
            self.request.session.pop(k, '')

    def send_auth_signal(self, success=True, user=None, username='', reason=''):
        """Send ``post_auth_success`` or ``post_auth_failed`` depending on ``success``."""
        if success:
            post_auth_success.send(
                sender=self.__class__, user=user, request=self.request
            )
        else:
            post_auth_failed.send(
                sender=self.__class__, username=username,
                request=self.request, reason=reason
            )

    def redirect_to_guard_view(self):
        """Redirect to the login guard view, keeping the current query string."""
        guard_url = reverse('authentication:login-guard')
        args = self.request.META.get('QUERY_STRING', '')
        if args:
            guard_url = "%s?%s" % (guard_url, args)
        return redirect(guard_url)