mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-09-05 01:11:57 +00:00
[Update] 重构 LDAP/AD 同步功能,添加缓存机制
This commit is contained in:
@@ -13,10 +13,16 @@ from django.core.mail import send_mail
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from .models import Setting
|
||||
from .utils import LDAPUtil
|
||||
from .utils import (
|
||||
LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
|
||||
LDAP_USE_CACHE_FLAGS
|
||||
|
||||
)
|
||||
from .tasks import sync_ldap_user_task
|
||||
from common.permissions import IsOrgAdmin, IsSuperUser
|
||||
from common.utils import get_logger
|
||||
from .serializers import MailTestSerializer, LDAPTestSerializer, LDAPUserSerializer
|
||||
from users.models import User
|
||||
|
||||
|
||||
logger = get_logger(__file__)
|
||||
@@ -67,65 +73,107 @@ class LDAPTestingAPI(APIView):
|
||||
success_message = _("Test ldap success")
|
||||
|
||||
@staticmethod
|
||||
def get_ldap_util(serializer):
|
||||
host = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
|
||||
def get_ldap_config(serializer):
|
||||
server_uri = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
|
||||
bind_dn = serializer.validated_data["AUTH_LDAP_BIND_DN"]
|
||||
password = serializer.validated_data["AUTH_LDAP_BIND_PASSWORD"]
|
||||
use_ssl = serializer.validated_data.get("AUTH_LDAP_START_TLS", False)
|
||||
search_ougroup = serializer.validated_data["AUTH_LDAP_SEARCH_OU"]
|
||||
search_filter = serializer.validated_data["AUTH_LDAP_SEARCH_FILTER"]
|
||||
attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
|
||||
try:
|
||||
attr_map = json.loads(attr_map)
|
||||
except json.JSONDecodeError:
|
||||
return Response({"error": "AUTH_LDAP_USER_ATTR_MAP not valid"}, status=401)
|
||||
|
||||
util = LDAPUtil(
|
||||
use_settings_config=False, server_uri=host, bind_dn=bind_dn,
|
||||
password=password, use_ssl=use_ssl,
|
||||
search_ougroup=search_ougroup, search_filter=search_filter,
|
||||
attr_map=attr_map
|
||||
)
|
||||
return util
|
||||
config = {
|
||||
'server_uri': server_uri,
|
||||
'bind_dn': bind_dn,
|
||||
'password': password,
|
||||
'use_ssl': use_ssl,
|
||||
'search_ougroup': search_ougroup,
|
||||
'search_filter': search_filter,
|
||||
'attr_map': json.loads(attr_map),
|
||||
}
|
||||
return config
|
||||
|
||||
def post(self, request):
|
||||
serializer = self.serializer_class(data=request.data)
|
||||
if not serializer.is_valid():
|
||||
return Response({"error": str(serializer.errors)}, status=401)
|
||||
|
||||
util = self.get_ldap_util(serializer)
|
||||
|
||||
attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
|
||||
try:
|
||||
users = util.search_user_items()
|
||||
json.loads(attr_map)
|
||||
except json.JSONDecodeError:
|
||||
return Response({"error": "AUTH_LDAP_USER_ATTR_MAP not valid"}, status=401)
|
||||
|
||||
config = self.get_ldap_config(serializer)
|
||||
util = LDAPServerUtil(config=config)
|
||||
try:
|
||||
users = util.search()
|
||||
except Exception as e:
|
||||
return Response({"error": str(e)}, status=401)
|
||||
|
||||
if len(users) > 0:
|
||||
return Response({"msg": _("Match {} s users").format(len(users))})
|
||||
else:
|
||||
return Response({"error": "Have user but attr mapping error"}, status=401)
|
||||
return Response({"msg": _("Match {} s users").format(len(users))})
|
||||
|
||||
|
||||
class LDAPUserListApi(generics.ListAPIView):
|
||||
permission_classes = (IsOrgAdmin,)
|
||||
serializer_class = LDAPUserSerializer
|
||||
|
||||
def get_queryset_from_cache(self):
|
||||
search_value = self.request.query_params.get('search')
|
||||
users = LDAPCacheUtil().search(search_value=search_value)
|
||||
return users
|
||||
|
||||
def get_queryset_from_server(self):
|
||||
search_value = self.request.query_params.get('search')
|
||||
users = LDAPServerUtil().search(search_value=search_value)
|
||||
return users
|
||||
|
||||
def get_queryset(self):
|
||||
if hasattr(self, 'swagger_fake_view'):
|
||||
return []
|
||||
q = self.request.query_params.get('search')
|
||||
try:
|
||||
util = LDAPUtil()
|
||||
extra_filter = util.construct_extra_filter(util.SEARCH_FIELD_ALL, q)
|
||||
users = util.search_user_items(extra_filter)
|
||||
except Exception as e:
|
||||
users = []
|
||||
logger.error(e)
|
||||
# 前端data_table会根据row.id对table.selected值进行操作
|
||||
for user in users:
|
||||
user['id'] = user['username']
|
||||
cache_police = self.request.query_params.get('cache_police', True)
|
||||
if cache_police in LDAP_USE_CACHE_FLAGS:
|
||||
users = self.get_queryset_from_cache()
|
||||
else:
|
||||
users = self.get_queryset_from_server()
|
||||
return users
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
cache_police = self.request.query_params.get('cache_police', True)
|
||||
# 不是用缓存
|
||||
if cache_police not in LDAP_USE_CACHE_FLAGS:
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
queryset = self.get_queryset()
|
||||
# 缓存有数据
|
||||
if queryset is not None:
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
sync_util = LDAPSyncUtil()
|
||||
# 还没有同步任务
|
||||
if sync_util.task_no_start:
|
||||
task = sync_ldap_user_task.delay()
|
||||
data = {'msg': 'Cache no data, sync task {} started.'.format(task.id)}
|
||||
return Response(data=data, status=409)
|
||||
# 同步任务正在执行
|
||||
if sync_util.task_is_running:
|
||||
data = {'msg': 'synchronization is running.'}
|
||||
return Response(data=data, status=409)
|
||||
# 同步任务执行结束
|
||||
if sync_util.task_is_over:
|
||||
msg = sync_util.get_task_error_msg()
|
||||
data = {'msg': 'Synchronization task report error: {}'.format(msg)}
|
||||
return Response(data=data, status=400)
|
||||
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def processing_queryset(queryset):
|
||||
db_username_list = User.objects.all().values_list('username', flat=True)
|
||||
for q in queryset:
|
||||
q['id'] = q['username']
|
||||
q['existing'] = q['username'] in db_username_list
|
||||
return queryset
|
||||
|
||||
def sort_queryset(self, queryset):
|
||||
order_by = self.request.query_params.get('order')
|
||||
if not order_by:
|
||||
@@ -138,32 +186,41 @@ class LDAPUserListApi(generics.ListAPIView):
|
||||
queryset = sorted(queryset, key=lambda x: x[order_by], reverse=reverse)
|
||||
return queryset
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
queryset = self.get_queryset()
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = self.processing_queryset(queryset)
|
||||
queryset = self.sort_queryset(queryset)
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
return self.get_paginated_response(page)
|
||||
return Response(queryset)
|
||||
return queryset
|
||||
|
||||
|
||||
class LDAPUserSyncAPI(APIView):
|
||||
class LDAPUserImportAPI(APIView):
|
||||
permission_classes = (IsOrgAdmin,)
|
||||
|
||||
def post(self, request):
|
||||
username_list = request.data.get('username_list', [])
|
||||
|
||||
util = LDAPUtil()
|
||||
try:
|
||||
result = util.sync_users(username_list)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
return Response({'error': str(e)}, status=401)
|
||||
def get_ldap_users(self):
|
||||
username_list = self.request.data.get('username_list', [])
|
||||
cache_police = self.request.query_params.get('cache_police', True)
|
||||
if cache_police in LDAP_USE_CACHE_FLAGS:
|
||||
users = LDAPCacheUtil().search(search_users=username_list)
|
||||
else:
|
||||
msg = _("succeed: {} failed: {} total: {}").format(
|
||||
result['succeed'], result['failed'], result['total']
|
||||
)
|
||||
return Response({'msg': msg})
|
||||
users = LDAPServerUtil().search(search_users=username_list)
|
||||
return users
|
||||
|
||||
def post(self, request):
|
||||
users = self.get_ldap_users()
|
||||
errors = LDAPImportUtil().perform_import(users)
|
||||
if errors:
|
||||
return Response({'Error': errors}, status=401)
|
||||
return Response({'msg': 'Imported {} users successfully'.format(len(users))})
|
||||
|
||||
|
||||
class LDAPCacheRefreshAPI(generics.RetrieveAPIView):
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
try:
|
||||
LDAPSyncUtil().clear_cache()
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
return Response(data={'msg': str(e)}, status=400)
|
||||
return Response(data={'msg': 'success'})
|
||||
|
||||
|
||||
class ReplayStorageCreateAPI(APIView):
|
||||
|
Reference in New Issue
Block a user