mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-09-02 07:55:16 +00:00
feat: 支持saml2协议的单点登录,合并代码 (#7347)
* fix: 支持saml2协议的单点登录 * feat: 支持saml2协议的单点登录,合并代码 Co-authored-by: jiangweidong <weidong.jiang@fit2cloud.com>
This commit is contained in:
3
apps/authentication/backends/saml2/__init__.py
Normal file
3
apps/authentication/backends/saml2/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from .backends import *
|
67
apps/authentication/backends/saml2/backends.py
Normal file
67
apps/authentication/backends/saml2/backends.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.backends import ModelBackend
|
||||
from django.db import transaction
|
||||
|
||||
from common.utils import get_logger
|
||||
from authentication.errors import reason_choices, reason_user_invalid
|
||||
from .signals import (
|
||||
saml2_user_authenticated, saml2_user_authentication_failed,
|
||||
saml2_create_or_update_user
|
||||
)
|
||||
|
||||
__all__ = ['SAML2Backend']
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class SAML2Backend(ModelBackend):
|
||||
@staticmethod
|
||||
def user_can_authenticate(user):
|
||||
is_valid = getattr(user, 'is_valid', None)
|
||||
return is_valid or is_valid is None
|
||||
|
||||
@transaction.atomic
|
||||
def get_or_create_from_saml_data(self, request, **saml_user_data):
|
||||
log_prompt = "Get or Create user [SAML2Backend]: {}"
|
||||
logger.debug(log_prompt.format('start'))
|
||||
|
||||
user, created = get_user_model().objects.get_or_create(
|
||||
username=saml_user_data['username'], defaults=saml_user_data
|
||||
)
|
||||
logger.debug(log_prompt.format("user: {}|created: {}".format(user, created)))
|
||||
|
||||
logger.debug(log_prompt.format("Send signal => saml2 create or update user"))
|
||||
saml2_create_or_update_user.send(
|
||||
sender=self, request=request, user=user, created=created, attrs=saml_user_data
|
||||
)
|
||||
return user, created
|
||||
|
||||
def authenticate(self, request, saml_user_data=None, **kwargs):
|
||||
log_prompt = "Process authenticate [SAML2AuthCodeBackend]: {}"
|
||||
logger.debug(log_prompt.format('Start'))
|
||||
if saml_user_data is None:
|
||||
logger.debug(log_prompt.format('saml_user_data is missing'))
|
||||
return None
|
||||
|
||||
username = saml_user_data.get('username')
|
||||
if not username:
|
||||
logger.debug(log_prompt.format('username is missing'))
|
||||
return None
|
||||
|
||||
user, created = self.get_or_create_from_saml_data(request, **saml_user_data)
|
||||
|
||||
if self.user_can_authenticate(user):
|
||||
logger.debug(log_prompt.format('SAML2 user login success'))
|
||||
saml2_user_authenticated.send(
|
||||
sender=self, request=request, user=user, created=created
|
||||
)
|
||||
return user
|
||||
else:
|
||||
logger.debug(log_prompt.format('SAML2 user login failed'))
|
||||
saml2_user_authentication_failed.send(
|
||||
sender=self, request=request, username=username,
|
||||
reason=reason_choices.get(reason_user_invalid)
|
||||
)
|
||||
return None
|
12
apps/authentication/backends/saml2/settings.py
Normal file
12
apps/authentication/backends/saml2/settings.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from django.conf import settings
|
||||
from onelogin.saml2.settings import OneLogin_Saml2_Settings
|
||||
|
||||
|
||||
class JmsSaml2Settings(OneLogin_Saml2_Settings):
|
||||
def get_sp_key(self):
|
||||
key = getattr(settings, 'SAML2_SP_KEY_CONTENT', '')
|
||||
return key
|
||||
|
||||
def get_sp_cert(self):
|
||||
cert = getattr(settings, 'SAML2_SP_CERT_CONTENT', '')
|
||||
return cert
|
6
apps/authentication/backends/saml2/signals.py
Normal file
6
apps/authentication/backends/saml2/signals.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.dispatch import Signal
|
||||
|
||||
|
||||
saml2_create_or_update_user = Signal(providing_args=('user', 'created', 'request', 'attrs'))
|
||||
saml2_user_authenticated = Signal(providing_args=('user', 'created', 'request'))
|
||||
saml2_user_authentication_failed = Signal(providing_args=('request', 'username', 'reason'))
|
13
apps/authentication/backends/saml2/urls.py
Normal file
13
apps/authentication/backends/saml2/urls.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'),
|
||||
path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'),
|
||||
path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'),
|
||||
path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'),
|
||||
]
|
269
apps/authentication/backends/saml2/views.py
Normal file
269
apps/authentication/backends/saml2/views.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from django.views import View
|
||||
from django.contrib import auth as auth
|
||||
from django.urls import reverse
|
||||
from django.conf import settings
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.http import HttpResponseRedirect, HttpResponse, HttpResponseServerError
|
||||
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth
|
||||
from onelogin.saml2.errors import OneLogin_Saml2_Error
|
||||
from onelogin.saml2.idp_metadata_parser import (
|
||||
OneLogin_Saml2_IdPMetadataParser as IdPMetadataParse,
|
||||
dict_deep_merge
|
||||
)
|
||||
|
||||
from .settings import JmsSaml2Settings
|
||||
|
||||
from common.utils import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class PrepareRequestMixin:
|
||||
@staticmethod
|
||||
def prepare_django_request(request):
|
||||
result = {
|
||||
'https': 'on' if request.is_secure() else 'off',
|
||||
'http_host': request.META['HTTP_HOST'],
|
||||
'script_name': request.META['PATH_INFO'],
|
||||
'get_data': request.GET.copy(),
|
||||
'post_data': request.POST.copy()
|
||||
}
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_idp_settings():
|
||||
idp_metadata_xml = settings.SAML2_IDP_METADATA_XML
|
||||
idp_metadata_url = settings.SAML2_IDP_METADATA_URL
|
||||
logger.debug('Start getting IDP configuration')
|
||||
|
||||
try:
|
||||
xml_idp_settings = IdPMetadataParse.parse(idp_metadata_xml)
|
||||
except Exception as err:
|
||||
xml_idp_settings = None
|
||||
logger.warning('Failed to get IDP metadata XML settings, error: %s', str(err))
|
||||
|
||||
try:
|
||||
url_idp_settings = IdPMetadataParse.parse_remote(
|
||||
idp_metadata_url, timeout=20
|
||||
)
|
||||
except Exception as err:
|
||||
url_idp_settings = None
|
||||
logger.warning('Failed to get IDP metadata URL settings, error: %s', str(err))
|
||||
|
||||
idp_settings = url_idp_settings or xml_idp_settings
|
||||
|
||||
if idp_settings is None:
|
||||
msg = 'Unable to resolve IDP settings. '
|
||||
tip = 'Please contact your administrator to check system settings,' \
|
||||
'or login using other methods.'
|
||||
logger.error(msg)
|
||||
raise OneLogin_Saml2_Error(msg + tip, OneLogin_Saml2_Error.SETTINGS_INVALID)
|
||||
|
||||
logger.debug('IDP settings obtained successfully')
|
||||
return idp_settings
|
||||
|
||||
@staticmethod
|
||||
def get_attribute_consuming_service():
|
||||
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES
|
||||
name_prefix = settings.SITE_URL
|
||||
if attr_mapping and isinstance(attr_mapping, dict):
|
||||
attr_list = [
|
||||
{
|
||||
"name": '{}/{}'.format(name_prefix, sp_key),
|
||||
"friendlyName": idp_key, "isRequired": True
|
||||
}
|
||||
for idp_key, sp_key in attr_mapping.items()
|
||||
]
|
||||
request_attribute_template = {
|
||||
"attributeConsumingService": {
|
||||
"isDefault": False,
|
||||
"serviceName": "JumpServer",
|
||||
"serviceDescription": "JumpServer",
|
||||
"requestedAttributes": attr_list
|
||||
}
|
||||
}
|
||||
return request_attribute_template
|
||||
else:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_advanced_settings():
|
||||
other_settings = {}
|
||||
other_settings_path = settings.SAML2_OTHER_SETTINGS_PATH
|
||||
if os.path.exists(other_settings_path):
|
||||
with open(other_settings_path, 'r') as json_data:
|
||||
try:
|
||||
other_settings = json.loads(json_data.read())
|
||||
except Exception as error:
|
||||
logger.error('Get other settings error: %s', error)
|
||||
|
||||
default = {
|
||||
"organization": {
|
||||
"en": {
|
||||
"name": "JumpServer",
|
||||
"displayname": "JumpServer",
|
||||
"url": "https://jumpserver.org/"
|
||||
}
|
||||
}
|
||||
}
|
||||
default.update(other_settings)
|
||||
return default
|
||||
|
||||
def get_sp_settings(self):
|
||||
sp_host = settings.SITE_URL
|
||||
attrs = self.get_attribute_consuming_service()
|
||||
sp_settings = {
|
||||
'sp': {
|
||||
'entityId': f"{sp_host}{reverse('authentication:saml2:saml2-login')}",
|
||||
'assertionConsumerService': {
|
||||
'url': f"{sp_host}{reverse('authentication:saml2:saml2-callback')}",
|
||||
},
|
||||
'singleLogoutService': {
|
||||
'url': f"{sp_host}{reverse('authentication:saml2:saml2-logout')}"
|
||||
}
|
||||
}
|
||||
}
|
||||
sp_settings['sp'].update(attrs)
|
||||
advanced_settings = self.get_advanced_settings()
|
||||
sp_settings.update(advanced_settings)
|
||||
return sp_settings
|
||||
|
||||
def get_saml2_settings(self):
|
||||
sp_settings = self.get_sp_settings()
|
||||
idp_settings = self.get_idp_settings()
|
||||
saml2_settings = dict_deep_merge(sp_settings, idp_settings)
|
||||
return saml2_settings
|
||||
|
||||
def init_saml_auth(self, request):
|
||||
request = self.prepare_django_request(request)
|
||||
_settings = self.get_saml2_settings()
|
||||
saml_instance = OneLogin_Saml2_Auth(
|
||||
request, old_settings=_settings, custom_base_path=settings.SAML_FOLDER
|
||||
)
|
||||
return saml_instance
|
||||
|
||||
@staticmethod
|
||||
def value_to_str(attr):
|
||||
if isinstance(attr, str):
|
||||
return attr
|
||||
elif isinstance(attr, list) and len(attr) > 0:
|
||||
return str(attr[0])
|
||||
|
||||
def get_attributes(self, saml_instance):
|
||||
user_attrs = {}
|
||||
real_key_index = len(settings.SITE_URL) + 1
|
||||
attrs = saml_instance.get_attributes()
|
||||
|
||||
for attr, value in attrs.items():
|
||||
attr = attr[real_key_index:]
|
||||
user_attrs[attr] = self.value_to_str(value)
|
||||
return user_attrs
|
||||
|
||||
|
||||
class Saml2AuthRequestView(View, PrepareRequestMixin):
|
||||
|
||||
def get(self, request):
|
||||
log_prompt = "Process GET requests [SAML2AuthRequestView]: {}"
|
||||
logger.debug(log_prompt.format('Start'))
|
||||
|
||||
try:
|
||||
saml_instance = self.init_saml_auth(request)
|
||||
except OneLogin_Saml2_Error as error:
|
||||
logger.error(log_prompt.format('Init saml auth error: %s' % error))
|
||||
return HttpResponse(error, status=412)
|
||||
|
||||
next_url = settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT
|
||||
url = saml_instance.login(return_to=next_url)
|
||||
logger.debug(log_prompt.format('Redirect login url'))
|
||||
return HttpResponseRedirect(url)
|
||||
|
||||
|
||||
class Saml2EndSessionView(View, PrepareRequestMixin):
|
||||
http_method_names = ['get', 'post', ]
|
||||
|
||||
def get(self, request):
|
||||
log_prompt = "Process GET requests [SAML2EndSessionView]: {}"
|
||||
logger.debug(log_prompt.format('Start'))
|
||||
return self.post(request)
|
||||
|
||||
def post(self, request):
|
||||
log_prompt = "Process POST requests [SAML2EndSessionView]: {}"
|
||||
logger.debug(log_prompt.format('Start'))
|
||||
|
||||
logout_url = settings.LOGOUT_REDIRECT_URL or '/'
|
||||
|
||||
if request.user.is_authenticated:
|
||||
logger.debug(log_prompt.format('Log out the current user: {}'.format(request.user)))
|
||||
auth.logout(request)
|
||||
|
||||
if settings.SAML2_LOGOUT_COMPLETELY:
|
||||
saml_instance = self.init_saml_auth(request)
|
||||
logger.debug(log_prompt.format('Log out IDP user session synchronously'))
|
||||
return HttpResponseRedirect(saml_instance.logout())
|
||||
|
||||
logger.debug(log_prompt.format('Redirect logout url'))
|
||||
return HttpResponseRedirect(logout_url)
|
||||
|
||||
|
||||
class Saml2AuthCallbackView(View, PrepareRequestMixin):
|
||||
|
||||
def post(self, request):
|
||||
log_prompt = "Process POST requests [SAML2AuthCallbackView]: {}"
|
||||
post_data = request.POST
|
||||
|
||||
try:
|
||||
saml_instance = self.init_saml_auth(request)
|
||||
except OneLogin_Saml2_Error as error:
|
||||
logger.error(log_prompt.format('Init saml auth error: %s' % error))
|
||||
return HttpResponse(error, status=412)
|
||||
|
||||
request_id = None
|
||||
if 'AuthNRequestID' in request.session:
|
||||
request_id = request.session['AuthNRequestID']
|
||||
|
||||
logger.debug(log_prompt.format('Process saml response'))
|
||||
saml_instance.process_response(request_id=request_id)
|
||||
errors = saml_instance.get_errors()
|
||||
|
||||
if not errors:
|
||||
if 'AuthNRequestID' in request.session:
|
||||
del request.session['AuthNRequestID']
|
||||
|
||||
logger.debug(log_prompt.format('Process authenticate'))
|
||||
saml_user_data = self.get_attributes(saml_instance)
|
||||
user = auth.authenticate(request=request, saml_user_data=saml_user_data)
|
||||
if user and user.is_valid:
|
||||
logger.debug(log_prompt.format('Login: {}'.format(user)))
|
||||
auth.login(self.request, user)
|
||||
|
||||
logger.debug(log_prompt.format('Redirect'))
|
||||
next_url = saml_instance.redirect_to(post_data.get('RelayState', '/'))
|
||||
return HttpResponseRedirect(next_url)
|
||||
logger.error(log_prompt.format('Saml response has error: %s' % str(errors)))
|
||||
return HttpResponseRedirect(settings.AUTH_SAML2_AUTHENTICATION_FAILURE_REDIRECT_URI)
|
||||
|
||||
@csrf_exempt
|
||||
def dispatch(self, *args, **kwargs):
|
||||
return super().dispatch(*args, **kwargs)
|
||||
|
||||
|
||||
class Saml2AuthMetadataView(View, PrepareRequestMixin):
|
||||
|
||||
def get(self, _):
|
||||
saml_settings = self.get_sp_settings()
|
||||
saml_settings = JmsSaml2Settings(
|
||||
settings=saml_settings, sp_validation_only=True,
|
||||
custom_base_path=settings.SAML_FOLDER
|
||||
)
|
||||
metadata = saml_settings.get_sp_metadata()
|
||||
errors = saml_settings.validate_metadata(metadata)
|
||||
|
||||
if len(errors) == 0:
|
||||
resp = HttpResponse(content=metadata, content_type='text/xml')
|
||||
else:
|
||||
resp = HttpResponseServerError(content=', '.join(errors))
|
||||
return resp
|
Reference in New Issue
Block a user