diff --git a/virtcontainers/factory/factory.go b/virtcontainers/factory/factory.go index b0d308f4e8..cb023a11c2 100644 --- a/virtcontainers/factory/factory.go +++ b/virtcontainers/factory/factory.go @@ -100,6 +100,71 @@ func resetHypervisorConfig(config *vc.VMConfig) { config.ProxyConfig = vc.ProxyConfig{} } +func compareStruct(foo, bar reflect.Value) bool { + for i := 0; i < foo.NumField(); i++ { + if !deepCompareValue(foo.Field(i), bar.Field(i)) { + return false + } + } + + return true +} + +func compareMap(foo, bar reflect.Value) bool { + if foo.Len() != bar.Len() { + return false + } + + for _, k := range foo.MapKeys() { + if !deepCompareValue(foo.MapIndex(k), bar.MapIndex(k)) { + return false + } + } + + return true +} + +func compareSlice(foo, bar reflect.Value) bool { + if foo.Len() != bar.Len() { + return false + } + for j := 0; j < foo.Len(); j++ { + if !deepCompareValue(foo.Index(j), bar.Index(j)) { + return false + } + } + return true +} + +func deepCompareValue(foo, bar reflect.Value) bool { + if !foo.IsValid() || !bar.IsValid() { + return foo.IsValid() == bar.IsValid() + } + + if foo.Type() != bar.Type() { + return false + } + switch foo.Kind() { + case reflect.Map: + return compareMap(foo, bar) + case reflect.Array: + fallthrough + case reflect.Slice: + return compareSlice(foo, bar) + case reflect.Struct: + return compareStruct(foo, bar) + default: + return foo.Interface() == bar.Interface() + } +} + +func deepCompare(foo, bar interface{}) bool { + v1 := reflect.ValueOf(foo) + v2 := reflect.ValueOf(bar) + + return deepCompareValue(v1, v2) +} + // It's important that baseConfig and newConfig are passed by value! func checkVMConfig(config1, config2 vc.VMConfig) error { if config1.HypervisorType != config2.HypervisorType { @@ -114,7 +179,7 @@ func checkVMConfig(config1, config2 vc.VMConfig) error { resetHypervisorConfig(&config1) resetHypervisorConfig(&config2) - if !reflect.DeepEqual(config1, config2) { + if !deepCompare(config1, config2) { return fmt.Errorf("hypervisor config does not match, base: %+v. new: %+v", config1, config2) } diff --git a/virtcontainers/factory/factory_test.go b/virtcontainers/factory/factory_test.go index 452889d5cb..6893c4c5f1 100644 --- a/virtcontainers/factory/factory_test.go +++ b/virtcontainers/factory/factory_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" vc "github.com/kata-containers/runtime/virtcontainers" + "github.com/kata-containers/runtime/virtcontainers/factory/base" ) func TestNewFactory(t *testing.T) { @@ -248,3 +249,66 @@ func TestFactoryGetVM(t *testing.T) { f.CloseFactory(ctx) } + +func TestDeepCompare(t *testing.T) { + assert := assert.New(t) + + foo := vc.VMConfig{} + bar := vc.VMConfig{} + assert.True(deepCompare(foo, bar)) + + foo.HypervisorConfig.NumVCPUs = 1 + assert.False(deepCompare(foo, bar)) + bar.HypervisorConfig.NumVCPUs = 1 + assert.True(deepCompare(foo, bar)) + + // slice + foo.HypervisorConfig.KernelParams = []vc.Param{} + assert.True(deepCompare(foo, bar)) + foo.HypervisorConfig.KernelParams = append(foo.HypervisorConfig.KernelParams, vc.Param{Key: "key", Value: "value"}) + assert.False(deepCompare(foo, bar)) + bar.HypervisorConfig.KernelParams = append(bar.HypervisorConfig.KernelParams, vc.Param{Key: "key", Value: "value"}) + assert.True(deepCompare(foo, bar)) + + // map + var fooMap map[string]vc.VMConfig + var barMap map[string]vc.VMConfig + assert.False(deepCompare(foo, fooMap)) + assert.True(deepCompare(fooMap, barMap)) + fooMap = make(map[string]vc.VMConfig) + assert.True(deepCompare(fooMap, barMap)) + fooMap["foo"] = foo + assert.False(deepCompare(fooMap, barMap)) + barMap = make(map[string]vc.VMConfig) + assert.False(deepCompare(fooMap, barMap)) + barMap["foo"] = bar + assert.True(deepCompare(fooMap, barMap)) + + // invalid interface + var f1 vc.Factory + var f2 vc.Factory + var f3 base.FactoryBase + assert.True(deepCompare(f1, f2)) + assert.True(deepCompare(f1, f3)) + + // valid interface + var config Config + var err error + ctx := context.Background() + config.VMConfig = vc.VMConfig{ + HypervisorType: vc.MockHypervisor, + AgentType: vc.NoopAgentType, + ProxyType: vc.NoopProxyType, + } + testDir, _ := ioutil.TempDir("", "vmfactory-tmp-") + config.VMConfig.HypervisorConfig = vc.HypervisorConfig{ + KernelPath: testDir, + ImagePath: testDir, + } + f1, err = NewFactory(ctx, config, false) + assert.Nil(err) + assert.True(deepCompare(f1, f1)) + f2, err = NewFactory(ctx, config, false) + assert.Nil(err) + assert.False(deepCompare(f1, f2)) +}