mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-08 03:33:56 +00:00
Remove CORS headers from pod proxy responses
The API server sends its own CORS headers in its response, and if the proxied pod response also includes its own headers, it confuses clients.
This commit is contained in:
parent
521446503a
commit
e8af67c180
@ -214,9 +214,37 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun
|
|||||||
suffix += "/"
|
suffix += "/"
|
||||||
}
|
}
|
||||||
pathPrepend := strings.TrimSuffix(url.Path, suffix)
|
pathPrepend := strings.TrimSuffix(url.Path, suffix)
|
||||||
return &proxy.Transport{
|
internalTransport := &proxy.Transport{
|
||||||
Scheme: scheme,
|
Scheme: scheme,
|
||||||
Host: host,
|
Host: host,
|
||||||
PathPrepend: pathPrepend,
|
PathPrepend: pathPrepend,
|
||||||
}
|
}
|
||||||
|
return &corsRemovingTransport{
|
||||||
|
RoundTripper: internalTransport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// corsRemovingTransport is a wrapper for an internal transport. It removes CORS headers
|
||||||
|
// from the internal response.
|
||||||
|
type corsRemovingTransport struct {
|
||||||
|
http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
resp, err := p.RoundTripper.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
removeCORSHeaders(resp)
|
||||||
|
return resp, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCORSHeaders strip CORS headers sent from the backend
|
||||||
|
// This should be called on all responses before returning
|
||||||
|
func removeCORSHeaders(resp *http.Response) {
|
||||||
|
resp.Header.Del("Access-Control-Allow-Credentials")
|
||||||
|
resp.Header.Del("Access-Control-Allow-Headers")
|
||||||
|
resp.Header.Del("Access-Control-Allow-Methods")
|
||||||
|
resp.Header.Del("Access-Control-Allow-Origin")
|
||||||
}
|
}
|
||||||
|
@ -51,8 +51,10 @@ func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range s.responseHeader {
|
if s.responseHeader != nil {
|
||||||
w.Header().Add(k, v)
|
for k, v := range s.responseHeader {
|
||||||
|
w.Header().Add(k, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w.Write([]byte(s.responseBody))
|
w.Write([]byte(s.responseBody))
|
||||||
}
|
}
|
||||||
@ -71,7 +73,7 @@ func validateParameters(t *testing.T, name string, actual url.Values, expected m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string) {
|
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
|
||||||
for k, v := range expected {
|
for k, v := range expected {
|
||||||
actualValue, ok := actual[k]
|
actualValue, ok := actual[k]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -83,17 +85,28 @@ func validateHeaders(t *testing.T, name string, actual http.Header, expected map
|
|||||||
name, k, actualValue, v)
|
name, k, actualValue, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if notExpected == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, h := range notExpected {
|
||||||
|
if _, present := actual[h]; present {
|
||||||
|
t.Errorf("%s: unexpected header: %s", name, h)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
func TestServeHTTP(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
requestPath string
|
requestPath string
|
||||||
expectedPath string
|
expectedPath string
|
||||||
requestBody string
|
requestBody string
|
||||||
requestParams map[string]string
|
requestParams map[string]string
|
||||||
requestHeader map[string]string
|
requestHeader map[string]string
|
||||||
|
responseHeader map[string]string
|
||||||
|
expectedRespHeader map[string]string
|
||||||
|
notExpectedRespHeader []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "root path, simple get",
|
name: "root path, simple get",
|
||||||
@ -128,14 +141,37 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
requestPath: "",
|
requestPath: "",
|
||||||
expectedPath: "/",
|
expectedPath: "/",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "remove CORS headers",
|
||||||
|
method: "GET",
|
||||||
|
requestPath: "/some/path",
|
||||||
|
expectedPath: "/some/path",
|
||||||
|
responseHeader: map[string]string{
|
||||||
|
"Header1": "value1",
|
||||||
|
"Access-Control-Allow-Origin": "some.server",
|
||||||
|
"Access-Control-Allow-Methods": "GET"},
|
||||||
|
expectedRespHeader: map[string]string{
|
||||||
|
"Header1": "value1",
|
||||||
|
},
|
||||||
|
notExpectedRespHeader: []string{
|
||||||
|
"Access-Control-Allow-Origin",
|
||||||
|
"Access-Control-Allow-Methods",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
func() {
|
func() {
|
||||||
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
|
backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
|
||||||
|
backendResponseHeader := test.responseHeader
|
||||||
|
// Test a simple header if not specified in the test
|
||||||
|
if backendResponseHeader == nil && test.expectedRespHeader == nil {
|
||||||
|
backendResponseHeader = map[string]string{"Content-Type": "text/html"}
|
||||||
|
test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
|
||||||
|
}
|
||||||
backendHandler := &SimpleBackendHandler{
|
backendHandler := &SimpleBackendHandler{
|
||||||
responseBody: backendResponse,
|
responseBody: backendResponse,
|
||||||
responseHeader: map[string]string{"Content-Type": "text/html"},
|
responseHeader: backendResponseHeader,
|
||||||
}
|
}
|
||||||
backendServer := httptest.NewServer(backendHandler)
|
backendServer := httptest.NewServer(backendHandler)
|
||||||
defer backendServer.Close()
|
defer backendServer.Close()
|
||||||
@ -197,9 +233,13 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
|
|
||||||
// Headers
|
// Headers
|
||||||
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
|
validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
|
||||||
test.requestHeader)
|
test.requestHeader, nil)
|
||||||
|
|
||||||
// Validate proxy response
|
// Validate proxy response
|
||||||
|
|
||||||
|
// Response Headers
|
||||||
|
validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
|
||||||
|
|
||||||
// Validate Body
|
// Validate Body
|
||||||
responseBody, err := ioutil.ReadAll(res.Body)
|
responseBody, err := ioutil.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -297,7 +337,7 @@ func TestDefaultProxyTransport(t *testing.T) {
|
|||||||
Location: locURL,
|
Location: locURL,
|
||||||
}
|
}
|
||||||
result := h.defaultProxyTransport(URL)
|
result := h.defaultProxyTransport(URL)
|
||||||
transport := result.(*proxy.Transport)
|
transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport)
|
||||||
if transport.Scheme != test.expectedScheme {
|
if transport.Scheme != test.expectedScheme {
|
||||||
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
|
t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user