diff --git a/apps/common/api/mixin.py b/apps/common/api/mixin.py index b939af1d5..0473bea0c 100644 --- a/apps/common/api/mixin.py +++ b/apps/common/api/mixin.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # from collections import defaultdict +from contextlib import nullcontext from itertools import chain from typing import Callable @@ -15,6 +16,7 @@ from common.drf.filters import ( IDNotFilterBackend, NotOrRelFilterBackend, LabelFilterBackend ) from common.utils import get_logger, lazyproperty +from orgs.utils import tmp_to_org, tmp_to_root_org from .action import RenderToJsonMixin from .serializer import SerializerMixin @@ -125,25 +127,39 @@ class QuerySetMixin: return queryset def paginate_queryset(self, queryset): - in_root_org = getattr(queryset, 'in_root_org') page = super().paginate_queryset(queryset) - - if in_root_org: - return page - model = getattr(queryset, 'model', None) if not model or hasattr(queryset, 'custom'): return page serializer_class = self.get_serializer_class() if page and serializer_class: - ids = [str(obj.id) for obj in page] - page = model.objects.filter(id__in=ids) - page = self.setup_eager_loading(page, is_paginated=True) + # 必须要返回 ids,用于排序 + queryset, ids = self._get_page_again(page, model) + page = self.setup_eager_loading(queryset, is_paginated=True) page_mapper = {str(obj.id): obj for obj in page} page = [page_mapper.get(_id) for _id in ids if _id in page_mapper] return page + def _get_page_again(self, page, model): + """ + 因为 setup_eager_loading 需要是 queryset 结构, 所以必须要重新构造 + """ + id_org_mapper = {str(obj.id): getattr(obj, 'org_id', None) for obj in page} + ids = list(id_org_mapper.keys()) + org_ids = list(set(id_org_mapper.values()) - {None}) + + if not org_ids: + context = nullcontext() + elif len(org_ids) == 1: + context = tmp_to_org(org_ids[0]) + else: + context = tmp_to_root_org() + + with context: + page = model.objects.filter(id__in=ids) + return page, ids + class ExtraFilterFieldsMixin: """ diff --git a/apps/orgs/utils.py b/apps/orgs/utils.py index d861d680b..1b9025b18 100644 --- a/apps/orgs/utils.py +++ b/apps/orgs/utils.py @@ -153,9 +153,6 @@ def filter_org_queryset(queryset): # print(line) # print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<") queryset = queryset.filter(**kwargs) - - if org and org.is_root(): - queryset.in_root_org = True return queryset