From caedb1ef7f813ab84de1d041b85cf5e105b66f7e Mon Sep 17 00:00:00 2001 From: Dimitris Karakasilis Date: Wed, 24 Sep 2025 13:04:13 +0300 Subject: [PATCH] Avoid global vars Signed-off-by: Dimitris Karakasilis --- cmd/discovery/main.go | 98 ++++++++++++++++++++++++------------------- 1 file changed, 54 insertions(+), 44 deletions(-) diff --git a/cmd/discovery/main.go b/cmd/discovery/main.go index 023fbd6..63bd2a5 100644 --- a/cmd/discovery/main.go +++ b/cmd/discovery/main.go @@ -14,16 +14,20 @@ import ( "github.com/spf13/cobra" ) +// GetFlags holds all flags specific to the get command +type GetFlags struct { + PartitionName string + PartitionUUID string + PartitionLabel string + Attempts int + ChallengerServer string + EnableMDNS bool + ServerCertificate string +} + var ( - // Global flags for the get subcommand (passphrase retrieval) - partitionName string - partitionUUID string - partitionLabel string - attempts int - challengerServer string - enableMDNS bool - serverCertificate string - debug bool + // Global/persistent flags + debug bool ) // rootCmd represents the base command (TPM hash generation) @@ -63,11 +67,14 @@ Configuration: }, } -// getCmd represents the get command (passphrase retrieval) -var getCmd = &cobra.Command{ - Use: "get", - Short: "Get passphrase for encrypted partition", - Long: `Get passphrase for encrypted partition using TPM attestation. +// newGetCmd creates the get command with its flags +func newGetCmd() *cobra.Command { + flags := &GetFlags{} + + cmd := &cobra.Command{ + Use: "get", + Short: "Get passphrase for encrypted partition", + Long: `Get passphrase for encrypted partition using TPM attestation. This command retrieves passphrases for encrypted partitions by communicating with a challenger server using TPM-based attestation. At least one partition @@ -78,7 +85,7 @@ can override specific settings: --challenger-server Override kcrypt.challenger.challenger_server --mdns Override kcrypt.challenger.mdns --certificate Override kcrypt.challenger.certificate`, - Example: ` # Get passphrase using partition name + Example: ` # Get passphrase using partition name kcrypt-discovery-challenger get --partition-name=/dev/sda2 # Get passphrase using UUID @@ -92,16 +99,28 @@ can override specific settings: # Get passphrase with custom server kcrypt-discovery-challenger get --partition-label=encrypted-data --challenger-server=https://my-server.com:8082`, - PreRunE: func(cmd *cobra.Command, args []string) error { - // Validate that at least one partition identifier is provided - if partitionName == "" && partitionUUID == "" && partitionLabel == "" { - return fmt.Errorf("at least one of --partition-name, --partition-uuid, or --partition-label must be provided") - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - return runGetPassphrase() - }, + PreRunE: func(cmd *cobra.Command, args []string) error { + // Validate that at least one partition identifier is provided + if flags.PartitionName == "" && flags.PartitionUUID == "" && flags.PartitionLabel == "" { + return fmt.Errorf("at least one of --partition-name, --partition-uuid, or --partition-label must be provided") + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + return runGetPassphrase(flags) + }, + } + + // Register flags + cmd.Flags().StringVar(&flags.PartitionName, "partition-name", "", "Name of the partition (at least one identifier required)") + cmd.Flags().StringVar(&flags.PartitionUUID, "partition-uuid", "", "UUID of the partition (at least one identifier required)") + cmd.Flags().StringVar(&flags.PartitionLabel, "partition-label", "", "Filesystem label of the partition (at least one identifier required)") + cmd.Flags().IntVar(&flags.Attempts, "attempts", 30, "Number of attempts to get the passphrase") + cmd.Flags().StringVar(&flags.ChallengerServer, "challenger-server", "", "URL of the challenger server (overrides config)") + cmd.Flags().BoolVar(&flags.EnableMDNS, "mdns", false, "Enable mDNS discovery (overrides config)") + cmd.Flags().StringVar(&flags.ServerCertificate, "certificate", "", "Server certificate for verification (overrides config)") + + return cmd } // pluginCmd represents the plugin event commands @@ -121,20 +140,11 @@ with kcrypt and other tools.`, bus.EventDiscoveryPassword), } func init() { - // Global flags (available to all commands) + // Global/persistent flags (available to all commands) rootCmd.PersistentFlags().BoolVar(&debug, "debug", false, "Enable debug logging") - // Get command flags (for passphrase retrieval) - getCmd.Flags().StringVar(&partitionName, "partition-name", "", "Name of the partition (at least one identifier required)") - getCmd.Flags().StringVar(&partitionUUID, "partition-uuid", "", "UUID of the partition (at least one identifier required)") - getCmd.Flags().StringVar(&partitionLabel, "partition-label", "", "Filesystem label of the partition (at least one identifier required)") - getCmd.Flags().IntVar(&attempts, "attempts", 30, "Number of attempts to get the passphrase") - getCmd.Flags().StringVar(&challengerServer, "challenger-server", "", "URL of the challenger server (overrides config)") - getCmd.Flags().BoolVar(&enableMDNS, "mdns", false, "Enable mDNS discovery (overrides config)") - getCmd.Flags().StringVar(&serverCertificate, "certificate", "", "Server certificate for verification (overrides config)") - // Add subcommands - rootCmd.AddCommand(getCmd) + rootCmd.AddCommand(newGetCmd()) rootCmd.AddCommand(pluginCmd) } @@ -193,7 +203,7 @@ func runTPMHash() error { } // runGetPassphrase handles the get subcommand - passphrase retrieval -func runGetPassphrase() error { +func runGetPassphrase(flags *GetFlags) error { // Create logger based on debug flag var logger types.KairosLogger if debug { @@ -203,16 +213,16 @@ func runGetPassphrase() error { } // Create client with potential CLI overrides - c, err := createClientWithOverrides(challengerServer, enableMDNS, serverCertificate, logger) + c, err := createClientWithOverrides(flags.ChallengerServer, flags.EnableMDNS, flags.ServerCertificate, logger) if err != nil { return fmt.Errorf("creating client: %w", err) } // Create partition object partition := &block.Partition{ - Name: partitionName, - UUID: partitionUUID, - FilesystemLabel: partitionLabel, + Name: flags.PartitionName, + UUID: flags.PartitionUUID, + FilesystemLabel: flags.PartitionLabel, } // Log partition information @@ -220,13 +230,13 @@ func runGetPassphrase() error { logger.Debugf(" Name: %s", partition.Name) logger.Debugf(" UUID: %s", partition.UUID) logger.Debugf(" Label: %s", partition.FilesystemLabel) - logger.Debugf(" Attempts: %d", attempts) + logger.Debugf(" Attempts: %d", flags.Attempts) // Get the passphrase using the same backend logic as the plugin fmt.Fprintf(os.Stderr, "Requesting passphrase for partition %s (UUID: %s, Label: %s)...\n", - partitionName, partitionUUID, partitionLabel) + flags.PartitionName, flags.PartitionUUID, flags.PartitionLabel) - passphrase, err := c.GetPassphrase(partition, attempts) + passphrase, err := c.GetPassphrase(partition, flags.Attempts) if err != nil { return fmt.Errorf("getting passphrase: %w", err) }