perf: Update IP group validation to include address validation

This commit is contained in:
wangruidong 2025-07-03 19:22:23 +08:00 committed by 老广
parent 99c4622ccb
commit ab06ac1f1f
2 changed files with 22 additions and 6 deletions

View File

@ -1,5 +1,7 @@
# coding: utf-8 # coding: utf-8
# #
from urllib.parse import urlparse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
@ -8,7 +10,7 @@ from common.utils.ip import is_ip_address, is_ip_network, is_ip_segment
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = ['RuleSerializer', 'ip_group_child_validator', 'ip_group_help_text'] __all__ = ['RuleSerializer', 'ip_group_child_validator', 'ip_group_help_text', 'address_validator']
def ip_group_child_validator(ip_group_child): def ip_group_child_validator(ip_group_child):
@ -21,6 +23,19 @@ def ip_group_child_validator(ip_group_child):
raise serializers.ValidationError(error) raise serializers.ValidationError(error)
def address_validator(value):
parsed = urlparse(value)
is_basic_url = parsed.scheme in ('http', 'https') and parsed.netloc
is_valid = value == '*' \
or is_ip_address(value) \
or is_ip_network(value) \
or is_ip_segment(value) \
or is_basic_url
if not is_valid:
error = _('address invalid: `{}`').format(value)
raise serializers.ValidationError(error)
ip_group_help_text = _( ip_group_help_text = _(
'With * indicating a match all. ' 'With * indicating a match all. '
'Such as: ' 'Such as: '

View File

@ -1,7 +1,7 @@
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text from acls.serializers.rules import address_validator, ip_group_help_text
from common.serializers import BulkModelSerializer from common.serializers import BulkModelSerializer
from common.serializers.fields import ObjectRelatedField from common.serializers.fields import ObjectRelatedField
from ..models import Endpoint, EndpointRule from ..models import Endpoint, EndpointRule
@ -16,7 +16,7 @@ class EndpointSerializer(BulkModelSerializer):
fields_small = [ fields_small = [
'host', 'https_port', 'http_port', 'ssh_port', 'rdp_port', 'host', 'https_port', 'http_port', 'ssh_port', 'rdp_port',
'mysql_port', 'mariadb_port', 'postgresql_port', 'redis_port', 'vnc_port', 'mysql_port', 'mariadb_port', 'postgresql_port', 'redis_port', 'vnc_port',
'oracle_port', 'sqlserver_port', 'mongodb_port','is_active' 'oracle_port', 'sqlserver_port', 'mongodb_port', 'is_active'
] ]
fields = fields_mini + fields_small + [ fields = fields_mini + fields_small + [
'comment', 'date_created', 'date_updated', 'created_by' 'comment', 'date_created', 'date_updated', 'created_by'
@ -29,6 +29,7 @@ class EndpointSerializer(BulkModelSerializer):
) )
}, },
} }
def get_extra_kwargs(self): def get_extra_kwargs(self):
extra_kwargs = super().get_extra_kwargs() extra_kwargs = super().get_extra_kwargs()
model_fields = self.Meta.model._meta.fields model_fields = self.Meta.model._meta.fields
@ -49,13 +50,13 @@ class EndpointSerializer(BulkModelSerializer):
class EndpointRuleSerializer(BulkModelSerializer): class EndpointRuleSerializer(BulkModelSerializer):
_ip_group_help_text = '{}, {} <br>{}'.format( _ip_group_help_text = '{}, {} <br>{}'.format(
_('The assets within this IP range, the following endpoint will be used for the connection'), _('The assets within this IP range or Host, the following endpoint will be used for the connection'),
_('If asset IP addresses under different endpoints conflict, use asset labels'), _('If asset IP addresses under different endpoints conflict, use asset labels'),
ip_group_help_text, ip_group_help_text,
) )
ip_group = serializers.ListField( ip_group = serializers.ListField(
default=['*'], label=_('Asset IP'), help_text=_ip_group_help_text, default=['*'], label=_('Address'), help_text=_ip_group_help_text,
child=serializers.CharField(max_length=1024, validators=[ip_group_child_validator]) child=serializers.CharField(max_length=1024, validators=[address_validator]),
) )
endpoint = ObjectRelatedField( endpoint = ObjectRelatedField(
allow_null=True, required=False, queryset=Endpoint.objects, label=_('Endpoint') allow_null=True, required=False, queryset=Endpoint.objects, label=_('Endpoint')