diff --git a/seahub/adfs_auth/views.py b/seahub/adfs_auth/views.py index 65e99ce767..58d8009bec 100644 --- a/seahub/adfs_auth/views.py +++ b/seahub/adfs_auth/views.py @@ -154,7 +154,16 @@ def auth_complete(request): token = get_token_v2( request, request.user.username, platform, device_id, device_name, client_version, platform_version) - elif all(['shib_' + key not in request.GET for key in keys]): + elif all(['shib_' + key in request.session for key in keys]): + platform = request.session['shib_platform'] + device_id = request.session['shib_device_id'] + device_name = request.session['shib_device_name'] + client_version = request.session['shib_client_version'] + platform_version = request.session['shib_platform_version'] + token = get_token_v2( + request, request.user.username, platform, device_id, + device_name, client_version, platform_version) + else: token = get_token_v1(request.user.username) resp = HttpResponseRedirect(reverse('libraries')) diff --git a/seahub/api2/utils.py b/seahub/api2/utils.py index 9e35cce8f8..e4f197bd28 100644 --- a/seahub/api2/utils.py +++ b/seahub/api2/utils.py @@ -229,6 +229,17 @@ def get_api_token(request, keys=None, key_prefix='shib_'): token = get_token_v2(request, request.user.username, platform, device_id, device_name, client_version, platform_version) + + elif all([key in request.session for key in keys]): + platform = request.session['%splatform' % key_prefix] + device_id = request.session['%sdevice_id' % key_prefix] + device_name = request.session['%sdevice_name' % key_prefix] + client_version = request.session['%sclient_version' % key_prefix] + platform_version = request.session['%splatform_version' % key_prefix] + token = get_token_v2( + request, request.user.username, platform, device_id, + device_name, client_version, platform_version) + else: token = get_token_v1(request.user.username) diff --git a/seahub/views/sso.py b/seahub/views/sso.py index 132e187788..9132971a42 100644 --- a/seahub/views/sso.py +++ b/seahub/views/sso.py @@ -86,6 +86,14 @@ def jwt_sso(request): def shib_login(request): # client platform args used to create api v2 token + keys = ('platform', 'device_id', 'device_name', 'client_version', 'platform_version') + if all(['shib_' + key in request.GET for key in keys]): + request.session['shib_platform'] = request.GET['shib_platform'] + request.session['shib_device_id'] = request.GET['shib_device_id'] + request.session['shib_device_name'] = request.GET['shib_device_name'] + request.session['shib_client_version'] = request.GET['shib_client_version'] + request.session['shib_platform_version'] = request.GET['shib_platform_version'] + next_page = request.GET.get(REDIRECT_FIELD_NAME, '') query_string = request.META.get('QUERY_STRING', '') params = '?%s=%s&%s' % (REDIRECT_FIELD_NAME, urlquote(next_page), query_string)