diff --git a/main.go b/main.go index 5be94a1..1401d1c 100644 --- a/main.go +++ b/main.go @@ -831,11 +831,20 @@ The validate command expects a configuration file as its only argument. Local fi Name: "passive", Usage: "List the system extensions for the passive boot entry", }, + &cli.BoolFlag{ + Name: "recovery", + Usage: "List the system extensions for the recovery boot entry", + }, + &cli.BoolFlag{ + Name: "common", + Usage: "List the system extensions for the common boot entry (applies to all boot states)", + }, }, Before: func(c *cli.Context) error { - if c.Bool("active") && c.Bool("passive") { - return fmt.Errorf("only one of --active or --passive can be set") + if moreThanOneEnabled(c.Bool("active"), c.Bool("passive"), c.Bool("recovery"), c.Bool("common")) { + return fmt.Errorf("only one of --active, --passive, --recovery or --common can be set") } + if err := checkRoot(); err != nil { return err } @@ -882,21 +891,32 @@ The validate command expects a configuration file as its only argument. Local fi Name: "passive", Usage: "Enable the system extension for the passive boot entry", }, + &cli.BoolFlag{ + Name: "recovery", + Usage: "List the system extensions for the recovery boot entry", + }, + &cli.BoolFlag{ + Name: "common", + Usage: "List the system extensions for the common boot entry (applies to all boot states)", + }, &cli.BoolFlag{ Name: "now", Usage: "Enable the system extension now and reload systemd-sysext", }, }, Before: func(c *cli.Context) error { - if c.Bool("active") && c.Bool("passive") { - return fmt.Errorf("only one of --active or --passive can be set") - } if c.Args().Len() != 1 { return fmt.Errorf("extension name required") } - if c.Bool("active") == false && c.Bool("passive") == false { - return fmt.Errorf("either --active or --passive must be set") + + if moreThanOneEnabled(c.Bool("active"), c.Bool("passive"), c.Bool("recovery"), c.Bool("common")) { + return fmt.Errorf("only one of --active, --passive, --recovery or --common can be set") } + + if noneOfEnabled(c.Bool("active"), c.Bool("passive"), c.Bool("recovery"), c.Bool("common")) { + return fmt.Errorf("either --active, --passive, --recovery or --common must be set") + } + if err := checkRoot(); err != nil { return err } @@ -936,20 +956,30 @@ The validate command expects a configuration file as its only argument. Local fi Name: "passive", Usage: "Disable the system extension for the passive boot entry", }, + &cli.BoolFlag{ + Name: "recovery", + Usage: "List the system extensions for the recovery boot entry", + }, + &cli.BoolFlag{ + Name: "common", + Usage: "List the system extensions for the common boot entry (applies to all boot states)", + }, &cli.BoolFlag{ Name: "now", Usage: "Disable the system extension now and reload systemd-sysext", }, }, Before: func(c *cli.Context) error { - if c.Bool("active") && c.Bool("passive") { - return fmt.Errorf("only one of --active or --passive can be set") - } if c.Args().Len() != 1 { return fmt.Errorf("extension name required") } - if c.Bool("active") == false && c.Bool("passive") == false { - return fmt.Errorf("either --active or --passive must be set") + + if moreThanOneEnabled(c.Bool("active"), c.Bool("passive"), c.Bool("recovery"), c.Bool("common")) { + return fmt.Errorf("only one of --active, --passive, --recovery or --common can be set") + } + + if noneOfEnabled(c.Bool("active"), c.Bool("passive"), c.Bool("recovery"), c.Bool("common")) { + return fmt.Errorf("either --active, --passive, --recovery or --common must be set") } if err := checkRoot(); err != nil { return err @@ -1171,3 +1201,26 @@ func getReleasesFromProvider(includePrereleases bool) ([]string, error) { return tags, nil } + +func moreThanOneEnabled(bools ...bool) bool { + count := 0 + for _, b := range bools { + if b { + count++ + } + if count > 1 { + return true + } + } + return false +} + +func noneOfEnabled(bools ...bool) bool { + count := 0 + for _, b := range bools { + if b { + count++ + } + } + return count == 0 +} diff --git a/pkg/action/sysext.go b/pkg/action/sysext.go index 197316d..4151c2d 100644 --- a/pkg/action/sysext.go +++ b/pkg/action/sysext.go @@ -29,9 +29,11 @@ import ( // TODO: On remove we should check if the extension is running and refresh systemd-sysext? YES const ( - sysextDir = "/var/lib/kairos/extensions/" - sysextDirActive = "/var/lib/kairos/extensions/active" - sysextDirPassive = "/var/lib/kairos/extensions/passive" + sysextDir = "/var/lib/kairos/extensions/" + sysextDirActive = "/var/lib/kairos/extensions/active" + sysextDirPassive = "/var/lib/kairos/extensions/passive" + sysextDirRecovery = "/var/lib/kairos/extensions/recovery" + sysextDirCommon = "/var/lib/kairos/extensions/common" ) // SysExtension represents a system extension @@ -54,6 +56,12 @@ func ListSystemExtensions(cfg *config.Config, bootState string) ([]SysExtension, case "passive": cfg.Logger.Debug("Listing passive system extensions") return getDirExtensions(cfg, sysextDirPassive) + case "recovery": + cfg.Logger.Debug("Listing recovery system extensions") + return getDirExtensions(cfg, sysextDirRecovery) + case "common": + cfg.Logger.Debug("Listing common system extensions") + return getDirExtensions(cfg, sysextDirCommon) default: cfg.Logger.Debug("Listing all system extensions (Enabled or not)") return getDirExtensions(cfg, sysextDir) @@ -124,6 +132,10 @@ func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) targetDir = sysextDirActive case "passive": targetDir = sysextDirPassive + case "recovery": + targetDir = sysextDirRecovery + case "common": + targetDir = sysextDirCommon default: return fmt.Errorf("boot state %s not supported", bootState) } @@ -157,7 +169,7 @@ func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) _, stateMatches := cfg.Fs.Stat(fmt.Sprintf("/run/cos/%s_mode", bootState)) // TODO: Check in UKI? cfg.Logger.Logger.Debug().Str("boot_state", bootState).Str("filecheck", fmt.Sprintf("/run/cos/%s_state", bootState)).Msg("Checking boot state") - if stateMatches == nil { + if stateMatches == nil || bootState == "common" { err = cfg.Fs.Symlink(filepath.Join(targetDir, extension.Name), filepath.Join("/run/extensions", extension.Name)) if err != nil { return fmt.Errorf("failed to create symlink for %s: %w", extension.Name, err) @@ -188,6 +200,10 @@ func DisableSystemExtension(cfg *config.Config, ext string, bootState string, no targetDir = sysextDirActive case "passive": targetDir = sysextDirPassive + case "recovery": + targetDir = sysextDirRecovery + case "common": + targetDir = sysextDirCommon default: return fmt.Errorf("boot state %s not supported", bootState) } @@ -213,7 +229,7 @@ func DisableSystemExtension(cfg *config.Config, ext string, bootState string, no // This is to avoid disabling the extension in the wrong boot state _, stateMatches := cfg.Fs.Stat(fmt.Sprintf("/run/cos/%s_mode", bootState)) cfg.Logger.Logger.Debug().Str("boot_state", bootState).Str("filecheck", fmt.Sprintf("/run/cos/%s_mode", bootState)).Msg("Checking boot state") - if stateMatches == nil { + if stateMatches == nil || bootState == "common" { // Remove the symlink from /run/extensions if is in there cfg.Logger.Logger.Debug().Str("stat", filepath.Join("/run/extensions", extension.Name)).Msg("Checking if symlink exists") _, stat := cfg.Fs.Readlink(filepath.Join("/run/extensions", extension.Name)) @@ -279,21 +295,15 @@ func RemoveSystemExtension(cfg *config.Config, extension string, now bool) error return nil } // Check if the extension is enabled in active or passive - enabledActive, err := GetSystemExtension(cfg, extension, "active") - if err == nil { - // Remove the symlink - if err := cfg.Fs.Remove(enabledActive.Location); err != nil { - return fmt.Errorf("failed to remove symlink for %s: %w", enabledActive.Name, err) + for _, state := range []string{"active", "passive", "recovery", "common"} { + enabled, err := GetSystemExtension(cfg, extension, state) + if err == nil { + // Remove the symlink + if err := cfg.Fs.Remove(enabled.Location); err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", enabled.Name, err) + } + cfg.Logger.Infof("System extension %s disabled from %s", enabled.Name, state) } - cfg.Logger.Infof("System extension %s disabled from active", enabledActive.Name) - } - enabledPassive, err := GetSystemExtension(cfg, extension, "passive") - if err == nil { - // Remove the symlink - if err := cfg.Fs.Remove(enabledPassive.Location); err != nil { - return fmt.Errorf("failed to remove symlink for %s: %w", enabledPassive.Name, err) - } - cfg.Logger.Infof("System extension %s disabled from passive", enabledPassive.Name) } // Remove the extension if err := cfg.Fs.RemoveAll(installed.Location); err != nil { diff --git a/pkg/action/sysext_test.go b/pkg/action/sysext_test.go index 94989c2..9f2f4be 100644 --- a/pkg/action/sysext_test.go +++ b/pkg/action/sysext_test.go @@ -95,6 +95,18 @@ var _ = Describe("Sysext Actions test", func() { Expect(err).ToNot(HaveOccurred()) Expect(extensions).To(BeEmpty()) }) + It("should return no extensions for recovery enabled extensions", func() { + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "recovery") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return no extensions for common enabled extensions", func() { + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "common") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) }) Describe("With empty dir", func() { It("should return no extensions", func() { @@ -112,6 +124,16 @@ var _ = Describe("Sysext Actions test", func() { Expect(err).ToNot(HaveOccurred()) Expect(extensions).To(BeEmpty()) }) + It("should return no extensions for recovery enabled extensions", func() { + extensions, err := action.ListSystemExtensions(config, "recovery") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return no extensions for common enabled extensions", func() { + extensions, err := action.ListSystemExtensions(config, "common") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) }) Describe("With dir with files", func() { It("should not return files that are not valid extensions", func() { @@ -178,6 +200,12 @@ var _ = Describe("Sysext Actions test", func() { extensions, err = action.ListSystemExtensions(config, "passive") Expect(err).ToNot(HaveOccurred()) Expect(extensions).To(BeEmpty()) + extensions, err = action.ListSystemExtensions(config, "recovery") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + extensions, err = action.ListSystemExtensions(config, "common") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) // Enable it for passive err = action.EnableSystemExtension(config, "valid.raw", "passive", false) Expect(err).ToNot(HaveOccurred()) @@ -199,6 +227,18 @@ var _ = Describe("Sysext Actions test", func() { Location: "/var/lib/kairos/extensions/active/valid.raw", }, })) + // Enable it for recovery + err = action.EnableSystemExtension(config, "valid.raw", "recovery", false) + Expect(err).ToNot(HaveOccurred()) + // Passive should have the extension + extensions, err = action.ListSystemExtensions(config, "recovery") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/recovery/valid.raw", + }, + })) }) It("should enable an installed extension and reload the system with it", func() { @@ -242,6 +282,43 @@ var _ = Describe("Sysext Actions test", func() { Expect(err).ToNot(HaveOccurred()) Expect(readlink).To(Equal(realPath)) }) + It("should enable an installed extension and reload the system with it if its a common one", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "common") + Expect(err).ToNot(HaveOccurred()) + // This basically returns an error if the command is not executed + Expect(runner.IncludesCmds([][]string{ + {"systemctl", "restart", "systemd-sysext"}, + })).To(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for common + err = action.EnableSystemExtension(config, "valid.raw", "common", true) + Expect(err).ToNot(HaveOccurred()) + // Should have refreshed the systemd-sysext + Expect(runner.IncludesCmds([][]string{ + {"systemctl", "restart", "systemd-sysext"}, + })).ToNot(HaveOccurred()) + // Should be enabled + extensions, err = action.ListSystemExtensions(config, "common") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/common/valid.raw", + }, + })) + // Active and Passive should be empty + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Symlink should be created in /run/extensions + _, err = config.Fs.Stat("/run/extensions/valid.raw") + Expect(err).ToNot(HaveOccurred()) + }) It("should enable an installed extension and NOT reload the system with it if we are on the wrong boot state", func() { err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) Expect(err).ToNot(HaveOccurred())