package lib

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	cp "github.com/otiai10/copy"
)

// Compression format identifiers for initrd archives, as returned by detect
// and accepted by createInitrd.
const (
	// GZType selects gzip compression (packed with gzip, unpacked with zcat).
	GZType   = "gz"
	// XZType selects xz compression.
	XZType   = "xz"
	// LZMAType selects the legacy LZMA container (xz --format=lzma).
	LZMAType = "lzma"
)

// createInitrd packs the directory tree rooted at src into a newc-format cpio
// archive owned by root:root, compresses it according to format (GZType,
// XZType or LZMAType) and writes the result to initrd.
//
// The work is done by a shell pipeline executed through SH; whatever SH
// returns as output is printed to stdout.
//
// It returns an error if src cannot be stat'ed, if the initrd path cannot be
// made absolute, if format is not one of the supported types, or if SH fails.
func createInitrd(initrd string, src string, format string) error {
	fmt.Printf("Creating '%s' from '%s' as '%s'\n", initrd, src, format)

	if _, err := os.Stat(src); err != nil {
		return err
	}

	// The pipeline runs after `cd src`, so a relative output path would be
	// resolved inside src rather than the caller's working directory.
	target, err := filepath.Abs(initrd)
	if err != nil {
		return fmt.Errorf("cannot resolve initrd path, %w", err)
	}

	// quote wraps a path in single quotes for sh so that spaces and shell
	// metacharacters in it are taken literally.
	quote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
	}

	var pipeline string
	switch format {
	case XZType:
		// cpio --null expects NUL-terminated names, so find must emit them
		// with -print0; newline-terminated output would be read as a single
		// (non-existent) file name.
		pipeline = "find . -print0 2>/dev/null | cpio -H newc --quiet --null -o -R root:root | xz -0 --check=crc32"
	case GZType:
		pipeline = "find . | cpio -H newc -o -R root:root | gzip -9"
	case LZMAType:
		pipeline = "find . 2>/dev/null | cpio -H newc -o -R root:root | xz -9 --format=lzma"
	default:
		// Previously an unknown format fell through and reported success
		// without writing anything.
		return fmt.Errorf("unsupported initrd format %q", format)
	}

	out, err := SH(fmt.Sprintf("cd %s && %s > %s", quote(src), pipeline, quote(target)))
	fmt.Println(out)

	return err
}

// InjectInitrd adds file to the initrd image at path initrd, placing it at
// dst inside the image, and rewrites the image in place using the same
// compression format that was detected on the original.
//
// The image is unpacked into a temporary directory (removed again before
// returning), file is copied to dst below that directory, and the tree is
// then repacked over initrd.
//
// Errors from the tempdir, extract and copy steps are wrapped, so the
// underlying cause stays reachable through errors.Is / errors.As; errors from
// format detection and repacking are returned unchanged.
func InjectInitrd(initrd string, file, dst string) error {
	fmt.Printf("Injecting '%s' as '%s' into '%s'\n", file, dst, initrd)

	// Remember the original format up front so the rebuilt image matches it.
	format, err := detect(initrd)
	if err != nil {
		return err
	}

	tmp, err := os.MkdirTemp("", "kcrypt")
	if err != nil {
		return fmt.Errorf("cannot create tempdir, %w", err)
	}
	defer os.RemoveAll(tmp)

	fmt.Printf("Extracting '%s' in '%s' ...\n", initrd, tmp)
	if err := ExtractInitrd(initrd, tmp); err != nil {
		return fmt.Errorf("cannot extract initrd, %w", err)
	}

	d := filepath.Join(tmp, dst)
	fmt.Printf("Copying '%s' in '%s' ...\n", file, d)
	if err := cp.Copy(file, d); err != nil {
		return fmt.Errorf("cannot copy file, %w", err)
	}

	return createInitrd(initrd, tmp, format)
}

// ExtractInitrd unpacks the compressed cpio image at initrd into the
// directory dst, creating dst (and any missing parents) first.
//
// The compression format is determined with detect; xz and lzma images are
// decompressed with xz, gzip images with zcat, and the result is fed to
// `cpio -idmv` inside dst. Whatever SH returns as output is printed to
// stdout.
//
// It returns an error if dst cannot be created, if the format cannot be
// detected or is not supported, if the initrd path cannot be made absolute,
// or if the shell pipeline run through SH fails.
func ExtractInitrd(initrd string, dst string) error {
	if err := os.MkdirAll(dst, os.ModePerm); err != nil {
		return err
	}

	format, err := detect(initrd)
	if err != nil {
		return err
	}

	// The pipeline runs after `cd dst`, so a relative image path would be
	// looked up inside dst instead of where detect just found it.
	archive, err := filepath.Abs(initrd)
	if err != nil {
		return fmt.Errorf("cannot resolve initrd path, %w", err)
	}

	// quote wraps a path in single quotes for sh so that spaces and shell
	// metacharacters in it are taken literally.
	quote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
	}

	var decompress string
	switch format {
	case XZType, LZMAType:
		decompress = "xz -dc < " + quote(archive)
	case GZType:
		decompress = "zcat " + quote(archive)
	default:
		// Guard against a format detect knows about but this function does
		// not, rather than reporting success without extracting anything.
		return fmt.Errorf("unsupported initrd format %q", format)
	}

	out, err := SH(fmt.Sprintf("cd %s && %s | cpio -idmv", quote(dst), decompress))
	fmt.Println(out)

	return err
}

// detect reports the compression format of archive as one of XZType,
// LZMAType or GZType, based on the description printed by file(1).
//
// It returns the error from SH if running file fails, and an error naming the
// unrecognized description if none of the known formats match.
func detect(archive string) (string, error) {
	// Single-quote the path for sh so spaces and metacharacters are literal.
	quoted := "'" + strings.ReplaceAll(archive, "'", `'\''`) + "'"

	// -b suppresses the "<name>: " prefix; otherwise a path that merely
	// contains "xz", "lzma" or "gz" would be mistaken for that format.
	out, err := SH(fmt.Sprintf("file -b -- %s", quoted))
	if err != nil {
		return "", err
	}

	// Only the leading field names the container format. Later fields can
	// embed arbitrary text, e.g. the original file name stored in a gzip
	// header ("gzip compressed data, was \"initrd.xz\", ...").
	desc := strings.ToLower(strings.TrimSpace(strings.SplitN(out, ",", 2)[0]))

	switch {
	case strings.Contains(desc, "xz"):
		return XZType, nil
	case strings.Contains(desc, "lzma"):
		return LZMAType, nil
	case strings.Contains(desc, "gz"):
		return GZType, nil
	}

	return "", fmt.Errorf("unknown initrd compression format %q", desc)
}