diff --git a/pkg/probe/http/http_test.go b/pkg/probe/http/http_test.go index 096cb430439..9f9afc30d6f 100644 --- a/pkg/probe/http/http_test.go +++ b/pkg/probe/http/http_test.go @@ -314,3 +314,53 @@ func TestHTTPProbeChecker_NonLocalRedirects(t *testing.T) { }) } } + +func TestHTTPProbeChecker_HostHeaderPreservedAfterRedirect(t *testing.T) { + successHostHeader := "www.success.com" + failHostHeader := "www.fail.com" + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/redirect": + http.Redirect(w, r, "/success", http.StatusFound) + case "/success": + if r.Host == successHostHeader { + w.WriteHeader(http.StatusOK) + } else { + http.Error(w, "", http.StatusBadRequest) + } + default: + http.Error(w, "", http.StatusInternalServerError) + } + }) + server := httptest.NewServer(handler) + defer server.Close() + + testCases := map[string]struct { + hostHeader string + expectedResult probe.Result + }{ + "success": {successHostHeader, probe.Success}, + "fail": {failHostHeader, probe.Failure}, + } + for desc, test := range testCases { + headers := http.Header{} + headers.Add("Host", test.hostHeader) + t.Run(desc+"local", func(t *testing.T) { + followNonLocalRedirects := false + prober := New(followNonLocalRedirects) + target, err := url.Parse(server.URL + "/redirect") + require.NoError(t, err) + result, _, _ := prober.Probe(target, headers, wait.ForeverTestTimeout) + assert.Equal(t, test.expectedResult, result) + }) + t.Run(desc+"nonlocal", func(t *testing.T) { + followNonLocalRedirects := true + prober := New(followNonLocalRedirects) + target, err := url.Parse(server.URL + "/redirect") + require.NoError(t, err) + result, _, _ := prober.Probe(target, headers, wait.ForeverTestTimeout) + assert.Equal(t, test.expectedResult, result) + }) + } +}