Compare commits

..

5 Commits

Author SHA1 Message Date
ibuler
4def7bc5ec fix: 修复 close connection 的问题 2021-11-25 18:29:39 +08:00
ibuler
d401a44317 fix: 修复重置 mfa 的提示问题 2021-11-25 17:50:17 +08:00
xinwen
7e793a6e0a fix: 按资产ip搜索数据不全 2021-11-25 17:29:28 +08:00
ibuler
b61559d078 perf: 去掉登录页面更好 2021-11-25 16:04:25 +08:00
ibuler
0e8260a37c fix: 修复 oidc cas 登录时跳转问题
perf: 优化一波,容易debug

perf: 还原回来的世界
2021-11-25 15:00:38 +08:00
13 changed files with 77 additions and 65 deletions

View File

@@ -84,10 +84,10 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe() subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
msgs = subscribe.listen() msgs = subscribe.listen()
# 开始之前关闭连接因为server端可能关闭了连接而 client 还在 CONN_MAX_AGE 中 # 开始之前关闭连接因为server端可能关闭了连接而 client 还在 CONN_MAX_AGE 中
close_old_connections()
for message in msgs: for message in msgs:
if message["type"] != "message": if message["type"] != "message":
continue continue
close_old_connections()
org_id = message['data'].decode() org_id = message['data'].decode()
root_org_id = Organization.ROOT_ID root_org_id = Organization.ROOT_ID
Node.expire_node_all_asset_ids_mapping_from_memory(org_id) Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
@@ -96,6 +96,7 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
"Expire node assets id mapping from memory of org={}, pid={}" "Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid()) "".format(str(org_id), os.getpid())
) )
close_old_connections()
except Exception as e: except Exception as e:
logger.exception(f'subscribe_node_assets_mapping_expire: {e}') logger.exception(f'subscribe_node_assets_mapping_expire: {e}')
Node.expire_all_orgs_node_all_asset_ids_mapping_from_memory() Node.expire_all_orgs_node_all_asset_ids_mapping_from_memory()

View File

@@ -1,17 +1,36 @@
from django.shortcuts import redirect from django.shortcuts import redirect, reverse
from django.http import HttpResponse
class MFAMiddleware: class MFAMiddleware:
"""
这个 中间件 是用来全局拦截开启了 MFA 却没有认证的,如 OIDC, CAS使用第三方库做的登录直接 login 了,
所以只能在 Middleware 中控制
"""
def __init__(self, get_response): def __init__(self, get_response):
self.get_response = get_response self.get_response = get_response
def __call__(self, request): def __call__(self, request):
response = self.get_response(request) response = self.get_response(request)
# 没有校验
if not request.session.get('auth_mfa_required'):
return response
# 没有认证过,证明不是从 第三方 来的
if request.user.is_anonymous:
return response
white_urls = ['login/mfa', 'mfa/select', 'jsi18n/', '/static/'] # 这个是 mfa 登录页需要的请求, 也得放出来, 用户其实已经在 CAS/OIDC 中完成登录了
white_urls = [
'login/mfa', 'mfa/select', 'jsi18n/', '/static/',
'/profile/otp', '/logout/',
]
for url in white_urls: for url in white_urls:
if request.path.find(url) > -1: if request.path.find(url) > -1:
return response return response
if request.session.get('auth_mfa_required'):
return redirect('authentication:login-mfa') # 因为使用 CAS/OIDC 登录的,不小心去了别的页面就回不来了
return response if request.path.find('users/profile') > -1:
return HttpResponse('', status=401)
url = reverse('authentication:login-mfa') + '?_=middleware'
return redirect(url)

View File

@@ -257,7 +257,8 @@ class MFAMixin:
def _check_login_page_mfa_if_need(self, user): def _check_login_page_mfa_if_need(self, user):
if not settings.SECURITY_MFA_IN_LOGIN_PAGE: if not settings.SECURITY_MFA_IN_LOGIN_PAGE:
return return
self._check_if_no_active_mfa(user) if not user.active_mfa_backends:
return
request = self.request request = self.request
data = request.data if hasattr(request, 'data') else request.POST data = request.data if hasattr(request, 'data') else request.POST
@@ -274,10 +275,8 @@ class MFAMixin:
if not user.mfa_enabled: if not user.mfa_enabled:
return return
self._check_if_no_active_mfa(user) active_mfa_names = user.active_mfa_backends_mapper.keys()
raise errors.MFARequiredError(mfa_types=tuple(active_mfa_names))
active_mfa_mapper = user.active_mfa_backends_mapper
raise errors.MFARequiredError(mfa_types=tuple(active_mfa_mapper.keys()))
def mark_mfa_ok(self, mfa_type): def mark_mfa_ok(self, mfa_type):
self.request.session['auth_mfa'] = 1 self.request.session['auth_mfa'] = 1

View File

@@ -3,6 +3,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.views.generic.edit import FormView from django.views.generic.edit import FormView
from django.shortcuts import redirect
from common.utils import get_logger from common.utils import get_logger
from .. import forms, errors, mixins from .. import forms, errors, mixins
@@ -19,9 +20,15 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
try: try:
self.get_user_from_session() user = self.get_user_from_session()
except errors.SessionEmptyError: except errors.SessionEmptyError:
return redirect_to_guard_view() return redirect_to_guard_view('session_empty')
try:
self._check_if_no_active_mfa(user)
except errors.MFAUnsetError as e:
return redirect(e.url + '?_=login_mfa')
return super().get(*args, **kwargs) return super().get(*args, **kwargs)
def form_valid(self, form): def form_valid(self, form):
@@ -30,17 +37,17 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
try: try:
self._do_check_user_mfa(code, mfa_type) self._do_check_user_mfa(code, mfa_type)
return redirect_to_guard_view() return redirect_to_guard_view('mfa_ok')
except (errors.MFAFailedError, errors.BlockMFAError) as e: except (errors.MFAFailedError, errors.BlockMFAError) as e:
form.add_error('code', e.msg) form.add_error('code', e.msg)
return super().form_invalid(form) return super().form_invalid(form)
except errors.SessionEmptyError: except errors.SessionEmptyError:
return redirect_to_guard_view() return redirect_to_guard_view('session_empty')
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return redirect_to_guard_view() return redirect_to_guard_view('unexpect')
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
user = self.get_user_from_session() user = self.get_user_from_session()

View File

@@ -3,6 +3,6 @@
from django.shortcuts import reverse, redirect from django.shortcuts import reverse, redirect
def redirect_to_guard_view(): def redirect_to_guard_view(comment=''):
continue_url = reverse('authentication:login-guard') continue_url = reverse('authentication:login-guard') + '?_=' + comment
return redirect(continue_url) return redirect(continue_url)

View File

@@ -52,11 +52,10 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
try: try:
msgs = self.chan.listen() msgs = self.chan.listen()
# 开始之前关闭连接因为server端可能关闭了连接而 client 还在 CONN_MAX_AGE 中 # 开始之前关闭连接因为server端可能关闭了连接而 client 还在 CONN_MAX_AGE 中
close_old_connections()
for message in msgs: for message in msgs:
if message['type'] != 'message': if message['type'] != 'message':
continue continue
close_old_connections()
try: try:
msg = json.loads(message['data'].decode()) msg = json.loads(message['data'].decode())
except json.JSONDecoder as e: except json.JSONDecoder as e:
@@ -70,6 +69,7 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
logger.debug('Message users: {}'.format(users)) logger.debug('Message users: {}'.format(users))
if user_id in users: if user_id in users:
self.send_unread_msg_count() self.send_unread_msg_count()
close_old_connections()
except ConnectionError: except ConnectionError:
logger.error('Redis chan closed') logger.error('Redis chan closed')
finally: finally:

View File

@@ -170,13 +170,13 @@ class AssetPermissionFilter(PermissionBaseFilter):
return queryset return queryset
if not assets: if not assets:
return queryset.none() return queryset.none()
asset = assets.first() assetids = list(assets.values_list('id', flat=True))
if not is_query_all: if not is_query_all:
queryset = queryset.filter(assets=asset) queryset = queryset.filter(assets__in=assetids)
return queryset return queryset
inherit_all_nodekeys = set() inherit_all_nodekeys = set()
inherit_nodekeys = asset.nodes.values_list('key', flat=True) inherit_nodekeys = set(assets.values_list('nodes__key', flat=True))
for key in inherit_nodekeys: for key in inherit_nodekeys:
ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True) ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True)
@@ -185,8 +185,8 @@ class AssetPermissionFilter(PermissionBaseFilter):
inherit_all_nodeids = Node.objects.filter(key__in=inherit_all_nodekeys).values_list('id', flat=True) inherit_all_nodeids = Node.objects.filter(key__in=inherit_all_nodekeys).values_list('id', flat=True)
inherit_all_nodeids = list(inherit_all_nodeids) inherit_all_nodeids = list(inherit_all_nodeids)
qs1 = queryset.filter(assets=asset).distinct() qs1 = queryset.filter(assets__in=assetids).distinct()
qs2 = queryset.filter(nodes__id__in=inherit_all_nodeids).distinct() qs2 = queryset.filter(nodes__in=inherit_all_nodeids).distinct()
qs = UnionQuerySet(qs1, qs2) qs = UnionQuerySet(qs1, qs2)
return qs return qs

View File

@@ -86,17 +86,17 @@ def subscribe_settings_change(sender, **kwargs):
sub = setting_pub_sub.subscribe() sub = setting_pub_sub.subscribe()
msgs = sub.listen() msgs = sub.listen()
# 开始之前关闭连接因为server端可能关闭了连接而 client 还在 CONN_MAX_AGE 中 # 开始之前关闭连接因为server端可能关闭了连接而 client 还在 CONN_MAX_AGE 中
close_old_connections()
for msg in msgs: for msg in msgs:
if msg["type"] != "message": if msg["type"] != "message":
continue continue
close_old_connections()
item = msg['data'].decode() item = msg['data'].decode()
logger.debug("Found setting change: {}".format(str(item))) logger.debug("Found setting change: {}".format(str(item)))
Setting.refresh_item(item) Setting.refresh_item(item)
close_old_connections()
except Exception as e: except Exception as e:
logger.exception(f'subscribe_settings_change: {e}') logger.exception(f'subscribe_settings_change: {e}')
Setting.refresh_all_settings() Setting.refresh_all_settings()
finally:
close_old_connections() close_old_connections()
t = threading.Thread(target=keep_subscribe_settings_change) t = threading.Thread(target=keep_subscribe_settings_change)

View File

@@ -136,6 +136,10 @@ article ul li:last-child{
border-radius: 6px; border-radius: 6px;
color: white; color: white;
} }
.next:hover {
color: white;
}
/*绑定TOTP*/ /*绑定TOTP*/
/*版权信息*/ /*版权信息*/

View File

@@ -207,7 +207,7 @@ class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView):
user = self.get_object() if kwargs.get('pk') else request.user user = self.get_object() if kwargs.get('pk') else request.user
if user == request.user: if user == request.user:
msg = _("Could not reset self otp, use profile reset instead") msg = _("Could not reset self otp, use profile reset instead")
return Response({"error": msg}, status=401) return Response({"error": msg}, status=400)
backends = user.active_mfa_backends_mapper backends = user.active_mfa_backends_mapper
for backend in backends.values(): for backend in backends.values():

View File

@@ -7,36 +7,8 @@
{% endblock %} {% endblock %}
{% block content %} {% block content %}
<div class="verify">{% trans 'Please enter the password of' %}&nbsp;{% trans 'account' %}&nbsp;<span>{{ user.username }}</span>&nbsp;{% trans 'to complete the binding operation' %}</div> <hr style="width: 500px; margin: 10px auto auto;">
<hr style="width: 500px; margin: auto; margin-top: 10px;"> <a type="submit" class="next" href="{% url 'authentication:user-otp-enable-install-app' %}" >
<form id="verify-form" class="" role="form" method="post" action=""> {% trans 'Next' %}
{% csrf_token %} </a>
<div class="form-input">
<input id="password" type="password" class="" placeholder="{% trans 'Password' %}" required="" autofocus="autofocus">
<input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}">
</div>
<button type="submit" class="next" onclick="doVerify();return false;">{% trans 'Next' %}</button>
{% if 'password' in form.errors %}
<p class="red-fonts">{{ form.password.errors.as_text }}</p>
{% endif %}
</form>
<script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script>
<script>
function encryptLoginPassword(password, rsaPublicKey) {
var jsencrypt = new JSEncrypt(); //加密对象
jsencrypt.setPublicKey(rsaPublicKey); // 设置密钥
return jsencrypt.encrypt(password); //加密
}
function doVerify() {
//公钥加密
var rsaPublicKey = "{{ rsa_public_key }}"
var password = $('#password').val(); //明文密码
var passwordEncrypted = encryptLoginPassword(password, rsaPublicKey)
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#verify-form').submit();//post提交
}
$(document).ready(function () {
})
</script>
{% endblock %} {% endblock %}

View File

@@ -16,12 +16,12 @@
<div id="qr_code"></div> <div id="qr_code"></div>
<div style="display: block; margin: 0">Secret: {{ otp_secret_key }}</div> <div style="display: block; margin: 0">Secret: {{ otp_secret_key }}</div>
<form class="" role="form" method="post" action=""> <form id="bind-form" class="" role="form" method="post" action="">
{% csrf_token %} {% csrf_token %}
<div class="form-input"> <div class="form-input">
<input type="text" class="" name="otp_code" placeholder="{% trans 'Six figures' %}" required=""> <input type="text" class="" name="otp_code" placeholder="{% trans 'Six figures' %}" required="">
</div> </div>
<button type="submit" class="next">{% trans 'Next' %}</button> <a type="submit" class="next button" onclick="submitForm()">{% trans 'Next' %}</a>
{% if 'otp_code' in form.errors %} {% if 'otp_code' in form.errors %}
<p style="color: #ed5565">{{ form.otp_code.errors.as_text }}</p> <p style="color: #ed5565">{{ form.otp_code.errors.as_text }}</p>
{% endif %} {% endif %}
@@ -33,6 +33,10 @@
$('.change-color li:eq(1) i').css('color', '#1ab394'); $('.change-color li:eq(1) i').css('color', '#1ab394');
$('.change-color li:eq(2) i').css('color', '#1ab394'); $('.change-color li:eq(2) i').css('color', '#1ab394');
function submitForm() {
$('#bind-form').submit()
}
$(document).ready(function() { $(document).ready(function() {
// 生成用户绑定otp的二维码 // 生成用户绑定otp的二维码
var qrcode = new QRCode(document.getElementById('qr_code'), { var qrcode = new QRCode(document.getElementById('qr_code'), {

View File

@@ -6,10 +6,12 @@ from django.utils.translation import ugettext as _
from django.views.generic.base import TemplateView from django.views.generic.base import TemplateView
from django.views.generic.edit import FormView from django.views.generic.edit import FormView
from django.contrib.auth import logout as auth_logout from django.contrib.auth import logout as auth_logout
from django.shortcuts import redirect
from django.http.response import HttpResponseRedirect from django.http.response import HttpResponseRedirect
from authentication.mixins import AuthMixin from authentication.mixins import AuthMixin
from authentication.mfa import MFAOtp, otp_failed_msg from authentication.mfa import MFAOtp, otp_failed_msg
from authentication.errors import SessionEmptyError
from common.utils import get_logger, FlashMessageUtil from common.utils import get_logger, FlashMessageUtil
from common.mixins.views import PermissionsMixin from common.mixins.views import PermissionsMixin
from common.permissions import IsValidUser from common.permissions import IsValidUser
@@ -30,11 +32,15 @@ __all__ = [
logger = get_logger(__name__) logger = get_logger(__name__)
class UserOtpEnableStartView(UserVerifyPasswordView): class UserOtpEnableStartView(AuthMixin, TemplateView):
template_name = 'users/user_otp_check_password.html' template_name = 'users/user_otp_check_password.html'
def get_success_url(self): def get(self, request, *args, **kwargs):
return reverse('authentication:user-otp-enable-install-app') try:
self.get_user_from_session()
except SessionEmptyError:
return redirect('authentication:login') + '?_=otp_enable_start'
return super().get(request, *args, **kwargs)
class UserOtpEnableInstallAppView(TemplateView): class UserOtpEnableInstallAppView(TemplateView):