[Update] 修改swagger

This commit is contained in:
ibuler
2019-09-18 22:06:46 +08:00
parent 0db3e41bde
commit 5464c884db
44 changed files with 979 additions and 633 deletions

View File

@@ -1,11 +1,15 @@
# -*- coding: utf-8 -*-
#
import coreapi
from rest_framework import filters
from rest_framework.fields import DateTimeField
from rest_framework.serializers import ValidationError
from django.core.cache import cache
import logging
__all__ = ["DatetimeRangeFilter"]
from . import const
__all__ = ["DatetimeRangeFilter", "IDSpmFilter", "CustomFilter"]
class DatetimeRangeFilter(filters.BaseFilterBackend):
@@ -40,3 +44,50 @@ class DatetimeRangeFilter(filters.BaseFilterBackend):
if kwargs:
queryset = queryset.filter(**kwargs)
return queryset
class IDSpmFilter(filters.BaseFilterBackend):
def get_schema_fields(self, view):
return [
coreapi.Field(
name='spm', location='query', required=False,
type='string', example='',
description='Pre post objects id get spm id, then using filter'
)
]
def filter_queryset(self, request, queryset, view):
spm = request.query_params.get('spm')
if not spm:
return queryset
cache_key = const.KEY_CACHE_RESOURCES_ID.format(spm)
resources_id = cache.get(cache_key)
if not resources_id or not isinstance(resources_id, list):
queryset = queryset.none()
return queryset
queryset = queryset.filter(id__in=resources_id)
return queryset
class CustomFilter(filters.BaseFilterBackend):
custom_filter_fields = [] # ["node", "asset"]
def get_schema_fields(self, view):
fields = []
defaults = dict(
location='query', required=False,
type='string', example='',
description=''
)
for field in self.custom_filter_fields:
if isinstance(field, str):
defaults['name'] = field
elif isinstance(field, dict):
defaults.update(field)
else:
continue
fields.append(coreapi.Field(**defaults))
return fields
def filter_queryset(self, request, queryset, view):
return queryset

View File

@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
#
from django.http import JsonResponse
from django.core.cache import cache
from django.utils.translation import ugettext_lazy as _
from django.contrib import messages
from rest_framework.settings import api_settings
from ..const import KEY_CACHE_RESOURCES_ID
from ..filters import IDSpmFilter, CustomFilter
__all__ = [
"JSONResponseMixin", "IDInCacheFilterMixin", "IDExportFilterMixin",
"IDInFilterMixin", "ApiMessageMixin"
"JSONResponseMixin", "CommonApiMixin",
"IDSpmFilterMixin", "CommonApiMixin",
]
@@ -20,69 +18,31 @@ class JSONResponseMixin(object):
return JsonResponse(context)
class IDInFilterMixin(object):
class IDSpmFilterMixin:
def get_filter_backends(self):
backends = super().get_filter_backends()
backends.append(IDSpmFilter)
return backends
class ExtraFilterFieldsMixin:
default_added_filters = [CustomFilter, IDSpmFilter]
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
extra_filter_fields = []
extra_filter_backends = []
def get_filter_backends(self):
if self.filter_backends != self.__class__.filter_backends:
return self.filter_backends
return list(self.filter_backends) + \
self.default_added_filters + \
list(self.extra_filter_backends)
def filter_queryset(self, queryset):
queryset = super(IDInFilterMixin, self).filter_queryset(queryset)
id_list = self.request.query_params.get('id__in')
if id_list:
import json
try:
ids = json.loads(id_list)
except Exception as e:
return queryset
if isinstance(ids, list):
queryset = queryset.filter(id__in=ids)
for backend in self.get_filter_backends():
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
class IDInCacheFilterMixin(object):
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
spm = self.request.query_params.get('spm')
if not spm:
return queryset
cache_key = KEY_CACHE_RESOURCES_ID.format(spm)
resources_id = cache.get(cache_key)
if not resources_id or not isinstance(resources_id, list):
queryset = queryset.none()
return queryset
queryset = queryset.filter(id__in=resources_id)
return queryset
class IDExportFilterMixin(object):
def filter_queryset(self, queryset):
# 下载导入模版
if self.request.query_params.get('template') == 'import':
return []
else:
return super(IDExportFilterMixin, self).filter_queryset(queryset)
class ApiMessageMixin:
success_message = _("%(name)s was %(action)s successfully")
_action_map = {"create": _("create"), "update": _("update")}
def get_success_message(self, cleaned_data):
if not isinstance(cleaned_data, dict):
return ''
data = {k: v for k, v in cleaned_data.items()}
action = getattr(self, "action", "create")
data["action"] = self._action_map.get(action)
try:
message = self.success_message % data
except:
message = ''
return message
def dispatch(self, request, *args, **kwargs):
resp = super().dispatch(request, *args, **kwargs)
if request.method.lower() in ("get", "delete", "patch"):
return resp
if resp.status_code >= 400:
return resp
message = self.get_success_message(resp.data)
if message:
messages.success(request, message)
return resp
class CommonApiMixin(ExtraFilterFieldsMixin):
pass

View File

@@ -8,7 +8,6 @@ import datetime
import uuid
from functools import wraps
import time
import copy
import ipaddress
@@ -199,3 +198,18 @@ def timeit(func):
logger.debug(msg)
return result
return wrapper
def group_obj_by_count(objs, count=50):
objs_grouped = [
objs[i:i + count] for i in range(0, len(objs), count)
]
return objs_grouped
def dict_get_any(d, keys):
for key in keys:
value = d.get(key)
if value:
return value
return None