kcrypt/pkg/lib/initrd.go

105 lines
2.4 KiB
Go
Raw Normal View History

package lib
import (
"fmt"
"os"
"path/filepath"
"strings"
cp "github.com/otiai10/copy"
)
const (
GZType = "gz"
XZType = "xz"
LZMAType = "lzma"
)
func createInitrd(initrd string, src string, format string) error {
fmt.Printf("Creating '%s' from '%s' as '%s'\n", initrd, src, format)
if _, err := os.Stat(src); err != nil {
return err
}
var err error
var out string
if format == XZType {
out, err = SH(fmt.Sprintf("cd %s && find . 2>/dev/null | cpio -H newc --quiet --null -o -R root:root | xz -0 --check=crc32 > %s", src, initrd))
} else if format == GZType {
out, err = SH(fmt.Sprintf("cd %s && find . | cpio -H newc -o -R root:root | gzip -9 > %s", src, initrd))
} else if format == LZMAType {
out, err = SH(fmt.Sprintf("cd %s && find . 2>/dev/null | cpio -H newc -o -R root:root | xz -9 --format=lzma > %s", src, initrd))
}
fmt.Println(out)
return err
}
func InjectInitrd(initrd string, file, dst string) error {
fmt.Printf("Injecting '%s' as '%s' into '%s'\n", file, dst, initrd)
format, err := detect(initrd)
if err != nil {
return err
}
tmp, err := os.MkdirTemp("", "kcrypt")
if err != nil {
return fmt.Errorf("cannot create tempdir, %s", err)
}
defer os.RemoveAll(tmp)
fmt.Printf("Extracting '%s' in '%s' ...\n", initrd, tmp)
if err := ExtractInitrd(initrd, tmp); err != nil {
return fmt.Errorf("cannot extract initrd, %s", err)
}
d := filepath.Join(tmp, dst)
fmt.Printf("Copying '%s' in '%s' ...\n", file, d)
if err := cp.Copy(file, d); err != nil {
return fmt.Errorf("cannot copy file, %s", err)
}
return createInitrd(initrd, tmp, format)
}
func ExtractInitrd(initrd string, dst string) error {
var out string
var err error
err = os.MkdirAll(dst, os.ModePerm)
if err != nil {
return err
}
format, err := detect(initrd)
if err != nil {
return err
}
if format == XZType || format == LZMAType {
out, err = SH(fmt.Sprintf("cd %s && xz -dc < %s | cpio -idmv", dst, initrd))
} else if format == GZType {
out, err = SH(fmt.Sprintf("cd %s && zcat %s | cpio -idmv", dst, initrd))
}
fmt.Println(out)
return err
}
func detect(archive string) (string, error) {
out, err := SH(fmt.Sprintf("file %s", archive))
if err != nil {
return "", err
}
out = strings.ToLower(out)
if strings.Contains(out, "xz") {
return XZType, nil
} else if strings.Contains(out, "lzma") {
return LZMAType, nil
} else if strings.Contains(out, "gz") {
return GZType, nil
}
return "", fmt.Errorf("Unknown")
}