diff --git a/apps/authentication/backends/oidc/backends.py b/apps/authentication/backends/oidc/backends.py index f29bf95e5..7586eb479 100644 --- a/apps/authentication/backends/oidc/backends.py +++ b/apps/authentication/backends/oidc/backends.py @@ -8,27 +8,26 @@ """ import base64 -import requests -from rest_framework.exceptions import ParseError +import requests +from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.core.exceptions import SuspiciousOperation from django.db import transaction from django.urls import reverse -from django.conf import settings +from rest_framework.exceptions import ParseError -from common.utils import get_logger +from authentication.signals import user_auth_success, user_auth_failed from authentication.utils import build_absolute_uri_for_oidc +from common.utils import get_logger from users.utils import construct_user_email - -from ..base import JMSBaseAuthBackend -from .utils import validate_and_return_id_token from .decorator import ssl_verification from .signals import ( openid_create_or_update_user ) -from authentication.signals import user_auth_success, user_auth_failed +from .utils import validate_and_return_id_token +from ..base import JMSBaseAuthBackend logger = get_logger(__file__) @@ -55,16 +54,17 @@ class UserMixin: logger.debug(log_prompt.format(user_attrs)) username = user_attrs.get('username') - name = user_attrs.get('name') + groups = user_attrs.pop('groups', None) user, created = get_user_model().objects.get_or_create( username=username, defaults=user_attrs ) + user_attrs['groups'] = groups logger.debug(log_prompt.format("user: {}|created: {}".format(user, created))) logger.debug(log_prompt.format("Send signal => openid create or update user")) openid_create_or_update_user.send( - sender=self.__class__, request=request, user=user, created=created, - name=name, username=username, email=email + sender=self.__class__, request=request, user=user, + created=created, attrs=user_attrs, ) return user, created @@ -269,7 +269,8 @@ class OIDCAuthPasswordBackend(OIDCBaseBackend): # Calls the token endpoint. logger.debug(log_prompt.format('Call the token endpoint')) - token_response = requests.post(settings.AUTH_OPENID_PROVIDER_TOKEN_ENDPOINT, data=token_payload, timeout=request_timeout) + token_response = requests.post(settings.AUTH_OPENID_PROVIDER_TOKEN_ENDPOINT, data=token_payload, + timeout=request_timeout) try: token_response.raise_for_status() token_response_data = token_response.json() diff --git a/apps/authentication/backends/saml2/backends.py b/apps/authentication/backends/saml2/backends.py index 570557c1d..ac2aa7bb7 100644 --- a/apps/authentication/backends/saml2/backends.py +++ b/apps/authentication/backends/saml2/backends.py @@ -27,9 +27,13 @@ class SAML2Backend(JMSModelBackend): log_prompt = "Get or Create user [SAML2Backend]: {}" logger.debug(log_prompt.format('start')) + groups = saml_user_data.pop('groups', None) + user, created = get_user_model().objects.get_or_create( username=saml_user_data['username'], defaults=saml_user_data ) + + saml_user_data['groups'] = groups logger.debug(log_prompt.format("user: {}|created: {}".format(user, created))) logger.debug(log_prompt.format("Send signal => saml2 create or update user")) diff --git a/apps/authentication/backends/saml2/views.py b/apps/authentication/backends/saml2/views.py index 5a866cc47..22a5a9a68 100644 --- a/apps/authentication/backends/saml2/views.py +++ b/apps/authentication/backends/saml2/views.py @@ -87,6 +87,7 @@ class PrepareRequestMixin: ('name', 'name', False), ('phone', 'phone', False), ('comment', 'comment', False), + ('groups', 'groups', False), ) attr_list = [] for name, friend_name, is_required in need_attrs: @@ -185,7 +186,7 @@ class PrepareRequestMixin: user_attrs = {} attr_mapping = settings.SAML2_RENAME_ATTRIBUTES attrs = saml_instance.get_attributes() - valid_attrs = ['username', 'name', 'email', 'comment', 'phone'] + valid_attrs = ['username', 'name', 'email', 'comment', 'phone', 'groups'] for attr, value in attrs.items(): attr = attr.rsplit('/', 1)[-1] diff --git a/apps/users/signal_handlers.py b/apps/users/signal_handlers.py index 5bc116adf..c7be331aa 100644 --- a/apps/users/signal_handlers.py +++ b/apps/users/signal_handlers.py @@ -21,11 +21,13 @@ from common.signals import django_ready from common.utils import get_logger from jumpserver.utils import get_current_request from ops.celery.decorator import register_as_period_task +from orgs.models import Organization +from orgs.utils import tmp_to_root_org from rbac.builtin import BuiltinRole from rbac.const import Scope from rbac.models import RoleBinding from settings.signals import setting_changed -from .models import User, UserPasswordHistory +from .models import User, UserPasswordHistory, UserGroup from .signals import post_user_create logger = get_logger(__file__) @@ -50,7 +52,9 @@ def user_authenticated_handle(user, created, source, attrs=None, **kwargs): if created: user.source = source user.save() - bind_user_to_org_role(user) + org_ids = bind_user_to_org_role(user) + group_names = attrs.get('groups') + bind_user_to_group(org_ids, group_names, user) if not attrs: return @@ -146,7 +150,7 @@ def radius_create_user(sender, user, **kwargs): @receiver(openid_create_or_update_user) -def on_openid_create_or_update_user(sender, request, user, created, name, username, email, **kwargs): +def on_openid_create_or_update_user(sender, request, user, created, attrs, **kwargs): if not check_only_allow_exist_user_auth(created): return @@ -157,7 +161,13 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna ) user.source = User.Source.openid.value user.save() - bind_user_to_org_role(user) + org_ids = bind_user_to_org_role(user) + group_names = attrs.get('groups') + bind_user_to_group(org_ids, group_names, user) + + name = attrs.get('name') + username = attrs.get('username') + email = attrs.get('email') if not created and settings.AUTH_OPENID_ALWAYS_UPDATE_USER: logger.debug( @@ -225,3 +235,37 @@ def bind_user_to_org_role(user): ] RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True) + return org_ids + + +def bind_user_to_group(org_ids, group_names, user): + if not isinstance(group_names, list): + return + + org_ids = org_ids or [Organization.DEFAULT_ID] + + with tmp_to_root_org(): + existing_groups = UserGroup.objects.filter(org_id__in=org_ids).values_list('org_id', 'name') + + org_groups_map = {} + for org_id, group_name in existing_groups: + org_groups_map.setdefault(org_id, []).append(group_name) + + groups_to_create = [] + for org_id in org_ids: + existing_group_names = set(org_groups_map.get(org_id, [])) + new_group_names = set(group_names) - existing_group_names + groups_to_create.extend( + UserGroup(org_id=org_id, name=name) for name in new_group_names + ) + + UserGroup.objects.bulk_create(groups_to_create) + + user_groups = UserGroup.objects.filter(org_id__in=org_ids, name__in=group_names) + + user_group_links = [ + User.groups.through(user_id=user.id, usergroup_id=group.id) + for group in user_groups + ] + if user_group_links: + User.groups.through.objects.bulk_create(user_group_links)