feat: LDAP HA

This commit is contained in:
wangruidong
2024-09-04 15:49:59 +08:00
committed by Bryan
parent 512e727ac6
commit c2784c44ad
13 changed files with 572 additions and 260 deletions

View File

@@ -3,16 +3,17 @@
import json
import asyncio
from asgiref.sync import sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from django.utils.translation import gettext_lazy as _, activate
from django.utils import translation
from urllib.parse import parse_qs
from common.db.utils import close_old_connections
from common.utils import get_logger
from settings.serializers import (
LDAPHATestConfigSerializer,
LDAPTestConfigSerializer,
LDAPTestLoginSerializer
)
@@ -101,8 +102,12 @@ class ToolsWebsocket(AsyncJsonWebsocketConsumer):
class LdapWebsocket(AsyncJsonWebsocketConsumer):
category: str
async def connect(self):
user = self.scope["user"]
query = parse_qs(self.scope['query_string'].decode())
self.category = query.get('category', ['ldap'])[0]
if user.is_authenticated:
await self.accept()
else:
@@ -125,30 +130,21 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
await self.close()
close_old_connections()
@staticmethod
def get_ldap_config(serializer):
server_uri = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
bind_dn = serializer.validated_data["AUTH_LDAP_BIND_DN"]
password = serializer.validated_data["AUTH_LDAP_BIND_PASSWORD"]
use_ssl = serializer.validated_data.get("AUTH_LDAP_START_TLS", False)
search_ou = serializer.validated_data["AUTH_LDAP_SEARCH_OU"]
search_filter = serializer.validated_data["AUTH_LDAP_SEARCH_FILTER"]
attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
auth_ldap = serializer.validated_data.get('AUTH_LDAP', False)
if not password:
password = settings.AUTH_LDAP_BIND_PASSWORD
def get_ldap_config(self, serializer):
prefix = 'AUTH_LDAP_' if self.category == 'ldap' else 'AUTH_LDAP_HA_'
config = {
'server_uri': server_uri,
'bind_dn': bind_dn,
'password': password,
'use_ssl': use_ssl,
'search_ou': search_ou,
'search_filter': search_filter,
'attr_map': attr_map,
'auth_ldap': auth_ldap
'server_uri': serializer.validated_data.get(f"{prefix}SERVER_URI"),
'bind_dn': serializer.validated_data.get(f"{prefix}BIND_DN"),
'password': (serializer.validated_data.get(f"{prefix}BIND_PASSWORD") or
getattr(settings, f"{prefix}BIND_PASSWORD")),
'use_ssl': serializer.validated_data.get(f"{prefix}START_TLS", False),
'search_ou': serializer.validated_data.get(f"{prefix}SEARCH_OU"),
'search_filter': serializer.validated_data.get(f"{prefix}SEARCH_FILTER"),
'attr_map': serializer.validated_data.get(f"{prefix}USER_ATTR_MAP"),
'auth_ldap': serializer.validated_data.get(f"{prefix.rstrip('_')}", False)
}
return config
@staticmethod
@@ -160,7 +156,10 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
cache.set(task_key, TASK_STATUS_IS_OVER, ttl)
def run_testing_config(self, data):
serializer = LDAPTestConfigSerializer(data=data)
if self.category == 'ldap':
serializer = LDAPTestConfigSerializer(data=data)
else:
serializer = LDAPHATestConfigSerializer(data=data)
if not serializer.is_valid():
self.send_msg(msg=f'error: {str(serializer.errors)}')
config = self.get_ldap_config(serializer)
@@ -175,14 +174,13 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
self.send_msg(msg=f'error: {str(serializer.errors)}')
username = serializer.validated_data['username']
password = serializer.validated_data['password']
ok, msg = LDAPTestUtil().test_login(username, password)
ok, msg = LDAPTestUtil(category=self.category).test_login(username, password)
return ok, msg
@staticmethod
def run_sync_user(data):
sync_util = LDAPSyncUtil()
def run_sync_user(self, data):
sync_util = LDAPSyncUtil(category=self.category)
sync_util.clear_cache()
sync_ldap_user()
sync_ldap_user(category=self.category)
msg = sync_util.get_task_error_msg()
ok = False if msg else True
return ok, msg
@@ -215,7 +213,7 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
return ok, msg
def set_users_status(self, import_users, errors):
util = LDAPCacheUtil()
util = LDAPCacheUtil(category=self.category)
all_users = util.get_users()
import_usernames = [u['username'] for u in import_users]
errors_mapper = {k: v for err in errors for k, v in err.items()}
@@ -225,7 +223,7 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
user['status'] = {'error': errors_mapper[username]}
elif username in import_usernames:
user['status'] = ImportStatus.ok
LDAPCacheUtil().set_users(all_users)
LDAPCacheUtil(category=self.category).set_users(all_users)
@staticmethod
def get_orgs(org_ids):
@@ -235,12 +233,11 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
orgs = [current_org]
return orgs
@staticmethod
def get_ldap_users(username_list, cache_police):
def get_ldap_users(self, username_list, cache_police):
if '*' in username_list:
users = LDAPServerUtil().search()
users = LDAPServerUtil(category=self.category).search()
elif cache_police in LDAP_USE_CACHE_FLAGS:
users = LDAPCacheUtil().search(search_users=username_list)
users = LDAPCacheUtil(category=self.category).search(search_users=username_list)
else:
users = LDAPServerUtil().search(search_users=username_list)
users = LDAPServerUtil(category=self.category).search(search_users=username_list)
return users