Fix proxy rewriting

This commit is contained in:
Daniel Smith 2015-06-18 17:52:36 -07:00
parent 32114f6256
commit ddbe4c914f
3 changed files with 41 additions and 15 deletions

View File

@ -20,6 +20,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@ -55,6 +56,8 @@ type ProxyHandler struct {
} }
func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
proxyHandlerTraceID := rand.Int63()
var verb string var verb string
var apiResource string var apiResource string
var httpCode int var httpCode int
@ -108,7 +111,7 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
location, transport, err := redirector.ResourceLocation(ctx, id) location, roundTripper, err := redirector.ResourceLocation(ctx, id)
if err != nil { if err != nil {
httplog.LogOf(req, w).Addf("Error getting ResourceLocation: %v", err) httplog.LogOf(req, w).Addf("Error getting ResourceLocation: %v", err)
status := errToAPIStatus(err) status := errToAPIStatus(err)
@ -123,8 +126,11 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
// If we have a custom dialer, and no pre-existing transport, initialize it to use the dialer. // If we have a custom dialer, and no pre-existing transport, initialize it to use the dialer.
if transport == nil && r.dial != nil { if roundTripper == nil && r.dial != nil {
transport = &http.Transport{Dial: r.dial} glog.V(5).Infof("[%x: %v] making a dial-only transport...", proxyHandlerTraceID, req.URL)
roundTripper = &http.Transport{Dial: r.dial}
} else if roundTripper != nil {
glog.V(5).Infof("[%x: %v] using transport %T...", proxyHandlerTraceID, req.URL, roundTripper)
} }
// Default to http // Default to http
@ -158,7 +164,7 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// TODO convert this entire proxy to an UpgradeAwareProxy similar to // TODO convert this entire proxy to an UpgradeAwareProxy similar to
// https://github.com/openshift/origin/blob/master/pkg/util/httpproxy/upgradeawareproxy.go. // https://github.com/openshift/origin/blob/master/pkg/util/httpproxy/upgradeawareproxy.go.
// That proxy needs to be modified to support multiple backends, not just 1. // That proxy needs to be modified to support multiple backends, not just 1.
if r.tryUpgrade(w, req, newReq, location, transport) { if r.tryUpgrade(w, req, newReq, location, roundTripper) {
return return
} }
@ -175,19 +181,33 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
start := time.Now()
glog.V(4).Infof("[%x] Beginning proxy %s...", proxyHandlerTraceID, req.URL)
defer func() {
glog.V(4).Infof("[%x] Proxy %v finished %v.", proxyHandlerTraceID, req.URL, time.Now().Sub(start))
}()
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: location.Scheme, Host: location.Host}) proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: location.Scheme, Host: location.Host})
if transport == nil { alreadyRewriting := false
if roundTripper != nil {
_, alreadyRewriting = roundTripper.(*proxyutil.Transport)
glog.V(5).Infof("[%x] Not making a reriting transport for proxy %s...", proxyHandlerTraceID, req.URL)
}
if !alreadyRewriting {
glog.V(5).Infof("[%x] making a transport for proxy %s...", proxyHandlerTraceID, req.URL)
prepend := path.Join(r.prefix, resource, id) prepend := path.Join(r.prefix, resource, id)
if len(namespace) > 0 { if len(namespace) > 0 {
prepend = path.Join(r.prefix, "namespaces", namespace, resource, id) prepend = path.Join(r.prefix, "namespaces", namespace, resource, id)
} }
transport = &proxyutil.Transport{ pTransport := &proxyutil.Transport{
Scheme: req.URL.Scheme, Scheme: req.URL.Scheme,
Host: req.URL.Host, Host: req.URL.Host,
PathPrepend: prepend, PathPrepend: prepend,
RoundTripper: roundTripper,
} }
roundTripper = pTransport
} }
proxy.Transport = transport proxy.Transport = roundTripper
proxy.FlushInterval = 200 * time.Millisecond proxy.FlushInterval = 200 * time.Millisecond
proxy.ServeHTTP(w, newReq) proxy.ServeHTTP(w, newReq)
} }

View File

@ -98,10 +98,10 @@ func TestAccept(t *testing.T) {
acceptPaths: DefaultPathAcceptRE, acceptPaths: DefaultPathAcceptRE,
rejectPaths: DefaultPathRejectRE, rejectPaths: DefaultPathRejectRE,
acceptHosts: DefaultHostAcceptRE, acceptHosts: DefaultHostAcceptRE,
path: "/foo/v1/pods", path: "/ui",
host: "localhost", host: "localhost",
method: "GET", method: "GET",
expectAccept: false, expectAccept: true,
}, },
{ {
acceptPaths: DefaultPathAcceptRE, acceptPaths: DefaultPathAcceptRE,
@ -230,7 +230,7 @@ func TestAPIRequests(t *testing.T) {
// httptest.NewServer should always generate a valid URL. // httptest.NewServer should always generate a valid URL.
target, _ := url.Parse(ts.URL) target, _ := url.Parse(ts.URL)
proxy := newProxyServer(target) proxy := newProxy(target)
tests := []struct{ method, body string }{ tests := []struct{ method, body string }{
{"GET", ""}, {"GET", ""},
@ -291,7 +291,7 @@ func TestPathHandling(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("%#v: %v", item, err) t.Fatalf("%#v: %v", item, err)
} }
pts := httptest.NewServer(p.mux) pts := httptest.NewServer(p.handler)
defer pts.Close() defer pts.Close()
r, err := http.Get(pts.URL + item.reqPath) r, err := http.Get(pts.URL + item.reqPath)

View File

@ -73,6 +73,8 @@ type Transport struct {
Scheme string Scheme string
Host string Host string
PathPrepend string PathPrepend string
http.RoundTripper
} }
// RoundTrip implements the http.RoundTripper interface // RoundTrip implements the http.RoundTripper interface
@ -86,7 +88,11 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Forwarded-Host", t.Host) req.Header.Set("X-Forwarded-Host", t.Host)
req.Header.Set("X-Forwarded-Proto", t.Scheme) req.Header.Set("X-Forwarded-Proto", t.Scheme)
resp, err := http.DefaultTransport.RoundTrip(req) rt := t.RoundTripper
if rt == nil {
rt = http.DefaultTransport
}
resp, err := rt.RoundTrip(req)
if err != nil { if err != nil {
message := fmt.Sprintf("Error: '%s'\nTrying to reach: '%v'", err.Error(), req.URL.String()) message := fmt.Sprintf("Error: '%s'\nTrying to reach: '%v'", err.Error(), req.URL.String())