mirror of
https://github.com/kairos-io/kcrypt.git
synced 2025-04-28 03:30:07 +00:00
105 lines
2.4 KiB
Go
105 lines
2.4 KiB
Go
package lib
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
cp "github.com/otiai10/copy"
|
|
)
|
|
|
|
// GZType identifies a gzip-compressed initrd archive.
const GZType = "gz"

// XZType identifies an xz-compressed initrd archive.
const XZType = "xz"

// LZMAType identifies an initrd archive compressed in the legacy LZMA format.
const LZMAType = "lzma"
|
|
|
|
func createInitrd(initrd string, src string, format string) error {
|
|
fmt.Printf("Creating '%s' from '%s' as '%s'\n", initrd, src, format)
|
|
|
|
if _, err := os.Stat(src); err != nil {
|
|
return err
|
|
}
|
|
var err error
|
|
var out string
|
|
if format == XZType {
|
|
out, err = SH(fmt.Sprintf("cd %s && find . 2>/dev/null | cpio -H newc --quiet --null -o -R root:root | xz -0 --check=crc32 > %s", src, initrd))
|
|
} else if format == GZType {
|
|
out, err = SH(fmt.Sprintf("cd %s && find . | cpio -H newc -o -R root:root | gzip -9 > %s", src, initrd))
|
|
} else if format == LZMAType {
|
|
out, err = SH(fmt.Sprintf("cd %s && find . 2>/dev/null | cpio -H newc -o -R root:root | xz -9 --format=lzma > %s", src, initrd))
|
|
}
|
|
fmt.Println(out)
|
|
|
|
return err
|
|
}
|
|
|
|
func InjectInitrd(initrd string, file, dst string) error {
|
|
fmt.Printf("Injecting '%s' as '%s' into '%s'\n", file, dst, initrd)
|
|
format, err := detect(initrd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tmp, err := os.MkdirTemp("", "kcrypt")
|
|
if err != nil {
|
|
return fmt.Errorf("cannot create tempdir, %s", err)
|
|
}
|
|
defer os.RemoveAll(tmp)
|
|
|
|
fmt.Printf("Extracting '%s' in '%s' ...\n", initrd, tmp)
|
|
if err := ExtractInitrd(initrd, tmp); err != nil {
|
|
return fmt.Errorf("cannot extract initrd, %s", err)
|
|
}
|
|
|
|
d := filepath.Join(tmp, dst)
|
|
fmt.Printf("Copying '%s' in '%s' ...\n", file, d)
|
|
if err := cp.Copy(file, d); err != nil {
|
|
return fmt.Errorf("cannot copy file, %s", err)
|
|
}
|
|
|
|
return createInitrd(initrd, tmp, format)
|
|
}
|
|
|
|
func ExtractInitrd(initrd string, dst string) error {
|
|
var out string
|
|
var err error
|
|
err = os.MkdirAll(dst, os.ModePerm)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
format, err := detect(initrd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if format == XZType || format == LZMAType {
|
|
out, err = SH(fmt.Sprintf("cd %s && xz -dc < %s | cpio -idmv", dst, initrd))
|
|
} else if format == GZType {
|
|
out, err = SH(fmt.Sprintf("cd %s && zcat %s | cpio -idmv", dst, initrd))
|
|
}
|
|
fmt.Println(out)
|
|
|
|
return err
|
|
}
|
|
|
|
func detect(archive string) (string, error) {
|
|
out, err := SH(fmt.Sprintf("file %s", archive))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
out = strings.ToLower(out)
|
|
if strings.Contains(out, "xz") {
|
|
return XZType, nil
|
|
|
|
} else if strings.Contains(out, "lzma") {
|
|
return LZMAType, nil
|
|
|
|
} else if strings.Contains(out, "gz") {
|
|
return GZType, nil
|
|
|
|
}
|
|
|
|
return "", fmt.Errorf("Unknown")
|
|
}
|