diff --git a/recovery.go b/recovery.go index 6c28b4f..e788cc4 100644 --- a/recovery.go +++ b/recovery.go @@ -10,9 +10,12 @@ import ( "io" "io/ioutil" "log" + "net" "net/http" "net/http/httputil" + "os" "runtime" + "syscall" "time" ) @@ -37,16 +40,37 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc { return func(c *Context) { defer func() { if err := recover(); err != nil { - if logger != nil { - stack := stack(3) - if IsDebugging() { - httprequest, _ := httputil.DumpRequest(c.Request, false) - logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s", timeFormat(time.Now()), string(httprequest), err, stack, reset) - } else { - logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s", timeFormat(time.Now()), err, stack, reset) + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + var brokenPipe bool + if ne, ok := err.(*net.OpError); ok { + if se, ok := ne.Err.(*os.SyscallError); ok { + if se.Err == syscall.EPIPE || se.Err == syscall.ECONNRESET { + brokenPipe = true + } } } - c.AbortWithStatus(http.StatusInternalServerError) + if logger != nil { + stack := stack(3) + httprequest, _ := httputil.DumpRequest(c.Request, false) + if brokenPipe { + logger.Printf("%s\n%s%s", err, string(httprequest), reset) + } else if IsDebugging() { + logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s", + timeFormat(time.Now()), string(httprequest), err, stack, reset) + } else { + logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s", + timeFormat(time.Now()), err, stack, reset) + } + } + + // If the connection is dead, we can't write a status to it. + if brokenPipe { + c.Error(err.(error)) + c.Abort() + } else { + c.AbortWithStatus(http.StatusInternalServerError) + } } }() c.Next() diff --git a/recovery_test.go b/recovery_test.go index 7d422b7..cafaee9 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -2,11 +2,16 @@ // Use of this source code is governed by a MIT style // license that can be found in the LICENSE file. +// +build go1.7 + package gin import ( "bytes" + "net" "net/http" + "os" + "syscall" "testing" "github.com/stretchr/testify/assert" @@ -72,3 +77,38 @@ func TestFunction(t *testing.T) { bs := function(1) assert.Equal(t, []byte("???"), bs) } + +// TestPanicWithBrokenPipe asserts that recovery specifically handles +// writing responses to broken pipes +func TestPanicWithBrokenPipe(t *testing.T) { + const expectCode = 204 + + expectMsgs := map[syscall.Errno]string{ + syscall.EPIPE: "broken pipe", + syscall.ECONNRESET: "connection reset", + } + + for errno, expectMsg := range expectMsgs { + t.Run(expectMsg, func(t *testing.T) { + + var buf bytes.Buffer + + router := New() + router.Use(RecoveryWithWriter(&buf)) + router.GET("/recovery", func(c *Context) { + // Start writing response + c.Header("X-Test", "Value") + c.Status(expectCode) + + // Oops. Client connection closed + e := &net.OpError{Err: &os.SyscallError{Err: errno}} + panic(e) + }) + // RUN + w := performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, expectCode, w.Code) + assert.Contains(t, buf.String(), expectMsg) + }) + } +}