diff --git a/pkg/api/rest/rest.go b/pkg/api/rest/rest.go index d4db3d0a887..57be322a7ed 100644 --- a/pkg/api/rest/rest.go +++ b/pkg/api/rest/rest.go @@ -202,20 +202,24 @@ type Redirector interface { ResourceLocation(ctx api.Context, id string) (remoteLocation *url.URL, transport http.RoundTripper, err error) } -// ConnectHandler is a handler for HTTP connection requests. It extends the standard -// http.Handler interface by adding a method that returns an error object if an error -// occurred during the handling of the request. -type ConnectHandler interface { - http.Handler - - // RequestError returns an error if one occurred during handling of an HTTP request - RequestError() error +// Responder abstracts the normal response behavior for a REST method and is passed to callers that +// may wish to handle the response directly in some cases, but delegate to the normal error or object +// behavior in other cases. +type Responder interface { + // Object writes the provided object to the response. Invoking this method multiple times is undefined. + Object(statusCode int, obj runtime.Object) + // Error writes the provided error to the response. This method may only be invoked once. + Error(err error) } -// Connecter is a storage object that responds to a connection request +// Connecter is a storage object that responds to a connection request. type Connecter interface { - // Connect returns a ConnectHandler that will handle the request/response for a request - Connect(ctx api.Context, id string, options runtime.Object) (ConnectHandler, error) + // Connect returns an http.Handler that will handle the request/response for a given API invocation. + // The provided responder may be used for common API responses. The responder will write both status + // code and body, so the ServeHTTP method should exit after invoking the responder. The Handler will + // be used for a single API request and then discarded. The Responder is guaranteed to write to the + // same http.ResponseWriter passed to ServeHTTP. + Connect(ctx api.Context, id string, options runtime.Object, r Responder) (http.Handler, error) // NewConnectOptions returns an empty options object that will be used to pass // options to the Connect method. If nil, then a nil options object is passed to diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 7b35219e2a9..86d454c4297 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -338,19 +338,14 @@ func (s *SimpleStream) InputStream(version, accept string) (io.ReadCloser, bool, return s, false, s.contentType, s.err } -type SimpleConnectHandler struct { +type OutputConnect struct { response string - err error } -func (h *SimpleConnectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (h *OutputConnect) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.Write([]byte(h.response)) } -func (h *SimpleConnectHandler) RequestError() error { - return h.err -} - func (storage *SimpleRESTStorage) Get(ctx api.Context, id string) (runtime.Object, error) { storage.checkContext(ctx) if id == "binary" { @@ -448,10 +443,13 @@ func (storage *SimpleRESTStorage) ResourceLocation(ctx api.Context, id string) ( // Implement Connecter type ConnecterRESTStorage struct { - connectHandler rest.ConnectHandler + connectHandler http.Handler + handlerFunc func() http.Handler + emptyConnectOptions runtime.Object receivedConnectOptions runtime.Object receivedID string + receivedResponder rest.Responder takesPath string } @@ -462,9 +460,13 @@ func (s *ConnecterRESTStorage) New() runtime.Object { return &apiservertesting.Simple{} } -func (s *ConnecterRESTStorage) Connect(ctx api.Context, id string, options runtime.Object) (rest.ConnectHandler, error) { +func (s *ConnecterRESTStorage) Connect(ctx api.Context, id string, options runtime.Object, responder rest.Responder) (http.Handler, error) { s.receivedConnectOptions = options s.receivedID = id + s.receivedResponder = responder + if s.handlerFunc != nil { + return s.handlerFunc(), nil + } return s.connectHandler, nil } @@ -1287,7 +1289,7 @@ func TestConnect(t *testing.T) { responseText := "Hello World" itemID := "theID" connectStorage := &ConnecterRESTStorage{ - connectHandler: &SimpleConnectHandler{ + connectHandler: &OutputConnect{ response: responseText, }, } @@ -1320,9 +1322,92 @@ func TestConnect(t *testing.T) { } } +func TestConnectResponderObject(t *testing.T) { + itemID := "theID" + simple := &apiservertesting.Simple{Other: "foo"} + connectStorage := &ConnecterRESTStorage{} + connectStorage.handlerFunc = func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + connectStorage.receivedResponder.Object(http.StatusCreated, simple) + }) + } + storage := map[string]rest.Storage{ + "simple": &SimpleRESTStorage{}, + "simple/connect": connectStorage, + } + handler := handle(storage) + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL + "/api/version/namespaces/default/simple/" + itemID + "/connect") + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusCreated { + t.Errorf("unexpected response: %#v", resp) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if connectStorage.receivedID != itemID { + t.Errorf("Unexpected item id. Expected: %s. Actual: %s.", itemID, connectStorage.receivedID) + } + obj, err := codec.Decode(body) + if err != nil { + t.Fatal(err) + } + if !api.Semantic.DeepEqual(obj, simple) { + t.Errorf("Unexpected response: %#v", obj) + } +} + +func TestConnectResponderError(t *testing.T) { + itemID := "theID" + connectStorage := &ConnecterRESTStorage{} + connectStorage.handlerFunc = func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + connectStorage.receivedResponder.Error(apierrs.NewForbidden("simple", itemID, errors.New("you are terminated"))) + }) + } + storage := map[string]rest.Storage{ + "simple": &SimpleRESTStorage{}, + "simple/connect": connectStorage, + } + handler := handle(storage) + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL + "/api/version/namespaces/default/simple/" + itemID + "/connect") + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Errorf("unexpected response: %#v", resp) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if connectStorage.receivedID != itemID { + t.Errorf("Unexpected item id. Expected: %s. Actual: %s.", itemID, connectStorage.receivedID) + } + obj, err := codec.Decode(body) + if err != nil { + t.Fatal(err) + } + if obj.(*unversioned.Status).Code != http.StatusForbidden { + t.Errorf("Unexpected response: %#v", obj) + } +} + func TestConnectWithOptionsRouteParams(t *testing.T) { connectStorage := &ConnecterRESTStorage{ - connectHandler: &SimpleConnectHandler{}, + connectHandler: &OutputConnect{}, emptyConnectOptions: &apiservertesting.SimpleGetOptions{}, } storage := map[string]rest.Storage{ @@ -1351,7 +1436,7 @@ func TestConnectWithOptions(t *testing.T) { responseText := "Hello World" itemID := "theID" connectStorage := &ConnecterRESTStorage{ - connectHandler: &SimpleConnectHandler{ + connectHandler: &OutputConnect{ response: responseText, }, emptyConnectOptions: &apiservertesting.SimpleGetOptions{}, @@ -1383,6 +1468,9 @@ func TestConnectWithOptions(t *testing.T) { if string(body) != responseText { t.Errorf("Unexpected response. Expected: %s. Actual: %s.", responseText, string(body)) } + if connectStorage.receivedResponder == nil { + t.Errorf("Unexpected responder") + } opts, ok := connectStorage.receivedConnectOptions.(*apiservertesting.SimpleGetOptions) if !ok { t.Errorf("Unexpected options type: %#v", connectStorage.receivedConnectOptions) @@ -1397,7 +1485,7 @@ func TestConnectWithOptionsAndPath(t *testing.T) { itemID := "theID" testPath := "a/b/c/def" connectStorage := &ConnecterRESTStorage{ - connectHandler: &SimpleConnectHandler{ + connectHandler: &OutputConnect{ response: responseText, }, emptyConnectOptions: &apiservertesting.SimpleGetOptions{}, diff --git a/pkg/apiserver/resthandler.go b/pkg/apiserver/resthandler.go index 200150e4cc2..eeec1163840 100644 --- a/pkg/apiserver/resthandler.go +++ b/pkg/apiserver/resthandler.go @@ -179,20 +179,30 @@ func ConnectResource(connecter rest.Connecter, scope RequestScope, admit admissi return } } - handler, err := connecter.Connect(ctx, name, opts) + handler, err := connecter.Connect(ctx, name, opts, &responder{scope: scope, req: req.Request, w: w}) if err != nil { errorJSON(err, scope.Codec, w) return } handler.ServeHTTP(w, req.Request) - err = handler.RequestError() - if err != nil { - errorJSON(err, scope.Codec, w) - return - } } } +// responder implements rest.Responder for assisting a connector in writing objects or errors. +type responder struct { + scope RequestScope + req *http.Request + w http.ResponseWriter +} + +func (r *responder) Object(statusCode int, obj runtime.Object) { + write(statusCode, r.scope.APIVersion, r.scope.Codec, obj, r.w, r.req) +} + +func (r *responder) Error(err error) { + errorJSON(err, r.scope.Codec, r.w) +} + // ListResource returns a function that handles retrieving a list of resources from a rest.Storage object. func ListResource(r rest.Lister, rw rest.Watcher, scope RequestScope, forceWatch bool, minRequestTimeout time.Duration) restful.RouteFunction { return func(req *restful.Request, res *restful.Response) { diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index cc4d74440f7..b53d3bef4b6 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -44,30 +44,32 @@ type UpgradeAwareProxyHandler struct { WrapTransport bool FlushInterval time.Duration MaxBytesPerSec int64 - err error + Responder ErrorResponder } const defaultFlushInterval = 200 * time.Millisecond -// NewUpgradeAwareProxyHandler creates a new proxy handler with a default flush interval -func NewUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool) *UpgradeAwareProxyHandler { +// ErrorResponder abstracts error reporting to the proxy handler to remove the need to hardcode a particular +// error format. +type ErrorResponder interface { + Error(err error) +} + +// NewUpgradeAwareProxyHandler creates a new proxy handler with a default flush interval. Responder is required for returning +// errors to the caller. +func NewUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder ErrorResponder) *UpgradeAwareProxyHandler { return &UpgradeAwareProxyHandler{ Location: location, Transport: transport, WrapTransport: wrapTransport, UpgradeRequired: upgradeRequired, FlushInterval: defaultFlushInterval, + Responder: responder, } } -// RequestError returns an error that occurred while handling request -func (h *UpgradeAwareProxyHandler) RequestError() error { - return h.err -} - // ServeHTTP handles the proxy request func (h *UpgradeAwareProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - h.err = nil if len(h.Location.Scheme) == 0 { h.Location.Scheme = "http" } @@ -75,7 +77,7 @@ func (h *UpgradeAwareProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Re return } if h.UpgradeRequired { - h.err = errors.NewBadRequest("Upgrade request required") + h.Responder.Error(errors.NewBadRequest("Upgrade request required")) return } @@ -108,7 +110,7 @@ func (h *UpgradeAwareProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Re newReq, err := http.NewRequest(req.Method, loc.String(), req.Body) if err != nil { - h.err = err + h.Responder.Error(err) return } newReq.Header = req.Header @@ -127,27 +129,27 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R backendConn, err := proxy.DialURL(h.Location, h.Transport) if err != nil { - h.err = err + h.Responder.Error(err) return true } defer backendConn.Close() requestHijackedConn, _, err := w.(http.Hijacker).Hijack() if err != nil { - h.err = err + h.Responder.Error(err) return true } defer requestHijackedConn.Close() newReq, err := http.NewRequest(req.Method, h.Location.String(), req.Body) if err != nil { - h.err = err + h.Responder.Error(err) return true } newReq.Header = req.Header if err = newReq.Write(backendConn); err != nil { - h.err = err + h.Responder.Error(err) return true } diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index 829b1a6d5ef..9497d7454d6 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -26,6 +26,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "golang.org/x/net/websocket" @@ -33,6 +34,19 @@ import ( "k8s.io/kubernetes/pkg/util/proxy" ) +type fakeResponder struct { + called bool + err error +} + +func (r *fakeResponder) Error(err error) { + if r.called { + panic("called twice") + } + r.called = true + r.err = err +} + type SimpleBackendHandler struct { requestURL url.URL requestHeader http.Header @@ -110,6 +124,8 @@ func TestServeHTTP(t *testing.T) { responseHeader map[string]string expectedRespHeader map[string]string notExpectedRespHeader []string + upgradeRequired bool + expectError func(err error) bool }{ { name: "root path, simple get", @@ -117,6 +133,15 @@ func TestServeHTTP(t *testing.T) { requestPath: "/", expectedPath: "/", }, + { + name: "no upgrade header sent", + method: "GET", + requestPath: "/", + upgradeRequired: true, + expectError: func(err error) bool { + return err != nil && strings.Contains(err.Error(), "Upgrade request required") + }, + }, { name: "simple path, get", method: "GET", @@ -163,7 +188,7 @@ func TestServeHTTP(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { func() { backendResponse := "
Hello" backendResponseHeader := test.responseHeader @@ -179,10 +204,13 @@ func TestServeHTTP(t *testing.T) { backendServer := httptest.NewServer(backendHandler) defer backendServer.Close() + responder := &fakeResponder{} backendURL, _ := url.Parse(backendServer.URL) backendURL.Path = test.requestPath proxyHandler := &UpgradeAwareProxyHandler{ - Location: backendURL, + Location: backendURL, + Responder: responder, + UpgradeRequired: test.upgradeRequired, } proxyServer := httptest.NewServer(proxyHandler) defer proxyServer.Close() @@ -214,6 +242,17 @@ func TestServeHTTP(t *testing.T) { t.Errorf("Error from proxy request: %v", err) } + if test.expectError != nil { + if !responder.called { + t.Errorf("%d: responder was not invoked", i) + return + } + if !test.expectError(responder.err) { + t.Errorf("%d: unexpected error: %v", i, responder.err) + } + return + } + // Validate backend request // Method if backendHandler.requestMethod != test.method { @@ -253,9 +292,8 @@ func TestServeHTTP(t *testing.T) { } // Error - err = proxyHandler.RequestError() - if err != nil { - t.Errorf("Unexpected proxy handler error: %v", err) + if responder.called { + t.Errorf("Unexpected proxy handler error: %v", responder.err) } }() } diff --git a/pkg/registry/pod/etcd/etcd.go b/pkg/registry/pod/etcd/etcd.go index df2496746e5..f54ad192891 100644 --- a/pkg/registry/pod/etcd/etcd.go +++ b/pkg/registry/pod/etcd/etcd.go @@ -283,7 +283,7 @@ func (r *ProxyREST) NewConnectOptions() (runtime.Object, bool, string) { } // Connect returns a handler for the pod proxy -func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object) (rest.ConnectHandler, error) { +func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object, responder rest.Responder) (http.Handler, error) { proxyOpts, ok := opts.(*api.PodProxyOptions) if !ok { return nil, fmt.Errorf("Invalid options object: %#v", opts) @@ -294,7 +294,7 @@ func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object) (re } location.Path = path.Join(location.Path, proxyOpts.Path) // Return a proxy handler that uses the desired transport, wrapped with additional proxy handling (to get URL rewriting, X-Forwarded-* headers, etc) - return newThrottledUpgradeAwareProxyHandler(location, transport, true, false), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, true, false, responder), nil } // Support both GET and POST methods. We must support GET for browsers that want to use WebSockets. @@ -316,7 +316,7 @@ func (r *AttachREST) New() runtime.Object { } // Connect returns a handler for the pod exec proxy -func (r *AttachREST) Connect(ctx api.Context, name string, opts runtime.Object) (rest.ConnectHandler, error) { +func (r *AttachREST) Connect(ctx api.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) { attachOpts, ok := opts.(*api.PodAttachOptions) if !ok { return nil, fmt.Errorf("Invalid options object: %#v", opts) @@ -325,7 +325,7 @@ func (r *AttachREST) Connect(ctx api.Context, name string, opts runtime.Object) if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -354,16 +354,16 @@ func (r *ExecREST) New() runtime.Object { } // Connect returns a handler for the pod exec proxy -func (r *ExecREST) Connect(ctx api.Context, name string, opts runtime.Object) (rest.ConnectHandler, error) { +func (r *ExecREST) Connect(ctx api.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) { execOpts, ok := opts.(*api.PodExecOptions) if !ok { - return nil, fmt.Errorf("Invalid options object: %#v", opts) + return nil, fmt.Errorf("invalid options object: %#v", opts) } location, transport, err := pod.ExecLocation(r.store, r.kubeletConn, ctx, name, execOpts) if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -402,16 +402,16 @@ func (r *PortForwardREST) ConnectMethods() []string { } // Connect returns a handler for the pod portforward proxy -func (r *PortForwardREST) Connect(ctx api.Context, name string, opts runtime.Object) (rest.ConnectHandler, error) { +func (r *PortForwardREST) Connect(ctx api.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) { location, transport, err := pod.PortForwardLocation(r.store, r.kubeletConn, ctx, name) if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil } -func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool) *genericrest.UpgradeAwareProxyHandler { - handler := genericrest.NewUpgradeAwareProxyHandler(location, transport, wrapTransport, upgradeRequired) +func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) *genericrest.UpgradeAwareProxyHandler { + handler := genericrest.NewUpgradeAwareProxyHandler(location, transport, wrapTransport, upgradeRequired, responder) handler.MaxBytesPerSec = capabilities.Get().PerConnectionBandwidthLimitBytesPerSec return handler }