diff --git a/apps/common/drf/api.py b/apps/common/drf/api.py index fddd37939..073e9fd7e 100644 --- a/apps/common/drf/api.py +++ b/apps/common/drf/api.py @@ -2,19 +2,10 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet, ReadOnlyModelV from rest_framework_bulk import BulkModelViewSet from ..mixins.api import ( - SerializerMixin, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin, - RelationMixin, AllowBulkDestroyMixin, RenderToJsonMixin, + RelationMixin, AllowBulkDestroyMixin, CommonMixin ) -class CommonMixin(SerializerMixin, - QuerySetMixin, - ExtraFilterFieldsMixin, - PaginatedResponseMixin, - RenderToJsonMixin): - pass - - class JMSGenericViewSet(CommonMixin, GenericViewSet): pass diff --git a/apps/common/mixins/api/common.py b/apps/common/mixins/api/common.py index ba3895356..8dbf4fb1e 100644 --- a/apps/common/mixins/api/common.py +++ b/apps/common/mixins/api/common.py @@ -1,18 +1,28 @@ # -*- coding: utf-8 -*- # +from typing import Callable from rest_framework.response import Response +from collections import defaultdict + +from django.db.models.signals import m2m_changed from .serializer import SerializerMixin from .filter import ExtraFilterFieldsMixin from .action import RenderToJsonMixin +from .queryset import QuerySetMixin + __all__ = [ - 'CommonApiMixin', 'PaginatedResponseMixin', + 'CommonApiMixin', 'PaginatedResponseMixin', 'RelationMixin', 'CommonMixin' ] class PaginatedResponseMixin: - def get_paginated_response_with_query_set(self, queryset): + paginate_queryset: Callable + get_serializer: Callable + get_paginated_response: Callable + + def get_paginated_response_from_queryset(self, queryset): page = self.paginate_queryset(queryset) if page is not None: serializer = self.get_serializer(page, many=True) @@ -22,9 +32,65 @@ class PaginatedResponseMixin: return Response(serializer.data) +class RelationMixin: + m2m_field = None + from_field = None + to_field = None + to_model = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.m2m_field is not None, ''' + `m2m_field` should not be `None` + ''' + + self.from_field = self.m2m_field.m2m_field_name() + self.to_field = self.m2m_field.m2m_reverse_field_name() + self.to_model = self.m2m_field.related_model + self.through = getattr(self.m2m_field.model, self.m2m_field.attname).through + + def get_queryset(self): + # 注意,此处拦截了 `get_queryset` 没有 `super` + queryset = self.through.objects.all() + return queryset + + def send_m2m_changed_signal(self, instances, action): + if not isinstance(instances, list): + instances = [instances] + + from_to_mapper = defaultdict(list) + + for i in instances: + to_id = getattr(i, self.to_field).id + # TODO 优化,不应该每次都查询数据库 + from_obj = getattr(i, self.from_field) + from_to_mapper[from_obj].append(to_id) + + for from_obj, to_ids in from_to_mapper.items(): + m2m_changed.send( + sender=self.through, instance=from_obj, action=action, + reverse=False, model=self.to_model, pk_set=to_ids + ) + + def perform_create(self, serializer): + instance = serializer.save() + self.send_m2m_changed_signal(instance, 'post_add') + + def perform_destroy(self, instance): + instance.delete() + self.send_m2m_changed_signal(instance, 'post_remove') + + class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin, RenderToJsonMixin): pass +class CommonMixin(SerializerMixin, + QuerySetMixin, + ExtraFilterFieldsMixin, + RenderToJsonMixin): + pass + diff --git a/apps/common/mixins/api/serializer.py b/apps/common/mixins/api/serializer.py index c5c9b4737..52b0637df 100644 --- a/apps/common/mixins/api/serializer.py +++ b/apps/common/mixins/api/serializer.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- # -from collections import defaultdict -from django.db.models.signals import m2m_changed from rest_framework.request import Request -__all__ = ['SerializerMixin', 'RelationMixin'] +__all__ = ['SerializerMixin'] class SerializerMixin: @@ -43,53 +41,3 @@ class SerializerMixin: if serializer_class is None: serializer_class = super().get_serializer_class() return serializer_class - - -class RelationMixin: - m2m_field = None - from_field = None - to_field = None - to_model = None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - assert self.m2m_field is not None, ''' - `m2m_field` should not be `None` - ''' - - self.from_field = self.m2m_field.m2m_field_name() - self.to_field = self.m2m_field.m2m_reverse_field_name() - self.to_model = self.m2m_field.related_model - self.through = getattr(self.m2m_field.model, self.m2m_field.attname).through - - def get_queryset(self): - # 注意,此处拦截了 `get_queryset` 没有 `super` - queryset = self.through.objects.all() - return queryset - - def send_m2m_changed_signal(self, instances, action): - if not isinstance(instances, list): - instances = [instances] - - from_to_mapper = defaultdict(list) - - for i in instances: - to_id = getattr(i, self.to_field).id - # TODO 优化,不应该每次都查询数据库 - from_obj = getattr(i, self.from_field) - from_to_mapper[from_obj].append(to_id) - - for from_obj, to_ids in from_to_mapper.items(): - m2m_changed.send( - sender=self.through, instance=from_obj, action=action, - reverse=False, model=self.to_model, pk_set=to_ids - ) - - def perform_create(self, serializer): - instance = serializer.save() - self.send_m2m_changed_signal(instance, 'post_add') - - def perform_destroy(self, instance): - instance.delete() - self.send_m2m_changed_signal(instance, 'post_remove')