diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/goaway_test.go b/staging/src/k8s.io/apiserver/pkg/server/filters/goaway_test.go index 330093e2aae..3b4aba993be 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/goaway_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/goaway_test.go @@ -17,12 +17,18 @@ limitations under the License. package filters import ( + "bytes" + "context" "crypto/tls" + "fmt" "io" + "io/ioutil" "math/rand" "net" "net/http" "net/http/httptest" + "reflect" + "sync" "testing" "time" @@ -85,25 +91,34 @@ func TestProbabilisticGoawayDecider(t *testing.T) { } } -// TestClientReceivedGOAWAY tests the in-flight watch requests will not be affected and new requests use a -// connection after client received GOAWAY, and server response watch request with GOAWAY will not break client -// watching body read. -func TestClientReceivedGOAWAY(t *testing.T) { - const ( - urlNormal = "/normal" - urlWatch = "/watch" - urlGoaway = "/goaway" - urlWatchWithGoaway = "/watch-with-goaway" - ) +const ( + urlGet = "/get" + urlPost = "/post" + urlWatch = "/watch" + urlGetWithGoaway = "/get-with-goaway" + urlPostWithGoaway = "/post-with-goaway" + urlWatchWithGoaway = "/watch-with-goaway" +) - const ( - // indicate the bytes watch request will be sent - // used to check if watch request was broke by GOAWAY - watchExpectSendBytes = 5 - ) +var ( + // responseBody is the response body which test GOAWAY server sent for each request, + // for watch request, test GOAWAY server push 1 byte in every second. + responseBody = []byte("hello") + // responseBodySize is the size of response body which test GOAWAY server sent for watch request, + // used to check if watch request was broken by GOAWAY frame. + responseBodySize = len(responseBody) + + // requestPostBody is the request body which client must send to test GOAWAY server for POST method, + // otherwise, test GOAWAY server will respond 400 HTTP status code. + requestPostBody = responseBody +) + +// newTestGOAWAYServer return a test GOAWAY server instance. +func newTestGOAWAYServer() (*httptest.Server, error) { watchHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { timer := time.NewTicker(time.Second) + defer timer.Stop() w.Header().Set("Transfer-Encoding", "chunked") w.WriteHeader(200) @@ -114,30 +129,44 @@ func TestClientReceivedGOAWAY(t *testing.T) { count := 0 for { <-timer.C - n, err := w.Write([]byte("w")) + n, err := w.Write(responseBody[count : count+1]) if err != nil { return } flusher.Flush() count += n - if count == watchExpectSendBytes { + if count == len(responseBody) { return } } }) + getHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(responseBody) + return + }) + postHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !reflect.DeepEqual(requestPostBody, reqBody) { + http.Error(w, fmt.Sprintf("expect request body: %s, got: %s", requestPostBody, reqBody), http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(responseBody) + return + }) mux := http.NewServeMux() - mux.Handle(urlNormal, WithProbabilisticGoaway(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("hello")) - return - }), 0)) + mux.Handle(urlGet, WithProbabilisticGoaway(getHandler, 0)) + mux.Handle(urlPost, WithProbabilisticGoaway(postHandler, 0)) mux.Handle(urlWatch, WithProbabilisticGoaway(watchHandler, 0)) - mux.Handle(urlGoaway, WithProbabilisticGoaway(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("hello")) - return - }), 1)) + mux.Handle(urlGetWithGoaway, WithProbabilisticGoaway(getHandler, 1)) + mux.Handle(urlPostWithGoaway, WithProbabilisticGoaway(postHandler, 1)) mux.Handle(urlWatchWithGoaway, WithProbabilisticGoaway(watchHandler, 1)) s := httptest.NewUnstartedServer(mux) @@ -145,18 +174,139 @@ func TestClientReceivedGOAWAY(t *testing.T) { http2Options := &http2.Server{} if err := http2.ConfigureServer(s.Config, http2Options); err != nil { - t.Fatalf("failed to configure test server to be HTTP2 server, err: %v", err) + return nil, fmt.Errorf("failed to configure test server to be HTTP2 server, err: %v", err) } s.TLS = s.Config.TLSConfig - s.StartTLS() - defer s.Close() + return s, nil +} + +// watchResponse wraps watch response with data which server send and an error may occur. +type watchResponse struct { + // body is the response data which test GOAWAY server sent to client + body []byte + // err will be set to be a non-nil value if watch request is not end with EOF nor http2.GoAwayError + err error +} + +// newGOAWAYClient return a configured http client which used to request test GOAWAY server, a dial may specified +// to encounter the TCP connection. +func newGOAWAYClient(dial func(network, addr string, cfg *tls.Config) (net.Conn, error)) (*http.Client, error) { tlsConfig := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{http2.NextProtoTLS}, } + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: tlsConfig, + MaxIdleConnsPerHost: 25, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if dial != nil { + return dial(network, addr, tlsConfig) + } + + return tls.Dial(network, addr, tlsConfig) + }, + } + if err := http2.ConfigureTransport(tr); err != nil { + return nil, err + } + + client := &http.Client{ + Transport: tr, + } + + return client, nil +} + +// requestGOAWAYServer request test GOAWAY server using specified method and data according to the given url. +// A non-nil channel will be returned if the request is watch, and a watchResponse can be got from the channel when watch done. +func requestGOAWAYServer(client *http.Client, serverBaseURL, url string) (<-chan watchResponse, error) { + method := http.MethodGet + var reqBody io.Reader + + if url == urlPost || url == urlPostWithGoaway { + method = http.MethodPost + reqBody = bytes.NewReader(requestPostBody) + } + + req, err := http.NewRequest(method, serverBaseURL+url, reqBody) + if err != nil { + return nil, fmt.Errorf("unexpect new request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed request test server, err: %v", err) + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body and status code is %d, error: %v", resp.StatusCode, err) + } + + return nil, fmt.Errorf("expect response status code: %d, but got: %d. response body: %s", http.StatusOK, resp.StatusCode, body) + } + + // encounter watch bytes received, does not expect to be broken + if url == urlWatch || url == urlWatchWithGoaway { + ch := make(chan watchResponse) + go func() { + defer resp.Body.Close() + + body := make([]byte, 0) + buffer := make([]byte, 1) + for { + n, err := resp.Body.Read(buffer) + if err != nil { + // urlWatch will receive io.EOF, + // urlWatchWithGoaway will receive http2.GoAwayError + if err == io.EOF { + err = nil + } else if _, ok := err.(http2.GoAwayError); ok { + err = nil + } + + ch <- watchResponse{ + body: body, + err: err, + } + return + } + body = append(body, buffer[0:n]...) + } + }() + return ch, nil + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body, error: %v", err) + } + + if !reflect.DeepEqual(responseBody, body) { + return nil, fmt.Errorf("expect response body: %s, got: %s", string(responseBody), string(body)) + } + + return nil, nil +} + +// TestClientReceivedGOAWAY tests the in-flight watch requests will not be affected and new requests use a new +// connection after client received GOAWAY. +func TestClientReceivedGOAWAY(t *testing.T) { + s, err := newTestGOAWAYServer() + if err != nil { + t.Fatalf("failed to set-up test GOAWAY http server, err: %v", err) + } + + s.StartTLS() + defer s.Close() + cases := []struct { name string reqs []string @@ -165,27 +315,27 @@ func TestClientReceivedGOAWAY(t *testing.T) { }{ { name: "all normal requests use only one connection", - reqs: []string{urlNormal, urlNormal, urlNormal}, + reqs: []string{urlGet, urlPost, urlGet}, expectConnections: 1, }, { name: "got GOAWAY after set-up watch", - reqs: []string{urlNormal, urlWatch, urlGoaway, urlNormal, urlNormal}, + reqs: []string{urlPost, urlWatch, urlGetWithGoaway, urlGet, urlPost}, expectConnections: 2, }, { name: "got GOAWAY after set-up watch, and set-up a new watch", - reqs: []string{urlNormal, urlWatch, urlGoaway, urlWatch, urlNormal, urlNormal}, + reqs: []string{urlGet, urlWatch, urlGetWithGoaway, urlWatch, urlGet, urlPost}, expectConnections: 2, }, { name: "got 2 GOAWAY after set-up watch", - reqs: []string{urlNormal, urlWatch, urlGoaway, urlGoaway, urlNormal, urlNormal}, + reqs: []string{urlPost, urlWatch, urlGetWithGoaway, urlGetWithGoaway, urlGet, urlPost}, expectConnections: 3, }, { name: "combine with watch-with-goaway", - reqs: []string{urlNormal, urlWatchWithGoaway, urlNormal, urlWatch, urlGoaway, urlNormal, urlNormal}, + reqs: []string{urlGet, urlWatchWithGoaway, urlGet, urlWatch, urlGetWithGoaway, urlGet, urlPost}, expectConnections: 3, }, } @@ -195,55 +345,26 @@ func TestClientReceivedGOAWAY(t *testing.T) { // localAddr indicates how many TCP connection set up localAddr := make([]string, 0) - // init HTTP2 client - client := http.Client{ - Transport: &http2.Transport{ - TLSClientConfig: tlsConfig, - DialTLS: func(network, addr string, cfg *tls.Config) (conn net.Conn, err error) { - conn, err = tls.Dial(network, addr, cfg) - if err != nil { - t.Fatalf("unexpect connection err: %v", err) - } - localAddr = append(localAddr, conn.LocalAddr().String()) - return - }, - }, + client, err := newGOAWAYClient(func(network, addr string, cfg *tls.Config) (conn net.Conn, err error) { + conn, err = tls.Dial(network, addr, cfg) + if err != nil { + t.Fatalf("unexpect connection err: %v", err) + } + localAddr = append(localAddr, conn.LocalAddr().String()) + return + }) + if err != nil { + t.Fatalf("failed to set-up client, err: %v", err) } - watchChs := make([]chan int, 0) + watchChs := make([]<-chan watchResponse, 0) for _, url := range tc.reqs { - req, err := http.NewRequest(http.MethodGet, s.URL+url, nil) + w, err := requestGOAWAYServer(client, s.URL, url) if err != nil { - t.Fatalf("unexpect new request error: %v", err) + t.Fatalf("failed to request server, err: %v", err) } - resp, err := client.Do(req) - if err != nil { - t.Fatalf("failed request test server, err: %v", err) - } - - // encounter watch bytes received, does not expect to be broken - if url == urlWatch || url == urlWatchWithGoaway { - ch := make(chan int) - watchChs = append(watchChs, ch) - go func() { - count := 0 - for { - buffer := make([]byte, 1) - n, err := resp.Body.Read(buffer) - if err != nil { - // urlWatch will receive io.EOF, - // urlWatchWithGoaway will receive http2.GoAwayError - if err != io.EOF { - if _, ok := err.(http2.GoAwayError); !ok { - t.Errorf("watch received not EOF err: %v", err) - } - } - ch <- count - return - } - count += n - } - }() + if w != nil { + watchChs = append(watchChs, w) } } @@ -252,13 +373,17 @@ func TestClientReceivedGOAWAY(t *testing.T) { t.Fatalf("expect TCP connection: %d, actual: %d", tc.expectConnections, len(localAddr)) } - // check if watch request is broken by GOAWAY response + // check if watch request is broken by GOAWAY frame watchTimeout := time.NewTimer(time.Second * 10) + defer watchTimeout.Stop() for _, watchCh := range watchChs { select { - case n := <-watchCh: - if n != watchExpectSendBytes { - t.Fatalf("in-flight watch was broken by GOAWAY response, expect go bytes: %d, actual got: %d", watchExpectSendBytes, n) + case watchResp := <-watchCh: + if watchResp.err != nil { + t.Fatalf("watch response got an unexepct error: %v", watchResp.err) + } + if !reflect.DeepEqual(responseBody, watchResp.body) { + t.Fatalf("in-flight watch was broken by GOAWAY frame, expect response body: %s, got: %s", responseBody, watchResp.body) } case <-watchTimeout.C: t.Error("watch receive timeout") @@ -268,7 +393,8 @@ func TestClientReceivedGOAWAY(t *testing.T) { } } -func TestHTTP1Requests(t *testing.T) { +// TestGOAWAYHTTP1Requests tests GOAWAY filter will not affect HTTP1.1 requests. +func TestGOAWAYHTTP1Requests(t *testing.T) { s := httptest.NewUnstartedServer(WithProbabilisticGoaway(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("hello")) @@ -305,3 +431,94 @@ func TestHTTP1Requests(t *testing.T) { t.Errorf("expect response HTTP header Connection to be empty, but got: %s", v) } } + +// TestGOAWAYConcurrency tests GOAWAY frame will not affect concurrency requests in a single http client instance. +// Known issues in history: https://github.com/kubernetes/kubernetes/issues/91131. +func TestGOAWAYConcurrency(t *testing.T) { + s, err := newTestGOAWAYServer() + if err != nil { + t.Fatalf("failed to set-up test GOAWAY http server, err: %v", err) + } + + s.StartTLS() + defer s.Close() + + client, err := newGOAWAYClient(nil) + if err != nil { + t.Fatalf("failed to set-up client, err: %v", err) + } + + const ( + requestCount = 300 + workers = 10 + ) + + expectWatchers := 0 + + urlsForTest := []string{urlGet, urlPost, urlWatch, urlGetWithGoaway, urlPostWithGoaway, urlWatchWithGoaway} + urls := make(chan string, requestCount) + for i := 0; i < requestCount; i++ { + index := rand.Intn(len(urlsForTest)) + url := urlsForTest[index] + + if url == urlWatch || url == urlWatchWithGoaway { + expectWatchers++ + } + + urls <- url + } + close(urls) + + wg := &sync.WaitGroup{} + wg.Add(workers) + + watchers := make(chan (<-chan watchResponse), expectWatchers) + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + + for { + url, ok := <-urls + if !ok { + return + } + + w, err := requestGOAWAYServer(client, s.URL, url) + if err != nil { + t.Errorf("failed to request %q, err: %v", url, err) + } + + if w != nil { + watchers <- w + } + } + }() + } + + wg.Wait() + + // check if watch request is broken by GOAWAY frame + watchTimeout := time.NewTimer(time.Second * 10) + defer watchTimeout.Stop() + for i := 0; i < expectWatchers; i++ { + var watcher <-chan watchResponse + + select { + case watcher = <-watchers: + default: + t.Fatalf("expect watcher count: %d, but got: %d", expectWatchers, i) + } + + select { + case watchResp := <-watcher: + if watchResp.err != nil { + t.Fatalf("watch response got an unexepct error: %v", watchResp.err) + } + if !reflect.DeepEqual(responseBody, watchResp.body) { + t.Fatalf("in-flight watch was broken by GOAWAY frame, expect response body: %s, got: %s", responseBody, watchResp.body) + } + case <-watchTimeout.C: + t.Error("watch receive timeout") + } + } +}