diff --git a/apps/audits/api.py b/apps/audits/api.py index 207e590cf..dcd22c33c 100644 --- a/apps/audits/api.py +++ b/apps/audits/api.py @@ -18,6 +18,7 @@ from rest_framework.response import Response from common.api import CommonApiMixin from common.const.http import GET, POST from common.drf.filters import DatetimeRangeFilterBackend +from common.drf.throttling import FileTransferThrottle from common.permissions import IsServiceAccount from common.plugins.es import QuerySet as ESQuerySet from common.sessions.cache import user_session_manager @@ -111,6 +112,7 @@ class FTPLogViewSet(OrgModelViewSet): @action( methods=[GET], detail=True, permission_classes=[RBACPermission, ], + throttle_classes=[FileTransferThrottle], url_path='file/download' ) def download(self, request, *args, **kwargs): @@ -133,7 +135,9 @@ class FTPLogViewSet(OrgModelViewSet): ) return response - @action(methods=[POST], detail=True, permission_classes=[IsServiceAccount, ], serializer_class=FileSerializer) + @action(methods=[POST], detail=True, permission_classes=[IsServiceAccount, ], + throttle_classes=[FileTransferThrottle], + serializer_class=FileSerializer) def upload(self, request, *args, **kwargs): ftp_log = self.get_object() serializer = self.get_serializer(data=request.data) diff --git a/apps/common/drf/throttling.py b/apps/common/drf/throttling.py index 3483c6de4..9df9652c9 100644 --- a/apps/common/drf/throttling.py +++ b/apps/common/drf/throttling.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from rest_framework.throttling import SimpleRateThrottle +__all__ = ['RateThrottle', 'FileTransferThrottle'] + class RateThrottle(SimpleRateThrottle): @@ -32,3 +34,17 @@ class RateThrottle(SimpleRateThrottle): 'scope': self.scope, 'ident': ident } + + +class FileTransferThrottle(SimpleRateThrottle): + """ + 文件上传下载限流,防止DOS攻击 + """ + scope = 'file_transfer' + + def get_cache_key(self, request, view): + if request.user and request.user.is_authenticated: + ident = request.user.pk + else: + ident = self.get_ident(request) + return self.cache_format % {'scope': self.scope, 'ident': ident} diff --git a/apps/jumpserver/conf.py b/apps/jumpserver/conf.py index bca77db57..e011fb3ab 100644 --- a/apps/jumpserver/conf.py +++ b/apps/jumpserver/conf.py @@ -226,6 +226,9 @@ class Config(dict): 'THROTTLE_RATES_USER': '180/min', 'THROTTLE_RATES_SERVICE_ACCOUNT': '300/min', + # 文件上传下载限流 (防止DOS攻击) + 'THROTTLE_FILE_TRANSFER': '50/hour', + # Security 'X_FRAME_OPTIONS': 'SAMEORIGIN', 'VERIFY_EXTERNAL_SSL': True, diff --git a/apps/jumpserver/settings/libs.py b/apps/jumpserver/settings/libs.py index 6512321cd..75ad6121e 100644 --- a/apps/jumpserver/settings/libs.py +++ b/apps/jumpserver/settings/libs.py @@ -45,6 +45,7 @@ REST_FRAMEWORK = { 'anon': CONFIG.THROTTLE_RATES_ANON, 'user': CONFIG.THROTTLE_RATES_USER, 'service_account': CONFIG.THROTTLE_RATES_SERVICE_ACCOUNT, + 'file_transfer': CONFIG.THROTTLE_FILE_TRANSFER, }, 'DEFAULT_FILTER_BACKENDS': ( 'django_filters.rest_framework.DjangoFilterBackend', diff --git a/apps/ops/api/job.py b/apps/ops/api/job.py index 3eccabc19..68cfaefc9 100644 --- a/apps/ops/api/job.py +++ b/apps/ops/api/job.py @@ -18,6 +18,7 @@ from rest_framework.views import APIView from acls.models import LoginAssetACL from assets.models import Asset from common.const.http import POST +from common.drf.throttling import FileTransferThrottle from common.permissions import IsValidUser from common.utils import get_request_ip_or_data from ops.celery import app @@ -171,7 +172,9 @@ class JobViewSet(LoginAssetACLCheckMixin, OrgBulkModelViewSet): return exceeds_limit_files @action(methods=[POST], detail=False, serializer_class=FileSerializer, - permission_classes=[IsValidUser, ], url_path='upload') + permission_classes=[IsValidUser, ], + throttle_classes=[FileTransferThrottle], + url_path='upload') def upload(self, request, *args, **kwargs): uploaded_files = request.FILES.getlist('files') serializer = self.get_serializer(data=request.data) diff --git a/apps/terminal/api/session/session.py b/apps/terminal/api/session/session.py index 4efedbb4a..6dccb3c45 100644 --- a/apps/terminal/api/session/session.py +++ b/apps/terminal/api/session/session.py @@ -25,6 +25,7 @@ from common.const.http import GET, POST from common.drf.filters import BaseFilterSet from common.drf.filters import DatetimeRangeFilterBackend from common.drf.renders import PassthroughRenderer +from common.drf.throttling import FileTransferThrottle from common.permissions import IsServiceAccount from common.storage.replay import ReplayStorageHandler, SessionPartReplayStorageHandler from common.utils import data_to_json, is_uuid, i18n_fmt @@ -127,6 +128,7 @@ class SessionViewSet(OrgBulkModelViewSet): return file @action(methods=[GET], detail=True, renderer_classes=(PassthroughRenderer,), url_path='replay/download', + throttle_classes=[FileTransferThrottle], url_name='replay-download') def download(self, request, *args, **kwargs): session = self.get_object()