Extend iptables packet tracer to check the protocol

This commit is contained in:
Dan Winship 2023-06-25 15:20:16 -04:00
parent a25fb03c00
commit 0910fe4b98

View File

@ -1395,9 +1395,8 @@ func newIPTablesTracer(t *testing.T, ipt *iptablestest.FakeIPTables, nodeIP stri
}
// ruleMatches checks if the given iptables rule matches (at least probabilistically) a
// packet with the given sourceIP, destIP, and destPort. (Note that protocol is currently
// ignored.)
func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, destIP, destPort string) bool {
// packet with the given sourceIP, destIP, and destPort.
func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, protocol, destIP, destPort string) bool {
// The sub-rules within an iptables rule are ANDed together, so the rule only
// matches if all of them match. So go through the subrules, and if any of them
// DON'T match, then fail.
@ -1415,6 +1414,10 @@ func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, des
}
}
if rule.Protocol != nil && !rule.Protocol.Matches(protocol) {
return false
}
if rule.DestinationAddress != nil && !addressMatches(tracer.t, rule.DestinationAddress, destIP) {
return false
}
@ -1442,7 +1445,7 @@ func (tracer *iptablesTracer) ruleMatches(rule *iptablestest.Rule, sourceIP, des
// runChain runs the given packet through the rules in the given table and chain, updating
// tracer's internal state accordingly. It returns true if it hits a terminal action.
func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptables.Chain, sourceIP, destIP, destPort string) bool {
func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptables.Chain, sourceIP, protocol, destIP, destPort string) bool {
c, _ := tracer.ipt.Dump.GetChain(table, chain)
if c == nil {
return false
@ -1453,7 +1456,7 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab
continue
}
if !tracer.ruleMatches(rule, sourceIP, destIP, destPort) {
if !tracer.ruleMatches(rule, sourceIP, protocol, destIP, destPort) {
continue
}
// record the matched rule for debugging purposes
@ -1476,7 +1479,7 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab
default:
// We got a "-j KUBE-SOMETHING", so process that chain
terminated := tracer.runChain(table, utiliptables.Chain(rule.Jump.Value), sourceIP, destIP, destPort)
terminated := tracer.runChain(table, utiliptables.Chain(rule.Jump.Value), sourceIP, protocol, destIP, destPort)
// If the subchain hit a terminal rule AND the rule that sent us
// to that chain was non-probabilistic, then this chain terminates
@ -1492,18 +1495,19 @@ func (tracer *iptablesTracer) runChain(table utiliptables.Table, chain utiliptab
return false
}
// tracePacket determines what would happen to a packet with the given sourceIP, destIP,
// and destPort, given the indicated iptables ruleData. nodeIP is the local node IP (for
// rules matching "LOCAL").
// tracePacket determines what would happen to a packet with the given sourceIP, protocol,
// destIP, and destPort, given the indicated iptables ruleData. nodeIP is the local node
// IP (for rules matching "LOCAL"). (The protocol value should be lowercase as in iptables
// rules, not uppercase as in corev1.)
//
// The return values are: an array of matched rules (for debugging), the final packet
// destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT",
// "DROP", or "REJECT"), and whether the packet would be masqueraded.
func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP, destPort, nodeIP string) ([]string, string, bool) {
func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, protocol, destIP, destPort, nodeIP string) ([]string, string, bool) {
tracer := newIPTablesTracer(t, ipt, nodeIP)
// nat:PREROUTING goes first
tracer.runChain(utiliptables.TableNAT, utiliptables.ChainPrerouting, sourceIP, destIP, destPort)
tracer.runChain(utiliptables.TableNAT, utiliptables.ChainPrerouting, sourceIP, protocol, destIP, destPort)
// After the PREROUTING rules run, pending DNATs are processed (which would affect
// the destination IP that later rules match against).
@ -1515,10 +1519,10 @@ func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP,
// inbound, outbound, or intra-host packet, which we don't know. So we just run
// the interesting tables manually. (Theoretically this could cause conflicts in
// the future in which case we'd have to do something more complicated.)
tracer.runChain(utiliptables.TableFilter, kubeServicesChain, sourceIP, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeExternalServicesChain, sourceIP, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeNodePortsChain, sourceIP, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeProxyFirewallChain, sourceIP, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeServicesChain, sourceIP, protocol, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeExternalServicesChain, sourceIP, protocol, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeNodePortsChain, sourceIP, protocol, destIP, destPort)
tracer.runChain(utiliptables.TableFilter, kubeProxyFirewallChain, sourceIP, protocol, destIP, destPort)
// Finally, the nat:POSTROUTING rules run, but the only interesting thing that
// happens there is that the masquerade mark gets turned into actual masquerading.
@ -1529,6 +1533,7 @@ func tracePacket(t *testing.T, ipt *iptablestest.FakeIPTables, sourceIP, destIP,
type packetFlowTest struct {
name string
sourceIP string
protocol v1.Protocol
destIP string
destPort int
output string
@ -1542,7 +1547,11 @@ func runPacketFlowTests(t *testing.T, line int, ipt *iptablestest.FakeIPTables,
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
matches, output, masq := tracePacket(t, ipt, tc.sourceIP, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIP)
protocol := strings.ToLower(string(tc.protocol))
if protocol == "" {
protocol = "tcp"
}
matches, output, masq := tracePacket(t, ipt, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIP)
var errors []string
if output != tc.output {
errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output))
@ -1551,8 +1560,8 @@ func runPacketFlowTests(t *testing.T, line int, ipt *iptablestest.FakeIPTables,
errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq))
}
if errors != nil {
t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n",
tc.name, tc.sourceIP, tc.destIP, tc.destPort, lineStr, strings.Join(errors, "\n"), strings.Join(matches, "\n"))
t.Errorf("Test %q of a %s packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n",
tc.name, protocol, tc.sourceIP, tc.destIP, tc.destPort, lineStr, strings.Join(errors, "\n"), strings.Join(matches, "\n"))
}
})
}
@ -2139,6 +2148,7 @@ func TestClusterIPEndpointsMore(t *testing.T) {
{
name: "cluster IP accepted",
sourceIP: "10.180.0.2",
protocol: v1.ProtocolSCTP,
destIP: "172.30.0.41",
destPort: 80,
output: "10.180.0.1:80",
@ -2147,6 +2157,7 @@ func TestClusterIPEndpointsMore(t *testing.T) {
{
name: "hairpin to cluster IP",
sourceIP: "10.180.0.1",
protocol: v1.ProtocolSCTP,
destIP: "172.30.0.41",
destPort: 80,
output: "10.180.0.1:80",