diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index e82b74f6497..2d22da370c8 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -207,6 +207,8 @@ func TestGetListener(t *testing.T) { // kubelet.PortForwarder. type fakePortForwarder struct { lock sync.Mutex + // stores data expected from the stream per port + expected map[uint16]string // stores data received from the stream per port received map[uint16]string // data to be sent to the stream per port @@ -218,33 +220,23 @@ var _ kubelet.PortForwarder = &fakePortForwarder{} func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { defer stream.Close() - var wg sync.WaitGroup + // read from the client + received := make([]byte, len(pf.expected[port])) + n, err := stream.Read(received) + if err != nil { + return fmt.Errorf("error reading from client for port %d: %v", port, err) + } + if n != len(pf.expected[port]) { + return fmt.Errorf("unexpected length read from client for port %d: got %d, expected %d. data=%q", port, n, len(pf.expected[port]), string(received)) + } - // client -> server - wg.Add(1) - go func() { - defer wg.Done() + // store the received content + pf.lock.Lock() + pf.received[port] = string(received) + pf.lock.Unlock() - // copy from stream into a buffer - received := new(bytes.Buffer) - io.Copy(received, stream) - - // store the received content - pf.lock.Lock() - pf.received[port] = received.String() - pf.lock.Unlock() - }() - - // server -> client - wg.Add(1) - go func() { - defer wg.Done() - - // send the hardcoded data to the stream - io.Copy(stream, strings.NewReader(pf.send[port])) - }() - - wg.Wait() + // send the hardcoded data to the client + io.Copy(stream, strings.NewReader(pf.send[port])) return nil } @@ -254,6 +246,7 @@ func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16 func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { pf := &fakePortForwarder{ + expected: expectedFromClient, received: make(map[uint16]string), send: serverSends, }