Extend iptables packet tracer to check the protocol

This commit is contained in:
Dan Winship 2023-06-25 15:20:16 -04:00
parent a25fb03c00
commit 0910fe4b98

View File

@ -1395,9 +1395,8 @@ func newIPTablesTracer(t *testing.T, ipt *iptablestest.FakeIPTables, nodeIP stri
} }
// ruleMatches checks if the given iptables rule matches (at least probabilistically) a // ruleMatches checks if the given iptables rule matches (at least probabilistically) a
// packet with the given sourceIP, destIP, and destPort. (Note that protocol is currently // packet with the given sourceIP, destIP, and destPort.
// ignored.) func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, protocol, destIP, destPort string) bool {
func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, destIP, destPort string) bool {
// The sub-rules within an iptables rule are ANDed together, so the rule only // The sub-rules within an iptables rule are ANDed together, so the rule only
// matches if all of them match. So go through the subrules, and if any of them // matches if all of them match. So go through the subrules, and if any of them
// DON'T match, then fail. // DON'T match, then fail.
@ -1415,6 +1414,10 @@ func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, des
} }
} }
if rule.Protocol != nil && !rule.Protocol.Matches(protocol) {
return false
}
if rule.DestinationAddress != nil && !addressMatches(tracer.t, rule.DestinationAddress, destIP) { if rule.DestinationAddress != nil && !addressMatches(tracer.t, rule.DestinationAddress, destIP) {
return false return false
} }
@ -1442,7 +1445,7 @@ func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, des
// runChain runs the given packet through the rules in the given table and chain, updating // runChain runs the given packet through the rules in the given table and chain, updating
// tracer's internal state accordingly. It returns true if it hits a terminal action. // tracer's internal state accordingly. It returns true if it hits a terminal action.
func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptables.Chain, sourceIP, destIP, destPort string) bool { func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptables.Chain, sourceIP, protocol, destIP, destPort string) bool {
c, _ := tracer.ipt.Dump.GetChain(table, chain) c, _ := tracer.ipt.Dump.GetChain(table, chain)
if c == nil { if c == nil {
return false return false
@ -1453,7 +1456,7 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab
continue continue
} }
if !tracer.ruleMatches(rule, sourceIP, destIP, destPort) { if !tracer.ruleMatches(rule, sourceIP, protocol, destIP, destPort) {
continue continue
} }
// record the matched rule for debugging purposes // record the matched rule for debugging purposes
@ -1476,7 +1479,7 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab
default: default:
// We got a "-j KUBE-SOMETHING", so process that chain // We got a "-j KUBE-SOMETHING", so process that chain
terminated := tracer.runChain(table, utiliptables.Chain(rule.Jump.Value), sourceIP, destIP, destPort) terminated := tracer.runChain(table, utiliptables.Chain(rule.Jump.Value), sourceIP, protocol, destIP, destPort)
// If the subchain hit a terminal rule AND the rule that sent us // If the subchain hit a terminal rule AND the rule that sent us
// to that chain was non-probabilistic, then this chain terminates // to that chain was non-probabilistic, then this chain terminates
@ -1492,18 +1495,19 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab
return false return false
} }
// tracePacket determines what would happen to a packet with the given sourceIP, destIP, // tracePacket determines what would happen to a packet with the given sourceIP, protocol,
// and destPort, given the indicated iptables ruleData. nodeIP is the local node IP (for // destIP, and destPort, given the indicated iptables ruleData. nodeIP is the local node
// rules matching "LOCAL"). // IP (for rules matching "LOCAL"). (The protocol value should be lowercase as in iptables
// rules, not uppercase as in corev1.)
// //
// The return values are: an array of matched rules (for debugging), the final packet // The return values are: an array of matched rules (for debugging), the final packet
// destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT", // destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT",
// "DROP", or "REJECT"), and whether the packet would be masqueraded. // "DROP", or "REJECT"), and whether the packet would be masqueraded.
func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP, destPort, nodeIP string) ([]string, string, bool) { func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, protocol, destIP, destPort, nodeIP string) ([]string, string, bool) {
tracer := newIPTablesTracer(t, ipt, nodeIP) tracer := newIPTablesTracer(t, ipt, nodeIP)
// nat:PREROUTING goes first // nat:PREROUTING goes first
tracer.runChain(utiliptables.TableNAT, utiliptables.ChainPrerouting, sourceIP, destIP, destPort) tracer.runChain(utiliptables.TableNAT, utiliptables.ChainPrerouting, sourceIP, protocol, destIP, destPort)
// After the PREROUTING rules run, pending DNATs are processed (which would affect // After the PREROUTING rules run, pending DNATs are processed (which would affect
// the destination IP that later rules match against). // the destination IP that later rules match against).
@ -1515,10 +1519,10 @@ func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP,
// inbound, outbound, or intra-host packet, which we don't know. So we just run // inbound, outbound, or intra-host packet, which we don't know. So we just run
// the interesting tables manually. (Theoretically this could cause conflicts in // the interesting tables manually. (Theoretically this could cause conflicts in
// the future in which case we'd have to do something more complicated.) // the future in which case we'd have to do something more complicated.)
tracer.runChain(utiliptables.TableFilter, kubeServicesChain, sourceIP, destIP, destPort) tracer.runChain(utiliptables.TableFilter, kubeServicesChain, sourceIP, protocol, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeExternalServicesChain, sourceIP, destIP, destPort) tracer.runChain(utiliptables.TableFilter, kubeExternalServicesChain, sourceIP, protocol, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeNodePortsChain, sourceIP, destIP, destPort) tracer.runChain(utiliptables.TableFilter, kubeNodePortsChain, sourceIP, protocol, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeProxyFirewallChain, sourceIP, destIP, destPort) tracer.runChain(utiliptables.TableFilter, kubeProxyFirewallChain, sourceIP, protocol, destIP, destPort)
// Finally, the nat:POSTROUTING rules run, but the only interesting thing that // Finally, the nat:POSTROUTING rules run, but the only interesting thing that
// happens there is that the masquerade mark gets turned into actual masquerading. // happens there is that the masquerade mark gets turned into actual masquerading.
@ -1529,6 +1533,7 @@ func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP,
type packetFlowTest struct { type packetFlowTest struct {
name string name string
sourceIP string sourceIP string
protocol v1.Protocol
destIP string destIP string
destPort int destPort int
output string output string
@ -1542,7 +1547,11 @@ func runPacketFlowTests(t *testing.T, line int, ipt *iptablestest.FakeIPTables,
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
matches, output, masq := tracePacket(t, ipt, tc.sourceIP, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIP) protocol := strings.ToLower(string(tc.protocol))
if protocol == "" {
protocol = "tcp"
}
matches, output, masq := tracePacket(t, ipt, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIP)
var errors []string var errors []string
if output != tc.output { if output != tc.output {
errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output)) errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output))
@ -1551,8 +1560,8 @@ func runPacketFlowTests(t *testing.T, line int, ipt *iptablestest.FakeIPTables,
errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq)) errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq))
} }
if errors != nil { if errors != nil {
t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n", t.Errorf("Test %q of a %s packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n",
tc.name, tc.sourceIP, tc.destIP, tc.destPort, lineStr, strings.Join(errors, "\n"), strings.Join(matches, "\n")) tc.name, protocol, tc.sourceIP, tc.destIP, tc.destPort, lineStr, strings.Join(errors, "\n"), strings.Join(matches, "\n"))
} }
}) })
} }
@ -2139,6 +2148,7 @@ func TestClusterIPEndpointsMore(t *testing.T) {
{ {
name: "cluster IP accepted", name: "cluster IP accepted",
sourceIP: "10.180.0.2", sourceIP: "10.180.0.2",
protocol: v1.ProtocolSCTP,
destIP: "172.30.0.41", destIP: "172.30.0.41",
destPort: 80, destPort: 80,
output: "10.180.0.1:80", output: "10.180.0.1:80",
@ -2147,6 +2157,7 @@ func TestClusterIPEndpointsMore(t *testing.T) {
{ {
name: "hairpin to cluster IP", name: "hairpin to cluster IP",
sourceIP: "10.180.0.1", sourceIP: "10.180.0.1",
protocol: v1.ProtocolSCTP,
destIP: "172.30.0.41", destIP: "172.30.0.41",
destPort: 80, destPort: 80,
output: "10.180.0.1:80", output: "10.180.0.1:80",