From 43a410b9dd12f4ec320668682c3830afa5992f84 Mon Sep 17 00:00:00 2001 From: "M. Mert Yildiran" <me@mertyildiran.com> Date: Wed, 16 Apr 2025 20:28:21 +0300 Subject: [PATCH] Add `--config-path` flag to root command (#1744) * Add `--config-path` flag to root command * Use `filepath.Abs` --------- Co-authored-by: Alon Girmonsky <1990761+alongir@users.noreply.github.com> --- cmd/config.go | 3 +-- cmd/root.go | 1 + config/config.go | 49 ++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/cmd/config.go b/cmd/config.go index 74292d97b..715dea371 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -2,7 +2,6 @@ package cmd import ( "fmt" - "path" "github.com/creasty/defaults" "github.com/kubeshark/kubeshark/config" @@ -52,5 +51,5 @@ func init() { log.Debug().Err(err).Send() } - configCmd.Flags().BoolP(configStructs.RegenerateConfigName, "r", defaultConfig.Config.Regenerate, fmt.Sprintf("Regenerate the config file with default values to path %s", path.Join(misc.GetDotFolderPath(), "config.yaml"))) + configCmd.Flags().BoolP(configStructs.RegenerateConfigName, "r", defaultConfig.Config.Regenerate, fmt.Sprintf("Regenerate the config file with default values to path %s", config.GetConfigFilePath(nil))) } diff --git a/cmd/root.go b/cmd/root.go index 0f9e2c646..178cad30c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -33,6 +33,7 @@ func init() { rootCmd.PersistentFlags().StringSlice(config.SetCommandName, []string{}, fmt.Sprintf("Override values using --%s", config.SetCommandName)) rootCmd.PersistentFlags().BoolP(config.DebugFlag, "d", false, "Enable debug mode") + rootCmd.PersistentFlags().String(config.ConfigPathFlag, "", fmt.Sprintf("Set the config path, default: %s", config.GetConfigFilePath(nil))) } // Execute adds all child commands to the root command and sets flags appropriately. diff --git a/config/config.go b/config/config.go index 953b5e7a0..0d6d6f058 100644 --- a/config/config.go +++ b/config/config.go @@ -28,6 +28,7 @@ const ( FieldNameTag = "yaml" ReadonlyTag = "readonly" DebugFlag = "debug" + ConfigPathFlag = "config-path" ) var ( @@ -82,7 +83,7 @@ func InitConfig(cmd *cobra.Command) error { return err } - ConfigFilePath = path.Join(misc.GetDotFolderPath(), "config.yaml") + ConfigFilePath = GetConfigFilePath(cmd) if err := loadConfigFile(&Config, utils.Contains([]string{ "manifests", "license", @@ -134,21 +135,44 @@ func WriteConfig(config *ConfigStruct) error { return nil } -func loadConfigFile(config *ConfigStruct, silent bool) error { +func GetConfigFilePath(cmd *cobra.Command) string { + defaultConfigPath := path.Join(misc.GetDotFolderPath(), "config.yaml") + cwd, err := os.Getwd() if err != nil { - return err + return defaultConfigPath + } + + if cmd != nil { + configPathOverride, err := cmd.Flags().GetString(ConfigPathFlag) + if err == nil { + if configPathOverride != "" { + resolvedConfigPath, err := filepath.Abs(configPathOverride) + if err != nil { + log.Error().Err(err).Msg("--config-path flag path cannot be resolved") + } else { + return resolvedConfigPath + } + } + } else { + log.Error().Err(err).Msg("--config-path flag parser error") + } } cwdConfig := filepath.Join(cwd, fmt.Sprintf("%s.yaml", misc.Program)) reader, err := os.Open(cwdConfig) if err != nil { - reader, err = os.Open(ConfigFilePath) - if err != nil { - return err - } + return defaultConfigPath } else { - ConfigFilePath = cwdConfig + reader.Close() + return cwdConfig + } +} + +func loadConfigFile(config *ConfigStruct, silent bool) error { + reader, err := os.Open(ConfigFilePath) + if err != nil { + return err } defer reader.Close() @@ -176,9 +200,14 @@ func initFlag(f *pflag.Flag) { flagPath = append(flagPath, strings.Split(f.Name, "-")...) + flagPathJoined := strings.Join(flagPath, ".") + if strings.HasSuffix(flagPathJoined, ".config.path") { + return + } + sliceValue, isSliceValue := f.Value.(pflag.SliceValue) if !isSliceValue { - if err := mergeFlagValue(configElemValue, flagPath, strings.Join(flagPath, "."), f.Value.String()); err != nil { + if err := mergeFlagValue(configElemValue, flagPath, flagPathJoined, f.Value.String()); err != nil { log.Warn().Err(err).Send() } return @@ -191,7 +220,7 @@ func initFlag(f *pflag.Flag) { return } - if err := mergeFlagValues(configElemValue, flagPath, strings.Join(flagPath, "."), sliceValue.GetSlice()); err != nil { + if err := mergeFlagValues(configElemValue, flagPath, flagPathJoined, sliceValue.GetSlice()); err != nil { log.Warn().Err(err).Send() } }