From 94af8c8dd323fa071689d40e5fd4664dd587fe41 Mon Sep 17 00:00:00 2001
From: Itxaka <itxaka@kairos.io>
Date: Thu, 30 Nov 2023 11:39:21 +0100
Subject: [PATCH] Also unlock with TPM

so we can use the same functions everywhere just with a flag

Signed-off-by: Itxaka <itxaka@kairos.io>
---
 main.go           | 10 ++++++++--
 pkg/lib/unlock.go | 16 ++++++++++++----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/main.go b/main.go
index 99d40c1..f2bdec1 100644
--- a/main.go
+++ b/main.go
@@ -55,9 +55,15 @@ func main() {
 				UsageText:   "unlock-all",
 				Usage:       "Try to unlock all LUKS partitions",
 				Description: "Typically run during initrd to unlock all the LUKS partitions found",
-				ArgsUsage:   "kcrypt unlock-all",
+				ArgsUsage:   "kcrypt [--tpm] unlock-all",
+				Flags: []cli.Flag{
+					&cli.BoolFlag{
+						Name:  "tpm",
+						Usage: "Use TPM to unlock the partition",
+					},
+				},
 				Action: func(c *cli.Context) error {
-					return lib.UnlockAll()
+					return lib.UnlockAll(c.Bool("tpm"))
 				},
 			},
 			{
diff --git a/pkg/lib/unlock.go b/pkg/lib/unlock.go
index 6913c4c..f9b612d 100644
--- a/pkg/lib/unlock.go
+++ b/pkg/lib/unlock.go
@@ -15,7 +15,7 @@ import (
 )
 
 // UnlockAll Unlocks all encrypted devices found in the system
-func UnlockAll() error {
+func UnlockAll(tpm bool) error {
 	bus.Manager.Initialize()
 
 	config, err := configpkg.GetConfiguration(configpkg.ConfigScanDirs)
@@ -52,9 +52,17 @@ func UnlockAll() error {
 				// We mount it under /dev/mapper/DEVICE, so It's pretty easy to check
 				if !utils.Exists(filepath.Join("/dev", "mapper", p.Name)) {
 					fmt.Printf("Unmounted Luks found at '%s' LABEL '%s' \n", filepath.Join("/dev", p.Name), p.FilesystemLabel)
-					err = UnlockDisk(p)
-					if err != nil {
-						fmt.Printf("Unlocking failed: '%s'\n", err.Error())
+					if tpm {
+						out, err := utils.SH(fmt.Sprintf("/usr/lib/systemd/systemd-cryptsetup attach %s %s - tpm2-device=auto", p.Name, filepath.Join("/dev", p.Name)))
+						if err != nil {
+							fmt.Printf("Unlocking failed: '%s'\n", err.Error())
+							fmt.Printf("Unlocking failed, command output: '%s'\n", out)
+						}
+					} else {
+						err = UnlockDisk(p)
+						if err != nil {
+							fmt.Printf("Unlocking failed: '%s'\n", err.Error())
+						}
 					}
 				} else {
 					fmt.Printf("Device %s seems to be mounted at %s, skipping\n", filepath.Join("/dev", p.Name), filepath.Join("/dev", "mapper", p.Name))