diff --git a/main.go b/main.go index debc97c..3cb8837 100644 --- a/main.go +++ b/main.go @@ -279,11 +279,17 @@ func injectInitrd(initrd string, file, dst string) error { func unlockAll() error { bus.Manager.Initialize() + partitionInfo, err := pi.NewPartitionInfoFromFile(pi.DefaultPartitionInfoFile) + if err != nil { + return err + } + block, err := ghw.Block() if err == nil { for _, disk := range block.Disks { for _, p := range disk.Partitions { if p.Type == "crypto_LUKS" { + p.Label = partitionInfo.LookupLabelForUUID(p.UUID) fmt.Printf("Unmounted Luks found at '%s' LABEL '%s' \n", p.Name, p.Label) err = multierror.Append(err, unlockDisk(p)) if err != nil { diff --git a/pkg/partition_info/partition_info.go b/pkg/partition_info/partition_info.go index 3bbf986..304f4c1 100644 --- a/pkg/partition_info/partition_info.go +++ b/pkg/partition_info/partition_info.go @@ -38,6 +38,16 @@ func (pi PartitionInfo) LookupUUIDForLabel(l string) string { return pi.mapping[l] } +func (pi PartitionInfo) LookupLabelForUUID(uuid string) string { + for k, v := range pi.mapping { + if v == uuid { + return k + } + } + + return "" +} + // UpdatePartitionLabelMapping takes partition information as a string argument // the the form: `label:name:uuid` (that's what the `kcrypt encrypt` command returns // on success. This function stores it in the PartitionInfoFile yaml file for diff --git a/pkg/partition_info/partition_info_test.go b/pkg/partition_info/partition_info_test.go index 1a1e22d..b77483b 100644 --- a/pkg/partition_info/partition_info_test.go +++ b/pkg/partition_info/partition_info_test.go @@ -121,9 +121,36 @@ TO_KEEP: old_uuid_1 Expect(err).ToNot(HaveOccurred()) }) - It("parses the file correctly", func() { + It("returns the correct UUID", func() { uuid := partitionInfo.LookupUUIDForLabel("COS_PERSISTENT") Expect(uuid).To(Equal("some_uuid_1")) }) + + It("returns an empty UUID when the label is not found", func() { + uuid := partitionInfo.LookupUUIDForLabel("DOESNT_EXIST") + Expect(uuid).To(Equal("")) + }) + }) + + Describe("LookupLabelForUUID", func() { + var file string + var partitionInfo *pi.PartitionInfo + var err error + + BeforeEach(func() { + file = "../../tests/assets/partition_info.yaml" + partitionInfo, err = pi.NewPartitionInfoFromFile(file) + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns the correct label", func() { + uuid := partitionInfo.LookupLabelForUUID("some_uuid_1") + Expect(uuid).To(Equal("COS_PERSISTENT")) + }) + + It("returns an empty label when UUID doesn't exist", func() { + uuid := partitionInfo.LookupLabelForUUID("doesnt_exist") + Expect(uuid).To(Equal("")) + }) }) })