diff --git a/pkg/probe/http/http_test.go b/pkg/probe/http/http_test.go index 7ae17f046b9..894d4783b12 100644 --- a/pkg/probe/http/http_test.go +++ b/pkg/probe/http/http_test.go @@ -38,26 +38,11 @@ import ( const FailureCode int = -1 -func setEnv(key, value string) func() { - originalValue := os.Getenv(key) - os.Setenv(key, value) - if len(originalValue) > 0 { - return func() { - os.Setenv(key, originalValue) - } +func unsetEnv(t testing.TB, key string) { + if originalValue, ok := os.LookupEnv(key); ok { + t.Cleanup(func() { os.Setenv(key, originalValue) }) + os.Unsetenv(key) } - return func() {} -} - -func unsetEnv(key string) func() { - originalValue := os.Getenv(key) - os.Unsetenv(key) - if len(originalValue) > 0 { - return func() { - os.Setenv(key, originalValue) - } - } - return func() {} } func TestHTTPProbeProxy(t *testing.T) { @@ -70,10 +55,10 @@ func TestHTTPProbeProxy(t *testing.T) { localProxy := server.URL - defer setEnv("http_proxy", localProxy)() - defer setEnv("HTTP_PROXY", localProxy)() - defer unsetEnv("no_proxy")() - defer unsetEnv("NO_PROXY")() + t.Setenv("http_proxy", localProxy) + t.Setenv("HTTP_PROXY", localProxy) + unsetEnv(t, "no_proxy") + unsetEnv(t, "NO_PROXY") followNonLocalRedirects := true prober := New(followNonLocalRedirects) diff --git a/pkg/scheduler/framework/plugins/nodevolumelimits/non_csi_test.go b/pkg/scheduler/framework/plugins/nodevolumelimits/non_csi_test.go index ccab287182d..5b0b71d562b 100644 --- a/pkg/scheduler/framework/plugins/nodevolumelimits/non_csi_test.go +++ b/pkg/scheduler/framework/plugins/nodevolumelimits/non_csi_test.go @@ -20,7 +20,6 @@ import ( "context" "errors" "fmt" - "os" "reflect" "strings" "testing" @@ -942,8 +941,6 @@ func TestGCEPDLimits(t *testing.T) { } func TestGetMaxVols(t *testing.T) { - previousValue := os.Getenv(KubeMaxPDVols) - tests := []struct { rawMaxVols string expected int @@ -968,18 +965,13 @@ func TestGetMaxVols(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - os.Setenv(KubeMaxPDVols, test.rawMaxVols) + t.Setenv(KubeMaxPDVols, test.rawMaxVols) result := getMaxVolLimitFromEnv() if result != test.expected { t.Errorf("expected %v got %v", test.expected, result) } }) } - - os.Unsetenv(KubeMaxPDVols) - if previousValue != "" { - os.Setenv(KubeMaxPDVols, previousValue) - } } func getFakePVCLister(filterName string) fakeframework.PersistentVolumeClaimLister { diff --git a/pkg/util/env/env_test.go b/pkg/util/env/env_test.go index 883f3422d46..23f98176b2e 100644 --- a/pkg/util/env/env_test.go +++ b/pkg/util/env/env_test.go @@ -17,7 +17,6 @@ limitations under the License. package env import ( - "os" "strconv" "testing" @@ -30,7 +29,7 @@ func TestGetEnvAsStringOrFallback(t *testing.T) { assert := assert.New(t) key := "FLOCKER_SET_VAR" - os.Setenv(key, expected) + t.Setenv(key, expected) assert.Equal(expected, GetEnvAsStringOrFallback(key, "~"+expected)) key = "FLOCKER_UNSET_VAR" @@ -43,7 +42,7 @@ func TestGetEnvAsIntOrFallback(t *testing.T) { assert := assert.New(t) key := "FLOCKER_SET_VAR" - os.Setenv(key, strconv.Itoa(expected)) + t.Setenv(key, strconv.Itoa(expected)) returnVal, _ := GetEnvAsIntOrFallback(key, 1) assert.Equal(expected, returnVal) @@ -52,7 +51,7 @@ func TestGetEnvAsIntOrFallback(t *testing.T) { assert.Equal(expected, returnVal) key = "FLOCKER_SET_VAR" - os.Setenv(key, "not-an-int") + t.Setenv(key, "not-an-int") returnVal, err := GetEnvAsIntOrFallback(key, 1) assert.Equal(expected, returnVal) if err == nil { @@ -66,7 +65,7 @@ func TestGetEnvAsFloat64OrFallback(t *testing.T) { assert := assert.New(t) key := "FLOCKER_SET_VAR" - os.Setenv(key, "1.0") + t.Setenv(key, "1.0") returnVal, _ := GetEnvAsFloat64OrFallback(key, 2.0) assert.Equal(expected, returnVal) @@ -75,7 +74,7 @@ func TestGetEnvAsFloat64OrFallback(t *testing.T) { assert.Equal(expected, returnVal) key = "FLOCKER_SET_VAR" - os.Setenv(key, "not-a-float") + t.Setenv(key, "not-a-float") returnVal, err := GetEnvAsFloat64OrFallback(key, 1.0) assert.Equal(expected, returnVal) assert.EqualError(err, "strconv.ParseFloat: parsing \"not-a-float\": invalid syntax")