From e844c12a6184433ea510b39004d579f86e7411b2 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Tue, 21 Jun 2022 23:00:38 +0800 Subject: [PATCH] Cleanup: defer to close server in tests (#110367) * Cleanup: defer to close server in tests Signed-off-by: kerthcet * address comments Signed-off-by: kerthcet * address comments Signed-off-by: kerthcet --- pkg/client/tests/portfoward_test.go | 139 ++++++++++---------- pkg/client/tests/remotecommand_test.go | 172 ++++++++++++------------- 2 files changed, 153 insertions(+), 158 deletions(-) diff --git a/pkg/client/tests/portfoward_test.go b/pkg/client/tests/portfoward_test.go index 27afe391ee5..79f84a4d7fc 100644 --- a/pkg/client/tests/portfoward_test.go +++ b/pkg/client/tests/portfoward_test.go @@ -129,81 +129,80 @@ func TestForwardPorts(t *testing.T) { } for testName, test := range tests { - server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends)) + t.Run(testName, func(t *testing.T) { + server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends)) + defer server.Close() - transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) - if err != nil { - t.Fatal(err) - } - url, _ := url.Parse(server.URL) - dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) - - stopChan := make(chan struct{}, 1) - readyChan := make(chan struct{}) - - pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr) - if err != nil { - t.Fatalf("%s: unexpected error calling New: %v", testName, err) - } - - doneChan := make(chan error) - go func() { - doneChan <- pf.ForwardPorts() - }() - <-pf.Ready - - forwardedPorts, err := pf.GetPorts() - if err != nil { - t.Fatal(err) - } - - remoteToLocalMap := map[int32]int32{} - for _, forwardedPort := range forwardedPorts { - remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local) - } - - for port, data := range test.clientSends { - clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port])) + transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) if err != nil { - t.Errorf("%s: error dialing %d: %s", testName, port, err) - server.Close() - continue + t.Fatal(err) } - defer clientConn.Close() + url, _ := url.Parse(server.URL) + dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) - n, err := clientConn.Write([]byte(data)) - if err != nil && err != io.EOF { - t.Errorf("%s: Error sending data '%s': %s", testName, data, err) - server.Close() - continue - } - if n == 0 { - t.Errorf("%s: unexpected write of 0 bytes", testName) - server.Close() - continue - } - b := make([]byte, 4) - _, err = clientConn.Read(b) - if err != nil && err != io.EOF { - t.Errorf("%s: Error reading data: %s", testName, err) - server.Close() - continue - } - if !bytes.Equal([]byte(test.serverSends[port]), b) { - t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b) - server.Close() - continue - } - } - // tell r.ForwardPorts to stop - close(stopChan) + stopChan := make(chan struct{}, 1) + readyChan := make(chan struct{}) - // wait for r.ForwardPorts to actually return - err = <-doneChan - if err != nil { - t.Errorf("%s: unexpected error: %s", testName, err) - } - server.Close() + pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr) + if err != nil { + t.Fatalf("%s: unexpected error calling New: %v", testName, err) + } + + doneChan := make(chan error) + go func() { + doneChan <- pf.ForwardPorts() + }() + <-pf.Ready + + forwardedPorts, err := pf.GetPorts() + if err != nil { + t.Fatal(err) + } + + remoteToLocalMap := map[int32]int32{} + for _, forwardedPort := range forwardedPorts { + remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local) + } + + clientSend := func(port int32, data string) error { + clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port])) + if err != nil { + return fmt.Errorf("%s: error dialing %d: %s", testName, port, err) + + } + defer clientConn.Close() + + n, err := clientConn.Write([]byte(data)) + if err != nil && err != io.EOF { + return fmt.Errorf("%s: Error sending data '%s': %s", testName, data, err) + } + if n == 0 { + return fmt.Errorf("%s: unexpected write of 0 bytes", testName) + } + b := make([]byte, 4) + _, err = clientConn.Read(b) + if err != nil && err != io.EOF { + return fmt.Errorf("%s: Error reading data: %s", testName, err) + } + if !bytes.Equal([]byte(test.serverSends[port]), b) { + return fmt.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b) + } + return nil + } + for port, data := range test.clientSends { + if err := clientSend(port, data); err != nil { + t.Error(err) + } + } + // tell r.ForwardPorts to stop + close(stopChan) + + // wait for r.ForwardPorts to actually return + err = <-doneChan + if err != nil { + t.Errorf("%s: unexpected error: %s", testName, err) + } + }) } } diff --git a/pkg/client/tests/remotecommand_test.go b/pkg/client/tests/remotecommand_test.go index 7c6a1124029..910cfce38f8 100644 --- a/pkg/client/tests/remotecommand_test.go +++ b/pkg/client/tests/remotecommand_test.go @@ -195,108 +195,104 @@ func TestStream(t *testing.T) { } else { name = testCase.TestName + " (attach)" } - var ( - streamIn io.Reader - streamOut, streamErr io.Writer - ) - localOut := &bytes.Buffer{} - localErr := &bytes.Buffer{} - requestReceived := make(chan struct{}) - server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols)) + t.Run(name, func(t *testing.T) { + var ( + streamIn io.Reader + streamOut, streamErr io.Writer + ) + localOut := &bytes.Buffer{} + localErr := &bytes.Buffer{} - url, _ := url.ParseRequestURI(server.URL) - config := restclient.ClientContentConfig{ - GroupVersion: schema.GroupVersion{Group: "x"}, - Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}), - } - c, err := restclient.NewRESTClient(url, "", config, nil, nil) - if err != nil { - t.Fatalf("failed to create a client: %v", err) - } - req := c.Post().Resource("testing") + requestReceived := make(chan struct{}) + server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols)) + defer server.Close() - if exec { - req.Param("command", "ls") - req.Param("command", "/") - } + url, _ := url.ParseRequestURI(server.URL) + config := restclient.ClientContentConfig{ + GroupVersion: schema.GroupVersion{Group: "x"}, + Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}), + } + c, err := restclient.NewRESTClient(url, "", config, nil, nil) + if err != nil { + t.Fatalf("failed to create a client: %v", err) + } + req := c.Post().Resource("testing") - if len(testCase.Stdin) > 0 { - req.Param(api.ExecStdinParam, "1") - streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)) - } + if exec { + req.Param("command", "ls") + req.Param("command", "/") + } - if len(testCase.Stdout) > 0 { - req.Param(api.ExecStdoutParam, "1") - streamOut = localOut - } + if len(testCase.Stdin) > 0 { + req.Param(api.ExecStdinParam, "1") + streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)) + } - if testCase.Tty { - req.Param(api.ExecTTYParam, "1") - } else if len(testCase.Stderr) > 0 { - req.Param(api.ExecStderrParam, "1") - streamErr = localErr - } + if len(testCase.Stdout) > 0 { + req.Param(api.ExecStdoutParam, "1") + streamOut = localOut + } - conf := &restclient.Config{ - Host: server.URL, - } - transport, upgradeTransport, err := spdy.RoundTripperFor(conf) - if err != nil { - t.Errorf("%s: unexpected error: %v", name, err) - continue - } - e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...) - if err != nil { - t.Errorf("%s: unexpected error: %v", name, err) - continue - } - err = e.Stream(remoteclient.StreamOptions{ - Stdin: streamIn, - Stdout: streamOut, - Stderr: streamErr, - Tty: testCase.Tty, - }) - hasErr := err != nil + if testCase.Tty { + req.Param(api.ExecTTYParam, "1") + } else if len(testCase.Stderr) > 0 { + req.Param(api.ExecStderrParam, "1") + streamErr = localErr + } - if len(testCase.Error) > 0 { - if !hasErr { - t.Errorf("%s: expected an error", name) - } else { - if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) { - t.Errorf("%s: expected error stream read %q, got %q", name, e, a) + conf := &restclient.Config{ + Host: server.URL, + } + transport, upgradeTransport, err := spdy.RoundTripperFor(conf) + if err != nil { + t.Fatalf("%s: unexpected error: %v", name, err) + } + e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...) + if err != nil { + t.Fatalf("%s: unexpected error: %v", name, err) + } + err = e.Stream(remoteclient.StreamOptions{ + Stdin: streamIn, + Stdout: streamOut, + Stderr: streamErr, + Tty: testCase.Tty, + }) + hasErr := err != nil + + if len(testCase.Error) > 0 { + if !hasErr { + t.Errorf("%s: expected an error", name) + } else { + if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) { + t.Errorf("%s: expected error stream read %q, got %q", name, e, a) + } + } + return + } + + if hasErr { + t.Fatalf("%s: unexpected error: %v", name, err) + } + + if len(testCase.Stdout) > 0 { + if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() { + t.Fatalf("%s: expected stdout data %q, got %q", name, e, a) } } - server.Close() - continue - } - - if hasErr { - t.Errorf("%s: unexpected error: %v", name, err) - server.Close() - continue - } - - if len(testCase.Stdout) > 0 { - if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() { - t.Errorf("%s: expected stdout data %q, got %q", name, e, a) + if testCase.Stderr != "" { + if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() { + t.Fatalf("%s: expected stderr data %q, got %q", name, e, a) + } } - } - if testCase.Stderr != "" { - if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() { - t.Errorf("%s: expected stderr data %q, got %q", name, e, a) + select { + case <-requestReceived: + case <-time.After(time.Minute): + t.Errorf("%s: expected fakeServerInstance to receive request", name) } - } - - select { - case <-requestReceived: - case <-time.After(time.Minute): - t.Errorf("%s: expected fakeServerInstance to receive request", name) - } - - server.Close() + }) } } }