diff --git a/server/server.go b/server/server.go index 9ad31b1..787ad05 100644 --- a/server/server.go +++ b/server/server.go @@ -36,20 +36,30 @@ type ListenOpts struct { BindHost string NoRedirect bool TLSListenerConfig dynamiclistener.Config + + // Override legacy behavior where server logs written to the application's logrus object + // were dropped unless logrus was set to debug-level (such as by launching steve with '--debug'). + // Setting this to true results in server logs appearing at an ERROR level. + DisplayServerLogs bool } func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.Handler, opts *ListenOpts) error { + logger := logrus.StandardLogger() + writer := logger.WriterLevel(logrus.DebugLevel) if opts == nil { opts = &ListenOpts{} } + if opts.DisplayServerLogs { + writer = logger.WriterLevel(logrus.ErrorLevel) + } + // Otherwise preserve legacy behaviour of displaying server logs only in debug mode. + + errorLog := log.New(writer, "", log.LstdFlags) if opts.TLSListenerConfig.TLSConfig == nil { opts.TLSListenerConfig.TLSConfig = &tls.Config{} } - logger := logrus.StandardLogger() - errorLog := log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags) - if httpsPort > 0 { tlsTCPListener, err := dynamiclistener.NewTCPListener(opts.BindHost, httpsPort) if err != nil { diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..28a17b1 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,134 @@ +package server + +import ( + "bytes" + "context" + "fmt" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/sirupsen/logrus" + assertPkg "github.com/stretchr/testify/assert" +) + +type alwaysPanicHandler struct { + msg string +} + +func (z *alwaysPanicHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { + panic(z.msg) +} + +// safeWriter is used to allow writing to a buffer-based log in a web server +// and safely read from it in the client (i.e. this test code) +type safeWriter struct { + writer *bytes.Buffer + mutex *sync.Mutex +} + +func newSafeWriter(writer *bytes.Buffer, mutex *sync.Mutex) *safeWriter { + return &safeWriter{writer: writer, mutex: mutex} +} + +func (s *safeWriter) Write(p []byte) (n int, err error) { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.writer.Write(p) +} + +func TestHttpServerLogWithLogrus(t *testing.T) { + assert := assertPkg.New(t) + message := "debug-level writer" + msg := fmt.Sprintf("panicking context: %s", message) + var buf bytes.Buffer + var mutex sync.Mutex + safeWriter := newSafeWriter(&buf, &mutex) + err := doRequest(safeWriter, message, logrus.ErrorLevel) + assert.Nil(err) + + mutex.Lock() + s := buf.String() + assert.Greater(len(s), 0) + assert.Contains(s, msg) + assert.Contains(s, "panic serving 127.0.0.1") + mutex.Unlock() +} + +func TestHttpNoServerLogsWithLogrus(t *testing.T) { + assert := assertPkg.New(t) + + message := "error-level writer" + var buf bytes.Buffer + var mutex sync.Mutex + safeWriter := newSafeWriter(&buf, &mutex) + err := doRequest(safeWriter, message, logrus.DebugLevel) + assert.Nil(err) + + mutex.Lock() + s := buf.String() + if len(s) > 0 { + assert.NotContains(s, message) + } + mutex.Unlock() +} + +func doRequest(safeWriter *safeWriter, message string, logLevel logrus.Level) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + host := "127.0.0.1" + httpPort := 9012 + httpsPort := 0 + msg := fmt.Sprintf("panicking context: %s", message) + handler := alwaysPanicHandler{msg: msg} + listenOpts := &ListenOpts{ + BindHost: host, + DisplayServerLogs: logLevel == logrus.ErrorLevel, + } + + logrus.StandardLogger().SetOutput(safeWriter) + if err := ListenAndServe(ctx, httpsPort, httpPort, &handler, listenOpts); err != nil { + return err + } + addr := fmt.Sprintf("%s:%d", host, httpPort) + return makeTheHttpRequest(addr) +} + +func makeTheHttpRequest(addr string) error { + url := fmt.Sprintf("%s://%s/", "http", addr) + + waitTime := 10 * time.Millisecond + totalTime := 0 * time.Millisecond + const maxWaitTime = 10 * time.Second + // Waiting for server to be ready..., max of maxWaitTime + for { + conn, err := net.Dial("tcp", addr) + if err == nil { + conn.Close() + break + } else if totalTime > maxWaitTime { + return fmt.Errorf("timed out waiting for the server to start after %d msec", totalTime/1e6) + } + time.Sleep(waitTime) + totalTime += waitTime + waitTime += 10 * time.Millisecond + } + + client := &http.Client{ + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + resp, err := client.Do(req) + if err == nil { + return fmt.Errorf("server should have panicked on request") + } + if resp != nil { + defer resp.Body.Close() + } + return nil +}