From 5ba69b1c5f828577ce6c0e683ea7576201ff2669 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Fri, 7 Feb 2020 14:20:45 -0700 Subject: [PATCH] Fix acme listener --- redirect.go | 4 +++- server/server.go | 37 +++++++++++++++++++------------------ 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/redirect.go b/redirect.go index 3987172..76ff144 100644 --- a/redirect.go +++ b/redirect.go @@ -11,8 +11,10 @@ import ( func HTTPRedirect(next http.Handler) http.Handler { return http.HandlerFunc( func(rw http.ResponseWriter, r *http.Request) { - if r.Header.Get("x-Forwarded-Proto") == "https" || + if r.TLS != nil || + r.Header.Get("x-Forwarded-Proto") == "https" || r.Header.Get("x-Forwarded-Proto") == "wss" || + strings.HasPrefix(r.URL.Path, "/.well-known/") || strings.HasPrefix(r.URL.Path, "/ping") || strings.HasPrefix(r.URL.Path, "/health") { next.ServeHTTP(rw, r) diff --git a/server/server.go b/server/server.go index 57c2780..66521dd 100644 --- a/server/server.go +++ b/server/server.go @@ -32,15 +32,11 @@ type ListenOpts struct { CertBackup string AcmeDomains []string BindHost string + NoRedirect bool TLSListenerConfig dynamiclistener.Config } func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.Handler, opts *ListenOpts) error { - var ( - // https listener will change this if http is enabled - targetHandler = handler - ) - if opts == nil { opts = &ListenOpts{} } @@ -58,23 +54,23 @@ func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.H return err } - dynListener, dynHandler, err := getTLSListener(ctx, tlsTCPListener, *opts) + tlsTCPListener, handler, err = getTLSListener(ctx, tlsTCPListener, handler, *opts) if err != nil { return err } - if dynHandler != nil { - targetHandler = wrapHandler(dynHandler, handler) + if !opts.NoRedirect { + handler = dynamiclistener.HTTPRedirect(handler) } + tlsServer := http.Server{ - Handler: targetHandler, + Handler: handler, ErrorLog: errorLog, } - targetHandler = dynamiclistener.HTTPRedirect(targetHandler) go func() { logrus.Infof("Listening on %s:%d", opts.BindHost, httpsPort) - err := tlsServer.Serve(dynListener) + err := tlsServer.Serve(tlsTCPListener) if err != http.ErrServerClosed && err != nil { logrus.Fatalf("https server failed: %v", err) } @@ -88,7 +84,7 @@ func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.H if httpPort > 0 { httpServer := http.Server{ Addr: fmt.Sprintf("%s:%d", opts.BindHost, httpPort), - Handler: targetHandler, + Handler: handler, ErrorLog: errorLog, } go func() { @@ -107,17 +103,17 @@ func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.H return nil } -func getTLSListener(ctx context.Context, tcp net.Listener, opts ListenOpts) (net.Listener, http.Handler, error) { +func getTLSListener(ctx context.Context, tcp net.Listener, handler http.Handler, opts ListenOpts) (net.Listener, http.Handler, error) { if len(opts.TLSListenerConfig.TLSConfig.NextProtos) == 0 { opts.TLSListenerConfig.TLSConfig.NextProtos = []string{"h2", "http/1.1"} } if len(opts.TLSListenerConfig.TLSConfig.Certificates) > 0 { - return tls.NewListener(tcp, opts.TLSListenerConfig.TLSConfig), nil, nil + return tls.NewListener(tcp, opts.TLSListenerConfig.TLSConfig), handler, nil } if len(opts.AcmeDomains) > 0 { - return acmeListener(tcp, opts), nil, nil + return acmeListener(tcp, handler, opts) } storage := opts.Storage @@ -130,7 +126,12 @@ func getTLSListener(ctx context.Context, tcp net.Listener, opts ListenOpts) (net return nil, nil, err } - return dynamiclistener.NewListener(tcp, storage, caCert, caKey, opts.TLSListenerConfig) + listener, dynHandler, err := dynamiclistener.NewListener(tcp, storage, caCert, caKey, opts.TLSListenerConfig) + if err != nil { + return nil, nil, err + } + + return listener, wrapHandler(dynHandler, handler), nil } func getCA(opts ListenOpts) (*x509.Certificate, crypto.Signer, error) { @@ -187,7 +188,7 @@ func wrapHandler(handler http.Handler, next http.Handler) http.Handler { }) } -func acmeListener(tcp net.Listener, opts ListenOpts) net.Listener { +func acmeListener(tcp net.Listener, handler http.Handler, opts ListenOpts) (net.Listener, http.Handler, error) { hosts := map[string]bool{} for _, domain := range opts.AcmeDomains { hosts[domain] = true @@ -215,5 +216,5 @@ func acmeListener(tcp net.Listener, opts ListenOpts) net.Listener { return manager.GetCertificate(hello) } - return tls.NewListener(tcp, opts.TLSListenerConfig.TLSConfig) + return tls.NewListener(tcp, opts.TLSListenerConfig.TLSConfig), manager.HTTPHandler(handler), nil }