diff --git a/apps/acls/api/login_acl.py b/apps/acls/api/login_acl.py index 806ffb343..07aa7f46e 100644 --- a/apps/acls/api/login_acl.py +++ b/apps/acls/api/login_acl.py @@ -1,7 +1,7 @@ from common.api import JMSBulkModelViewSet -from ..models import LoginACL from .. import serializers from ..filters import LoginAclFilter +from ..models import LoginACL __all__ = ['LoginACLViewSet'] @@ -11,4 +11,3 @@ class LoginACLViewSet(JMSBulkModelViewSet): filterset_class = LoginAclFilter search_fields = ('name',) serializer_class = serializers.LoginACLSerializer - diff --git a/apps/acls/api/login_asset_check.py b/apps/acls/api/login_asset_check.py index 3c157a1cc..bb61a9043 100644 --- a/apps/acls/api/login_asset_check.py +++ b/apps/acls/api/login_asset_check.py @@ -1,7 +1,6 @@ from rest_framework.generics import CreateAPIView from rest_framework.response import Response -from common.db.fields import JSONManyToManyField from common.utils import reverse, lazyproperty from orgs.utils import tmp_to_org from .. import serializers @@ -36,9 +35,9 @@ class LoginAssetCheckAPI(CreateAPIView): # 用户满足的 acls queryset = LoginAssetACL.objects.all() - q = JSONManyToManyField.get_filter_q(LoginAssetACL, 'users', user) + q = LoginAssetACL.users.get_filter_q(LoginAssetACL, 'users', user) queryset = queryset.filter(q) - q = JSONManyToManyField.get_filter_q(LoginAssetACL, 'assets', asset) + q = LoginAssetACL.assets.get_filter_q(LoginAssetACL, 'assets', asset) queryset = queryset.filter(q) account_username = self.serializer.validated_data.get('account_username') queryset = queryset.filter(accounts__contains=account_username) diff --git a/apps/common/db/fields.py b/apps/common/db/fields.py index 70c14c4f2..d2087f5d4 100644 --- a/apps/common/db/fields.py +++ b/apps/common/db/fields.py @@ -3,19 +3,20 @@ import ipaddress import json +import logging import re from django.apps import apps from django.core.exceptions import ValidationError from django.core.validators import MinValueValidator, MaxValueValidator from django.db import models -from django.db.models import Q +from django.db.models import Q, Manager from django.utils.encoding import force_text from django.utils.translation import ugettext_lazy as _ from rest_framework.utils.encoders import JSONEncoder from common.local import add_encrypted_field_set -from common.utils import signer, crypto +from common.utils import signer, crypto, contains_ip from .validators import PortRangeValidator __all__ = [ @@ -321,58 +322,82 @@ class RelatedManager: continue return q + def _get_filter_attrs_q(self, value, to_model): + filters = Q() + # 特殊情况有这几种, + # 1. 像 资产中的 type 和 category,集成自 Platform。所以不能直接查询 + # 2. 像 资产中的 nodes,不是简单的 m2m,是树 的关系 + # 3. 像 用户中的 orgs 也不是简单的 m2m,也是计算出来的 + # get_filter_{}_attr_q 处理复杂的 + custom_attr_filter = getattr(to_model, "get_json_filter_attr_q", None) + for attr in value["attrs"]: + if not isinstance(attr, dict): + continue + + name = attr.get('name') + val = attr.get('value') + match = attr.get('match', 'exact') + if name is None or val is None: + continue + + print("Has custom filter: {}".format(custom_attr_filter)) + if custom_attr_filter: + custom_filter_q = custom_attr_filter(name, val, match) + print("Custom filter: {}".format(custom_filter_q)) + if custom_filter_q: + filters &= custom_filter_q + continue + + if match == 'ip_in': + q = self.get_ip_in_q(name, val) + elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"): + lookup = "{}__{}".format(name, match) + q = Q(**{lookup: val}) + elif match == "not": + q = ~Q(**{name: val}) + elif match == "m2m": + if not isinstance(val, list): + val = [val] + q = Q(**{"{}__in".format(name): val}) + elif match == "in" and isinstance(val, list): + if '*' not in val: + lookup = "{}__in".format(name) + q = Q(**{lookup: val}) + else: + q = Q() + else: + if val == '*': + q = Q() + else: + q = Q(**{name: val}) + + filters &= q + return filters + def _get_queryset(self): - model = apps.get_model(self.field.to) + to_model = apps.get_model(self.field.to) value = self.value + if hasattr(to_model, "get_queryset"): + queryset = to_model.get_queryset() + else: + queryset = to_model.objects.all() + if not value or not isinstance(value, dict): - return model.objects.none() + return queryset.none() if value["type"] == "all": - return model.objects.all() + return queryset elif value["type"] == "ids" and isinstance(value.get("ids"), list): - return model.objects.filter(id__in=value["ids"]) + return queryset.filter(id__in=value["ids"]) elif value["type"] == "attrs" and isinstance(value.get("attrs"), list): - filters = Q() - excludes = Q() - for attr in value["attrs"]: - if not isinstance(attr, dict): - continue - - name = attr.get('name') - val = attr.get('value') - match = attr.get('match', 'exact') - rel = attr.get('rel', 'and') - if name is None or val is None: - continue - - if match == 'ip_in': - q = self.get_ip_in_q(name, val) - elif match in ("exact", "contains", "startswith", "endswith", "regex"): - lookup = "{}__{}".format(name, match) - q = Q(**{lookup: val}) - elif match == "not": - q = ~Q(**{name: val}) - elif match == "in" and isinstance(val, list): - if '*' not in val: - lookup = "{}__in".format(name) - q = Q(**{lookup: val}) - else: - q = Q() - else: - if val == '*': - q = Q() - else: - q = Q(**{name: val}) - - if rel == 'or': - filters |= q - elif rel == 'not': - excludes |= q - else: - filters &= q - return model.objects.filter(filters).exclude(excludes) + q = self._get_filter_attrs_q(value, to_model) + return queryset.filter(q) else: - return model.objects.none() + return queryset.none() + + def get_attr_q(self): + q = self._get_filter_attrs_q(self.value) + return q def all(self): return self._get_queryset() @@ -415,40 +440,68 @@ class JSONManyToManyDescriptor: value = value.value manager.set(value) - def test_is(self): - print("Self.field is", self.field) - print("Self.field to", self.field.to) - print("Self.field model", self.field.model) - print("Self.field column", self.field.column) - print("Self.field to", self.field.__dict__) + def is_match(self, obj, attr_rules): + # m2m 的情况 + # 自定义的情况:比如 nodes, category + res = True + to_model = apps.get_model(self.field.to) + src_model = self.field.model + field_name = self.field.name + custom_attr_filter = getattr(src_model, "get_filter_{}_attr_q".format(field_name), None) - @staticmethod - def attr_to_regex(attr): - """将属性规则转换为正则表达式""" - name, value, match = attr['name'], attr['value'], attr['match'] - if match == 'contains': - return r'.*{}.*'.format(escape_regex(value)) - elif match == 'startswith': - return r'^{}.*'.format(escape_regex(value)) - elif match == 'endswith': - return r'.*{}$'.format(escape_regex(value)) - elif match == 'regex': - return value - elif match == 'not': - return r'^(?!^{}$)'.format(escape_regex(value)) - elif match == 'in': - values = '|'.join(map(escape_regex, value)) - return r'^(?:{})$'.format(values) - else: - return r'^{}$'.format(escape_regex(value)) - - def is_match(self, attr_dict, attr_rules): + custom_q = Q() for rule in attr_rules: - value = attr_dict.get(rule['name'], '') - regex = self.attr_to_regex(rule) - if not re.match(regex, value): - return False - return True + value = getattr(obj, rule['name'], '') + rule_value = rule.get('value', '') + rule_match = rule.get('match', 'exact') + + if custom_attr_filter: + q = custom_attr_filter(rule['name'], rule_value, rule_match) + if q: + custom_q &= q + continue + + if rule_match == 'in': + res &= value in rule_value + elif rule_match == 'exact': + res &= value == rule_value + elif rule_match == 'contains': + res &= rule_value in value + elif rule_match == 'startswith': + res &= str(value).startswith(str(rule_value)) + elif rule_match == 'endswith': + res &= str(value).endswith(str(rule_value)) + elif rule_match == 'regex': + res &= re.match(rule_value, value) + elif rule_match == 'not': + res &= value != rule_value + elif rule['match'] == 'gte': + res &= value >= rule_value + elif rule['match'] == 'lte': + res &= value <= rule_value + elif rule['match'] == 'gt': + res &= value > rule_value + elif rule['match'] == 'lt': + res &= value < rule_value + elif rule['match'] == 'ip_in': + if isinstance(rule_value, str): + rule_value = [rule_value] + res &= contains_ip(value, rule_value) + elif rule['match'] == 'm2m': + if isinstance(value, Manager): + value = value.values_list('id', flat=True) + value = set(map(str, value)) + rule_value = set(map(str, rule_value)) + res &= rule_value.issubset(value) + else: + logging.error("unknown match: {}".format(rule['match'])) + res &= False + + if not res: + return res + if custom_q: + res &= to_model.objects.filter(custom_q).filter(id=obj.id).exists() + return res def get_filter_q(self, instance): model_cls = self.field.model @@ -457,18 +510,12 @@ class JSONManyToManyDescriptor: queryset_id_attrs = model_cls.objects \ .filter(**{'{}__type'.format(field_name): 'attrs'}) \ .values_list('id', '{}__attrs'.format(field_name)) - instance_attr = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')} - ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance_attr, attr_rules)] + ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance, attr_rules)] if ids: q |= Q(id__in=ids) return q -def escape_regex(s): - """转义字符串中的正则表达式特殊字符""" - return re.sub('[.*+?^${}()|[\\]]', r'\\\g<0>', s) - - class JSONManyToManyField(models.JSONField): def __init__(self, to, *args, **kwargs): self.to = to @@ -490,7 +537,7 @@ class JSONManyToManyField(models.JSONField): e = ValueError(_( "Invalid JSON data for JSONManyToManyField, should be like " "{'type': 'all'} or {'type': 'ids', 'ids': []} " - "or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': 'value', 'rel': 'and|or|not'}}" + "or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': '1.1.1.1'}}" )) if not isinstance(val, dict): raise e diff --git a/apps/common/utils/ip/utils.py b/apps/common/utils/ip/utils.py index e5d43911e..14851ff27 100644 --- a/apps/common/utils/ip/utils.py +++ b/apps/common/utils/ip/utils.py @@ -1,3 +1,4 @@ +import ipaddress import socket from ipaddress import ip_network, ip_address @@ -75,6 +76,23 @@ def contains_ip(ip, ip_group): return False +def is_ip(self, ip, rule_value): + if rule_value == '*': + return True + elif '/' in rule_value: + network = ipaddress.ip_network(rule_value) + return ip in network.hosts() + elif '-' in rule_value: + start_ip, end_ip = rule_value.split('-') + start_ip = ipaddress.ip_address(start_ip) + end_ip = ipaddress.ip_address(end_ip) + return start_ip <= ip <= end_ip + elif len(rule_value.split('.')) == 4: + return ip == rule_value + else: + return ip.startswith(rule_value) + + def get_ip_city(ip): if not ip or not isinstance(ip, str): return _("Invalid address") diff --git a/apps/users/models/user.py b/apps/users/models/user.py index b487365e5..5d1edb036 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -668,7 +668,33 @@ class MFAMixin: return backend -class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser): +class JSONFilterMixin: + """ + users = JSONManyToManyField('users.User', blank=True, null=True) + """ + + @staticmethod + def get_json_filter_attr_q(name, value, match): + from rbac.models import RoleBinding + from orgs.utils import current_org + + if name == 'system_roles': + user_id = RoleBinding.objects \ + .filter(role__in=value, scope='system') \ + .values_list('user_id', flat=True) + return models.Q(id__in=user_id) + elif name == 'org_roles': + kwargs = dict(role__in=value, scope='org') + if not current_org.is_root(): + kwargs['org_id'] = current_org.id + + user_id = RoleBinding.objects.filter(**kwargs) \ + .values_list('user_id', flat=True) + return models.Q(id__in=user_id) + return None + + +class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, JSONFilterMixin, AbstractUser): class Source(models.TextChoices): local = 'local', _('Local') ldap = 'ldap', 'LDAP/AD'