diff --git a/seahub/api2/endpoints/notifications.py b/seahub/api2/endpoints/notifications.py index 420dd7319e..45f5656c82 100644 --- a/seahub/api2/endpoints/notifications.py +++ b/seahub/api2/endpoints/notifications.py @@ -5,6 +5,7 @@ from rest_framework.authentication import SessionAuthentication from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework import status from django.core.cache import cache @@ -123,17 +124,32 @@ class NotificationView(APIView): notice_id = request.data.get('notice_id') + # argument check + try: + int(notice_id) + except Exception as e: + error_msg = 'notice_id invalid.' + logger.error(e) + return api_error(status.HTTP_400_BAD_REQUEST, error_msg) + + # resource check try: notice = UserNotification.objects.get(id=notice_id) except UserNotification.DoesNotExist as e: logger.error(e) - pass + error_msg = 'Notification %s not found.' % notice_id + return api_error(status.HTTP_404_NOT_FOUND, error_msg) + + # permission check + username = request.user.username + if notice.to_user != username: + error_msg = 'Permission denied.' + return api_error(status.HTTP_403_FORBIDDEN, error_msg) if not notice.seen: notice.seen = True notice.save() - username = request.user.username cache_key = get_cache_key_of_unseen_notifications(username) cache.delete(cache_key) diff --git a/tests/api/endpoints/test_notifications.py b/tests/api/endpoints/test_notifications.py index 0433cf1d83..418e763216 100644 --- a/tests/api/endpoints/test_notifications.py +++ b/tests/api/endpoints/test_notifications.py @@ -1,6 +1,7 @@ import json from seahub.test_utils import BaseTestCase from seahub.notifications.models import UserNotification +from seahub.base.accounts import UserManager class NotificationsTest(BaseTestCase): def setUp(self): @@ -58,3 +59,29 @@ class NotificationTest(BaseTestCase): self.assertEqual(200, resp.status_code) assert UserNotification.objects.count_unseen_user_notifications(self.username) == 0 + + def test_argument_check_notice_id_invalid(self): + self.login_as(self.user) + data = 'notice_id=%s' % 'a' + + resp = self.client.put(self.endpoint, data, 'application/x-www-form-urlencoded') + self.assertEqual(400, resp.status_code) + + def test_resource_check_notification_not_found(self): + self.login_as(self.user) + notice1 = UserNotification.objects.add_user_message(self.username, 'test1') + notice2 = UserNotification.objects.add_user_message(self.username, 'test2') + data = 'notice_id=%s' % str(notice2.id + 1) + + resp = self.client.put(self.endpoint, data, 'application/x-www-form-urlencoded') + self.assertEqual(404, resp.status_code) + + def test_permission_check_permission_denied(self): + self.login_as(self.user) + new_user = UserManager().create_user(email='new@new.com', password='root') + notice_to_new_user = UserNotification.objects.add_user_message(new_user.name, 'test for new user') + data = 'notice_id=%s' % notice_to_new_user.id + + resp = self.client.put(self.endpoint, data, 'application/x-www-form-urlencoded') + self.assertEqual(403, resp.status_code) +