diff --git a/seahub/base/accounts.py b/seahub/base/accounts.py index bddaf77f16..07c0d74c03 100644 --- a/seahub/base/accounts.py +++ b/seahub/base/accounts.py @@ -111,11 +111,11 @@ class UserManager(object): return self.get(email=virtual_id) - def update_role(self, email, role): + def update_role(self, email, role, is_manual_set=True): """ If user has a role, update it; or create a role for user. """ - ccnet_api.update_role_emailuser(email, role) + ccnet_api.update_role_emailuser(email, role, is_manual_set=is_manual_set) return self.get(email=email) def create_oauth_user(self, email=None, password=None, is_staff=False, is_active=False): diff --git a/seahub/utils/ccnet_db.py b/seahub/utils/ccnet_db.py index b47268a221..4cdce9265a 100644 --- a/seahub/utils/ccnet_db.py +++ b/seahub/utils/ccnet_db.py @@ -28,6 +28,11 @@ class CcnetUsers(object): self.role = kwargs.get('role') self.passwd = kwargs.get('passwd') +class CcnetUserRole(object): + + def __init__(self, **kwargs): + self.role = kwargs.get('role') + self.is_manual_set = kwargs.get('is_manual_set') class CcnetDB: @@ -211,3 +216,24 @@ class CcnetDB: cursor.execute(sql) user_count = cursor.fetchone()[0] return user_count + + def get_user_role_from_db(self, email): + + sql = f""" + SELECT `role`, `is_manual_set` FROM `{self.db_name}`.`UserRole` WHERE email = '{email}'; + """ + with connection.cursor() as cursor: + cursor.execute(sql) + row = cursor.fetchone() + if not row: + role = None + is_manual_set = False + else: + role = row[0] + is_manual_set = row[1] + + params = { + 'role': role, + 'is_manual_set': is_manual_set + } + return CcnetUserRole(**params) diff --git a/thirdpart/shibboleth/middleware.py b/thirdpart/shibboleth/middleware.py index 4ba691050f..ff5f058fae 100755 --- a/thirdpart/shibboleth/middleware.py +++ b/thirdpart/shibboleth/middleware.py @@ -19,6 +19,7 @@ from seahub.base.accounts import User from seahub.profile.models import Profile from seahub.utils.file_size import get_quota_from_string from seahub.role_permissions.utils import get_enabled_role_permissions_by_role +from seahub.utils.ccnet_db import CcnetDB # Get an instance of a logger logger = logging.getLogger(__name__) @@ -106,14 +107,20 @@ class ShibbolethRemoteUserMiddleware(RemoteUserMiddleware): # call make profile. self.make_profile(user, shib_meta) - if CUSTOM_SHIBBOLETH_GET_USER_ROLE: - user_role = custom_shibboleth_get_user_role(shib_meta) - if user_role: - ccnet_api.update_role_emailuser(user.email, user_role) - else: - user_role = self.update_user_role(user, shib_meta) + db_api = CcnetDB() + db_user_role = db_api.get_user_role_from_db(user.email) + if db_user_role.is_manual_set: + user_role = db_user_role.role + else: - user_role = self.update_user_role(user, shib_meta) + if CUSTOM_SHIBBOLETH_GET_USER_ROLE: + user_role = custom_shibboleth_get_user_role(shib_meta) + if user_role: + ccnet_api.update_role_emailuser(user.email, user_role, False) + else: + user_role = self.update_user_role(user, shib_meta, False) + else: + user_role = self.update_user_role(user, shib_meta, False) if user_role: self.update_user_quota(user, user_role) @@ -208,7 +215,7 @@ class ShibbolethRemoteUserMiddleware(RemoteUserMiddleware): return None - def update_user_role(self, user, shib_meta): + def update_user_role(self, user, shib_meta, is_manual_set): affiliation = shib_meta.get('affiliation', '') if not affiliation: return @@ -216,7 +223,7 @@ class ShibbolethRemoteUserMiddleware(RemoteUserMiddleware): for e in affiliation.split(';'): role = self._get_role_by_affiliation(e) if role: - User.objects.update_role(user.email, role) + User.objects.update_role(user.email, role, is_manual_set) return role def update_user_quota(self, user, user_role):