Files
jumpserver/apps/jumpserver/middleware.py
2026-03-29 21:25:26 +08:00

258 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ~*~ coding: utf-8 ~*~
import json
import os
import re
import time
from urllib.parse import urlparse, quote
import pytz
from django.conf import settings
from django.core.exceptions import MiddlewareNotUsed
from django.db.utils import OperationalError
from django.middleware.csrf import CsrfViewMiddleware
from django.http.response import HttpResponseForbidden, JsonResponse
from django.shortcuts import HttpResponse
from django.shortcuts import redirect
from django.urls import reverse
from django.utils import timezone
from rest_framework import status
from .utils import set_current_request
from common.utils.common import text_hmac_sha256
# True when the comma-separated DOMAINS environment variable contains an
# entry that is exactly "*"; CsrfCheckMiddleware then skips origin checks.
IGNORE_CSRF_CHECK = any(entry == "*" for entry in os.getenv("DOMAINS", "").split(","))
class TimezoneMiddleware:
    """
    Activate the time zone named by the ``X-TZ`` request header.

    The name is read from ``request.META['HTTP_X_TZ']``. When the header is
    absent, empty, the literal ``'undefined'``, or not a zone pytz knows, the
    current time zone is deactivated so Django falls back to its default.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        tz = self._resolve_timezone(request.META.get('HTTP_X_TZ'))
        # The active time zone is stored per thread/context by Django and is
        # not reset between requests, so it has to be cleared explicitly;
        # otherwise a zone activated for an earlier request on the same worker
        # would still be in effect for this one.
        if tz is None:
            timezone.deactivate()
        else:
            timezone.activate(tz)
        return self.get_response(request)

    @staticmethod
    def _resolve_timezone(tzname):
        """Return the pytz zone for ``tzname``, or ``None`` if missing or unknown."""
        if not tzname or tzname == 'undefined':
            return None
        try:
            return pytz.timezone(tzname)
        except pytz.UnknownTimeZoneError:
            return None
class DemoMiddleware:
    """
    Reject state-changing requests while the instance runs in demo mode.

    Demo mode is on when the ``DEMO_MODE`` environment variable is ``"1"``,
    ``"ok"`` or ``"True"`` (read once, at import time). In that mode a request
    is passed on only if its method is in ``SAFE_METHOD`` or its path matches
    ``SAFE_URL_PATTERN``; anything else gets a 403. When demo mode is off the
    middleware removes itself from the chain via ``MiddlewareNotUsed``.
    """
    DEMO_MODE_ENABLED = os.environ.get("DEMO_MODE", "") in ("1", "ok", "True")
    # Paths that stay writable in demo mode (matched from the start of the path).
    SAFE_URL_PATTERN = re.compile(
        r'^/users/login|'
        r'^/api/terminal/v1/.*|'
        r'^/api/terminal/.*|'
        r'^/api/users/v1/auth/|'
        r'^/api/users/v1/profile/'
    )
    SAFE_METHOD = ("GET", "HEAD")

    def __init__(self, get_response):
        self.get_response = get_response
        # Only needed in demo mode. The previous check was inverted: it dropped
        # the middleware exactly when demo mode was enabled, so nothing was
        # ever rejected.
        if not self.DEMO_MODE_ENABLED:
            raise MiddlewareNotUsed
        print("Demo mode enabled, reject unsafe method and url")

    def __call__(self, request):
        if (
            self.DEMO_MODE_ENABLED
            and request.method not in self.SAFE_METHOD
            and not self.SAFE_URL_PATTERN.match(request.path)
        ):
            return HttpResponse("Demo mode, only safe request accepted", status=403)
        return self.get_response(request)
class RequestMiddleware:
    """Register the incoming request with ``set_current_request`` before handing it on."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        set_current_request(request)
        return self.get_response(request)
class RefererCheckMiddleware:
    """
    Reject requests whose ``Referer`` header names a different host.

    Active only when ``settings.REFERER_CHECK_ENABLED`` is truthy. Requests
    without a Referer are let through; mismatches get a 403.
    """

    def __init__(self, get_response):
        if not settings.REFERER_CHECK_ENABLED:
            raise MiddlewareNotUsed
        self.get_response = get_response
        self.http_pattern = re.compile('https?://')

    def check_referer(self, request):
        """
        Return True if the request's Referer is empty or points at this host.

        Every ``http://``/``https://`` occurrence is stripped from the Referer.
        The remainder must then start with ``request.get_host()``, optionally
        followed by ``:<digits>``, and then the end of the string or one of
        ``/ ? #``. Requiring that boundary stops ``jms.example.com.evil.net``
        or ``jms.example.com:x@evil.net`` from passing for host
        ``jms.example.com``. A different port is still accepted, as before.
        """
        referer = request.META.get('HTTP_REFERER', '')
        referer = self.http_pattern.sub('', referer)
        if not referer:
            return True
        remote_host = request.get_host()
        host_pattern = re.escape(remote_host) + r'(?::\d*)?(?:[/?#]|\Z)'
        return re.match(host_pattern, referer) is not None

    def __call__(self, request):
        match = self.check_referer(request)
        if not match:
            return HttpResponseForbidden('CSRF CHECK ERROR')
        response = self.get_response(request)
        return response
class SQLCountMiddleware:
    """
    DEBUG_DEV-only aid: expose the request's SQL query count as a header.

    Sets ``X-JMS-SQL-COUNT`` to ``len(django.db.connection.queries) - 2``
    after the rest of the chain has produced the response. Django only
    records ``connection.queries`` when ``settings.DEBUG`` is on, so the
    value is only meaningful then.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        from django.db import connection
        response = self.get_response(request)
        logged_queries = len(connection.queries)
        # NOTE(review): the fixed offset of 2 presumably discounts queries
        # issued outside the view itself (e.g. session/user lookups) — confirm.
        response['X-JMS-SQL-COUNT'] = logged_queries - 2
        return response
class StartMiddleware:
    """
    Outer half of a DEBUG_DEV-only timing probe for ``/api/health/``.

    Stamps ``request._s_time_start``/``_s_time_end`` around the whole chain.
    EndMiddleware stamps ``request._e_time_start``/``_e_time_end`` around the
    view. For ``/api/health/``, the time spent before the view, in the view
    and after it is added to the JSON response body.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        request._s_time_start = time.time()
        response = self.get_response(request)
        request._s_time_end = time.time()
        if request.path == '/api/health/':
            self._attach_timings(request, response)
        return response

    @staticmethod
    def _attach_timings(request, response):
        """
        Add the three timing fields to ``response.data`` and re-serialise it.

        Skipped (response returned untouched) when EndMiddleware never stamped
        this request, e.g. an inner middleware short-circuited. Also skipped
        when the response has no dict ``data`` (not a DRF ``Response``).
        Previously both cases raised ``AttributeError`` and turned the
        response into a 500.
        """
        e_start = getattr(request, '_e_time_start', None)
        e_end = getattr(request, '_e_time_end', None)
        data = getattr(response, 'data', None)
        if e_start is None or e_end is None or not isinstance(data, dict):
            return
        data['pre_middleware_time'] = e_start - request._s_time_start
        data['api_time'] = e_end - e_start
        data['post_middleware_time'] = request._s_time_end - e_end
        response.content = json.dumps(data)
        response.headers['Content-Length'] = str(len(response.content))
class EndMiddleware:
    """
    Inner half of the DEBUG_DEV-only request-timing probe.

    Stamps ``request._e_time_start`` right before and ``request._e_time_end``
    right after the rest of the chain runs; StartMiddleware reads them for
    ``/api/health/``. While installed, it also turns a database
    ``OperationalError`` raised by a view into a JSON 503 response.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        clock = time.time
        request._e_time_start = clock()
        response = self.get_response(request)
        request._e_time_end = clock()
        return response

    def process_exception(self, request, exception):
        """Map ``OperationalError`` to a 503 JSON response; return ``None`` for anything else."""
        if not isinstance(exception, OperationalError):
            return None
        body = {
            'error': 'Database OperationalError: ' + str(exception),
            'message': 'Database operation failed, please try again later.',
            'code': 'DB_ERROR',
        }
        return JsonResponse(body, status=status.HTTP_503_SERVICE_UNAVAILABLE)
class SafeRedirectMiddleware:
    """
    Route redirects to other hosts through the ``redirect-confirm`` page.

    For a 3xx response whose ``Location`` would take the browser to a host
    other than the request's own, the response is replaced by a redirect to
    ``redirect-confirm?next=<quoted original Location>``. Exceptions: a
    netloc listed verbatim in ``settings.ALLOWED_HOSTS``, and routes in the
    ``authentication`` namespaces other than ``authentication:oauth2-provider``.
    """
    # Characters browsers strip from both ends of a URL (C0 controls + space).
    _URL_STRIP_CHARS = ''.join(map(chr, range(0x21)))
    # Tab / newline characters browsers drop anywhere inside a URL.
    _URL_REMOVE_TABLE = str.maketrans('', '', '\t\n\r')

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        response = self.get_response(request)
        if not (300 <= response.status_code < 400):
            return response
        if (
            request.resolver_match and
            request.resolver_match.namespace.startswith('authentication') and
            not request.resolver_match.namespace.startswith('authentication:oauth2-provider')
        ):
            # Authentication routes (/core/auth/...) skip the check, except
            # oauth2-provider: it redirects to third-party clients, so a
            # friendlier confirmation page is wanted there.
            return response
        location = response.get('Location')
        if not location or self._is_safe_redirect(request, location):
            return response
        safe_redirect_url = '%s?%s' % (reverse('redirect-confirm'), f'next={quote(location)}')
        return redirect(safe_redirect_url)

    def _is_safe_redirect(self, request, location):
        """
        Return True when following ``location`` keeps the browser on this host.

        Safe means one of:
        - a relative URL;
        - a netloc listed verbatim in ``ALLOWED_HOSTS``;
        - a hostname equal to the request's hostname (port ignored).

        Targets that cannot be parsed (e.g. a malformed IPv6 literal) are
        treated as unsafe instead of crashing the request.
        """
        try:
            target_netloc = self._get_redirect_netloc(location)
            if not target_netloc:
                return True
            if target_netloc in [*settings.ALLOWED_HOSTS]:
                return True
            target_host, _ = self._split_host_port(target_netloc)
            origin_host, _ = self._split_host_port(request.get_host())
        except ValueError:
            return False
        return target_host == origin_host

    @classmethod
    def _get_redirect_netloc(cls, location):
        """
        Return the ``host[:port]`` a browser would go to, or ``''`` if relative.

        Browsers treat ``\\`` like ``/`` in http(s) URLs. They also treat two
        or more leading slashes (``//evil.com``, ``/\\evil.com``,
        ``///evil.com``, ``https:///evil.com``) as naming another host.
        ``urlparse`` alone reports no netloc for several of these forms, so
        they are normalised first.
        Raises ``ValueError`` if ``urlparse`` rejects the URL.
        """
        url = location.strip(cls._URL_STRIP_CHARS).translate(cls._URL_REMOVE_TABLE)
        url = url.replace('\\', '/')
        parsed = urlparse(url)
        if parsed.netloc:
            return parsed.netloc
        if parsed.scheme not in ('', 'http', 'https'):
            return ''
        remainder = url[len(parsed.scheme) + 1:] if parsed.scheme else url
        if remainder.startswith('//'):
            return urlparse('//' + remainder.lstrip('/')).netloc
        return ''

    @staticmethod
    def _split_host_port(netloc):
        """
        Split ``netloc`` into ``(host, port)``; ``port`` defaults to ``'80'``.

        User-info (``user:pass@``) is discarded, IPv6 brackets are removed
        and the host is lowercased. For example,
        ``"jms.example.com:pw@evil.com"`` yields ``("evil.com", "80")``
        rather than the spoofable ``("jms.example.com", "pw@evil.com")``.
        Raises ``ValueError`` for a malformed IPv6 literal or a non-numeric port.
        """
        parsed = urlparse('//' + netloc)
        host = parsed.hostname or ''
        port = parsed.port
        return host, str(port) if port is not None else '80'
class CsrfCheckMiddleware(CsrfViewMiddleware):
    """
    CSRF middleware that can be switched off through the DOMAINS setting.

    When the module-level ``IGNORE_CSRF_CHECK`` flag is true, origin
    verification always succeeds. The request is also flagged with
    ``_dont_enforce_csrf_checks`` so that any later CSRF check on it that
    honours that flag accepts it too.
    """

    def _origin_verified(self, request):
        if not IGNORE_CSRF_CHECK:
            return super()._origin_verified(request)
        request._dont_enforce_csrf_checks = True
        return True
class HmacSignAuthMiddleware:
    """
    Write a plain-text session-state cookie, ``jms_session_sign``, on every response.

    It lets edge proxies, gateways and security devices (including WAFs)
    apply cookie-based access policies. It is not tied to any particular
    product.

    Value conventions (always non-empty, so rules are easy to write):

    - Logged in: ``<hex_hmac>:<username>|<session_id>``. The HMAC is
      ``text_hmac_sha256("<username>|<session_id>", HMAC_SIGN_KEY)``, and the
      message is stripped and lowercased before signing. If the user is
      authenticated but the session has no key yet, the value is ``expired``.
    - Session cookie present but not authenticated: ``expired``. This covers
      an expired session, a session kept after logout, an anonymous-only
      session, and similar cases.
    - No session cookie on the request: ``unauth`` (first visit, etc.).

    Enabled only when ``HMAC_SIGN_AUTH_ENABLED`` is ``1``/``true``/``yes``
    (case-insensitive) and ``HMAC_SIGN_KEY`` is non-empty.
    """
    SIGN_COOKIE_NAME = 'jms_session_sign'
    MARKER_UNAUTH = 'unauth'
    MARKER_EXPIRED = 'expired'

    def __init__(self, get_response):
        self.get_response = get_response
        enabled = os.getenv("HMAC_SIGN_AUTH_ENABLED", "").lower() in ("1", "true", "yes")
        hmac_sign_key = os.getenv("HMAC_SIGN_KEY", "")
        if not enabled or not hmac_sign_key:
            raise MiddlewareNotUsed
        self.hmac_sign_key = hmac_sign_key

    def __call__(self, request):
        response = self.get_response(request)
        return self._set_session_sign_cookie(request, response)

    def _set_session_sign_cookie(self, request, response):
        """Compute the marker/signature value and set it as ``SIGN_COOKIE_NAME``."""
        has_session_cookie = bool(request.COOKIES.get(settings.SESSION_COOKIE_NAME))
        # Tolerate running without the session/auth middlewares having set
        # these attributes (e.g. middleware ordering or a short-circuit):
        # treat that request as unauthenticated instead of raising.
        user = getattr(request, 'user', None)
        session = getattr(request, 'session', None)
        if user is not None and user.is_authenticated:
            session_id = session.session_key if session is not None else None
            if not session_id:
                value = self.MARKER_EXPIRED
            else:
                sign_data = f'{user.username}|{session_id}'
                signature = text_hmac_sha256(sign_data, self.hmac_sign_key)
                value = f'{signature}:{sign_data}'
        elif has_session_cookie:
            value = self.MARKER_EXPIRED
        else:
            value = self.MARKER_UNAUTH
        # The signed value embeds the session id, so this cookie gets the same
        # scope and protections as the session cookie. Without HttpOnly/Secure
        # here, page scripts (or plain-HTTP hops) could read the session id
        # that the session cookie itself protects. Proxies/WAFs still see it
        # in the request Cookie header.
        response.set_cookie(
            self.SIGN_COOKIE_NAME,
            value,
            domain=settings.SESSION_COOKIE_DOMAIN,
            path=settings.SESSION_COOKIE_PATH,
            secure=settings.SESSION_COOKIE_SECURE,
            httponly=settings.SESSION_COOKIE_HTTPONLY,
            samesite=settings.SESSION_COOKIE_SAMESITE,
        )
        return response