diff --git a/apps/users/api/profile.py b/apps/users/api/profile.py index 321ef54e1..e9631b5b7 100644 --- a/apps/users/api/profile.py +++ b/apps/users/api/profile.py @@ -83,8 +83,3 @@ class UserPublicKeyApi(generics.RetrieveUpdateAPIView): def get_object(self): return self.request.user - - def perform_update(self, serializer): - user = self.get_object() - user.public_key = serializer.validated_data['public_key'] - user.save() diff --git a/apps/users/models/user.py b/apps/users/models/user.py index 431a9789b..3fb058396 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -48,8 +48,9 @@ class AuthMixin: super().set_password(raw_password) def set_public_key(self, public_key): - self.public_key = public_key - self.save() + if self.can_update_ssh_key(): + self.public_key = public_key + self.save() def can_update_password(self): return self.is_local @@ -58,7 +59,7 @@ class AuthMixin: return self.can_use_ssh_key_login() def can_use_ssh_key_login(self): - return settings.TERMINAL_PUBLIC_KEY_AUTH + return self.is_local and settings.TERMINAL_PUBLIC_KEY_AUTH def is_public_key_valid(self): """ diff --git a/apps/users/serializers/user.py b/apps/users/serializers/user.py index fceeed5f3..019524cae 100644 --- a/apps/users/serializers/user.py +++ b/apps/users/serializers/user.py @@ -234,10 +234,6 @@ class UserProfileSerializer(UserSerializer): fields.remove('password') extra_kwargs.pop('password', None) - if 'public_key' in fields: - fields.remove('public_key') - extra_kwargs.pop('public_key', None) - @staticmethod def get_guide_url(obj): return settings.USER_GUIDE_URL @@ -247,6 +243,13 @@ class UserProfileSerializer(UserSerializer): return 2 return mfa_level + def validate_public_key(self, public_key): + if self.instance and self.instance.can_update_ssh_key(): + if not validate_ssh_public_key(public_key): + raise serializers.ValidationError(_('Not a valid ssh public key')) + return public_key + return None + class UserUpdatePasswordSerializer(serializers.ModelSerializer): old_password = serializers.CharField(required=True, max_length=128, write_only=True) diff --git a/apps/users/views/profile/pubkey.py b/apps/users/views/profile/pubkey.py index 52c149084..4010fd996 100644 --- a/apps/users/views/profile/pubkey.py +++ b/apps/users/views/profile/pubkey.py @@ -46,8 +46,7 @@ class UserPublicKeyGenerateView(PermissionsMixin, View): def get(self, request, *args, **kwargs): username = request.user.username private, public = ssh_key_gen(username=username, hostname='jumpserver') - request.user.public_key = public - request.user.save() + request.user.set_public_key(public) response = HttpResponse(private, content_type='text/plain') filename = "{0}-jumpserver.pem".format(username) response['Content-Disposition'] = 'attachment; filename={}'.format(filename)