Add a unit test for utiliptables.New()

This commit is contained in:
Dan Winship
2025-01-25 12:06:36 -05:00
parent 9c98d29795
commit f1d0eb4fe4

View File

@@ -42,44 +42,174 @@ func getLockPaths() (string, string) {
return lock14x, lock16x
}
func testIPTablesVersionCmds(t *testing.T, protocol Protocol) {
version := " v1.4.22"
iptablesCmd := iptablesCommand(protocol)
iptablesRestoreCmd := iptablesRestoreCommand(protocol)
type testCommand struct {
command string
action fakeexec.FakeAction
}
fcmd := fakeexec.FakeCmd{
CombinedOutputScript: []fakeexec.FakeAction{
// iptables version response (for runner instantiation)
func() ([]byte, []byte, error) { return []byte(iptablesCmd + version), nil, nil },
// iptables-restore version response (for runner instantiation)
func() ([]byte, []byte, error) { return []byte(iptablesRestoreCmd + version), nil, nil },
},
}
// Creates a FakeExec that expects exactly commands to be run (and will fail otherwise).
func fakeExecForCommands(commands []testCommand) *fakeexec.FakeExec {
fexec := &fakeexec.FakeExec{
CommandScript: []fakeexec.FakeCommandAction{
func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, cmd, args...) },
CommandScript: make([]fakeexec.FakeCommandAction, len(commands)),
ExactOrder: true,
}
for i := range commands {
fcmd := fakeexec.FakeCmd{
CombinedOutputScript: []fakeexec.FakeAction{commands[i].action},
}
argv := strings.Fields(commands[i].command)
fexec.CommandScript[i] = func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(&fcmd, argv[0], argv[1:]...) }
}
return fexec
}
func TestFakeExecForCommands(t *testing.T) {
var panicresult interface{}
defer func() {
panicresult = recover()
}()
fake1 := fakeExecForCommands([]testCommand{{
command: "foo bar baz",
action: func() ([]byte, []byte, error) { return []byte("output"), nil, nil },
}})
cmd := fake1.Command("foo", "bar", "baz")
out, err := cmd.CombinedOutput()
if string(out) != "output" {
t.Errorf("fake1: wrong output: expected %q, got %q", "output", out)
}
if err != nil {
t.Errorf("fake1: expected no error, got %v", err)
}
if panicresult != nil {
t.Errorf("fake1: expected no panic, got %q", panicresult)
}
fake2 := fakeExecForCommands([]testCommand{{
command: "foo bar baz",
action: func() ([]byte, []byte, error) { return []byte("output"), nil, nil },
}})
_ = fake2.Command("foo", "baz")
if panicresult == nil {
t.Errorf("fake2: expected panic from FakeExec, got none")
}
}
func TestNew(t *testing.T) {
testCases := []struct {
name string
commands []testCommand
expected *runner
}{
{
name: "ancient",
commands: []testCommand{
{
command: "iptables --version",
action: func() ([]byte, []byte, error) { return []byte("iptables v1.4.0"), nil, nil },
},
{
// iptables-restore version check: ignores --version and just no-ops
command: "iptables-restore --version",
action: func() ([]byte, []byte, error) { return nil, nil, nil },
},
},
expected: &runner{
hasCheck: false,
hasRandomFully: false,
waitFlag: nil,
restoreWaitFlag: nil,
},
},
{
name: "RHEL/CentOS 7",
commands: []testCommand{
{
command: "iptables --version",
action: func() ([]byte, []byte, error) { return []byte("iptables v1.4.21"), nil, nil },
},
{
command: "iptables-restore --version",
action: func() ([]byte, []byte, error) { return []byte("iptables-restore v1.4.21"), nil, nil },
},
},
expected: &runner{
hasCheck: true,
hasRandomFully: false,
waitFlag: []string{"-w"},
restoreWaitFlag: []string{"-w"},
},
},
{
name: "1.6",
commands: []testCommand{
{
command: "iptables --version",
action: func() ([]byte, []byte, error) { return []byte("iptables v1.6.2"), nil, nil },
},
},
expected: &runner{
hasCheck: true,
hasRandomFully: true,
waitFlag: []string{"-w", "5", "-W", "100000"},
restoreWaitFlag: []string{"-w", "5", "-W", "100000"},
},
},
{
name: "1.8",
commands: []testCommand{
{
command: "iptables --version",
action: func() ([]byte, []byte, error) { return []byte("iptables v1.8.11"), nil, nil },
},
},
expected: &runner{
hasCheck: true,
hasRandomFully: true,
waitFlag: []string{"-w", "5", "-W", "100000"},
restoreWaitFlag: []string{"-w", "5", "-W", "100000"},
},
},
{
name: "no iptables",
commands: []testCommand{
{
command: "iptables --version",
action: func() ([]byte, []byte, error) { return nil, nil, fmt.Errorf("no such file or directory") },
},
{
command: "iptables-restore --version",
action: func() ([]byte, []byte, error) { return nil, nil, fmt.Errorf("no such file or directory") },
},
},
expected: &runner{
hasCheck: true,
hasRandomFully: false,
waitFlag: nil,
restoreWaitFlag: nil,
},
},
}
_ = newInternal(fexec, protocol, "", "")
// Check that proper iptables version command was used during runner instantiation
if !sets.New(fcmd.CombinedOutputLog[0]...).HasAll(iptablesCmd, "--version") {
t.Errorf("%s runner instantiate: Expected cmd '%s --version', Got '%s'", protocol, iptablesCmd, fcmd.CombinedOutputLog[0])
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fexec := fakeExecForCommands(tc.commands)
runner := newInternal(fexec, ProtocolIPv4, "", "").(*runner)
if runner.hasCheck != tc.expected.hasCheck {
t.Errorf("Expected hasCheck=%v, got %v", tc.expected.hasCheck, runner.hasCheck)
}
if runner.hasRandomFully != tc.expected.hasRandomFully {
t.Errorf("Expected hasRandomFully=%v, got %v", tc.expected.hasRandomFully, runner.hasRandomFully)
}
if !reflect.DeepEqual(runner.waitFlag, tc.expected.waitFlag) {
t.Errorf("Expected waitFlag=%v, got %v", tc.expected.waitFlag, runner.waitFlag)
}
if !reflect.DeepEqual(runner.restoreWaitFlag, tc.expected.restoreWaitFlag) {
t.Errorf("Expected restoreWaitFlag=%v, got %v", tc.expected.restoreWaitFlag, runner.restoreWaitFlag)
}
})
}
// Check that proper iptables restore version command was used during runner instantiation
if !sets.New(fcmd.CombinedOutputLog[1]...).HasAll(iptablesRestoreCmd, "--version") {
t.Errorf("%s runner instantiate: Expected cmd '%s --version', Got '%s'", protocol, iptablesRestoreCmd, fcmd.CombinedOutputLog[1])
}
}
func TestIPTablesVersionCmdsIPv4(t *testing.T) {
testIPTablesVersionCmds(t, ProtocolIPv4)
}
func TestIPTablesVersionCmdsIPv6(t *testing.T) {
testIPTablesVersionCmds(t, ProtocolIPv6)
}
func testEnsureChain(t *testing.T, protocol Protocol) {