diff --git a/seahub/api2/throttling.py b/seahub/api2/throttling.py index 20855eaee6..8ac724d6de 100644 --- a/seahub/api2/throttling.py +++ b/seahub/api2/throttling.py @@ -71,6 +71,9 @@ class SimpleRateThrottle(BaseThrottle): def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate() + print '000000000', self.THROTTLE_RATES +# print '-------', self.rate +# assert False self.num_requests, self.duration = self.parse_rate(self.rate) def get_cache_key(self, request, view): @@ -116,14 +119,14 @@ class SimpleRateThrottle(BaseThrottle): On success calls `throttle_success`. On failure calls `throttle_failure`. """ + print '-------', self.num_requests if self.rate is None: return True - if get_remote_ip(request) in \ - settings.REST_FRAMEWORK_THROTTING_WHITELIST: + if get_remote_ip(request) in settings.REST_FRAMEWORK_THROTTING_WHITELIST: return True - else: - self.key = self.get_cache_key(request, view) + + self.key = self.get_cache_key(request, view) if self.key is None: return True diff --git a/tests/api/test_throttings.py b/tests/api/test_throttings.py index e49b626c53..8bebeae7a4 100644 --- a/tests/api/test_throttings.py +++ b/tests/api/test_throttings.py @@ -1,30 +1,41 @@ +from mock import patch import time from django.core.urlresolvers import reverse -from django.conf import settings from django.test import override_settings +from seahub.api2.throttling import SimpleRateThrottle from seahub.test_utils import BaseTestCase -@override_settings(REST_FRAMEWORK = {'DEFAULT_THROTTLE_RATES': - {'ping': '600/minute', 'anon': '5000/minute', 'user': '10/minute',},}) class ThrottingsTest(BaseTestCase): def setUp(self): + # clear cache between every test case to avoid cache issue in throtting + self.clear_cache() + self.login_as(self.user) - def test_whitelist(self): - WHITELIST = settings.REST_FRAMEWORK_THROTTING_WHITELIST + @patch.object(SimpleRateThrottle, 'get_rate') + def test_default(self, mock_get_rate): + mock_get_rate.return_value = '10/minute' + for i in range(12): - time.sleep(0.1) res = self.client.get(reverse('api2-pub-repos')) - if i > 10: + if i >= 10: assert res.status_code == 429 - WHITELIST.append('127.0.0.1') - count = 0 - for i in range(12): - time.sleep(0.1) - res = self.client.get(reverse('api2-pub-repos')) - if i > 10: + else: assert res.status_code == 200 + + time.sleep(0.1) + + @override_settings(REST_FRAMEWORK_THROTTING_WHITELIST=['127.0.0.1']) + @patch.object(SimpleRateThrottle, 'get_rate') + def test_whitelist(self, mock_get_rate): + mock_get_rate.return_value = '10/minute' + + for i in range(12): + res = self.client.get(reverse('api2-pub-repos')) + assert res.status_code == 200 + + time.sleep(0.1)