diff --git a/main.go b/main.go index 60da4cb..5be94a1 100644 --- a/main.go +++ b/main.go @@ -936,6 +936,10 @@ The validate command expects a configuration file as its only argument. Local fi Name: "passive", Usage: "Disable the system extension for the passive boot entry", }, + &cli.BoolFlag{ + Name: "now", + Usage: "Disable the system extension now and reload systemd-sysext", + }, }, Before: func(c *cli.Context) error { if c.Bool("active") && c.Bool("passive") { @@ -965,7 +969,7 @@ The validate command expects a configuration file as its only argument. Local fi bootState = "passive" } ext := c.Args().First() - if err := action.DisableSystemExtension(cfg, ext, bootState); err != nil { + if err := action.DisableSystemExtension(cfg, ext, bootState, c.Bool("now")); err != nil { cfg.Logger.Logger.Error().Err(err).Msg("failed disabling system extension") return err } @@ -1002,6 +1006,12 @@ The validate command expects a configuration file as its only argument. Local fi Usage: "Remove a system extension", UsageText: "remove EXTENSION", Description: "Remove a installed system extension", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "now", + Usage: "Remove the system extension now and reload systemd-sysext", + }, + }, Action: func(c *cli.Context) error { if c.Args().Len() != 1 { return fmt.Errorf("extension required") @@ -1011,7 +1021,7 @@ The validate command expects a configuration file as its only argument. Local fi if err != nil { return err } - if err := action.RemoveSystemExtension(cfg, extension); err != nil { + if err := action.RemoveSystemExtension(cfg, extension, c.Bool("now")); err != nil { cfg.Logger.Logger.Error().Err(err).Msg("failed removing system extension") return err } diff --git a/pkg/action/sysext.go b/pkg/action/sysext.go index b276317..dbfe323 100644 --- a/pkg/action/sysext.go +++ b/pkg/action/sysext.go @@ -152,7 +152,6 @@ func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) cfg.Logger.Infof("System extension %s enabled in %s", extension.Name, bootState) if now { - // Check if the boot state is the same as the one we are enabling // This is to avoid enabling the extension in the wrong boot state _, stateMatches := cfg.Fs.Stat(fmt.Sprintf("/run/cos/%s_mode", bootState)) @@ -182,7 +181,7 @@ func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) // DisableSystemExtension disables a system extension that is already in the system for a given bootstate // It removes the symlink from the target dir according to the bootstate given -func DisableSystemExtension(cfg *config.Config, ext string, bootState string) error { +func DisableSystemExtension(cfg *config.Config, ext string, bootState string, now bool) error { var targetDir string switch bootState { case "active": @@ -209,6 +208,32 @@ func DisableSystemExtension(cfg *config.Config, ext string, bootState string) er if err := cfg.Fs.Remove(extension.Location); err != nil { return fmt.Errorf("failed to remove symlink for %s: %w", ext, err) } + if now { + // Check if the boot state is the same as the one we are disabling + // This is to avoid disabling the extension in the wrong boot state + _, stateMatches := cfg.Fs.Stat(fmt.Sprintf("/run/cos/%s_mode", bootState)) + cfg.Logger.Logger.Debug().Str("boot_state", bootState).Str("filecheck", fmt.Sprintf("/run/cos/%s_state", bootState)).Msg("Checking boot state") + if stateMatches == nil { + // Remove the symlink from /run/extensions if is in there + _, stat := cfg.Fs.Stat(filepath.Join("/run/extensions", extension.Name)) + if stat == nil { + err = cfg.Fs.Remove(filepath.Join("/run/extensions", extension.Name)) + if err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", extension.Name, err) + } + cfg.Logger.Infof("System extension %s disabled from /run/extensions", extension.Name) + // Now that its removed we refresh systemd-sysext + output, err := cfg.Runner.Run("systemctl", "restart", "systemd-sysext") + if err != nil { + cfg.Logger.Logger.Err(err).Str("output", string(output)).Msg("Failed to refresh systemd-sysext") + return err + } + cfg.Logger.Infof("System extension %s refreshed by systemd-sysext", extension.Name) + } + } else { + cfg.Logger.Infof("System extension %s disabled in %s but not refreshed by systemd-sysext as we are currently not booted in %s", extension.Name, bootState, bootState) + } + } cfg.Logger.Infof("System extension %s disabled in %s", ext, bootState) return nil } @@ -240,7 +265,7 @@ func InstallSystemExtension(cfg *config.Config, uri string) error { // It will remove any symlinks to the extension // Then it will remove the extension // It will check if the extension is installed before doing anything -func RemoveSystemExtension(cfg *config.Config, extension string) error { +func RemoveSystemExtension(cfg *config.Config, extension string, now bool) error { // Check if the extension is installed installed, err := GetSystemExtension(cfg, extension, "") if err != nil { @@ -271,6 +296,28 @@ func RemoveSystemExtension(cfg *config.Config, extension string) error { if err := cfg.Fs.RemoveAll(installed.Location); err != nil { return fmt.Errorf("failed to remove extension %s: %w", installed.Name, err) } + + if now { + // Here as we are removing the extension we need to check if its in /run/extensions + // We dont care about the bootState because we are removing it from all + _, stat := cfg.Fs.Stat(filepath.Join("/run/extensions", installed.Name)) + if stat == nil { + err = cfg.Fs.Remove(filepath.Join("/run/extensions", installed.Name)) + if err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", installed.Name, err) + } + cfg.Logger.Infof("System extension %s removed from /run/extensions", installed.Name) + // Now that its removed we refresh systemd-sysext + output, err := cfg.Runner.Run("systemctl", "restart", "systemd-sysext") + if err != nil { + cfg.Logger.Logger.Err(err).Str("output", string(output)).Msg("Failed to refresh systemd-sysext") + return err + } + cfg.Logger.Infof("System extension %s refreshed by systemd-sysext", installed.Name) + } + + } + cfg.Logger.Infof("System extension %s removed", installed.Name) return nil } diff --git a/pkg/action/sysext_test.go b/pkg/action/sysext_test.go index 58eb2c0..94989c2 100644 --- a/pkg/action/sysext_test.go +++ b/pkg/action/sysext_test.go @@ -323,7 +323,7 @@ var _ = Describe("Sysext Actions test", func() { }) Describe("Disabling extensions", func() { It("should fail if bootState is not valid", func() { - err := action.DisableSystemExtension(config, "whatever", "invalid") + err := action.DisableSystemExtension(config, "whatever", "invalid", false) Expect(err).To(HaveOccurred()) }) It("should disable an enabled extension", func() { @@ -344,7 +344,7 @@ var _ = Describe("Sysext Actions test", func() { }, })) // Disable it - err = action.DisableSystemExtension(config, "valid.raw", "active") + err = action.DisableSystemExtension(config, "valid.raw", "active", false) Expect(err).ToNot(HaveOccurred()) extensions, err = action.ListSystemExtensions(config, "active") Expect(err).ToNot(HaveOccurred()) @@ -368,7 +368,7 @@ var _ = Describe("Sysext Actions test", func() { }, })) // Disable a non enabled extension - err = action.DisableSystemExtension(config, "invalid.raw", "active") + err = action.DisableSystemExtension(config, "invalid.raw", "active", false) Expect(err).ToNot(HaveOccurred()) extensions, err = action.ListSystemExtensions(config, "active") Expect(err).ToNot(HaveOccurred()) @@ -455,7 +455,7 @@ var _ = Describe("Sysext Actions test", func() { Location: "/var/lib/kairos/extensions/valid.raw", }, })) - err = action.RemoveSystemExtension(config, "valid.raw") + err = action.RemoveSystemExtension(config, "valid.raw", false) Expect(err).ToNot(HaveOccurred()) extensions, err = action.ListSystemExtensions(config, "") Expect(err).ToNot(HaveOccurred()) @@ -478,7 +478,7 @@ var _ = Describe("Sysext Actions test", func() { Location: "/var/lib/kairos/extensions/active/valid.raw", }, })) - err = action.RemoveSystemExtension(config, "valid.raw") + err = action.RemoveSystemExtension(config, "valid.raw", false) Expect(err).ToNot(HaveOccurred()) // Check if it is removed from active extensions, err = action.ListSystemExtensions(config, "active") @@ -504,7 +504,7 @@ var _ = Describe("Sysext Actions test", func() { Location: "/var/lib/kairos/extensions/valid.raw", }, })) - err = action.RemoveSystemExtension(config, "invalid.raw") + err = action.RemoveSystemExtension(config, "invalid.raw", false) Expect(err).To(HaveOccurred()) }) })