diff --git a/apps/assets/api/asset/asset.py b/apps/assets/api/asset/asset.py index 1d8bafc20..c076723ec 100644 --- a/apps/assets/api/asset/asset.py +++ b/apps/assets/api/asset/asset.py @@ -16,7 +16,7 @@ from assets.tasks import ( ) from common.api import SuggestionMixin from common.drf.filters import BaseFilterSet -from common.utils import get_logger +from common.utils import get_logger, is_uuid from orgs.mixins import generics from orgs.mixins.api import OrgBulkModelViewSet from ..mixin import NodeFilterMixin @@ -31,6 +31,7 @@ __all__ = [ class AssetFilterSet(BaseFilterSet): labels = django_filters.CharFilter(method='filter_labels') platform = django_filters.CharFilter(method='filter_platform') + domain = django_filters.CharFilter(method='filter_domain') type = django_filters.CharFilter(field_name="platform__type", lookup_expr="exact") category = django_filters.CharFilter(field_name="platform__category", lookup_expr="exact") domain_enabled = django_filters.BooleanFilter( @@ -59,7 +60,7 @@ class AssetFilterSet(BaseFilterSet): model = Asset fields = [ "id", "name", "address", "is_active", "labels", - "type", "category", "platform" + "type", "category", "platform", ] @staticmethod @@ -69,13 +70,21 @@ class AssetFilterSet(BaseFilterSet): else: return queryset.filter(platform__name=value) + @staticmethod + def filter_domain(queryset, name, value): + if is_uuid(value): + return queryset.filter(domain_id=value) + else: + return queryset.filter(domain__name__contains=value) + @staticmethod def filter_labels(queryset, name, value): if ':' in value: n, v = value.split(':', 1) queryset = queryset.filter(labels__name=n, labels__value=v) else: - queryset = queryset.filter(Q(labels__name=value) | Q(labels__value=value)) + q = Q(labels__name__contains=value) | Q(labels__value__contains=value) + queryset = queryset.filter(q) return queryset diff --git a/apps/assets/api/domain.py b/apps/assets/api/domain.py index 7ccaa6dcb..041de210d 100644 --- a/apps/assets/api/domain.py +++ b/apps/assets/api/domain.py @@ -6,7 +6,7 @@ from rest_framework.views import APIView, Response from common.utils import get_logger from assets.tasks import test_gateways_connectivity_manual from orgs.mixins.api import OrgBulkModelViewSet -from .asset import AssetViewSet +from .asset import HostViewSet from .. import serializers from ..models import Domain, Gateway @@ -28,11 +28,15 @@ class DomainViewSet(OrgBulkModelViewSet): return super().get_serializer_class() -class GatewayViewSet(AssetViewSet): +class GatewayViewSet(HostViewSet): perm_model = Gateway filterset_fields = ("domain__name", "name", "domain") search_fields = ("domain__name",) - serializer_class = serializers.GatewaySerializer + + def get_serializer_classes(self): + serializer_classes = super().get_serializer_classes() + serializer_classes['default'] = serializers.GatewaySerializer + return serializer_classes def get_queryset(self): queryset = Domain.get_gateway_queryset()