pkg/util/iptables/testing: Fix FakeIPTables

FakeIPTables barely implemented any of the iptables interface, and the
main part that it did implement, it implemented incorrectly. Fix it:

- Implement EnsureChain, DeleteChain, EnsureRule, and DeleteRule, not
  just SaveInto/Restore/RestoreAll.

- Restore/RestoreAll now correctly merge the provided state with the
  existing state, rather than simply overwriting it.

- SaveInto now returns the table that was requested, rather than just
  echoing back the Restore/RestoreAll.
This commit is contained in:
Dan Winship 2022-04-07 16:13:34 -04:00
parent 10a72a9e03
commit 913f4bc0ba
3 changed files with 508 additions and 27 deletions

View File

@ -3835,12 +3835,8 @@ func getRules(ipt *iptablestest.FakeIPTables, chain utiliptables.Chain) []*iptab
var rules []*iptablestest.Rule
buf := bytes.NewBuffer(nil)
// FIXME: FakeIPTables.SaveInto is currently broken and ignores the "table"
// argument and just echoes whatever was last passed to RestoreAll(), so even
// though we want to see the rules from both "nat" and "filter", we have to
// only request one of them, or else we'll get all the rules twice...
_ = ipt.SaveInto(utiliptables.TableNAT, buf)
// _ = ipt.SaveInto(utiliptable.TableFilter, buf)
_ = ipt.SaveInto(utiliptables.TableFilter, buf)
lines := strings.Split(string(buf.Bytes()), "\n")
for _, l := range lines {
if !strings.HasPrefix(l, "-A ") {

View File

@ -18,6 +18,8 @@ package testing
import (
"bytes"
"fmt"
"strings"
"time"
"k8s.io/kubernetes/pkg/util/iptables"
@ -26,53 +28,146 @@ import (
// FakeIPTables is no-op implementation of iptables Interface.
type FakeIPTables struct {
hasRandomFully bool
Lines []byte
protocol iptables.Protocol
Dump *IPTablesDump
}
// NewFake returns a no-op iptables.Interface
func NewFake() *FakeIPTables {
return &FakeIPTables{protocol: iptables.ProtocolIPv4}
f := &FakeIPTables{
protocol: iptables.ProtocolIPv4,
Dump: &IPTablesDump{
Tables: []Table{
{
Name: iptables.TableNAT,
Chains: []Chain{
{Name: iptables.ChainPrerouting},
{Name: iptables.ChainInput},
{Name: iptables.ChainOutput},
{Name: iptables.ChainPostrouting},
},
},
{
Name: iptables.TableFilter,
Chains: []Chain{
{Name: iptables.ChainInput},
{Name: iptables.ChainForward},
{Name: iptables.ChainOutput},
},
},
{
Name: iptables.TableMangle,
Chains: []Chain{},
},
},
},
}
return f
}
// NewIPv6Fake returns a no-op iptables.Interface with IsIPv6() == true
func NewIPv6Fake() *FakeIPTables {
return &FakeIPTables{protocol: iptables.ProtocolIPv6}
f := NewFake()
f.protocol = iptables.ProtocolIPv6
return f
}
// SetHasRandomFully is part of iptables.Interface
// SetHasRandomFully sets f's return value for HasRandomFully()
func (f *FakeIPTables) SetHasRandomFully(can bool) *FakeIPTables {
f.hasRandomFully = can
return f
}
// EnsureChain is part of iptables.Interface
func (*FakeIPTables) EnsureChain(table iptables.Table, chain iptables.Chain) (bool, error) {
return true, nil
func (f *FakeIPTables) EnsureChain(table iptables.Table, chain iptables.Chain) (bool, error) {
t, err := f.Dump.GetTable(table)
if err != nil {
return false, err
}
if c, _ := f.Dump.GetChain(table, chain); c != nil {
return true, nil
}
t.Chains = append(t.Chains, Chain{Name: chain})
return false, nil
}
// FlushChain is part of iptables.Interface
func (*FakeIPTables) FlushChain(table iptables.Table, chain iptables.Chain) error {
func (f *FakeIPTables) FlushChain(table iptables.Table, chain iptables.Chain) error {
if c, _ := f.Dump.GetChain(table, chain); c != nil {
c.Rules = nil
}
return nil
}
// DeleteChain is part of iptables.Interface
func (*FakeIPTables) DeleteChain(table iptables.Table, chain iptables.Chain) error {
func (f *FakeIPTables) DeleteChain(table iptables.Table, chain iptables.Chain) error {
t, err := f.Dump.GetTable(table)
if err != nil {
return err
}
for i := range t.Chains {
if t.Chains[i].Name == chain {
t.Chains = append(t.Chains[:i], t.Chains[i+1:]...)
return nil
}
}
return nil
}
// ChainExists is part of iptables.Interface
func (*FakeIPTables) ChainExists(table iptables.Table, chain iptables.Chain) (bool, error) {
return true, nil
func (f *FakeIPTables) ChainExists(table iptables.Table, chain iptables.Chain) (bool, error) {
if _, err := f.Dump.GetTable(table); err != nil {
return false, err
}
if c, _ := f.Dump.GetChain(table, chain); c != nil {
return true, nil
}
return false, nil
}
// EnsureRule is part of iptables.Interface
func (*FakeIPTables) EnsureRule(position iptables.RulePosition, table iptables.Table, chain iptables.Chain, args ...string) (bool, error) {
return true, nil
func (f *FakeIPTables) EnsureRule(position iptables.RulePosition, table iptables.Table, chain iptables.Chain, args ...string) (bool, error) {
c, err := f.Dump.GetChain(table, chain)
if err != nil {
return false, err
}
rule := "-A " + string(chain) + " " + strings.Join(args, " ")
for _, r := range c.Rules {
if r.Raw == rule {
return true, nil
}
}
parsed, err := ParseRule(rule, false)
if err != nil {
return false, err
}
if position == iptables.Append {
c.Rules = append(c.Rules, parsed)
} else {
c.Rules = append([]*Rule{parsed}, c.Rules...)
}
return false, nil
}
// DeleteRule is part of iptables.Interface
func (*FakeIPTables) DeleteRule(table iptables.Table, chain iptables.Chain, args ...string) error {
func (f *FakeIPTables) DeleteRule(table iptables.Table, chain iptables.Chain, args ...string) error {
c, err := f.Dump.GetChain(table, chain)
if err != nil {
return err
}
rule := "-A " + string(chain) + " " + strings.Join(args, " ")
for i, r := range c.Rules {
if r.Raw == rule {
c.Rules = append(c.Rules[:i], c.Rules[i+1:]...)
break
}
}
return nil
}
@ -86,27 +181,102 @@ func (f *FakeIPTables) Protocol() iptables.Protocol {
return f.protocol
}
// Save is part of iptables.Interface
func (f *FakeIPTables) Save(table iptables.Table) ([]byte, error) {
lines := make([]byte, len(f.Lines))
copy(lines, f.Lines)
return lines, nil
func (f *FakeIPTables) saveTable(table iptables.Table, buffer *bytes.Buffer) error {
t, err := f.Dump.GetTable(table)
if err != nil {
return err
}
fmt.Fprintf(buffer, "*%s\n", table)
for _, c := range t.Chains {
fmt.Fprintf(buffer, ":%s - [%d:%d]\n", c.Name, c.Packets, c.Bytes)
}
for _, c := range t.Chains {
for _, r := range c.Rules {
fmt.Fprintf(buffer, "%s\n", r.Raw)
}
}
fmt.Fprintf(buffer, "COMMIT\n")
return nil
}
// SaveInto is part of iptables.Interface
func (f *FakeIPTables) SaveInto(table iptables.Table, buffer *bytes.Buffer) error {
buffer.Write(f.Lines)
if table == "" {
// As a secret extension to the API, FakeIPTables treats table="" as
// meaning "all tables"
for i := range f.Dump.Tables {
err := f.saveTable(f.Dump.Tables[i].Name, buffer)
if err != nil {
return err
}
}
return nil
}
return f.saveTable(table, buffer)
}
func (f *FakeIPTables) restoreTable(newTable *Table, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
oldTable, err := f.Dump.GetTable(newTable.Name)
if err != nil {
return err
}
if flush == iptables.FlushTables {
oldTable.Chains = make([]Chain, 0, len(newTable.Chains))
}
for _, newChain := range newTable.Chains {
oldChain, _ := f.Dump.GetChain(newTable.Name, newChain.Name)
switch {
case oldChain == nil && newChain.Deleted:
// no-op
case oldChain == nil && !newChain.Deleted:
oldTable.Chains = append(oldTable.Chains, newChain)
case oldChain != nil && newChain.Deleted:
// FIXME: should make sure chain is not referenced from other jumps
_ = f.DeleteChain(newTable.Name, newChain.Name)
case oldChain != nil && !newChain.Deleted:
// replace old data with new
oldChain.Rules = newChain.Rules
if counters == iptables.RestoreCounters {
oldChain.Packets = newChain.Packets
oldChain.Bytes = newChain.Bytes
}
}
}
return nil
}
// Restore is part of iptables.Interface
func (*FakeIPTables) Restore(table iptables.Table, data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
return nil
func (f *FakeIPTables) Restore(table iptables.Table, data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
dump, err := ParseIPTablesDump(string(data))
if err != nil {
return err
}
newTable, err := dump.GetTable(table)
if err != nil {
return err
}
return f.restoreTable(newTable, flush, counters)
}
// RestoreAll is part of iptables.Interface
func (f *FakeIPTables) RestoreAll(data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
f.Lines = data
dump, err := ParseIPTablesDump(string(data))
if err != nil {
return err
}
for i := range dump.Tables {
err = f.restoreTable(&dump.Tables[i], flush, counters)
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,315 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testing
import (
"bytes"
"strings"
"testing"
"github.com/lithammer/dedent"
"k8s.io/kubernetes/pkg/util/iptables"
)
func TestFakeIPTables(t *testing.T) {
fake := NewFake()
buf := bytes.NewBuffer(nil)
err := fake.SaveInto("", buf)
if err != nil {
t.Fatalf("unexpected error from SaveInto: %v", err)
}
expected := dedent.Dedent(strings.Trim(`
*nat
:PREROUTING - [0:0]
:INPUT - [0:0]
:OUTPUT - [0:0]
:POSTROUTING - [0:0]
COMMIT
*filter
:INPUT - [0:0]
:FORWARD - [0:0]
:OUTPUT - [0:0]
COMMIT
*mangle
COMMIT
`, "\n"))
if string(buf.Bytes()) != expected {
t.Fatalf("bad initial dump. expected:\n%s\n\ngot:\n%s\n", expected, buf.Bytes())
}
// EnsureChain
existed, err := fake.EnsureChain(iptables.Table("blah"), iptables.Chain("KUBE-TEST"))
if err == nil {
t.Errorf("did not get expected error creating chain in non-existent table")
} else if existed {
t.Errorf("wrong return value from EnsureChain with non-existent table")
}
existed, err = fake.EnsureChain(iptables.TableNAT, iptables.Chain("KUBE-TEST"))
if err != nil {
t.Errorf("unexpected error creating chain: %v", err)
} else if existed {
t.Errorf("wrong return value from EnsureChain with non-existent chain")
}
existed, err = fake.EnsureChain(iptables.TableNAT, iptables.Chain("KUBE-TEST"))
if err != nil {
t.Errorf("unexpected error creating chain: %v", err)
} else if !existed {
t.Errorf("wrong return value from EnsureChain with existing chain")
}
// ChainExists
exists, err := fake.ChainExists(iptables.TableNAT, iptables.Chain("KUBE-TEST"))
if err != nil {
t.Errorf("unexpected error checking chain: %v", err)
} else if !exists {
t.Errorf("wrong return value from ChainExists with existing chain")
}
exists, err = fake.ChainExists(iptables.TableNAT, iptables.Chain("KUBE-TEST-NOT"))
if err != nil {
t.Errorf("unexpected error checking chain: %v", err)
} else if exists {
t.Errorf("wrong return value from ChainExists with non-existent chain")
}
// EnsureRule
existed, err = fake.EnsureRule(iptables.Append, iptables.Table("blah"), iptables.Chain("KUBE-TEST"), "-j", "ACCEPT")
if err == nil {
t.Errorf("did not get expected error creating rule in non-existent table")
} else if existed {
t.Errorf("wrong return value from EnsureRule with non-existent table")
}
existed, err = fake.EnsureRule(iptables.Append, iptables.TableNAT, iptables.Chain("KUBE-TEST-NOT"), "-j", "ACCEPT")
if err == nil {
t.Errorf("did not get expected error creating rule in non-existent chain")
} else if existed {
t.Errorf("wrong return value from EnsureRule with non-existent chain")
}
existed, err = fake.EnsureRule(iptables.Append, iptables.TableNAT, iptables.Chain("KUBE-TEST"), "-j", "ACCEPT")
if err != nil {
t.Errorf("unexpected error creating rule: %v", err)
} else if existed {
t.Errorf("wrong return value from EnsureRule with non-existent rule")
}
existed, err = fake.EnsureRule(iptables.Prepend, iptables.TableNAT, iptables.Chain("KUBE-TEST"), "-j", "DROP")
if err != nil {
t.Errorf("unexpected error creating rule: %v", err)
} else if existed {
t.Errorf("wrong return value from EnsureRule with non-existent rule")
}
existed, err = fake.EnsureRule(iptables.Append, iptables.TableNAT, iptables.Chain("KUBE-TEST"), "-j", "DROP")
if err != nil {
t.Errorf("unexpected error creating rule: %v", err)
} else if !existed {
t.Errorf("wrong return value from EnsureRule with already-existing rule")
}
// Sanity-check...
buf.Reset()
err = fake.SaveInto("", buf)
if err != nil {
t.Fatalf("unexpected error from SaveInto: %v", err)
}
expected = dedent.Dedent(strings.Trim(`
*nat
:PREROUTING - [0:0]
:INPUT - [0:0]
:OUTPUT - [0:0]
:POSTROUTING - [0:0]
:KUBE-TEST - [0:0]
-A KUBE-TEST -j DROP
-A KUBE-TEST -j ACCEPT
COMMIT
*filter
:INPUT - [0:0]
:FORWARD - [0:0]
:OUTPUT - [0:0]
COMMIT
*mangle
COMMIT
`, "\n"))
if string(buf.Bytes()) != expected {
t.Fatalf("bad sanity-check dump. expected:\n%s\n\ngot:\n%s\n", expected, buf.Bytes())
}
// DeleteRule
err = fake.DeleteRule(iptables.Table("blah"), iptables.Chain("KUBE-TEST"), "-j", "DROP")
if err == nil {
t.Errorf("did not get expected error deleting rule in non-existent table")
}
err = fake.DeleteRule(iptables.TableNAT, iptables.Chain("KUBE-TEST-NOT"), "-j", "DROP")
if err == nil {
t.Errorf("did not get expected error deleting rule in non-existent chain")
}
err = fake.DeleteRule(iptables.TableNAT, iptables.Chain("KUBE-TEST"), "-j", "DROPLET")
if err != nil {
t.Errorf("unexpected error deleting non-existent rule: %v", err)
}
err = fake.DeleteRule(iptables.TableNAT, iptables.Chain("KUBE-TEST"), "-j", "DROP")
if err != nil {
t.Errorf("unexpected error deleting rule: %v", err)
}
// Restore
rules := dedent.Dedent(strings.Trim(`
*nat
:KUBE-RESTORED - [0:0]
:KUBE-MISC-CHAIN - [0:0]
:KUBE-EMPTY - [0:0]
-A KUBE-RESTORED -m comment --comment "restored chain" -j ACCEPT
-A KUBE-MISC-CHAIN -s 1.2.3.4 -j DROP
-A KUBE-MISC-CHAIN -d 5.6.7.8 -j MASQUERADE
COMMIT
`, "\n"))
err = fake.Restore(iptables.TableNAT, []byte(rules), iptables.NoFlushTables, iptables.NoRestoreCounters)
if err != nil {
t.Fatalf("unexpected error from Restore: %v", err)
}
// We used NoFlushTables, so this should leave KUBE-TEST unchanged
buf.Reset()
err = fake.SaveInto("", buf)
if err != nil {
t.Fatalf("unexpected error from SaveInto: %v", err)
}
expected = dedent.Dedent(strings.Trim(`
*nat
:PREROUTING - [0:0]
:INPUT - [0:0]
:OUTPUT - [0:0]
:POSTROUTING - [0:0]
:KUBE-TEST - [0:0]
:KUBE-RESTORED - [0:0]
:KUBE-MISC-CHAIN - [0:0]
:KUBE-EMPTY - [0:0]
-A KUBE-TEST -j ACCEPT
-A KUBE-RESTORED -m comment --comment "restored chain" -j ACCEPT
-A KUBE-MISC-CHAIN -s 1.2.3.4 -j DROP
-A KUBE-MISC-CHAIN -d 5.6.7.8 -j MASQUERADE
COMMIT
*filter
:INPUT - [0:0]
:FORWARD - [0:0]
:OUTPUT - [0:0]
COMMIT
*mangle
COMMIT
`, "\n"))
if string(buf.Bytes()) != expected {
t.Fatalf("bad post-restore dump. expected:\n%s\n\ngot:\n%s\n", expected, buf.Bytes())
}
// more Restore; empty out one chain and delete another, but also update its counters
rules = dedent.Dedent(strings.Trim(`
*nat
:KUBE-RESTORED - [0:0]
:KUBE-TEST - [99:9999]
-X KUBE-RESTORED
COMMIT
`, "\n"))
err = fake.Restore(iptables.TableNAT, []byte(rules), iptables.NoFlushTables, iptables.RestoreCounters)
if err != nil {
t.Fatalf("unexpected error from Restore: %v", err)
}
buf.Reset()
err = fake.SaveInto("", buf)
if err != nil {
t.Fatalf("unexpected error from SaveInto: %v", err)
}
expected = dedent.Dedent(strings.Trim(`
*nat
:PREROUTING - [0:0]
:INPUT - [0:0]
:OUTPUT - [0:0]
:POSTROUTING - [0:0]
:KUBE-TEST - [99:9999]
:KUBE-MISC-CHAIN - [0:0]
:KUBE-EMPTY - [0:0]
-A KUBE-MISC-CHAIN -s 1.2.3.4 -j DROP
-A KUBE-MISC-CHAIN -d 5.6.7.8 -j MASQUERADE
COMMIT
*filter
:INPUT - [0:0]
:FORWARD - [0:0]
:OUTPUT - [0:0]
COMMIT
*mangle
COMMIT
`, "\n"))
if string(buf.Bytes()) != expected {
t.Fatalf("bad post-second-restore dump. expected:\n%s\n\ngot:\n%s\n", expected, buf.Bytes())
}
// RestoreAll, FlushTables
rules = dedent.Dedent(strings.Trim(`
*filter
:INPUT - [0:0]
:FORWARD - [0:0]
:OUTPUT - [0:0]
:KUBE-TEST - [0:0]
-A KUBE-TEST -m comment --comment "filter table KUBE-TEST" -j ACCEPT
COMMIT
*nat
:PREROUTING - [0:0]
:INPUT - [0:0]
:OUTPUT - [0:0]
:POSTROUTING - [0:0]
:KUBE-TEST - [88:8888]
:KUBE-NEW-CHAIN - [0:0]
-A KUBE-NEW-CHAIN -d 172.30.0.1 -j DNAT --to-destination 10.0.0.1
-A KUBE-NEW-CHAIN -d 172.30.0.2 -j DNAT --to-destination 10.0.0.2
-A KUBE-NEW-CHAIN -d 172.30.0.3 -j DNAT --to-destination 10.0.0.3
COMMIT
`, "\n"))
err = fake.RestoreAll([]byte(rules), iptables.FlushTables, iptables.NoRestoreCounters)
if err != nil {
t.Fatalf("unexpected error from RestoreAll: %v", err)
}
buf.Reset()
err = fake.SaveInto("", buf)
if err != nil {
t.Fatalf("unexpected error from SaveInto: %v", err)
}
expected = dedent.Dedent(strings.Trim(`
*nat
:PREROUTING - [0:0]
:INPUT - [0:0]
:OUTPUT - [0:0]
:POSTROUTING - [0:0]
:KUBE-TEST - [88:8888]
:KUBE-NEW-CHAIN - [0:0]
-A KUBE-NEW-CHAIN -d 172.30.0.1 -j DNAT --to-destination 10.0.0.1
-A KUBE-NEW-CHAIN -d 172.30.0.2 -j DNAT --to-destination 10.0.0.2
-A KUBE-NEW-CHAIN -d 172.30.0.3 -j DNAT --to-destination 10.0.0.3
COMMIT
*filter
:INPUT - [0:0]
:FORWARD - [0:0]
:OUTPUT - [0:0]
:KUBE-TEST - [0:0]
-A KUBE-TEST -m comment --comment "filter table KUBE-TEST" -j ACCEPT
COMMIT
*mangle
COMMIT
`, "\n"))
if string(buf.Bytes()) != expected {
t.Fatalf("bad post-restore-all dump. expected:\n%s\n\ngot:\n%s\n", expected, buf.Bytes())
}
}