diff --git a/main.go b/main.go index 34dd97f..0b13df7 100644 --- a/main.go +++ b/main.go @@ -855,11 +855,16 @@ The validate command expects a configuration file as its only argument. Local fi } var bootState string - if c.Bool("active") { - bootState = "active" - } - if c.Bool("passive") { - bootState = "passive" + for k, v := range map[string]bool{ + "active": c.Bool("active"), + "passive": c.Bool("passive"), + "recovery": c.Bool("recovery"), + "common": c.Bool("common"), + } { + if v { + bootState = k + break + } } out, err := action.ListSystemExtensions(cfg, bootState) if err != nil { @@ -926,12 +931,18 @@ The validate command expects a configuration file as its only argument. Local fi return err } var bootState string - if c.Bool("active") { - bootState = "active" - } - if c.Bool("passive") { - bootState = "passive" + for k, v := range map[string]bool{ + "active": c.Bool("active"), + "passive": c.Bool("passive"), + "recovery": c.Bool("recovery"), + "common": c.Bool("common"), + } { + if v { + bootState = k + break + } } + ext := c.Args().First() if err := action.EnableSystemExtension(cfg, ext, bootState, c.Bool("now")); err != nil { cfg.Logger.Logger.Error().Err(err).Msg("failed enabling system extension") @@ -990,11 +1001,16 @@ The validate command expects a configuration file as its only argument. Local fi return err } var bootState string - if c.Bool("active") { - bootState = "active" - } - if c.Bool("passive") { - bootState = "passive" + for k, v := range map[string]bool{ + "active": c.Bool("active"), + "passive": c.Bool("passive"), + "recovery": c.Bool("recovery"), + "common": c.Bool("common"), + } { + if v { + bootState = k + break + } } ext := c.Args().First() if err := action.DisableSystemExtension(cfg, ext, bootState, c.Bool("now")); err != nil { diff --git a/pkg/action/sysext.go b/pkg/action/sysext.go index 4151c2d..7828270 100644 --- a/pkg/action/sysext.go +++ b/pkg/action/sysext.go @@ -46,26 +46,25 @@ func (s *SysExtension) String() string { return s.Name } + +func dirFromBootState(bootState string) string { + switch bootState { + case "active": + return sysextDirActive + case "passive": + return sysextDirPassive + case "recovery": + return sysextDirRecovery + case "common": + return sysextDirCommon + default: + return sysextDir + } +} // ListSystemExtensions lists the system extensions in the given directory // If none is passed then it shows the generic ones func ListSystemExtensions(cfg *config.Config, bootState string) ([]SysExtension, error) { - switch bootState { - case "active": - cfg.Logger.Debug("Listing active system extensions") - return getDirExtensions(cfg, sysextDirActive) - case "passive": - cfg.Logger.Debug("Listing passive system extensions") - return getDirExtensions(cfg, sysextDirPassive) - case "recovery": - cfg.Logger.Debug("Listing recovery system extensions") - return getDirExtensions(cfg, sysextDirRecovery) - case "common": - cfg.Logger.Debug("Listing common system extensions") - return getDirExtensions(cfg, sysextDirCommon) - default: - cfg.Logger.Debug("Listing all system extensions (Enabled or not)") - return getDirExtensions(cfg, sysextDir) - } + return getDirExtensions(cfg, dirFromBootState(bootState)) } // getDirExtensions lists the system extensions in the given directory @@ -126,19 +125,7 @@ func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) return err } - var targetDir string - switch bootState { - case "active": - targetDir = sysextDirActive - case "passive": - targetDir = sysextDirPassive - case "recovery": - targetDir = sysextDirRecovery - case "common": - targetDir = sysextDirCommon - default: - return fmt.Errorf("boot state %s not supported", bootState) - } + targetDir := dirFromBootState(bootState) // Check if the target dir exists and create it if it doesn't if _, err := cfg.Fs.Stat(targetDir); os.IsNotExist(err) { @@ -194,19 +181,7 @@ func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) // DisableSystemExtension disables a system extension that is already in the system for a given bootstate // It removes the symlink from the target dir according to the bootstate given func DisableSystemExtension(cfg *config.Config, ext string, bootState string, now bool) error { - var targetDir string - switch bootState { - case "active": - targetDir = sysextDirActive - case "passive": - targetDir = sysextDirPassive - case "recovery": - targetDir = sysextDirRecovery - case "common": - targetDir = sysextDirCommon - default: - return fmt.Errorf("boot state %s not supported", bootState) - } + targetDir := dirFromBootState(bootState) // Check if the target dir exists if _, err := cfg.Fs.Stat(targetDir); os.IsNotExist(err) { diff --git a/pkg/action/sysext_test.go b/pkg/action/sysext_test.go index 9f2f4be..ef2dc98 100644 --- a/pkg/action/sysext_test.go +++ b/pkg/action/sysext_test.go @@ -13,7 +13,7 @@ import ( "github.com/twpayne/go-vfs/v5/vfst" ) -var _ = Describe("Sysext Actions test", func() { +var _ = Describe("Sysext Actions test", Label("sysext"), func() { var config *agentConfig.Config var runner *v1mock.FakeRunner var fs vfs.FS @@ -173,12 +173,6 @@ var _ = Describe("Sysext Actions test", func() { }) }) Describe("Enabling extensions", func() { - It("should fail to enable a extension if bootState is not valid", func() { - err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) - Expect(err).ToNot(HaveOccurred()) - err = action.EnableSystemExtension(config, "valid", "invalid", false) - Expect(err).To(HaveOccurred()) - }) It("should enable an installed extension", func() { err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) Expect(err).ToNot(HaveOccurred()) @@ -399,10 +393,6 @@ var _ = Describe("Sysext Actions test", func() { }) Describe("Disabling extensions", func() { - It("should fail if bootState is not valid", func() { - err := action.DisableSystemExtension(config, "whatever", "invalid", false) - Expect(err).To(HaveOccurred()) - }) It("should disable an enabled extension", func() { err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) Expect(err).ToNot(HaveOccurred())