Improve ResourceLocation API, allow proxy to use authenticated transport

This commit is contained in:
Jordan Liggitt
2015-03-23 14:42:39 -04:00
parent 1dc7bcf53b
commit a75b501821
18 changed files with 247 additions and 72 deletions

View File

@@ -24,6 +24,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"sync"
@@ -226,7 +227,7 @@ type SimpleRESTStorage struct {
// The id requested, and location to return for ResourceLocation
requestedResourceLocationID string
resourceLocation string
resourceLocation *url.URL
expectedResourceNamespace string
// If non-nil, called inside the WorkFunc when answering update, delete, create.
@@ -315,18 +316,23 @@ func (storage *SimpleRESTStorage) Watch(ctx api.Context, label labels.Selector,
}
// Implement Redirector.
func (storage *SimpleRESTStorage) ResourceLocation(ctx api.Context, id string) (string, error) {
var _ = rest.Redirector(&SimpleRESTStorage{})
// Implement Redirector.
func (storage *SimpleRESTStorage) ResourceLocation(ctx api.Context, id string) (*url.URL, http.RoundTripper, error) {
storage.checkContext(ctx)
// validate that the namespace context on the request matches the expected input
storage.requestedResourceNamespace = api.NamespaceValue(ctx)
if storage.expectedResourceNamespace != storage.requestedResourceNamespace {
return "", fmt.Errorf("Expected request namespace %s, but got namespace %s", storage.expectedResourceNamespace, storage.requestedResourceNamespace)
return nil, nil, fmt.Errorf("Expected request namespace %s, but got namespace %s", storage.expectedResourceNamespace, storage.requestedResourceNamespace)
}
storage.requestedResourceLocationID = id
if err := storage.errors["resourceLocation"]; err != nil {
return "", err
return nil, nil, err
}
return storage.resourceLocation, nil
// Make a copy so the internal URL never gets mutated
locationCopy := *storage.resourceLocation
return &locationCopy, nil, nil
}
type LegacyRESTStorage struct {

View File

@@ -19,6 +19,7 @@ package apiserver
import (
"bytes"
"compress/gzip"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
@@ -140,7 +141,7 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
location, err := redirector.ResourceLocation(ctx, id)
location, transport, err := redirector.ResourceLocation(ctx, id)
if err != nil {
httplog.LogOf(req, w).Addf("Error getting ResourceLocation: %v", err)
status := errToAPIStatus(err)
@@ -148,22 +149,31 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
httpCode = status.Code
return
}
destURL, err := url.Parse(location)
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
httpCode = status.Code
if location == nil {
httplog.LogOf(req, w).Addf("ResourceLocation for %v returned nil", id)
notFound(w, req)
httpCode = http.StatusNotFound
return
}
if destURL.Scheme == "" {
// If no scheme was present in location, url.Parse sometimes mistakes
// hosts for paths.
destURL.Host = location
// Default to http
if location.Scheme == "" {
location.Scheme = "http"
}
destURL.Path = remainder
destURL.RawQuery = req.URL.RawQuery
newReq, err := http.NewRequest(req.Method, destURL.String(), req.Body)
// Add the subpath
if len(remainder) > 0 {
location.Path = singleJoiningSlash(location.Path, remainder)
}
// Start with anything returned from the storage, and add the original request's parameters
values := location.Query()
for k, vs := range req.URL.Query() {
for _, v := range vs {
values.Add(k, v)
}
}
location.RawQuery = values.Encode()
newReq, err := http.NewRequest(req.Method, location.String(), req.Body)
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
@@ -177,29 +187,34 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// TODO convert this entire proxy to an UpgradeAwareProxy similar to
// https://github.com/openshift/origin/blob/master/pkg/util/httpproxy/upgradeawareproxy.go.
// That proxy needs to be modified to support multiple backends, not just 1.
if r.tryUpgrade(w, req, newReq, destURL) {
if r.tryUpgrade(w, req, newReq, location, transport) {
return
}
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: destURL.Host})
proxy.Transport = &proxyTransport{
proxyScheme: req.URL.Scheme,
proxyHost: req.URL.Host,
proxyPathPrepend: requestInfo.URLPath(),
proxy := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: location.Scheme, Host: location.Host})
if transport == nil {
prepend := path.Join(r.prefix, resource, id)
if len(namespace) > 0 {
prepend = path.Join(r.prefix, "namespaces", namespace, resource, id)
}
transport = &proxyTransport{
proxyScheme: req.URL.Scheme,
proxyHost: req.URL.Host,
proxyPathPrepend: prepend,
}
}
proxy.Transport = transport
proxy.FlushInterval = 200 * time.Millisecond
proxy.ServeHTTP(w, newReq)
}
// tryUpgrade returns true if the request was handled.
func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Request, destURL *url.URL) bool {
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || len(req.Header.Get(httpstream.HeaderUpgrade)) == 0 {
func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Request, location *url.URL, transport http.RoundTripper) bool {
if !httpstream.IsUpgradeRequest(req) {
return false
}
//TODO support TLS? Doesn't look like proxyTransport does anything special ...
dialAddr := netutil.CanonicalAddr(destURL)
backendConn, err := net.Dial("tcp", dialAddr)
backendConn, err := dialURL(location, transport)
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
@@ -246,6 +261,54 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque
return true
}
func dialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
switch url.Scheme {
case "http":
return net.Dial("tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
if transport != nil {
httpTransport, ok := transport.(*http.Transport)
if ok {
tlsConfig = httpTransport.TLSClientConfig
}
}
// Dial
tlsConn, err := tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
}
// Verify
host, _, _ := net.SplitHostPort(dialAddr)
if err := tlsConn.VerifyHostname(host); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
default:
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
}
}
// borrowed from net/http/httputil/reverseproxy.go
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
type proxyTransport struct {
proxyScheme string
proxyHost string

View File

@@ -275,9 +275,10 @@ func TestProxy(t *testing.T) {
}))
defer proxyServer.Close()
serverURL, _ := url.Parse(proxyServer.URL)
simpleStorage := &SimpleRESTStorage{
errors: map[string]error{},
resourceLocation: proxyServer.URL,
resourceLocation: serverURL,
expectedResourceNamespace: item.reqNamespace,
}
@@ -335,9 +336,10 @@ func TestProxyUpgrade(t *testing.T) {
}))
defer backendServer.Close()
serverURL, _ := url.Parse(backendServer.URL)
simpleStorage := &SimpleRESTStorage{
errors: map[string]error{},
resourceLocation: backendServer.URL,
resourceLocation: serverURL,
expectedResourceNamespace: "myns",
}

View File

@@ -17,7 +17,6 @@ limitations under the License.
package apiserver
import (
"fmt"
"net/http"
"time"
@@ -79,15 +78,26 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
location, err := redirector.ResourceLocation(ctx, id)
location, _, err := redirector.ResourceLocation(ctx, id)
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w)
httpCode = status.Code
return
}
if location == nil {
httplog.LogOf(req, w).Addf("ResourceLocation for %v returned nil", id)
notFound(w, req)
httpCode = http.StatusNotFound
return
}
w.Header().Set("Location", fmt.Sprintf("http://%s", location))
// Default to http
if location.Scheme == "" {
location.Scheme = "http"
}
w.Header().Set("Location", location.String())
w.WriteHeader(http.StatusTemporaryRedirect)
httpCode = http.StatusTemporaryRedirect
}

View File

@@ -53,7 +53,7 @@ func TestRedirect(t *testing.T) {
for _, item := range table {
simpleStorage.errors["resourceLocation"] = item.err
simpleStorage.resourceLocation = item.id
simpleStorage.resourceLocation = &url.URL{Host: item.id}
resp, err := client.Get(server.URL + "/api/version/redirect/foo/" + item.id)
if resp == nil {
t.Fatalf("Unexpected nil resp")
@@ -104,7 +104,7 @@ func TestRedirectWithNamespaces(t *testing.T) {
for _, item := range table {
simpleStorage.errors["resourceLocation"] = item.err
simpleStorage.resourceLocation = item.id
simpleStorage.resourceLocation = &url.URL{Host: item.id}
resp, err := client.Get(server.URL + "/api/version/redirect/namespaces/other/foo/" + item.id)
if resp == nil {
t.Fatalf("Unexpected nil resp")