From 43a410b9dd12f4ec320668682c3830afa5992f84 Mon Sep 17 00:00:00 2001
From: "M. Mert Yildiran" <me@mertyildiran.com>
Date: Wed, 16 Apr 2025 20:28:21 +0300
Subject: [PATCH] Add `--config-path` flag to root command (#1744)

* Add `--config-path` flag to root command

* Use `filepath.Abs`

---------

Co-authored-by: Alon Girmonsky <1990761+alongir@users.noreply.github.com>
---
 cmd/config.go    |  3 +--
 cmd/root.go      |  1 +
 config/config.go | 49 ++++++++++++++++++++++++++++++++++++++----------
 3 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/cmd/config.go b/cmd/config.go
index 74292d97b..715dea371 100644
--- a/cmd/config.go
+++ b/cmd/config.go
@@ -2,7 +2,6 @@ package cmd
 
 import (
 	"fmt"
-	"path"
 
 	"github.com/creasty/defaults"
 	"github.com/kubeshark/kubeshark/config"
@@ -52,5 +51,5 @@ func init() {
 		log.Debug().Err(err).Send()
 	}
 
-	configCmd.Flags().BoolP(configStructs.RegenerateConfigName, "r", defaultConfig.Config.Regenerate, fmt.Sprintf("Regenerate the config file with default values to path %s", path.Join(misc.GetDotFolderPath(), "config.yaml")))
+	configCmd.Flags().BoolP(configStructs.RegenerateConfigName, "r", defaultConfig.Config.Regenerate, fmt.Sprintf("Regenerate the config file with default values to path %s", config.GetConfigFilePath(nil)))
 }
diff --git a/cmd/root.go b/cmd/root.go
index 0f9e2c646..178cad30c 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -33,6 +33,7 @@ func init() {
 
 	rootCmd.PersistentFlags().StringSlice(config.SetCommandName, []string{}, fmt.Sprintf("Override values using --%s", config.SetCommandName))
 	rootCmd.PersistentFlags().BoolP(config.DebugFlag, "d", false, "Enable debug mode")
+	rootCmd.PersistentFlags().String(config.ConfigPathFlag, "", fmt.Sprintf("Set the config path, default: %s", config.GetConfigFilePath(nil)))
 }
 
 // Execute adds all child commands to the root command and sets flags appropriately.
diff --git a/config/config.go b/config/config.go
index 953b5e7a0..0d6d6f058 100644
--- a/config/config.go
+++ b/config/config.go
@@ -28,6 +28,7 @@ const (
 	FieldNameTag   = "yaml"
 	ReadonlyTag    = "readonly"
 	DebugFlag      = "debug"
+	ConfigPathFlag = "config-path"
 )
 
 var (
@@ -82,7 +83,7 @@ func InitConfig(cmd *cobra.Command) error {
 		return err
 	}
 
-	ConfigFilePath = path.Join(misc.GetDotFolderPath(), "config.yaml")
+	ConfigFilePath = GetConfigFilePath(cmd)
 	if err := loadConfigFile(&Config, utils.Contains([]string{
 		"manifests",
 		"license",
@@ -134,21 +135,44 @@ func WriteConfig(config *ConfigStruct) error {
 	return nil
 }
 
-func loadConfigFile(config *ConfigStruct, silent bool) error {
+func GetConfigFilePath(cmd *cobra.Command) string {
+	defaultConfigPath := path.Join(misc.GetDotFolderPath(), "config.yaml")
+
 	cwd, err := os.Getwd()
 	if err != nil {
-		return err
+		return defaultConfigPath
+	}
+
+	if cmd != nil {
+		configPathOverride, err := cmd.Flags().GetString(ConfigPathFlag)
+		if err == nil {
+			if configPathOverride != "" {
+				resolvedConfigPath, err := filepath.Abs(configPathOverride)
+				if err != nil {
+					log.Error().Err(err).Msg("--config-path flag path cannot be resolved")
+				} else {
+					return resolvedConfigPath
+				}
+			}
+		} else {
+			log.Error().Err(err).Msg("--config-path flag parser error")
+		}
 	}
 
 	cwdConfig := filepath.Join(cwd, fmt.Sprintf("%s.yaml", misc.Program))
 	reader, err := os.Open(cwdConfig)
 	if err != nil {
-		reader, err = os.Open(ConfigFilePath)
-		if err != nil {
-			return err
-		}
+		return defaultConfigPath
 	} else {
-		ConfigFilePath = cwdConfig
+		reader.Close()
+		return cwdConfig
+	}
+}
+
+func loadConfigFile(config *ConfigStruct, silent bool) error {
+	reader, err := os.Open(ConfigFilePath)
+	if err != nil {
+		return err
 	}
 	defer reader.Close()
 
@@ -176,9 +200,14 @@ func initFlag(f *pflag.Flag) {
 
 	flagPath = append(flagPath, strings.Split(f.Name, "-")...)
 
+	flagPathJoined := strings.Join(flagPath, ".")
+	if strings.HasSuffix(flagPathJoined, ".config.path") {
+		return
+	}
+
 	sliceValue, isSliceValue := f.Value.(pflag.SliceValue)
 	if !isSliceValue {
-		if err := mergeFlagValue(configElemValue, flagPath, strings.Join(flagPath, "."), f.Value.String()); err != nil {
+		if err := mergeFlagValue(configElemValue, flagPath, flagPathJoined, f.Value.String()); err != nil {
 			log.Warn().Err(err).Send()
 		}
 		return
@@ -191,7 +220,7 @@ func initFlag(f *pflag.Flag) {
 		return
 	}
 
-	if err := mergeFlagValues(configElemValue, flagPath, strings.Join(flagPath, "."), sliceValue.GetSlice()); err != nil {
+	if err := mergeFlagValues(configElemValue, flagPath, flagPathJoined, sliceValue.GetSlice()); err != nil {
 		log.Warn().Err(err).Send()
 	}
 }