diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index f6808b0ac..9984c8614 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -273,3 +273,7 @@ class Time: for timestamp, msg in zip(timestamps, self._msgs): logger.debug(f'TIME_IT: {msg} {timestamp-last}') last = timestamp + + +def isinstance_method(attr): + return isinstance(attr, type(Time().time)) diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py index 0b92da278..3a1d49016 100644 --- a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py +++ b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py @@ -34,12 +34,12 @@ class UserAllGrantedAssetsQuerysetMixin: pagination_class = AllGrantedAssetPagination user: User - def get_union_queryset(self, qs_stage: QuerySetStage): + def get_queryset(self): if getattr(self, 'swagger_fake_view', False): return Asset.objects.none() - qs_stage.prefetch_related('platform').only(*self.only_fields) queryset = UserGrantedAssetsQueryUtils(self.user) \ - .get_all_granted_assets(qs_stage) + .get_all_granted_assets() + queryset = queryset.prefetch_related('platform').only(*self.only_fields) return queryset @@ -47,13 +47,13 @@ class UserFavoriteGrantedAssetsMixin: only_fields = serializers.AssetGrantedSerializer.Meta.only_fields user: User - def get_union_queryset(self, qs_stage: QuerySetStage): + def get_queryset(self): if getattr(self, 'swagger_fake_view', False): return Asset.objects.none() user = self.user - qs_stage.prefetch_related('platform').only(*self.only_fields) utils = UserGrantedAssetsQueryUtils(user) - assets = utils.get_favorite_assets(qs_stage=qs_stage) + assets = utils.get_favorite_assets() + assets = assets.prefetch_related('platform').only(*self.only_fields) return assets @@ -63,58 +63,35 @@ class UserGrantedNodeAssetsMixin: pagination_node: Node user: User - def get_union_queryset(self, qs_stage: QuerySetStage): + def get_queryset(self): if getattr(self, 'swagger_fake_view', False): return Asset.objects.none() node_id = self.kwargs.get("node_id") - qs_stage.prefetch_related('platform').only(*self.only_fields) + node, assets = UserGrantedAssetsQueryUtils(self.user).get_node_all_assets( - node_id, qs_stage=qs_stage + node_id ) + assets = assets.prefetch_related('platform').only(*self.only_fields) self.pagination_node = node return assets # 控制格式的 ---------------------------------------------------- -class AssetsUnionQuerysetMixin: - def get_queryset_union_prefer(self): - if hasattr(self, 'get_union_queryset'): - # 为了支持 union 查询 - queryset = Asset.objects.all().distinct() - queryset = self.filter_queryset(queryset) - qs_stage = QuerySetStage() - qs_stage.and_with_queryset(queryset) - queryset = self.get_union_queryset(qs_stage) - else: - queryset = self.filter_queryset(self.get_queryset()) - return queryset - -class AssetsSerializerFormatMixin(AssetsUnionQuerysetMixin): +class AssetsSerializerFormatMixin: serializer_class = serializers.AssetGrantedSerializer filterset_fields = ['hostname', 'ip', 'id', 'comment'] search_fields = ['hostname', 'ip', 'comment'] - def list(self, request, *args, **kwargs): - queryset = self.get_queryset_union_prefer() - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - return self.get_paginated_response(serializer.data) - - serializer = self.get_serializer(queryset, many=True) - return Response(serializer.data) - - -class AssetsTreeFormatMixin(AssetsUnionQuerysetMixin, SerializeToTreeNodeMixin): +class AssetsTreeFormatMixin(SerializeToTreeNodeMixin): """ 将 资产 序列化成树的结构返回 """ def list(self, request: Request, *args, **kwargs): - queryset = self.get_queryset_union_prefer() + queryset = self.filter_queryset(self.get_queryset()) if request.query_params.get('search'): # 如果用户搜索的条件不精准,会导致返回大量的无意义数据。 diff --git a/apps/perms/utils/asset/user_permission.py b/apps/perms/utils/asset/user_permission.py index 6319f7f4c..6bd2e8b2c 100644 --- a/apps/perms/utils/asset/user_permission.py +++ b/apps/perms/utils/asset/user_permission.py @@ -1,5 +1,7 @@ from collections import defaultdict from typing import List, Tuple +from functools import reduce, partial +from common.utils import isinstance_method from django.core.cache import cache from django.conf import settings @@ -51,6 +53,81 @@ def get_user_all_asset_perm_ids(user) -> set: return asset_perm_ids +class UnionQuerySet(QuerySet): + after_union = ['order_by'] + not_return_qs = [ + 'query', 'get', 'create', 'get_or_create', + 'update_or_create', 'bulk_create', 'count', + 'latest', 'earliest', 'first', 'last', 'aggregate', + 'exists', 'update', 'delete', 'as_manager', 'explain', + ] + + def __init__(self, *queryset_list): + self.queryset_list = queryset_list + self.after_union_items = [] + self.before_union_items = [] + + def __execute(self): + queryset_list = [] + for qs in self.queryset_list: + for attr, args, kwargs in self.before_union_items: + qs = getattr(qs, attr)(*args, **kwargs) + queryset_list.append(qs) + union_qs = reduce(lambda x, y: x.union(y), queryset_list) + for attr, args, kwargs in self.after_union_items: + union_qs = getattr(union_qs, attr)(*args, **kwargs) + return union_qs + + def __before_union_perform(self, item, *args, **kwargs): + self.before_union_items.append((item, args, kwargs)) + return self.__clone(*self.queryset_list) + + def __after_union_perform(self, item, *args, **kwargs): + self.after_union_items.append((item, args, kwargs)) + return self.__clone(*self.queryset_list) + + def __clone(self, *queryset_list): + uqs = UnionQuerySet(*queryset_list) + uqs.after_union_items = self.after_union_items + uqs.before_union_items = self.before_union_items + return uqs + + def __getattribute__(self, item): + if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [ + 'queryset_list', 'after_union_items', 'before_union_items' + ]: + return object.__getattribute__(self, item) + + if item in UnionQuerySet.not_return_qs: + return getattr(self.__execute(), item) + + origin_item = object.__getattribute__(self, 'queryset_list')[0] + origin_attr = getattr(origin_item, item, None) + if not isinstance_method(origin_attr): + return getattr(self.__execute(), item) + + if item in UnionQuerySet.after_union: + attr = partial(self.__after_union_perform, item) + else: + attr = partial(self.__before_union_perform, item) + return attr + + def __getitem__(self, item): + return self.__execute()[item] + + def __next__(self): + return next(self.__execute()) + + @classmethod + def test_it(cls): + from assets.models import Asset + assets1 = Asset.objects.filter(hostname__startswith='a') + assets2 = Asset.objects.filter(hostname__startswith='b') + + qs = cls(assets1, assets2) + return qs + + class QuerySetStage: def __init__(self): self._prefetch_related = set() @@ -541,14 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase): class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): - def get_favorite_assets(self, qs_stage: QuerySetStage = None, only=('id', )) -> AssetQuerySet: + def get_favorite_assets(self, only=('id', )) -> QuerySet: favorite_asset_ids = FavoriteAsset.objects.filter( user=self.user ).values_list('asset_id', flat=True) favorite_asset_ids = list(favorite_asset_ids) - qs_stage = qs_stage or QuerySetStage() - qs_stage.filter(id__in=favorite_asset_ids).only(*only) - assets = self.get_all_granted_assets(qs_stage) + assets = self.get_all_granted_assets() + assets = assets.filter(id__in=favorite_asset_ids).only(*only) return assets def get_ungroup_assets(self) -> AssetQuerySet: @@ -560,39 +636,30 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): ).distinct() return queryset - def get_direct_granted_nodes_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet: + def get_direct_granted_nodes_assets(self) -> AssetQuerySet: granted_node_ids = AssetPermission.nodes.through.objects.filter( assetpermission_id__in=self.asset_perm_ids ).values_list('node_id', flat=True).distinct() granted_node_ids = list(granted_node_ids) granted_nodes = PermNode.objects.filter(id__in=granted_node_ids).only('id', 'key') queryset = PermNode.get_nodes_all_assets(*granted_nodes) - if qs_stage: - queryset = qs_stage.merge(queryset) return queryset - def get_all_granted_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet: + def get_all_granted_assets(self) -> QuerySet: nodes_assets = self.get_direct_granted_nodes_assets() assets = self.get_direct_granted_assets() - - if qs_stage: - nodes_assets, assets = qs_stage.merge_multi_before_union(nodes_assets, assets) - queryset = nodes_assets.union(assets) - if qs_stage: - queryset = qs_stage.merge_after_union(queryset) + queryset = UnionQuerySet(nodes_assets, assets) return queryset - def get_node_all_assets(self, id, qs_stage: QuerySetStage = None) -> Tuple[PermNode, QuerySet]: + def get_node_all_assets(self, id) -> Tuple[PermNode, QuerySet]: node = PermNode.objects.get(id=id) granted_status = node.get_granted_status(self.user) if granted_status == NodeFrom.granted: assets = PermNode.get_nodes_all_assets(node) - if qs_stage: - assets = qs_stage.merge(assets) return node, assets elif granted_status in (NodeFrom.asset, NodeFrom.child): node.use_granted_assets_amount() - assets = self._get_indirect_granted_node_all_assets(node, qs_stage=qs_stage) + assets = self._get_indirect_granted_node_all_assets(node) return node, assets else: node.assets_amount = 0 @@ -614,7 +681,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets() return assets - def _get_indirect_granted_node_all_assets(self, node, qs_stage: QuerySetStage = None) -> QuerySet: + def _get_indirect_granted_node_all_assets(self, node) -> QuerySet: """ 此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询 1. 查询该节点下的直接授权节点 @@ -645,10 +712,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): nodes__id__in=only_asset_granted_node_ids, granted_by_permissions__id__in=self.asset_perm_ids ).distinct().order_by() - if qs_stage: - node_assets, assets = qs_stage.merge_multi_before_union(node_assets, assets) - granted_assets = node_assets.union(assets) - granted_assets = qs_stage.merge_after_union(granted_assets) + granted_assets = UnionQuerySet(node_assets, assets) return granted_assets