diff --git a/apps/authentication/api/sso.py b/apps/authentication/api/sso.py index 5740c80d8..b953b8d37 100644 --- a/apps/authentication/api/sso.py +++ b/apps/authentication/api/sso.py @@ -21,16 +21,19 @@ from ..filters import AuthKeyQueryDeclaration from ..mixins import AuthMixin from ..errors import SSOAuthClosed +NEXT_URL = 'next' +AUTH_KEY = 'authkey' + class SSOViewSet(AuthMixin, JmsGenericViewSet): queryset = SSOToken.objects.all() serializer_classes = { - 'get_login_url': SSOTokenSerializer, + 'login_url': SSOTokenSerializer, 'login': EmptySerializer } - @action(methods=[POST], detail=False, permission_classes=[IsSuperUser]) - def get_login_url(self, request, *args, **kwargs): + @action(methods=[POST], detail=False, permission_classes=[IsSuperUser], url_path='login-url') + def login_url(self, request, *args, **kwargs): if not settings.AUTH_SSO: raise SSOAuthClosed() @@ -39,12 +42,14 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet): username = serializer.validated_data['username'] user = User.objects.get(username=username) + next_url = serializer.validated_data.get(NEXT_URL) operator = request.user.username # TODO `created_by` 和 `created_by` 可以通过 `ThreadLocal` 统一处理 token = SSOToken.objects.create(user=user, created_by=operator, updated_by=operator) query = { - 'authkey': token.authkey + AUTH_KEY: token.authkey, + NEXT_URL: next_url or '' } login_url = '%s?%s' % (reverse('api-auth:sso-login', external=True), urlencode(query)) return Response(data={'login_url': login_url}) @@ -55,7 +60,11 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet): 此接口违反了 `Restful` 的规范 `GET` 应该是安全的方法,但此接口是不安全的 """ - authkey = request.query_params.get('authkey') + authkey = request.query_params.get(AUTH_KEY) + next_url = request.query_params.get(NEXT_URL) + if not next_url or not next_url.startswith('/'): + next_url = reverse('index') + try: authkey = UUID(authkey) token = SSOToken.objects.get(authkey=authkey, expired=False) @@ -63,15 +72,15 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet): token.expired = True token.save() except (ValueError, SSOToken.DoesNotExist): - self.send_auth_signal(success=False, reason=f'authkey invalid: {authkey}') + self.send_auth_signal(success=False, reason='authkey_invalid') return HttpResponseRedirect(reverse('authentication:login')) # 判断是否过期 if (utcnow().timestamp() - token.date_created.timestamp()) > settings.AUTH_SSO_AUTHKEY_TTL: - self.send_auth_signal(success=False, reason=f'authkey timeout: {authkey}') + self.send_auth_signal(success=False, reason='authkey_timeout') return HttpResponseRedirect(reverse('authentication:login')) user = token.user login(self.request, user, 'authentication.backends.api.SSOAuthentication') self.send_auth_signal(success=True, user=user) - return HttpResponseRedirect(reverse('index')) + return HttpResponseRedirect(next_url) diff --git a/apps/authentication/serializers.py b/apps/authentication/serializers.py index f04b847b4..7d666db4c 100644 --- a/apps/authentication/serializers.py +++ b/apps/authentication/serializers.py @@ -81,3 +81,4 @@ class LoginConfirmSettingSerializer(serializers.ModelSerializer): class SSOTokenSerializer(serializers.Serializer): username = serializers.CharField(write_only=True) login_url = serializers.CharField(read_only=True) + next = serializers.CharField(write_only=True, allow_blank=True, required=False, allow_null=True)