Check iptables for -C flag at construct time

Also, reorganize the code a bit in preparation for checking for
another flag as well. And, if semver.NewVersion() returns an error, it
means there's a bug in the code somewhere (we should only ever be
passing it valid version strings), so just log that error rather than
returning it to the caller.
This commit is contained in:
Dan Winship 2015-08-26 10:08:37 -04:00
parent a41e422600
commit 6bab1adfd1
2 changed files with 46 additions and 31 deletions

View File

@ -115,11 +115,21 @@ type runner struct {
mu sync.Mutex
exec utilexec.Interface
protocol Protocol
hasCheck bool
}
// New returns a new Interface which will exec iptables.
func New(exec utilexec.Interface, protocol Protocol) Interface {
return &runner{exec: exec, protocol: protocol}
vstring, err := GetIptablesVersionString(exec)
if err != nil {
glog.Warningf("Error checking iptables version, assuming version at least %s: %v", MinCheckVersion, err)
vstring = MinCheckVersion
}
return &runner{
exec: exec,
protocol: protocol,
hasCheck: getIptablesHasCheckCommand(vstring),
}
}
// EnsureChain is part of Interface.
@ -308,12 +318,7 @@ func (runner *runner) run(op operation, args []string) ([]byte, error) {
// Returns (bool, nil) if it was able to check the existence of the rule, or
// (<undefined>, error) if the process of checking failed.
func (runner *runner) checkRule(table Table, chain Chain, args ...string) (bool, error) {
checkPresent, err := getIptablesHasCheckCommand(runner.exec)
if err != nil {
glog.Warningf("Error checking iptables version, assuming version at least 1.4.11: %v", err)
checkPresent = true
}
if checkPresent {
if runner.hasCheck {
return runner.checkRuleUsingCheck(makeFullArgs(table, chain, args...))
} else {
return runner.checkRuleWithoutCheck(table, chain, args...)
@ -399,23 +404,21 @@ func makeFullArgs(table Table, chain Chain, args ...string) []string {
}
// Checks if iptables has the "-C" flag
func getIptablesHasCheckCommand(exec utilexec.Interface) (bool, error) {
func getIptablesHasCheckCommand(vstring string) bool {
minVersion, err := semver.NewVersion(MinCheckVersion)
if err != nil {
return false, err
}
vstring, err := GetIptablesVersionString(exec)
if err != nil {
return false, err
glog.Errorf("MinCheckVersion (%s) is not a valid version string: %v", MinCheckVersion, err)
return true
}
version, err := semver.NewVersion(vstring)
if err != nil {
return false, err
glog.Errorf("vstring (%s) is not a valid version string: %v", vstring, err)
return true
}
if version.LessThan(*minVersion) {
return false, nil
return false
}
return true, nil
return true
}
// GetIptablesVersionString runs "iptables --version" to get the version string

View File

@ -36,6 +36,8 @@ func getIptablesCommand(protocol Protocol) string {
func testEnsureChain(t *testing.T, protocol Protocol) {
fcmd := exec.FakeCmd{
CombinedOutputScript: []exec.FakeCombinedOutputAction{
// iptables version check
func() ([]byte, error) { return []byte("iptables v1.9.22"), nil },
// Success.
func() ([]byte, error) { return []byte{}, nil },
// Exists.
@ -49,6 +51,7 @@ func testEnsureChain(t *testing.T, protocol Protocol) {
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
},
}
runner := New(&fexec, protocol)
@ -60,12 +63,12 @@ func testEnsureChain(t *testing.T, protocol Protocol) {
if exists {
t.Errorf("expected exists = false")
}
if fcmd.CombinedOutputCalls != 1 {
t.Errorf("expected 1 CombinedOutput() call, got %d", fcmd.CombinedOutputCalls)
if fcmd.CombinedOutputCalls != 2 {
t.Errorf("expected 2 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls)
}
cmd := getIptablesCommand(protocol)
if !util.NewStringSet(fcmd.CombinedOutputLog[0]...).HasAll(cmd, "-t", "nat", "-N", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[0])
if !util.NewStringSet(fcmd.CombinedOutputLog[1]...).HasAll(cmd, "-t", "nat", "-N", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[1])
}
// Exists.
exists, err = runner.EnsureChain(TableNAT, Chain("FOOBAR"))
@ -93,6 +96,8 @@ func TestEnsureChainIpv6(t *testing.T) {
func TestFlushChain(t *testing.T) {
fcmd := exec.FakeCmd{
CombinedOutputScript: []exec.FakeCombinedOutputAction{
// iptables version check
func() ([]byte, error) { return []byte("iptables v1.9.22"), nil },
// Success.
func() ([]byte, error) { return []byte{}, nil },
// Failure.
@ -103,6 +108,7 @@ func TestFlushChain(t *testing.T) {
CommandScript: []exec.FakeCommandAction{
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
},
}
runner := New(&fexec, ProtocolIpv4)
@ -111,11 +117,11 @@ func TestFlushChain(t *testing.T) {
if err != nil {
t.Errorf("expected success, got %v", err)
}
if fcmd.CombinedOutputCalls != 1 {
t.Errorf("expected 1 CombinedOutput() call, got %d", fcmd.CombinedOutputCalls)
if fcmd.CombinedOutputCalls != 2 {
t.Errorf("expected 2 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls)
}
if !util.NewStringSet(fcmd.CombinedOutputLog[0]...).HasAll("iptables", "-t", "nat", "-F", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[0])
if !util.NewStringSet(fcmd.CombinedOutputLog[1]...).HasAll("iptables", "-t", "nat", "-F", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[1])
}
// Failure.
err = runner.FlushChain(TableNAT, Chain("FOOBAR"))
@ -127,6 +133,8 @@ func TestFlushChain(t *testing.T) {
func TestDeleteChain(t *testing.T) {
fcmd := exec.FakeCmd{
CombinedOutputScript: []exec.FakeCombinedOutputAction{
// iptables version check
func() ([]byte, error) { return []byte("iptables v1.9.22"), nil },
// Success.
func() ([]byte, error) { return []byte{}, nil },
// Failure.
@ -137,6 +145,7 @@ func TestDeleteChain(t *testing.T) {
CommandScript: []exec.FakeCommandAction{
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
},
}
runner := New(&fexec, ProtocolIpv4)
@ -145,11 +154,11 @@ func TestDeleteChain(t *testing.T) {
if err != nil {
t.Errorf("expected success, got %v", err)
}
if fcmd.CombinedOutputCalls != 1 {
t.Errorf("expected 1 CombinedOutput() call, got %d", fcmd.CombinedOutputCalls)
if fcmd.CombinedOutputCalls != 2 {
t.Errorf("expected 2 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls)
}
if !util.NewStringSet(fcmd.CombinedOutputLog[0]...).HasAll("iptables", "-t", "nat", "-X", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[0])
if !util.NewStringSet(fcmd.CombinedOutputLog[1]...).HasAll("iptables", "-t", "nat", "-X", "FOOBAR") {
t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[1])
}
// Failure.
err = runner.DeleteChain(TableNAT, Chain("FOOBAR"))
@ -428,12 +437,15 @@ func TestGetIptablesHasCheckCommand(t *testing.T) {
func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) },
},
}
check, err := getIptablesHasCheckCommand(&fexec)
version, err := GetIptablesVersionString(&fexec)
if (err != nil) != testCase.Err {
t.Errorf("Expected error: %v, Got error: %v", testCase.Err, err)
}
if err == nil && testCase.Expected != check {
t.Errorf("Expected result: %v, Got result: %v", testCase.Expected, check)
if err == nil {
check := getIptablesHasCheckCommand(version)
if testCase.Expected != check {
t.Errorf("Expected result: %v, Got result: %v", testCase.Expected, check)
}
}
}
}