diff --git a/apps/common/drf/filters.py b/apps/common/drf/filters.py index 803849909..8c75dea8a 100644 --- a/apps/common/drf/filters.py +++ b/apps/common/drf/filters.py @@ -3,6 +3,7 @@ import base64 import json import logging +from collections import defaultdict from django.core.cache import cache from django.core.exceptions import ImproperlyConfigured @@ -180,7 +181,7 @@ class LabelFilterBackend(filters.BaseFilterBackend): ] @staticmethod - def parse_label_ids(labels_id): + def parse_labels(labels_id): from labels.models import Label label_ids = [i.strip() for i in labels_id.split(',')] cleaned = [] @@ -201,8 +202,8 @@ class LabelFilterBackend(filters.BaseFilterBackend): q = Q() for kwarg in args: q |= Q(**kwarg) - ids = Label.objects.filter(q).values_list('id', flat=True) - cleaned.extend(list(ids)) + labels = Label.objects.filter(q) + cleaned.extend(list(labels)) return cleaned def filter_queryset(self, request, queryset, view): @@ -221,13 +222,23 @@ class LabelFilterBackend(filters.BaseFilterBackend): app_label = model._meta.app_label model_name = model._meta.model_name - resources = labeled_resource_cls.objects.filter( + full_resources = labeled_resource_cls.objects.filter( res_type__app_label=app_label, res_type__model=model_name, ) - label_ids = self.parse_label_ids(labels_id) - resources = model.filter_resources_by_labels(resources, label_ids) - res_ids = resources.values_list('res_id', flat=True) - queryset = queryset.filter(id__in=set(res_ids)) + labels = self.parse_labels(labels_id) + grouped = defaultdict(set) + for label in labels: + grouped[label.name].add(label.id) + + matched_ids = set() + for name, label_ids in grouped.items(): + resources = model.filter_resources_by_labels(full_resources, label_ids, rel='any') + res_ids = resources.values_list('res_id', flat=True) + if not matched_ids: + matched_ids = set(res_ids) + else: + matched_ids &= set(res_ids) + queryset = queryset.filter(id__in=matched_ids) return queryset diff --git a/apps/labels/mixins.py b/apps/labels/mixins.py index 33e73b60b..bb059721d 100644 --- a/apps/labels/mixins.py +++ b/apps/labels/mixins.py @@ -1,4 +1,5 @@ from django.contrib.contenttypes.fields import GenericRelation +from django.contrib.contenttypes.models import ContentType from django.db import models from django.db.models import OneToOneField, Count @@ -38,8 +39,11 @@ class LabeledMixin(models.Model): self.real.labels.set(value, bulk=False) @classmethod - def filter_resources_by_labels(cls, resources, label_ids): - return cls._get_filter_res_by_labels_m2m_all(resources, label_ids) + def filter_resources_by_labels(cls, resources, label_ids, rel='all'): + if rel == 'all': + return cls._get_filter_res_by_labels_m2m_all(resources, label_ids) + else: + return cls._get_filter_res_by_labels_m2m_in(resources, label_ids) @classmethod def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids): @@ -60,7 +64,8 @@ class LabeledMixin(models.Model): @classmethod def get_labels_filter_attr_q(cls, value, match): - resources = LabeledResource.objects.all() + res_type = ContentType.objects.get_for_model(cls.label_model()) + resources = LabeledResource.objects.all().filter(res_type=res_type) if not value: return None