diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index 9c088da13ff..d705c74a22d 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -1395,9 +1395,8 @@ func newIPTablesTracer(t *testing.T, ipt *iptablestest.FakeIPTables, nodeIP stri } // ruleMatches checks if the given iptables rule matches (at least probabilistically) a -// packet with the given sourceIP, destIP, and destPort. (Note that protocol is currently -// ignored.) -func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, destIP, destPort string) bool { +// packet with the given sourceIP, destIP, and destPort. +func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, protocol, destIP, destPort string) bool { // The sub-rules within an iptables rule are ANDed together, so the rule only // matches if all of them match. So go through the subrules, and if any of them // DON'T match, then fail. @@ -1415,6 +1414,10 @@ func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, des } } + if rule.Protocol != nil && !rule.Protocol.Matches(protocol) { + return false + } + if rule.DestinationAddress != nil && !addressMatches(tracer.t, rule.DestinationAddress, destIP) { return false } @@ -1442,7 +1445,7 @@ func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, des // runChain runs the given packet through the rules in the given table and chain, updating // tracer's internal state accordingly. It returns true if it hits a terminal action. -func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptables.Chain, sourceIP, destIP, destPort string) bool { +func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptables.Chain, sourceIP, protocol, destIP, destPort string) bool { c, _ := tracer.ipt.Dump.GetChain(table, chain) if c == nil { return false @@ -1453,7 +1456,7 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab continue } - if !tracer.ruleMatches(rule, sourceIP, destIP, destPort) { + if !tracer.ruleMatches(rule, sourceIP, protocol, destIP, destPort) { continue } // record the matched rule for debugging purposes @@ -1476,7 +1479,7 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab default: // We got a "-j KUBE-SOMETHING", so process that chain - terminated := tracer.runChain(table, utiliptables.Chain(rule.Jump.Value), sourceIP, destIP, destPort) + terminated := tracer.runChain(table, utiliptables.Chain(rule.Jump.Value), sourceIP, protocol, destIP, destPort) // If the subchain hit a terminal rule AND the rule that sent us // to that chain was non-probabilistic, then this chain terminates @@ -1492,18 +1495,19 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab return false } -// tracePacket determines what would happen to a packet with the given sourceIP, destIP, -// and destPort, given the indicated iptables ruleData. nodeIP is the local node IP (for -// rules matching "LOCAL"). +// tracePacket determines what would happen to a packet with the given sourceIP, protocol, +// destIP, and destPort, given the indicated iptables ruleData. nodeIP is the local node +// IP (for rules matching "LOCAL"). (The protocol value should be lowercase as in iptables +// rules, not uppercase as in corev1.) // // The return values are: an array of matched rules (for debugging), the final packet // destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT", // "DROP", or "REJECT"), and whether the packet would be masqueraded. -func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP, destPort, nodeIP string) ([]string, string, bool) { +func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, protocol, destIP, destPort, nodeIP string) ([]string, string, bool) { tracer := newIPTablesTracer(t, ipt, nodeIP) // nat:PREROUTING goes first - tracer.runChain(utiliptables.TableNAT, utiliptables.ChainPrerouting, sourceIP, destIP, destPort) + tracer.runChain(utiliptables.TableNAT, utiliptables.ChainPrerouting, sourceIP, protocol, destIP, destPort) // After the PREROUTING rules run, pending DNATs are processed (which would affect // the destination IP that later rules match against). @@ -1515,10 +1519,10 @@ func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP, // inbound, outbound, or intra-host packet, which we don't know. So we just run // the interesting tables manually. (Theoretically this could cause conflicts in // the future in which case we'd have to do something more complicated.) - tracer.runChain(utiliptables.TableFilter, kubeServicesChain, sourceIP, destIP, destPort) - tracer.runChain(utiliptables.TableFilter, kubeExternalServicesChain, sourceIP, destIP, destPort) - tracer.runChain(utiliptables.TableFilter, kubeNodePortsChain, sourceIP, destIP, destPort) - tracer.runChain(utiliptables.TableFilter, kubeProxyFirewallChain, sourceIP, destIP, destPort) + tracer.runChain(utiliptables.TableFilter, kubeServicesChain, sourceIP, protocol, destIP, destPort) + tracer.runChain(utiliptables.TableFilter, kubeExternalServicesChain, sourceIP, protocol, destIP, destPort) + tracer.runChain(utiliptables.TableFilter, kubeNodePortsChain, sourceIP, protocol, destIP, destPort) + tracer.runChain(utiliptables.TableFilter, kubeProxyFirewallChain, sourceIP, protocol, destIP, destPort) // Finally, the nat:POSTROUTING rules run, but the only interesting thing that // happens there is that the masquerade mark gets turned into actual masquerading. @@ -1529,6 +1533,7 @@ func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP, type packetFlowTest struct { name string sourceIP string + protocol v1.Protocol destIP string destPort int output string @@ -1542,7 +1547,11 @@ func runPacketFlowTests(t *testing.T, line int, ipt *iptablestest.FakeIPTables, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - matches, output, masq := tracePacket(t, ipt, tc.sourceIP, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIP) + protocol := strings.ToLower(string(tc.protocol)) + if protocol == "" { + protocol = "tcp" + } + matches, output, masq := tracePacket(t, ipt, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIP) var errors []string if output != tc.output { errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output)) @@ -1551,8 +1560,8 @@ func runPacketFlowTests(t *testing.T, line int, ipt *iptablestest.FakeIPTables, errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq)) } if errors != nil { - t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n", - tc.name, tc.sourceIP, tc.destIP, tc.destPort, lineStr, strings.Join(errors, "\n"), strings.Join(matches, "\n")) + t.Errorf("Test %q of a %s packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n", + tc.name, protocol, tc.sourceIP, tc.destIP, tc.destPort, lineStr, strings.Join(errors, "\n"), strings.Join(matches, "\n")) } }) } @@ -2139,6 +2148,7 @@ func TestClusterIPEndpointsMore(t *testing.T) { { name: "cluster IP accepted", sourceIP: "10.180.0.2", + protocol: v1.ProtocolSCTP, destIP: "172.30.0.41", destPort: 80, output: "10.180.0.1:80", @@ -2147,6 +2157,7 @@ func TestClusterIPEndpointsMore(t *testing.T) { { name: "hairpin to cluster IP", sourceIP: "10.180.0.1", + protocol: v1.ProtocolSCTP, destIP: "172.30.0.41", destPort: 80, output: "10.180.0.1:80",