Avoid global vars

Signed-off-by: Dimitris Karakasilis <dimitris@karakasilis.me>
This commit is contained in:
Dimitris Karakasilis
2025-09-24 13:04:13 +03:00
parent 55a0d62231
commit caedb1ef7f

View File

@@ -14,16 +14,20 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// GetFlags holds all flags specific to the get command
type GetFlags struct {
PartitionName string
PartitionUUID string
PartitionLabel string
Attempts int
ChallengerServer string
EnableMDNS bool
ServerCertificate string
}
var ( var (
// Global flags for the get subcommand (passphrase retrieval) // Global/persistent flags
partitionName string debug bool
partitionUUID string
partitionLabel string
attempts int
challengerServer string
enableMDNS bool
serverCertificate string
debug bool
) )
// rootCmd represents the base command (TPM hash generation) // rootCmd represents the base command (TPM hash generation)
@@ -63,11 +67,14 @@ Configuration:
}, },
} }
// getCmd represents the get command (passphrase retrieval) // newGetCmd creates the get command with its flags
var getCmd = &cobra.Command{ func newGetCmd() *cobra.Command {
Use: "get", flags := &GetFlags{}
Short: "Get passphrase for encrypted partition",
Long: `Get passphrase for encrypted partition using TPM attestation. cmd := &cobra.Command{
Use: "get",
Short: "Get passphrase for encrypted partition",
Long: `Get passphrase for encrypted partition using TPM attestation.
This command retrieves passphrases for encrypted partitions by communicating This command retrieves passphrases for encrypted partitions by communicating
with a challenger server using TPM-based attestation. At least one partition with a challenger server using TPM-based attestation. At least one partition
@@ -78,7 +85,7 @@ can override specific settings:
--challenger-server Override kcrypt.challenger.challenger_server --challenger-server Override kcrypt.challenger.challenger_server
--mdns Override kcrypt.challenger.mdns --mdns Override kcrypt.challenger.mdns
--certificate Override kcrypt.challenger.certificate`, --certificate Override kcrypt.challenger.certificate`,
Example: ` # Get passphrase using partition name Example: ` # Get passphrase using partition name
kcrypt-discovery-challenger get --partition-name=/dev/sda2 kcrypt-discovery-challenger get --partition-name=/dev/sda2
# Get passphrase using UUID # Get passphrase using UUID
@@ -92,16 +99,28 @@ can override specific settings:
# Get passphrase with custom server # Get passphrase with custom server
kcrypt-discovery-challenger get --partition-label=encrypted-data --challenger-server=https://my-server.com:8082`, kcrypt-discovery-challenger get --partition-label=encrypted-data --challenger-server=https://my-server.com:8082`,
PreRunE: func(cmd *cobra.Command, args []string) error { PreRunE: func(cmd *cobra.Command, args []string) error {
// Validate that at least one partition identifier is provided // Validate that at least one partition identifier is provided
if partitionName == "" && partitionUUID == "" && partitionLabel == "" { if flags.PartitionName == "" && flags.PartitionUUID == "" && flags.PartitionLabel == "" {
return fmt.Errorf("at least one of --partition-name, --partition-uuid, or --partition-label must be provided") return fmt.Errorf("at least one of --partition-name, --partition-uuid, or --partition-label must be provided")
} }
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return runGetPassphrase() return runGetPassphrase(flags)
}, },
}
// Register flags
cmd.Flags().StringVar(&flags.PartitionName, "partition-name", "", "Name of the partition (at least one identifier required)")
cmd.Flags().StringVar(&flags.PartitionUUID, "partition-uuid", "", "UUID of the partition (at least one identifier required)")
cmd.Flags().StringVar(&flags.PartitionLabel, "partition-label", "", "Filesystem label of the partition (at least one identifier required)")
cmd.Flags().IntVar(&flags.Attempts, "attempts", 30, "Number of attempts to get the passphrase")
cmd.Flags().StringVar(&flags.ChallengerServer, "challenger-server", "", "URL of the challenger server (overrides config)")
cmd.Flags().BoolVar(&flags.EnableMDNS, "mdns", false, "Enable mDNS discovery (overrides config)")
cmd.Flags().StringVar(&flags.ServerCertificate, "certificate", "", "Server certificate for verification (overrides config)")
return cmd
} }
// pluginCmd represents the plugin event commands // pluginCmd represents the plugin event commands
@@ -121,20 +140,11 @@ with kcrypt and other tools.`, bus.EventDiscoveryPassword),
} }
func init() { func init() {
// Global flags (available to all commands) // Global/persistent flags (available to all commands)
rootCmd.PersistentFlags().BoolVar(&debug, "debug", false, "Enable debug logging") rootCmd.PersistentFlags().BoolVar(&debug, "debug", false, "Enable debug logging")
// Get command flags (for passphrase retrieval)
getCmd.Flags().StringVar(&partitionName, "partition-name", "", "Name of the partition (at least one identifier required)")
getCmd.Flags().StringVar(&partitionUUID, "partition-uuid", "", "UUID of the partition (at least one identifier required)")
getCmd.Flags().StringVar(&partitionLabel, "partition-label", "", "Filesystem label of the partition (at least one identifier required)")
getCmd.Flags().IntVar(&attempts, "attempts", 30, "Number of attempts to get the passphrase")
getCmd.Flags().StringVar(&challengerServer, "challenger-server", "", "URL of the challenger server (overrides config)")
getCmd.Flags().BoolVar(&enableMDNS, "mdns", false, "Enable mDNS discovery (overrides config)")
getCmd.Flags().StringVar(&serverCertificate, "certificate", "", "Server certificate for verification (overrides config)")
// Add subcommands // Add subcommands
rootCmd.AddCommand(getCmd) rootCmd.AddCommand(newGetCmd())
rootCmd.AddCommand(pluginCmd) rootCmd.AddCommand(pluginCmd)
} }
@@ -193,7 +203,7 @@ func runTPMHash() error {
} }
// runGetPassphrase handles the get subcommand - passphrase retrieval // runGetPassphrase handles the get subcommand - passphrase retrieval
func runGetPassphrase() error { func runGetPassphrase(flags *GetFlags) error {
// Create logger based on debug flag // Create logger based on debug flag
var logger types.KairosLogger var logger types.KairosLogger
if debug { if debug {
@@ -203,16 +213,16 @@ func runGetPassphrase() error {
} }
// Create client with potential CLI overrides // Create client with potential CLI overrides
c, err := createClientWithOverrides(challengerServer, enableMDNS, serverCertificate, logger) c, err := createClientWithOverrides(flags.ChallengerServer, flags.EnableMDNS, flags.ServerCertificate, logger)
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
// Create partition object // Create partition object
partition := &block.Partition{ partition := &block.Partition{
Name: partitionName, Name: flags.PartitionName,
UUID: partitionUUID, UUID: flags.PartitionUUID,
FilesystemLabel: partitionLabel, FilesystemLabel: flags.PartitionLabel,
} }
// Log partition information // Log partition information
@@ -220,13 +230,13 @@ func runGetPassphrase() error {
logger.Debugf(" Name: %s", partition.Name) logger.Debugf(" Name: %s", partition.Name)
logger.Debugf(" UUID: %s", partition.UUID) logger.Debugf(" UUID: %s", partition.UUID)
logger.Debugf(" Label: %s", partition.FilesystemLabel) logger.Debugf(" Label: %s", partition.FilesystemLabel)
logger.Debugf(" Attempts: %d", attempts) logger.Debugf(" Attempts: %d", flags.Attempts)
// Get the passphrase using the same backend logic as the plugin // Get the passphrase using the same backend logic as the plugin
fmt.Fprintf(os.Stderr, "Requesting passphrase for partition %s (UUID: %s, Label: %s)...\n", fmt.Fprintf(os.Stderr, "Requesting passphrase for partition %s (UUID: %s, Label: %s)...\n",
partitionName, partitionUUID, partitionLabel) flags.PartitionName, flags.PartitionUUID, flags.PartitionLabel)
passphrase, err := c.GetPassphrase(partition, attempts) passphrase, err := c.GetPassphrase(partition, flags.Attempts)
if err != nil { if err != nil {
return fmt.Errorf("getting passphrase: %w", err) return fmt.Errorf("getting passphrase: %w", err)
} }