Check iptables for -C flag at construct time

Also, reorganize the code a bit in preparation for checking for
another flag as well. And, if semver.NewVersion() returns an error, it
means there's a bug in the code somewhere (we should only ever be
passing it valid version strings), so just log that error rather than
returning it to the caller.
This commit is contained in:
Dan Winship 2015-08-26 10:08:37 -04:00
parent a41e422600
commit 6bab1adfd1
2 changed files with 46 additions and 31 deletions

View File

@ -115,11 +115,21 @@ type runner struct {
mu sync.Mutex mu sync.Mutex
exec utilexec.Interface exec utilexec.Interface
protocol Protocol protocol Protocol
hasCheck bool
} }
// New returns a new Interface which will exec iptables. // New returns a new Interface which will exec iptables.
func New(exec utilexec.Interface, protocol Protocol) Interface { func New(exec utilexec.Interface, protocol Protocol) Interface {
return &runner{exec: exec, protocol: protocol} vstring, err := GetIptablesVersionString(exec)
if err != nil {
glog.Warningf("Error checking iptables version, assuming version at least %s: %v", MinCheckVersion, err)
vstring = MinCheckVersion
}
return &runner{
exec: exec,
protocol: protocol,
hasCheck: getIptablesHasCheckCommand(vstring),
}
} }
// EnsureChain is part of Interface. // EnsureChain is part of Interface.
@ -308,12 +318,7 @@ func (runner *runner) run(op operation, args []string) ([]byte, error) {
// Returns (bool, nil) if it was able to check the existence of the rule, or // Returns (bool, nil) if it was able to check the existence of the rule, or
// (<undefined>, error) if the process of checking failed. // (<undefined>, error) if the process of checking failed.
func (runner *runner) checkRule(table Table, chain Chain, args ...string) (bool, error) { func (runner *runner) checkRule(table Table, chain Chain, args ...string) (bool, error) {
checkPresent, err := getIptablesHasCheckCommand(runner.exec) if runner.hasCheck {
if err != nil {
glog.Warningf("Error checking iptables version, assuming version at least 1.4.11: %v", err)
checkPresent = true
}
if checkPresent {
return runner.checkRuleUsingCheck(makeFullArgs(table, chain, args...)) return runner.checkRuleUsingCheck(makeFullArgs(table, chain, args...))
} else { } else {
return runner.checkRuleWithoutCheck(table, chain, args...) return runner.checkRuleWithoutCheck(table, chain, args...)
@ -399,23 +404,21 @@ func makeFullArgs(table Table, chain Chain, args ...string) []string {
} }
// Checks if iptables has the "-C" flag // Checks if iptables has the "-C" flag
func getIptablesHasCheckCommand(exec utilexec.Interface) (bool, error) { func getIptablesHasCheckCommand(vstring string) bool {
minVersion, err := semver.NewVersion(MinCheckVersion) minVersion, err := semver.NewVersion(MinCheckVersion)
if err != nil { if err != nil {
return false, err glog.Errorf("MinCheckVersion (%s) is not a valid version string: %v", MinCheckVersion, err)
} return true
vstring, err := GetIptablesVersionString(exec)
if err != nil {
return false, err
} }
version, err := semver.NewVersion(vstring) version, err := semver.NewVersion(vstring)
if err != nil { if err != nil {
return false, err glog.Errorf("vstring (%s) is not a valid version string: %v", vstring, err)
return true
} }
if version.LessThan(*minVersion) { if version.LessThan(*minVersion) {
return false, nil return false
} }
return true, nil return true
} }
// GetIptablesVersionString runs "iptables --version" to get the version string // GetIptablesVersionString runs "iptables --version" to get the version string

View File

@ -36,6 +36,8 @@ func getIptablesCommand(protocol Protocol) string {
func testEnsureChain(t *testing.T, protocol Protocol) { func testEnsureChain(t *testing.T, protocol Protocol) {
fcmd := exec.FakeCmd{ fcmd := exec.FakeCmd{
CombinedOutputScript: []exec.FakeCombinedOutputAction{ CombinedOutputScript: []exec.FakeCombinedOutputAction{
// iptables version check
func() ([]byte, error) { return []byte("iptables v1.9.22"), nil },
// Success. // Success.
func() ([]byte, error) { return []byte{}, nil }, func() ([]byte, error) { return []byte{}, nil },
// Exists. // Exists.
@ -49,6 +51,7 @@ func testEnsureChain(t *testing.T, protocol Protocol) {
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
}, },
} }
runner := New(&fexec, protocol) runner := New(&fexec, protocol)
@ -60,12 +63,12 @@ func testEnsureChain(t *testing.T, protocol Protocol) {
if exists { if exists {
t.Errorf("expected exists = false") t.Errorf("expected exists = false")
} }
if fcmd.CombinedOutputCalls != 1 { if fcmd.CombinedOutputCalls != 2 {
t.Errorf("expected 1 CombinedOutput() call, got %d", fcmd.CombinedOutputCalls) t.Errorf("expected 2 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls)
} }
cmd := getIptablesCommand(protocol) cmd := getIptablesCommand(protocol)
if !util.NewStringSet(fcmd.CombinedOutputLog[0]...).HasAll(cmd, "-t", "nat", "-N", "FOOBAR") { if !util.NewStringSet(fcmd.CombinedOutputLog[1]...).HasAll(cmd, "-t", "nat", "-N", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[0]) t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[1])
} }
// Exists. // Exists.
exists, err = runner.EnsureChain(TableNAT, Chain("FOOBAR")) exists, err = runner.EnsureChain(TableNAT, Chain("FOOBAR"))
@ -93,6 +96,8 @@ func TestEnsureChainIpv6(t *testing.T) {
func TestFlushChain(t *testing.T) { func TestFlushChain(t *testing.T) {
fcmd := exec.FakeCmd{ fcmd := exec.FakeCmd{
CombinedOutputScript: []exec.FakeCombinedOutputAction{ CombinedOutputScript: []exec.FakeCombinedOutputAction{
// iptables version check
func() ([]byte, error) { return []byte("iptables v1.9.22"), nil },
// Success. // Success.
func() ([]byte, error) { return []byte{}, nil }, func() ([]byte, error) { return []byte{}, nil },
// Failure. // Failure.
@ -103,6 +108,7 @@ func TestFlushChain(t *testing.T) {
CommandScript: []exec.FakeCommandAction{ CommandScript: []exec.FakeCommandAction{
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
}, },
} }
runner := New(&fexec, ProtocolIpv4) runner := New(&fexec, ProtocolIpv4)
@ -111,11 +117,11 @@ func TestFlushChain(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected success, got %v", err) t.Errorf("expected success, got %v", err)
} }
if fcmd.CombinedOutputCalls != 1 { if fcmd.CombinedOutputCalls != 2 {
t.Errorf("expected 1 CombinedOutput() call, got %d", fcmd.CombinedOutputCalls) t.Errorf("expected 2 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls)
} }
if !util.NewStringSet(fcmd.CombinedOutputLog[0]...).HasAll("iptables", "-t", "nat", "-F", "FOOBAR") { if !util.NewStringSet(fcmd.CombinedOutputLog[1]...).HasAll("iptables", "-t", "nat", "-F", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[0]) t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[1])
} }
// Failure. // Failure.
err = runner.FlushChain(TableNAT, Chain("FOOBAR")) err = runner.FlushChain(TableNAT, Chain("FOOBAR"))
@ -127,6 +133,8 @@ func TestFlushChain(t *testing.T) {
func TestDeleteChain(t *testing.T) { func TestDeleteChain(t *testing.T) {
fcmd := exec.FakeCmd{ fcmd := exec.FakeCmd{
CombinedOutputScript: []exec.FakeCombinedOutputAction{ CombinedOutputScript: []exec.FakeCombinedOutputAction{
// iptables version check
func() ([]byte, error) { return []byte("iptables v1.9.22"), nil },
// Success. // Success.
func() ([]byte, error) { return []byte{}, nil }, func() ([]byte, error) { return []byte{}, nil },
// Failure. // Failure.
@ -137,6 +145,7 @@ func TestDeleteChain(t *testing.T) {
CommandScript: []exec.FakeCommandAction{ CommandScript: []exec.FakeCommandAction{
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
}, },
} }
runner := New(&fexec, ProtocolIpv4) runner := New(&fexec, ProtocolIpv4)
@ -145,11 +154,11 @@ func TestDeleteChain(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected success, got %v", err) t.Errorf("expected success, got %v", err)
} }
if fcmd.CombinedOutputCalls != 1 { if fcmd.CombinedOutputCalls != 2 {
t.Errorf("expected 1 CombinedOutput() call, got %d", fcmd.CombinedOutputCalls) t.Errorf("expected 2 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls)
} }
if !util.NewStringSet(fcmd.CombinedOutputLog[0]...).HasAll("iptables", "-t", "nat", "-X", "FOOBAR") { if !util.NewStringSet(fcmd.CombinedOutputLog[1]...).HasAll("iptables", "-t", "nat", "-X", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[0]) t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[1])
} }
// Failure. // Failure.
err = runner.DeleteChain(TableNAT, Chain("FOOBAR")) err = runner.DeleteChain(TableNAT, Chain("FOOBAR"))
@ -428,12 +437,15 @@ func TestGetIptablesHasCheckCommand(t *testing.T) {
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
}, },
} }
check, err := getIptablesHasCheckCommand(&fexec) version, err := GetIptablesVersionString(&fexec)
if (err != nil) != testCase.Err { if (err != nil) != testCase.Err {
t.Errorf("Expected error: %v, Got error: %v", testCase.Err, err) t.Errorf("Expected error: %v, Got error: %v", testCase.Err, err)
} }
if err == nil && testCase.Expected != check { if err == nil {
t.Errorf("Expected result: %v, Got result: %v", testCase.Expected, check) check := getIptablesHasCheckCommand(version)
if testCase.Expected != check {
t.Errorf("Expected result: %v, Got result: %v", testCase.Expected, check)
}
} }
} }
} }