Cleanup: defer to close server in tests (#110367)

* Cleanup: defer to close server in tests

Signed-off-by: kerthcet <kerthcet@gmail.com>

* address comments

Signed-off-by: kerthcet <kerthcet@gmail.com>

* address comments

Signed-off-by: kerthcet <kerthcet@gmail.com>
This commit is contained in:
Kante Yin 2022-06-21 23:00:38 +08:00 committed by GitHub
parent 375fd32b9f
commit e844c12a61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 153 additions and 158 deletions

View File

@ -129,81 +129,80 @@ func TestForwardPorts(t *testing.T) {
}
for testName, test := range tests {
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
t.Run(testName, func(t *testing.T) {
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
defer server.Close()
transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
if err != nil {
t.Fatal(err)
}
url, _ := url.Parse(server.URL)
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
stopChan := make(chan struct{}, 1)
readyChan := make(chan struct{})
pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
if err != nil {
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
}
doneChan := make(chan error)
go func() {
doneChan <- pf.ForwardPorts()
}()
<-pf.Ready
forwardedPorts, err := pf.GetPorts()
if err != nil {
t.Fatal(err)
}
remoteToLocalMap := map[int32]int32{}
for _, forwardedPort := range forwardedPorts {
remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local)
}
for port, data := range test.clientSends {
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port]))
transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
if err != nil {
t.Errorf("%s: error dialing %d: %s", testName, port, err)
server.Close()
continue
t.Fatal(err)
}
defer clientConn.Close()
url, _ := url.Parse(server.URL)
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
t.Errorf("%s: Error sending data '%s': %s", testName, data, err)
server.Close()
continue
}
if n == 0 {
t.Errorf("%s: unexpected write of 0 bytes", testName)
server.Close()
continue
}
b := make([]byte, 4)
_, err = clientConn.Read(b)
if err != nil && err != io.EOF {
t.Errorf("%s: Error reading data: %s", testName, err)
server.Close()
continue
}
if !bytes.Equal([]byte(test.serverSends[port]), b) {
t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
server.Close()
continue
}
}
// tell r.ForwardPorts to stop
close(stopChan)
stopChan := make(chan struct{}, 1)
readyChan := make(chan struct{})
// wait for r.ForwardPorts to actually return
err = <-doneChan
if err != nil {
t.Errorf("%s: unexpected error: %s", testName, err)
}
server.Close()
pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
if err != nil {
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
}
doneChan := make(chan error)
go func() {
doneChan <- pf.ForwardPorts()
}()
<-pf.Ready
forwardedPorts, err := pf.GetPorts()
if err != nil {
t.Fatal(err)
}
remoteToLocalMap := map[int32]int32{}
for _, forwardedPort := range forwardedPorts {
remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local)
}
clientSend := func(port int32, data string) error {
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port]))
if err != nil {
return fmt.Errorf("%s: error dialing %d: %s", testName, port, err)
}
defer clientConn.Close()
n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
return fmt.Errorf("%s: Error sending data '%s': %s", testName, data, err)
}
if n == 0 {
return fmt.Errorf("%s: unexpected write of 0 bytes", testName)
}
b := make([]byte, 4)
_, err = clientConn.Read(b)
if err != nil && err != io.EOF {
return fmt.Errorf("%s: Error reading data: %s", testName, err)
}
if !bytes.Equal([]byte(test.serverSends[port]), b) {
return fmt.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
}
return nil
}
for port, data := range test.clientSends {
if err := clientSend(port, data); err != nil {
t.Error(err)
}
}
// tell r.ForwardPorts to stop
close(stopChan)
// wait for r.ForwardPorts to actually return
err = <-doneChan
if err != nil {
t.Errorf("%s: unexpected error: %s", testName, err)
}
})
}
}

View File

@ -195,108 +195,104 @@ func TestStream(t *testing.T) {
} else {
name = testCase.TestName + " (attach)"
}
var (
streamIn io.Reader
streamOut, streamErr io.Writer
)
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}
requestReceived := make(chan struct{})
server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
t.Run(name, func(t *testing.T) {
var (
streamIn io.Reader
streamOut, streamErr io.Writer
)
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}
url, _ := url.ParseRequestURI(server.URL)
config := restclient.ClientContentConfig{
GroupVersion: schema.GroupVersion{Group: "x"},
Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
}
c, err := restclient.NewRESTClient(url, "", config, nil, nil)
if err != nil {
t.Fatalf("failed to create a client: %v", err)
}
req := c.Post().Resource("testing")
requestReceived := make(chan struct{})
server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
defer server.Close()
if exec {
req.Param("command", "ls")
req.Param("command", "/")
}
url, _ := url.ParseRequestURI(server.URL)
config := restclient.ClientContentConfig{
GroupVersion: schema.GroupVersion{Group: "x"},
Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
}
c, err := restclient.NewRESTClient(url, "", config, nil, nil)
if err != nil {
t.Fatalf("failed to create a client: %v", err)
}
req := c.Post().Resource("testing")
if len(testCase.Stdin) > 0 {
req.Param(api.ExecStdinParam, "1")
streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
}
if exec {
req.Param("command", "ls")
req.Param("command", "/")
}
if len(testCase.Stdout) > 0 {
req.Param(api.ExecStdoutParam, "1")
streamOut = localOut
}
if len(testCase.Stdin) > 0 {
req.Param(api.ExecStdinParam, "1")
streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
}
if testCase.Tty {
req.Param(api.ExecTTYParam, "1")
} else if len(testCase.Stderr) > 0 {
req.Param(api.ExecStderrParam, "1")
streamErr = localErr
}
if len(testCase.Stdout) > 0 {
req.Param(api.ExecStdoutParam, "1")
streamOut = localOut
}
conf := &restclient.Config{
Host: server.URL,
}
transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
if err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
if err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
err = e.Stream(remoteclient.StreamOptions{
Stdin: streamIn,
Stdout: streamOut,
Stderr: streamErr,
Tty: testCase.Tty,
})
hasErr := err != nil
if testCase.Tty {
req.Param(api.ExecTTYParam, "1")
} else if len(testCase.Stderr) > 0 {
req.Param(api.ExecStderrParam, "1")
streamErr = localErr
}
if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%s: expected an error", name)
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
conf := &restclient.Config{
Host: server.URL,
}
transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
if err != nil {
t.Fatalf("%s: unexpected error: %v", name, err)
}
e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
if err != nil {
t.Fatalf("%s: unexpected error: %v", name, err)
}
err = e.Stream(remoteclient.StreamOptions{
Stdin: streamIn,
Stdout: streamOut,
Stderr: streamErr,
Tty: testCase.Tty,
})
hasErr := err != nil
if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%s: expected an error", name)
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
}
}
return
}
if hasErr {
t.Fatalf("%s: unexpected error: %v", name, err)
}
if len(testCase.Stdout) > 0 {
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
t.Fatalf("%s: expected stdout data %q, got %q", name, e, a)
}
}
server.Close()
continue
}
if hasErr {
t.Errorf("%s: unexpected error: %v", name, err)
server.Close()
continue
}
if len(testCase.Stdout) > 0 {
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
t.Errorf("%s: expected stdout data %q, got %q", name, e, a)
if testCase.Stderr != "" {
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
t.Fatalf("%s: expected stderr data %q, got %q", name, e, a)
}
}
}
if testCase.Stderr != "" {
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
t.Errorf("%s: expected stderr data %q, got %q", name, e, a)
select {
case <-requestReceived:
case <-time.After(time.Minute):
t.Errorf("%s: expected fakeServerInstance to receive request", name)
}
}
select {
case <-requestReceived:
case <-time.After(time.Minute):
t.Errorf("%s: expected fakeServerInstance to receive request", name)
}
server.Close()
})
}
}
}