[Update] 重构 LDAP/AD 同步功能,添加缓存机制

This commit is contained in:
BaiJiangJie
2019-11-11 16:41:32 +08:00
parent cf65e9bfe3
commit 596e5a6dd1
13 changed files with 569 additions and 317 deletions

View File

@@ -13,10 +13,16 @@ from django.core.mail import send_mail
from django.utils.translation import ugettext_lazy as _
from .models import Setting
from .utils import LDAPUtil
from .utils import (
LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
LDAP_USE_CACHE_FLAGS
)
from .tasks import sync_ldap_user_task
from common.permissions import IsOrgAdmin, IsSuperUser
from common.utils import get_logger
from .serializers import MailTestSerializer, LDAPTestSerializer, LDAPUserSerializer
from users.models import User
logger = get_logger(__file__)
@@ -67,65 +73,107 @@ class LDAPTestingAPI(APIView):
success_message = _("Test ldap success")
@staticmethod
def get_ldap_util(serializer):
host = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
def get_ldap_config(serializer):
server_uri = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
bind_dn = serializer.validated_data["AUTH_LDAP_BIND_DN"]
password = serializer.validated_data["AUTH_LDAP_BIND_PASSWORD"]
use_ssl = serializer.validated_data.get("AUTH_LDAP_START_TLS", False)
search_ougroup = serializer.validated_data["AUTH_LDAP_SEARCH_OU"]
search_filter = serializer.validated_data["AUTH_LDAP_SEARCH_FILTER"]
attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
try:
attr_map = json.loads(attr_map)
except json.JSONDecodeError:
return Response({"error": "AUTH_LDAP_USER_ATTR_MAP not valid"}, status=401)
util = LDAPUtil(
use_settings_config=False, server_uri=host, bind_dn=bind_dn,
password=password, use_ssl=use_ssl,
search_ougroup=search_ougroup, search_filter=search_filter,
attr_map=attr_map
)
return util
config = {
'server_uri': server_uri,
'bind_dn': bind_dn,
'password': password,
'use_ssl': use_ssl,
'search_ougroup': search_ougroup,
'search_filter': search_filter,
'attr_map': json.loads(attr_map),
}
return config
def post(self, request):
serializer = self.serializer_class(data=request.data)
if not serializer.is_valid():
return Response({"error": str(serializer.errors)}, status=401)
util = self.get_ldap_util(serializer)
attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
try:
users = util.search_user_items()
json.loads(attr_map)
except json.JSONDecodeError:
return Response({"error": "AUTH_LDAP_USER_ATTR_MAP not valid"}, status=401)
config = self.get_ldap_config(serializer)
util = LDAPServerUtil(config=config)
try:
users = util.search()
except Exception as e:
return Response({"error": str(e)}, status=401)
if len(users) > 0:
return Response({"msg": _("Match {} s users").format(len(users))})
else:
return Response({"error": "Have user but attr mapping error"}, status=401)
return Response({"msg": _("Match {} s users").format(len(users))})
class LDAPUserListApi(generics.ListAPIView):
permission_classes = (IsOrgAdmin,)
serializer_class = LDAPUserSerializer
def get_queryset_from_cache(self):
search_value = self.request.query_params.get('search')
users = LDAPCacheUtil().search(search_value=search_value)
return users
def get_queryset_from_server(self):
search_value = self.request.query_params.get('search')
users = LDAPServerUtil().search(search_value=search_value)
return users
def get_queryset(self):
if hasattr(self, 'swagger_fake_view'):
return []
q = self.request.query_params.get('search')
try:
util = LDAPUtil()
extra_filter = util.construct_extra_filter(util.SEARCH_FIELD_ALL, q)
users = util.search_user_items(extra_filter)
except Exception as e:
users = []
logger.error(e)
# 前端data_table会根据row.id对table.selected值进行操作
for user in users:
user['id'] = user['username']
cache_police = self.request.query_params.get('cache_police', True)
if cache_police in LDAP_USE_CACHE_FLAGS:
users = self.get_queryset_from_cache()
else:
users = self.get_queryset_from_server()
return users
def list(self, request, *args, **kwargs):
cache_police = self.request.query_params.get('cache_police', True)
# 不是用缓存
if cache_police not in LDAP_USE_CACHE_FLAGS:
return super().list(request, *args, **kwargs)
queryset = self.get_queryset()
# 缓存有数据
if queryset is not None:
return super().list(request, *args, **kwargs)
sync_util = LDAPSyncUtil()
# 还没有同步任务
if sync_util.task_no_start:
task = sync_ldap_user_task.delay()
data = {'msg': 'Cache no data, sync task {} started.'.format(task.id)}
return Response(data=data, status=409)
# 同步任务正在执行
if sync_util.task_is_running:
data = {'msg': 'synchronization is running.'}
return Response(data=data, status=409)
# 同步任务执行结束
if sync_util.task_is_over:
msg = sync_util.get_task_error_msg()
data = {'msg': 'Synchronization task report error: {}'.format(msg)}
return Response(data=data, status=400)
return super().list(request, *args, **kwargs)
@staticmethod
def processing_queryset(queryset):
db_username_list = User.objects.all().values_list('username', flat=True)
for q in queryset:
q['id'] = q['username']
q['existing'] = q['username'] in db_username_list
return queryset
def sort_queryset(self, queryset):
order_by = self.request.query_params.get('order')
if not order_by:
@@ -138,32 +186,41 @@ class LDAPUserListApi(generics.ListAPIView):
queryset = sorted(queryset, key=lambda x: x[order_by], reverse=reverse)
return queryset
def list(self, request, *args, **kwargs):
queryset = self.get_queryset()
def filter_queryset(self, queryset):
queryset = self.processing_queryset(queryset)
queryset = self.sort_queryset(queryset)
page = self.paginate_queryset(queryset)
if page is not None:
return self.get_paginated_response(page)
return Response(queryset)
return queryset
class LDAPUserSyncAPI(APIView):
class LDAPUserImportAPI(APIView):
permission_classes = (IsOrgAdmin,)
def post(self, request):
username_list = request.data.get('username_list', [])
util = LDAPUtil()
try:
result = util.sync_users(username_list)
except Exception as e:
logger.error(e, exc_info=True)
return Response({'error': str(e)}, status=401)
def get_ldap_users(self):
username_list = self.request.data.get('username_list', [])
cache_police = self.request.query_params.get('cache_police', True)
if cache_police in LDAP_USE_CACHE_FLAGS:
users = LDAPCacheUtil().search(search_users=username_list)
else:
msg = _("succeed: {} failed: {} total: {}").format(
result['succeed'], result['failed'], result['total']
)
return Response({'msg': msg})
users = LDAPServerUtil().search(search_users=username_list)
return users
def post(self, request):
users = self.get_ldap_users()
errors = LDAPImportUtil().perform_import(users)
if errors:
return Response({'Error': errors}, status=401)
return Response({'msg': 'Imported {} users successfully'.format(len(users))})
class LDAPCacheRefreshAPI(generics.RetrieveAPIView):
def retrieve(self, request, *args, **kwargs):
try:
LDAPSyncUtil().clear_cache()
except Exception as e:
logger.error(str(e))
return Response(data={'msg': str(e)}, status=400)
return Response(data={'msg': 'success'})
class ReplayStorageCreateAPI(APIView):