diff --git a/main.go b/main.go index 7f828ce..5be94a1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "os/exec" "path/filepath" "regexp" "runtime" @@ -805,6 +806,231 @@ The validate command expects a configuration file as its only argument. Local fi return action.ListBootEntries(cfg) }, }, + { + Name: "sysext", + Usage: "sysext subcommands", + Description: "sysext subcommands", + Before: func(c *cli.Context) error { + _, err := exec.LookPath("systemd-sysext") + if err != nil { + return fmt.Errorf("systemd-sysext not found in PATH") + } + return nil + }, + Subcommands: []*cli.Command{ + { + Name: "list", + Usage: "List all the installed system extensions", + Description: "List all the installed system extensions", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "active", + Usage: "List the system extensions for the active boot entry", + }, + &cli.BoolFlag{ + Name: "passive", + Usage: "List the system extensions for the passive boot entry", + }, + }, + Before: func(c *cli.Context) error { + if c.Bool("active") && c.Bool("passive") { + return fmt.Errorf("only one of --active or --passive can be set") + } + if err := checkRoot(); err != nil { + return err + } + return nil + }, + Action: func(c *cli.Context) error { + cfg, err := agentConfig.Scan(collector.Directories(constants.GetUserConfigDirs()...), collector.NoLogs) + if err != nil { + return err + } + var bootState string + + if c.Bool("active") { + bootState = "active" + } + if c.Bool("passive") { + bootState = "passive" + } + out, err := action.ListSystemExtensions(cfg, bootState) + if err != nil { + return err + } + if len(out) == 0 { + cfg.Logger.Logger.Info().Msg("No system extensions found") + return nil + } + for _, ext := range out { + cfg.Logger.Info(litter.Sdump(ext)) + } + return nil + }, + }, + { + Name: "enable", + Usage: "Enable a installed system extension for a give entry", + UsageText: "enable [--active|--passive] EXTENSION", + Description: "Enable a system extension for a given boot entry (active or passive)", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "active", + Usage: "Enable the system extension for the active boot entry", + }, + &cli.BoolFlag{ + Name: "passive", + Usage: "Enable the system extension for the passive boot entry", + }, + &cli.BoolFlag{ + Name: "now", + Usage: "Enable the system extension now and reload systemd-sysext", + }, + }, + Before: func(c *cli.Context) error { + if c.Bool("active") && c.Bool("passive") { + return fmt.Errorf("only one of --active or --passive can be set") + } + if c.Args().Len() != 1 { + return fmt.Errorf("extension name required") + } + if c.Bool("active") == false && c.Bool("passive") == false { + return fmt.Errorf("either --active or --passive must be set") + } + if err := checkRoot(); err != nil { + return err + } + return nil + }, + Action: func(c *cli.Context) error { + cfg, err := agentConfig.Scan(collector.Directories(constants.GetUserConfigDirs()...), collector.NoLogs) + if err != nil { + return err + } + var bootState string + if c.Bool("active") { + bootState = "active" + } + if c.Bool("passive") { + bootState = "passive" + } + ext := c.Args().First() + if err := action.EnableSystemExtension(cfg, ext, bootState, c.Bool("now")); err != nil { + cfg.Logger.Logger.Error().Err(err).Msg("failed enabling system extension") + return err + } + return nil + }, + }, + { + Name: "disable", + Usage: "Disable a installed system extension for a give entry", + UsageText: "disable [--active|--passive] EXTENSION", + Description: "Disable a system extension for a given boot entry (active or passive)", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "active", + Usage: "Disable the system extension for the active boot entry", + }, + &cli.BoolFlag{ + Name: "passive", + Usage: "Disable the system extension for the passive boot entry", + }, + &cli.BoolFlag{ + Name: "now", + Usage: "Disable the system extension now and reload systemd-sysext", + }, + }, + Before: func(c *cli.Context) error { + if c.Bool("active") && c.Bool("passive") { + return fmt.Errorf("only one of --active or --passive can be set") + } + if c.Args().Len() != 1 { + return fmt.Errorf("extension name required") + } + if c.Bool("active") == false && c.Bool("passive") == false { + return fmt.Errorf("either --active or --passive must be set") + } + if err := checkRoot(); err != nil { + return err + } + return nil + }, + Action: func(c *cli.Context) error { + cfg, err := agentConfig.Scan(collector.Directories(constants.GetUserConfigDirs()...), collector.NoLogs) + if err != nil { + return err + } + var bootState string + if c.Bool("active") { + bootState = "active" + } + if c.Bool("passive") { + bootState = "passive" + } + ext := c.Args().First() + if err := action.DisableSystemExtension(cfg, ext, bootState, c.Bool("now")); err != nil { + cfg.Logger.Logger.Error().Err(err).Msg("failed disabling system extension") + return err + } + return nil + }, + }, + { + Name: "install", + Usage: "Install a system extension", + UsageText: "install URI", + Description: "Install a system extension from a given URI", + Action: func(c *cli.Context) error { + if c.Args().Len() != 1 { + return fmt.Errorf("extension URI required") + } + uri := c.Args().First() + if err := validateSourceSysext(uri); err != nil { + return err + } + cfg, err := agentConfig.Scan(collector.Directories(constants.GetUserConfigDirs()...), collector.NoLogs) + if err != nil { + return err + } + if err := action.InstallSystemExtension(cfg, uri); err != nil { + cfg.Logger.Logger.Error().Err(err).Msg("failed installing system extension") + return err + } + cfg.Logger.Logger.Info().Msgf("System extension %s installed", uri) + return nil + }, + }, + { + Name: "remove", + Usage: "Remove a system extension", + UsageText: "remove EXTENSION", + Description: "Remove a installed system extension", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "now", + Usage: "Remove the system extension now and reload systemd-sysext", + }, + }, + Action: func(c *cli.Context) error { + if c.Args().Len() != 1 { + return fmt.Errorf("extension required") + } + extension := c.Args().First() + cfg, err := agentConfig.Scan(collector.Directories(constants.GetUserConfigDirs()...), collector.NoLogs) + if err != nil { + return err + } + if err := action.RemoveSystemExtension(cfg, extension, c.Bool("now")); err != nil { + cfg.Logger.Logger.Error().Err(err).Msg("failed removing system extension") + return err + } + cfg.Logger.Logger.Info().Msgf("System extension %s removed", extension) + return nil + }, + }, + }, + }, } func main() { @@ -896,6 +1122,22 @@ func validateSource(source string) error { return nil } +func validateSourceSysext(source string) error { + if source == "" { + return nil + } + + r, err := regexp.Compile(`^oci:|^file:|^http:|^https:`) + if err != nil { + return err + } + if !r.MatchString(source) { + return fmt.Errorf("source %s does not match any of oci:, file: or http(s): ", source) + } + + return nil +} + // Check func bootFromLiveMedia() bool { // Check if the system is booted from a LIVE media by checking if the file /run/cos/livecd is present diff --git a/pkg/action/sysext.go b/pkg/action/sysext.go new file mode 100644 index 0000000..197316d --- /dev/null +++ b/pkg/action/sysext.go @@ -0,0 +1,423 @@ +package action + +import ( + "fmt" + "github.com/distribution/reference" + "github.com/kairos-io/kairos-agent/v2/pkg/config" + "github.com/kairos-io/kairos-sdk/types" + "github.com/twpayne/go-vfs/v5" + "net/url" + "os" + "path/filepath" + "regexp" +) + +// Implementation details for not trusted boot +// sysext are stored under +// /var/lib/kairos/extensions/ +// we link them to /var/lib/kairos/extensions/{active,passive} depending on where we want it to be enabled +// Immucore on boot after mounting the persistent dir, will check those dirs\ +// it will then create the proper links to them under /run/extensions +// This means they are enabled on boot and they are ephemeral, nothing is left behind in the actual sysext dirs +// This prevents us from having to clean up in different dirs, we can just do cleaning in our dirs (remove links) +// and on reboot they will not be enabled on boot +// So all the actions (list, upgrade, download, remove) will be done on the persistent dir +// And on boot we dinamycally link and enable them based on the boot type (active,passive) via immucore + +// TODO: Check which extensions are running? is that possible? +// TODO: On disable we should check if the extension is running and refresh systemd-sysext? YES +// TODO: On remove we should check if the extension is running and refresh systemd-sysext? YES + +const ( + sysextDir = "/var/lib/kairos/extensions/" + sysextDirActive = "/var/lib/kairos/extensions/active" + sysextDirPassive = "/var/lib/kairos/extensions/passive" +) + +// SysExtension represents a system extension +type SysExtension struct { + Name string + Location string +} + +func (s *SysExtension) String() string { + return s.Name +} + +// ListSystemExtensions lists the system extensions in the given directory +// If none is passed then it shows the generic ones +func ListSystemExtensions(cfg *config.Config, bootState string) ([]SysExtension, error) { + switch bootState { + case "active": + cfg.Logger.Debug("Listing active system extensions") + return getDirExtensions(cfg, sysextDirActive) + case "passive": + cfg.Logger.Debug("Listing passive system extensions") + return getDirExtensions(cfg, sysextDirPassive) + default: + cfg.Logger.Debug("Listing all system extensions (Enabled or not)") + return getDirExtensions(cfg, sysextDir) + } +} + +// getDirExtensions lists the system extensions in the given directory +func getDirExtensions(cfg *config.Config, dir string) ([]SysExtension, error) { + var out []SysExtension + // get all the extensions in the sysextDir + // Try to create the dir if it does not exist + if _, err := cfg.Fs.Stat(dir); os.IsNotExist(err) { + if err := vfs.MkdirAll(cfg.Fs, dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create target dir %s: %w", dir, err) + } + } + entries, err := cfg.Fs.ReadDir(dir) + // We don't care if the dir does not exist, we just return an empty list + if err != nil && !os.IsNotExist(err) { + return nil, err + } + for _, entry := range entries { + if !entry.IsDir() && filepath.Ext(entry.Name()) == ".raw" { + out = append(out, SysExtension{Name: entry.Name(), Location: filepath.Join(dir, entry.Name())}) + } + } + return out, nil +} + +// GetSystemExtension returns the system extension for a given name +func GetSystemExtension(cfg *config.Config, name, bootState string) (SysExtension, error) { + // Get a list of all installed system extensions + installed, err := ListSystemExtensions(cfg, bootState) + if err != nil { + return SysExtension{}, err + } + // Check if the extension is installed + // regex against the name + re, err := regexp.Compile(name) + if err != nil { + return SysExtension{}, err + } + for _, ext := range installed { + if re.MatchString(ext.Name) { + return ext, nil + } + } + // If not, return an error + return SysExtension{}, fmt.Errorf("system extension %s not found", name) +} + +// EnableSystemExtension enables a system extension that is already in the system for a given bootstate +// It creates a symlink to the extension in the target dir according to the bootstate given +// It will create the target dir if it does not exist +// It will check if the extension is already enabled but not fail if it is +// It will check if the extension is installed +// If now is true, it will enable the extension immediately by linking it to /run/extensions and refreshing systemd-sysext +func EnableSystemExtension(cfg *config.Config, ext, bootState string, now bool) error { + // first check if the extension is installed + extension, err := GetSystemExtension(cfg, ext, "") + if err != nil { + return err + } + + var targetDir string + switch bootState { + case "active": + targetDir = sysextDirActive + case "passive": + targetDir = sysextDirPassive + default: + return fmt.Errorf("boot state %s not supported", bootState) + } + + // Check if the target dir exists and create it if it doesn't + if _, err := cfg.Fs.Stat(targetDir); os.IsNotExist(err) { + if err := vfs.MkdirAll(cfg.Fs, targetDir, 0755); err != nil { + return fmt.Errorf("failed to create target dir %s: %w", targetDir, err) + } + } + + // Check if the extension is already enabled + enabled, err := GetSystemExtension(cfg, ext, bootState) + // This doesnt fail if we have it already enabled + if err == nil { + if enabled.Name == extension.Name { + cfg.Logger.Infof("System extension %s is already enabled in %s", extension.Name, bootState) + return nil + } + } + + // Create a symlink to the extension in the target dir + if err := cfg.Fs.Symlink(extension.Location, filepath.Join(targetDir, extension.Name)); err != nil { + return fmt.Errorf("failed to create symlink for %s: %w", extension.Name, err) + } + cfg.Logger.Infof("System extension %s enabled in %s", extension.Name, bootState) + + if now { + // Check if the boot state is the same as the one we are enabling + // This is to avoid enabling the extension in the wrong boot state + _, stateMatches := cfg.Fs.Stat(fmt.Sprintf("/run/cos/%s_mode", bootState)) + // TODO: Check in UKI? + cfg.Logger.Logger.Debug().Str("boot_state", bootState).Str("filecheck", fmt.Sprintf("/run/cos/%s_state", bootState)).Msg("Checking boot state") + if stateMatches == nil { + err = cfg.Fs.Symlink(filepath.Join(targetDir, extension.Name), filepath.Join("/run/extensions", extension.Name)) + if err != nil { + return fmt.Errorf("failed to create symlink for %s: %w", extension.Name, err) + } + cfg.Logger.Infof("System extension %s enabled in /run/extensions", extension.Name) + // It makes the sysext check the extension for a valid signature + // Refresh systemd-sysext by restarting the service. As the config is set via the service overrides to nice things + output, err := cfg.Runner.Run("systemctl", "restart", "systemd-sysext") + if err != nil { + cfg.Logger.Logger.Err(err).Str("output", string(output)).Msg("Failed to refresh systemd-sysext") + return err + } + cfg.Logger.Infof("System extension %s merged by systemd-sysext", extension.Name) + } else { + cfg.Logger.Infof("System extension %s enabled in %s but not merged by systemd-sysext as we are currently not booted in %s", extension.Name, bootState, bootState) + } + + } + return nil +} + +// DisableSystemExtension disables a system extension that is already in the system for a given bootstate +// It removes the symlink from the target dir according to the bootstate given +func DisableSystemExtension(cfg *config.Config, ext string, bootState string, now bool) error { + var targetDir string + switch bootState { + case "active": + targetDir = sysextDirActive + case "passive": + targetDir = sysextDirPassive + default: + return fmt.Errorf("boot state %s not supported", bootState) + } + + // Check if the target dir exists + if _, err := cfg.Fs.Stat(targetDir); os.IsNotExist(err) { + return fmt.Errorf("target dir %s does not exist", targetDir) + } + + // Check if the extension is enabled, do not fail if it is not + extension, err := GetSystemExtension(cfg, ext, bootState) + if err != nil { + cfg.Logger.Infof("system extension %s is not enabled in %s", ext, bootState) + return nil + } + + // Remove the symlink + if err := cfg.Fs.Remove(extension.Location); err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", ext, err) + } + if now { + // Check if the boot state is the same as the one we are disabling + // This is to avoid disabling the extension in the wrong boot state + _, stateMatches := cfg.Fs.Stat(fmt.Sprintf("/run/cos/%s_mode", bootState)) + cfg.Logger.Logger.Debug().Str("boot_state", bootState).Str("filecheck", fmt.Sprintf("/run/cos/%s_mode", bootState)).Msg("Checking boot state") + if stateMatches == nil { + // Remove the symlink from /run/extensions if is in there + cfg.Logger.Logger.Debug().Str("stat", filepath.Join("/run/extensions", extension.Name)).Msg("Checking if symlink exists") + _, stat := cfg.Fs.Readlink(filepath.Join("/run/extensions", extension.Name)) + if stat == nil { + err = cfg.Fs.Remove(filepath.Join("/run/extensions", extension.Name)) + if err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", extension.Name, err) + } + cfg.Logger.Infof("System extension %s disabled from /run/extensions", extension.Name) + // Now that its removed we refresh systemd-sysext + output, err := cfg.Runner.Run("systemctl", "restart", "systemd-sysext") + if err != nil { + cfg.Logger.Logger.Err(err).Str("output", string(output)).Msg("Failed to refresh systemd-sysext") + return err + } + cfg.Logger.Infof("System extension %s refreshed by systemd-sysext", extension.Name) + } else { + cfg.Logger.Logger.Info().Msg("Extension not in /run/extensions, not refreshing") + } + } else { + cfg.Logger.Infof("System extension %s disabled in %s but not refreshed by systemd-sysext as we are currently not booted in %s", extension.Name, bootState, bootState) + } + } + cfg.Logger.Infof("System extension %s disabled in %s", ext, bootState) + return nil +} + +// InstallSystemExtension installs a system extension from a given URI +// It will download the extension and extract it to the target dir +// It will check if the extension is already installed before doing anything +func InstallSystemExtension(cfg *config.Config, uri string) error { + // Parse the URI + download, err := parseURI(cfg, uri) + if err != nil { + return fmt.Errorf("failed to parse URI %s: %w", uri, err) + } + // Check if directory exists or create it + if _, err := cfg.Fs.Stat(sysextDir); os.IsNotExist(err) { + if err := vfs.MkdirAll(cfg.Fs, sysextDir, 0755); err != nil { + return fmt.Errorf("failed to create target dir %s: %w", sysextDir, err) + } + } + // Download the extension + if err := download.Download(sysextDir); err != nil { + return err + } + + return nil +} + +// RemoveSystemExtension removes a system extension from the system +// It will remove any symlinks to the extension +// Then it will remove the extension +// It will check if the extension is installed before doing anything +func RemoveSystemExtension(cfg *config.Config, extension string, now bool) error { + // Check if the extension is installed + installed, err := GetSystemExtension(cfg, extension, "") + if err != nil { + return err + } + if installed.Name == "" && installed.Location == "" { + cfg.Logger.Infof("System extension %s is not installed", extension) + return nil + } + // Check if the extension is enabled in active or passive + enabledActive, err := GetSystemExtension(cfg, extension, "active") + if err == nil { + // Remove the symlink + if err := cfg.Fs.Remove(enabledActive.Location); err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", enabledActive.Name, err) + } + cfg.Logger.Infof("System extension %s disabled from active", enabledActive.Name) + } + enabledPassive, err := GetSystemExtension(cfg, extension, "passive") + if err == nil { + // Remove the symlink + if err := cfg.Fs.Remove(enabledPassive.Location); err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", enabledPassive.Name, err) + } + cfg.Logger.Infof("System extension %s disabled from passive", enabledPassive.Name) + } + // Remove the extension + if err := cfg.Fs.RemoveAll(installed.Location); err != nil { + return fmt.Errorf("failed to remove extension %s: %w", installed.Name, err) + } + + if now { + // Here as we are removing the extension we need to check if its in /run/extensions + // We dont care about the bootState because we are removing it from all + _, stat := cfg.Fs.Readlink(filepath.Join("/run/extensions", installed.Name)) + if stat == nil { + err = cfg.Fs.Remove(filepath.Join("/run/extensions", installed.Name)) + if err != nil { + return fmt.Errorf("failed to remove symlink for %s: %w", installed.Name, err) + } + cfg.Logger.Infof("System extension %s removed from /run/extensions", installed.Name) + // Now that its removed we refresh systemd-sysext + output, err := cfg.Runner.Run("systemctl", "restart", "systemd-sysext") + if err != nil { + cfg.Logger.Logger.Err(err).Str("output", string(output)).Msg("Failed to refresh systemd-sysext") + return err + } + cfg.Logger.Infof("System extension %s refreshed by systemd-sysext", installed.Name) + } else { + cfg.Logger.Logger.Info().Msg("Extension not in /run/extensions, not refreshing") + } + } + + cfg.Logger.Infof("System extension %s removed", installed.Name) + return nil +} + +// ParseURI parses a URI and returns a SourceDownload +// implementation based on the scheme of the URI +func parseURI(cfg *config.Config, uri string) (SourceDownload, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + scheme := u.Scheme + value := u.Opaque + if value == "" { + value = filepath.Join(u.Host, u.Path) + } + switch scheme { + case "oci", "docker", "container": + n, err := reference.ParseNormalizedNamed(value) + if err != nil { + return nil, fmt.Errorf("invalid image reference %s", value) + } else if reference.IsNameOnly(n) { + value += ":latest" + } + return &dockerSource{value, cfg}, nil + case "file": + return &fileSource{value, cfg}, nil + case "http", "https": + // Pass the full uri including the protocol + return &httpSource{uri, cfg}, nil + default: + return nil, fmt.Errorf("invalid URI reference %s", uri) + } +} + +// SourceDownload is an interface for downloading system extensions +// from different sources. It allows for different implementations +// for different sources of system extensions, such as files, directories, +// or docker images. The interface defines a single method, Download, +// which takes a destination path as an argument and returns an error +type SourceDownload interface { + Download(string) error +} + +// fileSource is a struct that implements the SourceDownload interface +// for downloading system extensions from a file. It has two fields, +// uri, which is the URI of the file to be downloaded and cfg which points to the Config +// The Download method takes a destination path as an argument and returns an error if the +// download fails. +type fileSource struct { + uri string + cfg *config.Config +} + +// Download just copies the file to the destination +// As this is a file source, we just copy the file to the destination, not much to it +func (f *fileSource) Download(dst string) error { + src, err := f.cfg.Fs.ReadFile(f.uri) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", f.uri, err) + } + + stat, _ := f.cfg.Fs.Stat(f.uri) + dstFile := filepath.Join(dst, filepath.Base(f.uri)) + f.cfg.Logger.Logger.Debug().Str("uri", f.uri).Str("target", dstFile).Msg("Copying system extension") + // Keep original permissions + if err = f.cfg.Fs.WriteFile(dstFile, src, stat.Mode()); err != nil { + return fmt.Errorf("failed to copy file %s to %s: %w", f.uri, dstFile, err) + } + + return nil +} + +type httpSource struct { + uri string + cfg *config.Config +} + +func (h httpSource) Download(s string) error { + // Download the file from the URI + // and save it to the destination path + h.cfg.Logger.Logger.Debug().Str("uri", h.uri).Str("target", filepath.Join(s, filepath.Base(h.uri))).Msg("Downloading system extension") + return h.cfg.Client.GetURL(types.NewNullLogger(), h.uri, filepath.Join(s, filepath.Base(h.uri))) +} + +type dockerSource struct { + uri string + cfg *config.Config +} + +func (d dockerSource) Download(s string) error { + // Download the file from the URI + // and save it to the destination path + err := d.cfg.ImageExtractor.ExtractImage(d.uri, s, "") + if err != nil { + return err + } + return nil +} diff --git a/pkg/action/sysext_test.go b/pkg/action/sysext_test.go new file mode 100644 index 0000000..94989c2 --- /dev/null +++ b/pkg/action/sysext_test.go @@ -0,0 +1,511 @@ +package action_test + +import ( + "bytes" + "fmt" + "github.com/kairos-io/kairos-agent/v2/pkg/action" + agentConfig "github.com/kairos-io/kairos-agent/v2/pkg/config" + v1mock "github.com/kairos-io/kairos-agent/v2/tests/mocks" + sdkTypes "github.com/kairos-io/kairos-sdk/types" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/twpayne/go-vfs/v5" + "github.com/twpayne/go-vfs/v5/vfst" +) + +var _ = Describe("Sysext Actions test", func() { + var config *agentConfig.Config + var runner *v1mock.FakeRunner + var fs vfs.FS + var logger sdkTypes.KairosLogger + var mounter *v1mock.ErrorMounter + var syscall *v1mock.FakeSyscall + var httpClient *v1mock.FakeHTTPClient + var cloudInit *v1mock.FakeCloudInitRunner + var cleanup func() + var memLog *bytes.Buffer + var extractor *v1mock.FakeImageExtractor + var err error + + BeforeEach(func() { + runner = v1mock.NewFakeRunner() + syscall = &v1mock.FakeSyscall{} + mounter = v1mock.NewErrorMounter() + httpClient = &v1mock.FakeHTTPClient{} + memLog = &bytes.Buffer{} + logger = sdkTypes.NewBufferLogger(memLog) + logger.SetLevel("debug") + extractor = v1mock.NewFakeImageExtractor(logger) + cloudInit = &v1mock.FakeCloudInitRunner{} + fs, cleanup, err = vfst.NewTestFS(map[string]interface{}{}) + Expect(err).ToNot(HaveOccurred()) + + err := vfs.MkdirAll(fs, "/var/lib/kairos/extensions", 0755) + Expect(err).ToNot(HaveOccurred()) + err = vfs.MkdirAll(fs, "/run/extensions", 0755) + Expect(err).ToNot(HaveOccurred()) + + // Config object with all of our fakes on it + config = agentConfig.NewConfig( + agentConfig.WithFs(fs), + agentConfig.WithRunner(runner), + agentConfig.WithLogger(logger), + agentConfig.WithMounter(mounter), + agentConfig.WithSyscall(syscall), + agentConfig.WithClient(httpClient), + agentConfig.WithCloudInitRunner(cloudInit), + agentConfig.WithImageExtractor(extractor), + agentConfig.WithPlatform("linux/amd64"), + ) + }) + + AfterEach(func() { + cleanup() + }) + + Describe("Listing extensions", func() { + It("should NOT fail if the bootstate is not valid", func() { + extensions, err := action.ListSystemExtensions(config, "invalid") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + Describe("With no dir", func() { + BeforeEach(func() { + err = config.Fs.RemoveAll("/var/lib/kairos/extensions") + Expect(err).ToNot(HaveOccurred()) + }) + AfterEach(func() { + cleanup() + }) + It("should return no extensions for installed extensions", func() { + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return no extensions for active enabled extensions", func() { + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return no extensions for passive enabled extensions", func() { + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + }) + Describe("With empty dir", func() { + It("should return no extensions", func() { + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return no extensions for active enabled extensions", func() { + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return no extensions for passive enabled extensions", func() { + extensions, err := action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + }) + Describe("With dir with files", func() { + It("should not return files that are not valid extensions", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/invalid", []byte("invalid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should return files that are valid extensions", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/valid.raw", + }, + })) + }) + It("should ONLY return files that are valid extensions", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/invalid", []byte("invalid"), 0644) + Expect(err).ToNot(HaveOccurred()) + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(len(extensions)).To(Equal(1)) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/valid.raw", + }, + })) + }) + }) + }) + Describe("Enabling extensions", func() { + It("should fail to enable a extension if bootState is not valid", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + err = action.EnableSystemExtension(config, "valid", "invalid", false) + Expect(err).To(HaveOccurred()) + }) + It("should enable an installed extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + // Passive should be empty + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for passive + err = action.EnableSystemExtension(config, "valid.raw", "passive", false) + Expect(err).ToNot(HaveOccurred()) + // Passive should have the extension + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/passive/valid.raw", + }, + })) + // Check active again to see if it is still there + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + + }) + It("should enable an installed extension and reload the system with it", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + // Fake the boot state + Expect(config.Fs.Mkdir("/run/cos", 0755)).ToNot(HaveOccurred()) + Expect(config.Fs.WriteFile("/run/cos/active_mode", []byte("true"), 0644)).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + // This basically returns an error if the command is not executed + Expect(runner.IncludesCmds([][]string{ + {"systemctl", "restart", "systemd-sysext"}, + })).To(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", true) + Expect(err).ToNot(HaveOccurred()) + Expect(runner.IncludesCmds([][]string{ + {"systemctl", "restart", "systemd-sysext"}, + })).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + // Passive should be empty + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Symlink should be created in /run/extensions + _, err = config.Fs.Stat("/run/extensions/valid.raw") + Expect(err).ToNot(HaveOccurred()) + readlink, err := config.Fs.Readlink("/run/extensions/valid.raw") + Expect(err).ToNot(HaveOccurred()) + // Get the raw path as the readlink will return the real path, not the one in our fake fs + realPath, err := config.Fs.RawPath("/var/lib/kairos/extensions/active/valid.raw") + Expect(err).ToNot(HaveOccurred()) + Expect(readlink).To(Equal(realPath)) + }) + It("should enable an installed extension and NOT reload the system with it if we are on the wrong boot state", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + // This basically returns an error if the command is not executed + Expect(runner.IncludesCmds([][]string{ + {"systemctl", "restart", "systemd-sysext"}, + })).To(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", true) + Expect(err).ToNot(HaveOccurred()) + Expect(runner.IncludesCmds([][]string{ + {"systemctl", "restart", "systemd-sysext"}, + })).To(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + // Passive should be empty + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Symlink should be created in /run/extensions + _, err = config.Fs.Stat("/run/extensions/valid.raw") + Expect(err).To(HaveOccurred()) + }) + It("should fail to enable a missing extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "invalid.raw", "active", false) + Expect(err).To(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Passive should be empty + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should not fail if the extension is already enabled", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + err = action.EnableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + }) + + }) + Describe("Disabling extensions", func() { + It("should fail if bootState is not valid", func() { + err := action.DisableSystemExtension(config, "whatever", "invalid", false) + Expect(err).To(HaveOccurred()) + }) + It("should disable an enabled extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + // Disable it + err = action.DisableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should not fail to disable a not enabled extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + // Disable a non enabled extension + err = action.DisableSystemExtension(config, "invalid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + }) + }) + Describe("Installing extensions", func() { + Describe("With a file source", func() { + It("should install a extension", func() { + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + err = config.Fs.WriteFile("/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + err = action.InstallSystemExtension(config, "file:///valid.raw") + Expect(err).ToNot(HaveOccurred(), memLog.String()) + // Check if the extension is installed + extensions, err = action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/valid.raw", + }, + })) + }) + It("should fail to install a missing extension", func() { + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + err = action.InstallSystemExtension(config, "file:///invalid.raw") + Expect(err).To(HaveOccurred(), memLog.String()) + // Check if the extension is installed + extensions, err = action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + }) + Describe("with a docker source", func() { + It("should install a extension", func() { + err = action.InstallSystemExtension(config, "docker://quay.io/valid:v1.0.0") + Expect(err).ToNot(HaveOccurred(), memLog.String()) + expectedCall := v1mock.ExtractCall{ImageRef: "quay.io/valid:v1.0.0", Destination: "/var/lib/kairos/extensions/", PlatformRef: ""} + Expect(extractor.WasCalledWithExtractCall(expectedCall)).To(BeTrue()) + }) + It("should fail to install a missing extension", func() { + extractor.SideEffect = func(imageRef, destination, platformRef string) error { + return fmt.Errorf("error") + } + err = action.InstallSystemExtension(config, "docker://quay.io/invalid:v1.0.0") + Expect(err).To(HaveOccurred(), memLog.String()) + expectedCall := v1mock.ExtractCall{ImageRef: "quay.io/invalid:v1.0.0", Destination: "/var/lib/kairos/extensions/", PlatformRef: ""} + Expect(extractor.WasCalledWithExtractCall(expectedCall)).To(BeTrue()) + }) + }) + Describe("with a http source", func() { + It("should install a extension", func() { + err = action.InstallSystemExtension(config, "http://localhost:8080/valid.raw") + Expect(err).ToNot(HaveOccurred(), memLog.String()) + Expect(httpClient.WasGetCalledWith("http://localhost:8080/valid.raw")).To(BeTrue()) + }) + It("should fail to install a missing extension", func() { + httpClient.Error = true + err = action.InstallSystemExtension(config, "http://localhost:8080/invalid.raw") + Expect(err).To(HaveOccurred()) + Expect(httpClient.WasGetCalledWith("http://localhost:8080/invalid.raw")).To(BeTrue()) + }) + }) + }) + Describe("Removing extensions", func() { + It("should remove an installed extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/valid.raw", + }, + })) + err = action.RemoveSystemExtension(config, "valid.raw", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should disable and remove an enabled extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Enable it for active + err = action.EnableSystemExtension(config, "valid.raw", "active", false) + Expect(err).ToNot(HaveOccurred()) + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/active/valid.raw", + }, + })) + err = action.RemoveSystemExtension(config, "valid.raw", false) + Expect(err).ToNot(HaveOccurred()) + // Check if it is removed from active + extensions, err = action.ListSystemExtensions(config, "active") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Check if it is removed from passive + extensions, err = action.ListSystemExtensions(config, "passive") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + // Check if it is removed from installed + extensions, err = action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(BeEmpty()) + }) + It("should fail to remove a missing extension", func() { + err = config.Fs.WriteFile("/var/lib/kairos/extensions/valid.raw", []byte("valid"), 0644) + Expect(err).ToNot(HaveOccurred()) + extensions, err := action.ListSystemExtensions(config, "") + Expect(err).ToNot(HaveOccurred()) + Expect(extensions).To(Equal([]action.SysExtension{ + { + Name: "valid.raw", + Location: "/var/lib/kairos/extensions/valid.raw", + }, + })) + err = action.RemoveSystemExtension(config, "invalid.raw", false) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/pkg/types/v1/fs.go b/pkg/types/v1/fs.go index d94dfe8..ee1476a 100644 --- a/pkg/types/v1/fs.go +++ b/pkg/types/v1/fs.go @@ -31,6 +31,7 @@ type FS interface { RemoveAll(path string) error ReadFile(filename string) ([]byte, error) Readlink(name string) (string, error) + Symlink(oldname, newname string) error RawPath(name string) (string, error) ReadDir(dirname string) ([]fs.DirEntry, error) Remove(name string) error diff --git a/tests/mocks/extractor_mock.go b/tests/mocks/extractor_mock.go index fb8b264..86d8a98 100644 --- a/tests/mocks/extractor_mock.go +++ b/tests/mocks/extractor_mock.go @@ -22,15 +22,22 @@ import ( ) type FakeImageExtractor struct { - Logger sdkTypes.KairosLogger - SideEffect func(imageRef, destination, platformRef string) error + Logger sdkTypes.KairosLogger + SideEffect func(imageRef, destination, platformRef string) error + ClientCalls []ExtractCall } -func (f FakeImageExtractor) GetOCIImageSize(imageRef, platformRef string) (int64, error) { +type ExtractCall struct { + ImageRef string + Destination string + PlatformRef string +} + +func (f *FakeImageExtractor) GetOCIImageSize(imageRef, platformRef string) (int64, error) { return 0, nil } -var _ v1.ImageExtractor = FakeImageExtractor{} +var _ v1.ImageExtractor = &FakeImageExtractor{} func NewFakeImageExtractor(logger sdkTypes.KairosLogger) *FakeImageExtractor { l := logger @@ -42,8 +49,9 @@ func NewFakeImageExtractor(logger sdkTypes.KairosLogger) *FakeImageExtractor { } } -func (f FakeImageExtractor) ExtractImage(imageRef, destination, platformRef string) error { +func (f *FakeImageExtractor) ExtractImage(imageRef, destination, platformRef string) error { f.Logger.Debugf("extracting %s to %s in platform %s", imageRef, destination, platformRef) + f.ClientCalls = append(f.ClientCalls, ExtractCall{ImageRef: imageRef, Destination: destination, PlatformRef: platformRef}) if f.SideEffect != nil { f.Logger.Debugf("running side effect") return f.SideEffect(imageRef, destination, platformRef) @@ -51,3 +59,24 @@ func (f FakeImageExtractor) ExtractImage(imageRef, destination, platformRef stri return nil } + +// WasCalledWithImageRef is a helper method to confirm that the client was called with the given image ref +func (f *FakeImageExtractor) WasCalledWithImageRef(imageRef string) bool { + for _, c := range f.ClientCalls { + if c.ImageRef == imageRef { + return true + } + } + return false +} + +// WasCalledWithExtractCall is a helper method to confirm that the client was called with the given extract call +// This matches exactly the calls made to the client in all fields +func (f *FakeImageExtractor) WasCalledWithExtractCall(call ExtractCall) bool { + for _, c := range f.ClientCalls { + if c == call { + return true + } + } + return false +}