perf: Support SAML2, OIDC user authentication services, mapping user group field information

This commit is contained in:
feng 2024-08-29 19:23:04 +08:00 committed by feng626
parent 1068662ab1
commit c545e2a3aa
4 changed files with 67 additions and 17 deletions

View File

@ -8,27 +8,26 @@
""" """
import base64 import base64
import requests
from rest_framework.exceptions import ParseError import requests
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend from django.contrib.auth.backends import ModelBackend
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
from django.db import transaction from django.db import transaction
from django.urls import reverse from django.urls import reverse
from django.conf import settings from rest_framework.exceptions import ParseError
from common.utils import get_logger from authentication.signals import user_auth_success, user_auth_failed
from authentication.utils import build_absolute_uri_for_oidc from authentication.utils import build_absolute_uri_for_oidc
from common.utils import get_logger
from users.utils import construct_user_email from users.utils import construct_user_email
from ..base import JMSBaseAuthBackend
from .utils import validate_and_return_id_token
from .decorator import ssl_verification from .decorator import ssl_verification
from .signals import ( from .signals import (
openid_create_or_update_user openid_create_or_update_user
) )
from authentication.signals import user_auth_success, user_auth_failed from .utils import validate_and_return_id_token
from ..base import JMSBaseAuthBackend
logger = get_logger(__file__) logger = get_logger(__file__)
@ -55,16 +54,17 @@ class UserMixin:
logger.debug(log_prompt.format(user_attrs)) logger.debug(log_prompt.format(user_attrs))
username = user_attrs.get('username') username = user_attrs.get('username')
name = user_attrs.get('name') groups = user_attrs.pop('groups', None)
user, created = get_user_model().objects.get_or_create( user, created = get_user_model().objects.get_or_create(
username=username, defaults=user_attrs username=username, defaults=user_attrs
) )
user_attrs['groups'] = groups
logger.debug(log_prompt.format("user: {}|created: {}".format(user, created))) logger.debug(log_prompt.format("user: {}|created: {}".format(user, created)))
logger.debug(log_prompt.format("Send signal => openid create or update user")) logger.debug(log_prompt.format("Send signal => openid create or update user"))
openid_create_or_update_user.send( openid_create_or_update_user.send(
sender=self.__class__, request=request, user=user, created=created, sender=self.__class__, request=request, user=user,
name=name, username=username, email=email created=created, attrs=user_attrs,
) )
return user, created return user, created
@ -269,7 +269,8 @@ class OIDCAuthPasswordBackend(OIDCBaseBackend):
# Calls the token endpoint. # Calls the token endpoint.
logger.debug(log_prompt.format('Call the token endpoint')) logger.debug(log_prompt.format('Call the token endpoint'))
token_response = requests.post(settings.AUTH_OPENID_PROVIDER_TOKEN_ENDPOINT, data=token_payload, timeout=request_timeout) token_response = requests.post(settings.AUTH_OPENID_PROVIDER_TOKEN_ENDPOINT, data=token_payload,
timeout=request_timeout)
try: try:
token_response.raise_for_status() token_response.raise_for_status()
token_response_data = token_response.json() token_response_data = token_response.json()

View File

@ -27,9 +27,13 @@ class SAML2Backend(JMSModelBackend):
log_prompt = "Get or Create user [SAML2Backend]: {}" log_prompt = "Get or Create user [SAML2Backend]: {}"
logger.debug(log_prompt.format('start')) logger.debug(log_prompt.format('start'))
groups = saml_user_data.pop('groups', None)
user, created = get_user_model().objects.get_or_create( user, created = get_user_model().objects.get_or_create(
username=saml_user_data['username'], defaults=saml_user_data username=saml_user_data['username'], defaults=saml_user_data
) )
saml_user_data['groups'] = groups
logger.debug(log_prompt.format("user: {}|created: {}".format(user, created))) logger.debug(log_prompt.format("user: {}|created: {}".format(user, created)))
logger.debug(log_prompt.format("Send signal => saml2 create or update user")) logger.debug(log_prompt.format("Send signal => saml2 create or update user"))

View File

@ -87,6 +87,7 @@ class PrepareRequestMixin:
('name', 'name', False), ('name', 'name', False),
('phone', 'phone', False), ('phone', 'phone', False),
('comment', 'comment', False), ('comment', 'comment', False),
('groups', 'groups', False),
) )
attr_list = [] attr_list = []
for name, friend_name, is_required in need_attrs: for name, friend_name, is_required in need_attrs:
@ -185,7 +186,7 @@ class PrepareRequestMixin:
user_attrs = {} user_attrs = {}
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES attr_mapping = settings.SAML2_RENAME_ATTRIBUTES
attrs = saml_instance.get_attributes() attrs = saml_instance.get_attributes()
valid_attrs = ['username', 'name', 'email', 'comment', 'phone'] valid_attrs = ['username', 'name', 'email', 'comment', 'phone', 'groups']
for attr, value in attrs.items(): for attr, value in attrs.items():
attr = attr.rsplit('/', 1)[-1] attr = attr.rsplit('/', 1)[-1]

View File

@ -21,11 +21,13 @@ from common.signals import django_ready
from common.utils import get_logger from common.utils import get_logger
from jumpserver.utils import get_current_request from jumpserver.utils import get_current_request
from ops.celery.decorator import register_as_period_task from ops.celery.decorator import register_as_period_task
from orgs.models import Organization
from orgs.utils import tmp_to_root_org
from rbac.builtin import BuiltinRole from rbac.builtin import BuiltinRole
from rbac.const import Scope from rbac.const import Scope
from rbac.models import RoleBinding from rbac.models import RoleBinding
from settings.signals import setting_changed from settings.signals import setting_changed
from .models import User, UserPasswordHistory from .models import User, UserPasswordHistory, UserGroup
from .signals import post_user_create from .signals import post_user_create
logger = get_logger(__file__) logger = get_logger(__file__)
@ -50,7 +52,9 @@ def user_authenticated_handle(user, created, source, attrs=None, **kwargs):
if created: if created:
user.source = source user.source = source
user.save() user.save()
bind_user_to_org_role(user) org_ids = bind_user_to_org_role(user)
group_names = attrs.get('groups')
bind_user_to_group(org_ids, group_names, user)
if not attrs: if not attrs:
return return
@ -146,7 +150,7 @@ def radius_create_user(sender, user, **kwargs):
@receiver(openid_create_or_update_user) @receiver(openid_create_or_update_user)
def on_openid_create_or_update_user(sender, request, user, created, name, username, email, **kwargs): def on_openid_create_or_update_user(sender, request, user, created, attrs, **kwargs):
if not check_only_allow_exist_user_auth(created): if not check_only_allow_exist_user_auth(created):
return return
@ -157,7 +161,13 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
) )
user.source = User.Source.openid.value user.source = User.Source.openid.value
user.save() user.save()
bind_user_to_org_role(user) org_ids = bind_user_to_org_role(user)
group_names = attrs.get('groups')
bind_user_to_group(org_ids, group_names, user)
name = attrs.get('name')
username = attrs.get('username')
email = attrs.get('email')
if not created and settings.AUTH_OPENID_ALWAYS_UPDATE_USER: if not created and settings.AUTH_OPENID_ALWAYS_UPDATE_USER:
logger.debug( logger.debug(
@ -225,3 +235,37 @@ def bind_user_to_org_role(user):
] ]
RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True) RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)
return org_ids
def bind_user_to_group(org_ids, group_names, user):
if not isinstance(group_names, list):
return
org_ids = org_ids or [Organization.DEFAULT_ID]
with tmp_to_root_org():
existing_groups = UserGroup.objects.filter(org_id__in=org_ids).values_list('org_id', 'name')
org_groups_map = {}
for org_id, group_name in existing_groups:
org_groups_map.setdefault(org_id, []).append(group_name)
groups_to_create = []
for org_id in org_ids:
existing_group_names = set(org_groups_map.get(org_id, []))
new_group_names = set(group_names) - existing_group_names
groups_to_create.extend(
UserGroup(org_id=org_id, name=name) for name in new_group_names
)
UserGroup.objects.bulk_create(groups_to_create)
user_groups = UserGroup.objects.filter(org_id__in=org_ids, name__in=group_names)
user_group_links = [
User.groups.through(user_id=user.id, usergroup_id=group.id)
for group in user_groups
]
if user_group_links:
User.groups.through.objects.bulk_create(user_group_links)