diff --git a/pkg/kubelet/server_test.go b/pkg/kubelet/server_test.go index 241b80c950c..d1ba9c6feb4 100644 --- a/pkg/kubelet/server_test.go +++ b/pkg/kubelet/server_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -639,7 +640,7 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1" - upgradeRoundTripper := spdy.NewRoundTripper(nil) + upgradeRoundTripper := spdy.NewSpdyRoundTripper(nil) c := &http.Client{Transport: upgradeRoundTripper} resp, err := c.Get(url) @@ -648,6 +649,10 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { } defer resp.Body.Close() + upgradeRoundTripper.Dialer = &net.Dialer{ + Deadline: time.Now().Add(60 * time.Second), + Timeout: 60 * time.Second, + } conn, err := upgradeRoundTripper.NewConnection(resp) if err != nil { t.Fatalf("Unexpected error creating streaming connection: %s", err) diff --git a/pkg/util/httpstream/spdy/roundtripper.go b/pkg/util/httpstream/spdy/roundtripper.go index 3d96570f2f6..c2ee3ad3d87 100644 --- a/pkg/util/httpstream/spdy/roundtripper.go +++ b/pkg/util/httpstream/spdy/roundtripper.go @@ -45,11 +45,18 @@ type SpdyRoundTripper struct { */ // conn is the underlying network connection to the remote server. conn net.Conn + + // Dialer is the dialer used to connect. Used if non-nil. + Dialer *net.Dialer } // NewSpdyRoundTripper creates a new SpdyRoundTripper that will use // the specified tlsConfig. func NewRoundTripper(tlsConfig *tls.Config) httpstream.UpgradeRoundTripper { + return NewSpdyRoundTripper(tlsConfig) +} + +func NewSpdyRoundTripper(tlsConfig *tls.Config) *SpdyRoundTripper { return &SpdyRoundTripper{tlsConfig: tlsConfig} } @@ -58,11 +65,21 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) { dialAddr := netutil.CanonicalAddr(req.URL) if req.URL.Scheme == "http" { - return net.Dial("tcp", dialAddr) + if s.Dialer == nil { + return net.Dial("tcp", dialAddr) + } else { + return s.Dialer.Dial("tcp", dialAddr) + } } // TODO validate the TLSClientConfig is set up? - conn, err := tls.Dial("tcp", dialAddr, s.tlsConfig) + var conn *tls.Conn + var err error + if s.Dialer == nil { + conn, err = tls.Dial("tcp", dialAddr, s.tlsConfig) + } else { + conn, err = tls.DialWithDialer(s.Dialer, "tcp", dialAddr, s.tlsConfig) + } if err != nil { return nil, err }