mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-09-26 07:22:27 +00:00
[Update] 修改swagger
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import coreapi
|
||||
from rest_framework import filters
|
||||
from rest_framework.fields import DateTimeField
|
||||
from rest_framework.serializers import ValidationError
|
||||
from django.core.cache import cache
|
||||
import logging
|
||||
|
||||
__all__ = ["DatetimeRangeFilter"]
|
||||
from . import const
|
||||
|
||||
__all__ = ["DatetimeRangeFilter", "IDSpmFilter", "CustomFilter"]
|
||||
|
||||
|
||||
class DatetimeRangeFilter(filters.BaseFilterBackend):
|
||||
@@ -40,3 +44,50 @@ class DatetimeRangeFilter(filters.BaseFilterBackend):
|
||||
if kwargs:
|
||||
queryset = queryset.filter(**kwargs)
|
||||
return queryset
|
||||
|
||||
|
||||
class IDSpmFilter(filters.BaseFilterBackend):
|
||||
def get_schema_fields(self, view):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='spm', location='query', required=False,
|
||||
type='string', example='',
|
||||
description='Pre post objects id get spm id, then using filter'
|
||||
)
|
||||
]
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
spm = request.query_params.get('spm')
|
||||
if not spm:
|
||||
return queryset
|
||||
cache_key = const.KEY_CACHE_RESOURCES_ID.format(spm)
|
||||
resources_id = cache.get(cache_key)
|
||||
if not resources_id or not isinstance(resources_id, list):
|
||||
queryset = queryset.none()
|
||||
return queryset
|
||||
queryset = queryset.filter(id__in=resources_id)
|
||||
return queryset
|
||||
|
||||
|
||||
class CustomFilter(filters.BaseFilterBackend):
|
||||
custom_filter_fields = [] # ["node", "asset"]
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
fields = []
|
||||
defaults = dict(
|
||||
location='query', required=False,
|
||||
type='string', example='',
|
||||
description=''
|
||||
)
|
||||
for field in self.custom_filter_fields:
|
||||
if isinstance(field, str):
|
||||
defaults['name'] = field
|
||||
elif isinstance(field, dict):
|
||||
defaults.update(field)
|
||||
else:
|
||||
continue
|
||||
fields.append(coreapi.Field(**defaults))
|
||||
return fields
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
return queryset
|
||||
|
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.http import JsonResponse
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.contrib import messages
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from ..const import KEY_CACHE_RESOURCES_ID
|
||||
from ..filters import IDSpmFilter, CustomFilter
|
||||
|
||||
__all__ = [
|
||||
"JSONResponseMixin", "IDInCacheFilterMixin", "IDExportFilterMixin",
|
||||
"IDInFilterMixin", "ApiMessageMixin"
|
||||
"JSONResponseMixin", "CommonApiMixin",
|
||||
"IDSpmFilterMixin", "CommonApiMixin",
|
||||
]
|
||||
|
||||
|
||||
@@ -20,69 +18,31 @@ class JSONResponseMixin(object):
|
||||
return JsonResponse(context)
|
||||
|
||||
|
||||
class IDInFilterMixin(object):
|
||||
class IDSpmFilterMixin:
|
||||
def get_filter_backends(self):
|
||||
backends = super().get_filter_backends()
|
||||
backends.append(IDSpmFilter)
|
||||
return backends
|
||||
|
||||
|
||||
class ExtraFilterFieldsMixin:
|
||||
default_added_filters = [CustomFilter, IDSpmFilter]
|
||||
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
|
||||
extra_filter_fields = []
|
||||
extra_filter_backends = []
|
||||
|
||||
def get_filter_backends(self):
|
||||
if self.filter_backends != self.__class__.filter_backends:
|
||||
return self.filter_backends
|
||||
return list(self.filter_backends) + \
|
||||
self.default_added_filters + \
|
||||
list(self.extra_filter_backends)
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super(IDInFilterMixin, self).filter_queryset(queryset)
|
||||
id_list = self.request.query_params.get('id__in')
|
||||
if id_list:
|
||||
import json
|
||||
try:
|
||||
ids = json.loads(id_list)
|
||||
except Exception as e:
|
||||
return queryset
|
||||
if isinstance(ids, list):
|
||||
queryset = queryset.filter(id__in=ids)
|
||||
for backend in self.get_filter_backends():
|
||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||
return queryset
|
||||
|
||||
|
||||
class IDInCacheFilterMixin(object):
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super().filter_queryset(queryset)
|
||||
spm = self.request.query_params.get('spm')
|
||||
if not spm:
|
||||
return queryset
|
||||
cache_key = KEY_CACHE_RESOURCES_ID.format(spm)
|
||||
resources_id = cache.get(cache_key)
|
||||
if not resources_id or not isinstance(resources_id, list):
|
||||
queryset = queryset.none()
|
||||
return queryset
|
||||
queryset = queryset.filter(id__in=resources_id)
|
||||
return queryset
|
||||
|
||||
|
||||
class IDExportFilterMixin(object):
|
||||
def filter_queryset(self, queryset):
|
||||
# 下载导入模版
|
||||
if self.request.query_params.get('template') == 'import':
|
||||
return []
|
||||
else:
|
||||
return super(IDExportFilterMixin, self).filter_queryset(queryset)
|
||||
|
||||
|
||||
class ApiMessageMixin:
|
||||
success_message = _("%(name)s was %(action)s successfully")
|
||||
_action_map = {"create": _("create"), "update": _("update")}
|
||||
|
||||
def get_success_message(self, cleaned_data):
|
||||
if not isinstance(cleaned_data, dict):
|
||||
return ''
|
||||
data = {k: v for k, v in cleaned_data.items()}
|
||||
action = getattr(self, "action", "create")
|
||||
data["action"] = self._action_map.get(action)
|
||||
try:
|
||||
message = self.success_message % data
|
||||
except:
|
||||
message = ''
|
||||
return message
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
resp = super().dispatch(request, *args, **kwargs)
|
||||
if request.method.lower() in ("get", "delete", "patch"):
|
||||
return resp
|
||||
if resp.status_code >= 400:
|
||||
return resp
|
||||
message = self.get_success_message(resp.data)
|
||||
if message:
|
||||
messages.success(request, message)
|
||||
return resp
|
||||
class CommonApiMixin(ExtraFilterFieldsMixin):
|
||||
pass
|
||||
|
@@ -8,7 +8,6 @@ import datetime
|
||||
import uuid
|
||||
from functools import wraps
|
||||
import time
|
||||
import copy
|
||||
import ipaddress
|
||||
|
||||
|
||||
@@ -199,3 +198,18 @@ def timeit(func):
|
||||
logger.debug(msg)
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
def group_obj_by_count(objs, count=50):
|
||||
objs_grouped = [
|
||||
objs[i:i + count] for i in range(0, len(objs), count)
|
||||
]
|
||||
return objs_grouped
|
||||
|
||||
|
||||
def dict_get_any(d, keys):
|
||||
for key in keys:
|
||||
value = d.get(key)
|
||||
if value:
|
||||
return value
|
||||
return None
|
||||
|
Reference in New Issue
Block a user