diff --git a/apps/authentication/signals_handlers.py b/apps/authentication/signals_handlers.py index c0b48c61d..b894e0651 100644 --- a/apps/authentication/signals_handlers.py +++ b/apps/authentication/signals_handlers.py @@ -47,7 +47,7 @@ def on_openid_login_success(sender, user=None, request=None, **kwargs): @receiver(populate_user) def on_ldap_create_user(sender, user, ldap_user, **kwargs): - if user and user.username != 'admin': + if user and user.username not in ['admin']: user.source = user.SOURCE_LDAP user.save() diff --git a/apps/settings/api.py b/apps/settings/api.py index 22a295b68..c226c9205 100644 --- a/apps/settings/api.py +++ b/apps/settings/api.py @@ -13,10 +13,16 @@ from django.core.mail import send_mail from django.utils.translation import ugettext_lazy as _ from .models import Setting -from .utils import LDAPUtil +from .utils import ( + LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil, + LDAP_USE_CACHE_FLAGS + +) +from .tasks import sync_ldap_user_task from common.permissions import IsOrgAdmin, IsSuperUser from common.utils import get_logger from .serializers import MailTestSerializer, LDAPTestSerializer, LDAPUserSerializer +from users.models import User logger = get_logger(__file__) @@ -67,65 +73,107 @@ class LDAPTestingAPI(APIView): success_message = _("Test ldap success") @staticmethod - def get_ldap_util(serializer): - host = serializer.validated_data["AUTH_LDAP_SERVER_URI"] + def get_ldap_config(serializer): + server_uri = serializer.validated_data["AUTH_LDAP_SERVER_URI"] bind_dn = serializer.validated_data["AUTH_LDAP_BIND_DN"] password = serializer.validated_data["AUTH_LDAP_BIND_PASSWORD"] use_ssl = serializer.validated_data.get("AUTH_LDAP_START_TLS", False) search_ougroup = serializer.validated_data["AUTH_LDAP_SEARCH_OU"] search_filter = serializer.validated_data["AUTH_LDAP_SEARCH_FILTER"] attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"] - try: - attr_map = json.loads(attr_map) - except json.JSONDecodeError: - return Response({"error": "AUTH_LDAP_USER_ATTR_MAP not valid"}, status=401) - - util = LDAPUtil( - use_settings_config=False, server_uri=host, bind_dn=bind_dn, - password=password, use_ssl=use_ssl, - search_ougroup=search_ougroup, search_filter=search_filter, - attr_map=attr_map - ) - return util + config = { + 'server_uri': server_uri, + 'bind_dn': bind_dn, + 'password': password, + 'use_ssl': use_ssl, + 'search_ougroup': search_ougroup, + 'search_filter': search_filter, + 'attr_map': json.loads(attr_map), + } + return config def post(self, request): serializer = self.serializer_class(data=request.data) if not serializer.is_valid(): return Response({"error": str(serializer.errors)}, status=401) - util = self.get_ldap_util(serializer) - + attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"] try: - users = util.search_user_items() + json.loads(attr_map) + except json.JSONDecodeError: + return Response({"error": "AUTH_LDAP_USER_ATTR_MAP not valid"}, status=401) + + config = self.get_ldap_config(serializer) + util = LDAPServerUtil(config=config) + try: + users = util.search() except Exception as e: return Response({"error": str(e)}, status=401) - if len(users) > 0: - return Response({"msg": _("Match {} s users").format(len(users))}) - else: - return Response({"error": "Have user but attr mapping error"}, status=401) + return Response({"msg": _("Match {} s users").format(len(users))}) class LDAPUserListApi(generics.ListAPIView): permission_classes = (IsOrgAdmin,) serializer_class = LDAPUserSerializer + def get_queryset_from_cache(self): + search_value = self.request.query_params.get('search') + users = LDAPCacheUtil().search(search_value=search_value) + return users + + def get_queryset_from_server(self): + search_value = self.request.query_params.get('search') + users = LDAPServerUtil().search(search_value=search_value) + return users + def get_queryset(self): if hasattr(self, 'swagger_fake_view'): return [] - q = self.request.query_params.get('search') - try: - util = LDAPUtil() - extra_filter = util.construct_extra_filter(util.SEARCH_FIELD_ALL, q) - users = util.search_user_items(extra_filter) - except Exception as e: - users = [] - logger.error(e) - # 前端data_table会根据row.id对table.selected值进行操作 - for user in users: - user['id'] = user['username'] + cache_police = self.request.query_params.get('cache_police', True) + if cache_police in LDAP_USE_CACHE_FLAGS: + users = self.get_queryset_from_cache() + else: + users = self.get_queryset_from_server() return users + def list(self, request, *args, **kwargs): + cache_police = self.request.query_params.get('cache_police', True) + # 不是用缓存 + if cache_police not in LDAP_USE_CACHE_FLAGS: + return super().list(request, *args, **kwargs) + + queryset = self.get_queryset() + # 缓存有数据 + if queryset is not None: + return super().list(request, *args, **kwargs) + + sync_util = LDAPSyncUtil() + # 还没有同步任务 + if sync_util.task_no_start: + task = sync_ldap_user_task.delay() + data = {'msg': 'Cache no data, sync task {} started.'.format(task.id)} + return Response(data=data, status=409) + # 同步任务正在执行 + if sync_util.task_is_running: + data = {'msg': 'synchronization is running.'} + return Response(data=data, status=409) + # 同步任务执行结束 + if sync_util.task_is_over: + msg = sync_util.get_task_error_msg() + data = {'msg': 'Synchronization task report error: {}'.format(msg)} + return Response(data=data, status=400) + + return super().list(request, *args, **kwargs) + + @staticmethod + def processing_queryset(queryset): + db_username_list = User.objects.all().values_list('username', flat=True) + for q in queryset: + q['id'] = q['username'] + q['existing'] = q['username'] in db_username_list + return queryset + def sort_queryset(self, queryset): order_by = self.request.query_params.get('order') if not order_by: @@ -138,32 +186,41 @@ class LDAPUserListApi(generics.ListAPIView): queryset = sorted(queryset, key=lambda x: x[order_by], reverse=reverse) return queryset - def list(self, request, *args, **kwargs): - queryset = self.get_queryset() + def filter_queryset(self, queryset): + queryset = self.processing_queryset(queryset) queryset = self.sort_queryset(queryset) - page = self.paginate_queryset(queryset) - if page is not None: - return self.get_paginated_response(page) - return Response(queryset) + return queryset -class LDAPUserSyncAPI(APIView): +class LDAPUserImportAPI(APIView): permission_classes = (IsOrgAdmin,) - def post(self, request): - username_list = request.data.get('username_list', []) - - util = LDAPUtil() - try: - result = util.sync_users(username_list) - except Exception as e: - logger.error(e, exc_info=True) - return Response({'error': str(e)}, status=401) + def get_ldap_users(self): + username_list = self.request.data.get('username_list', []) + cache_police = self.request.query_params.get('cache_police', True) + if cache_police in LDAP_USE_CACHE_FLAGS: + users = LDAPCacheUtil().search(search_users=username_list) else: - msg = _("succeed: {} failed: {} total: {}").format( - result['succeed'], result['failed'], result['total'] - ) - return Response({'msg': msg}) + users = LDAPServerUtil().search(search_users=username_list) + return users + + def post(self, request): + users = self.get_ldap_users() + errors = LDAPImportUtil().perform_import(users) + if errors: + return Response({'Error': errors}, status=401) + return Response({'msg': 'Imported {} users successfully'.format(len(users))}) + + +class LDAPCacheRefreshAPI(generics.RetrieveAPIView): + + def retrieve(self, request, *args, **kwargs): + try: + LDAPSyncUtil().clear_cache() + except Exception as e: + logger.error(str(e)) + return Response(data={'msg': str(e)}, status=400) + return Response(data={'msg': 'success'}) class ReplayStorageCreateAPI(APIView): diff --git a/apps/settings/serializers.py b/apps/settings/serializers.py index eb8a61679..0e2e48fa5 100644 --- a/apps/settings/serializers.py +++ b/apps/settings/serializers.py @@ -25,6 +25,7 @@ class LDAPTestSerializer(serializers.Serializer): class LDAPUserSerializer(serializers.Serializer): id = serializers.CharField() username = serializers.CharField() + name = serializers.CharField() email = serializers.CharField() existing = serializers.BooleanField(read_only=True) diff --git a/apps/settings/tasks/__init__.py b/apps/settings/tasks/__init__.py new file mode 100644 index 000000000..87bc6198f --- /dev/null +++ b/apps/settings/tasks/__init__.py @@ -0,0 +1,4 @@ +# coding: utf-8 +# + +from .ldap import * diff --git a/apps/settings/tasks/ldap.py b/apps/settings/tasks/ldap.py new file mode 100644 index 000000000..60058e03e --- /dev/null +++ b/apps/settings/tasks/ldap.py @@ -0,0 +1,17 @@ +# coding: utf-8 +# + +from celery import shared_task + +from common.utils import get_logger +from ..utils import LDAPSyncUtil + +__all__ = ['sync_ldap_user_task'] + + +logger = get_logger(__file__) + + +@shared_task +def sync_ldap_user_task(): + LDAPSyncUtil().perform_sync() diff --git a/apps/settings/templates/settings/_ldap_list_users_modal.html b/apps/settings/templates/settings/_ldap_list_users_modal.html index dd839eb5b..cc63f72c5 100644 --- a/apps/settings/templates/settings/_ldap_list_users_modal.html +++ b/apps/settings/templates/settings/_ldap_list_users_modal.html @@ -23,6 +23,7 @@