feat: 优化SAML2生成的metadata文件内容及属性映射

This commit is contained in:
jiangweidong 2022-04-21 15:18:17 +08:00 committed by 老广
parent 9804ca5dd0
commit 3a3f7eaf71

View File

@ -74,16 +74,28 @@ class PrepareRequestMixin:
return idp_settings return idp_settings
@staticmethod @staticmethod
def get_attribute_consuming_service(): def get_request_attributes():
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES attr_mapping = settings.SAML2_RENAME_ATTRIBUTES or {}
if attr_mapping and isinstance(attr_mapping, dict): attr_map_reverse = {v: k for k, v in attr_mapping.items()}
attr_list = [ need_attrs = (
{ ('username', 'username', True),
"name": sp_key, ('email', 'email', True),
"friendlyName": idp_key, "isRequired": True ('name', 'name', False),
} ('phone', 'phone', False),
for idp_key, sp_key in attr_mapping.items() ('comment', 'comment', False),
] )
attr_list = []
for name, friend_name, is_required in need_attrs:
rename_name = attr_map_reverse.get(friend_name)
name = rename_name if rename_name else name
attr_list.append({
"name": name, "isRequired": is_required,
"friendlyName": friend_name,
})
return attr_list
def get_attribute_consuming_service(self):
attr_list = self.get_request_attributes()
request_attribute_template = { request_attribute_template = {
"attributeConsumingService": { "attributeConsumingService": {
"isDefault": False, "isDefault": False,
@ -93,8 +105,6 @@ class PrepareRequestMixin:
} }
} }
return request_attribute_template return request_attribute_template
else:
return {}
@staticmethod @staticmethod
def get_advanced_settings(): def get_advanced_settings():
@ -167,11 +177,14 @@ class PrepareRequestMixin:
def get_attributes(self, saml_instance): def get_attributes(self, saml_instance):
user_attrs = {} user_attrs = {}
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES
attrs = saml_instance.get_attributes() attrs = saml_instance.get_attributes()
valid_attrs = ['username', 'name', 'email', 'comment', 'phone'] valid_attrs = ['username', 'name', 'email', 'comment', 'phone']
for attr, value in attrs.items(): for attr, value in attrs.items():
attr = attr.rsplit('/', 1)[-1] attr = attr.rsplit('/', 1)[-1]
if attr_mapping and attr_mapping.get(attr):
attr = attr_mapping.get(attr)
if attr not in valid_attrs: if attr not in valid_attrs:
continue continue
user_attrs[attr] = self.value_to_str(value) user_attrs[attr] = self.value_to_str(value)