Add CustomRecovery builtin middleware (#2322)
* Add CustomRecovery and CustomRecoveryWithWriter methods * add CustomRecovery example to README * add test for CustomRecovery * support RecoveryWithWriter(io.Writer, ...RecoveryFunc)
This commit is contained in:
		
							
								
								
									
										33
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								README.md
									
									
									
									
									
								
							@ -496,6 +496,39 @@ func main() {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Custom Recovery behavior
 | 
				
			||||||
 | 
					```go
 | 
				
			||||||
 | 
					func main() {
 | 
				
			||||||
 | 
						// Creates a router without any middleware by default
 | 
				
			||||||
 | 
						r := gin.New()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Global middleware
 | 
				
			||||||
 | 
						// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
 | 
				
			||||||
 | 
						// By default gin.DefaultWriter = os.Stdout
 | 
				
			||||||
 | 
						r.Use(gin.Logger())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Recovery middleware recovers from any panics and writes a 500 if there was one.
 | 
				
			||||||
 | 
						r.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
 | 
				
			||||||
 | 
							if err, ok := recovered.(string); ok {
 | 
				
			||||||
 | 
								c.String(http.StatusInternalServerError, fmt.Sprintf("error: %s", err))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							c.AbortWithStatus(http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r.GET("/panic", func(c *gin.Context) {
 | 
				
			||||||
 | 
							// panic with a string -- the custom middleware could save this to a database or report it to the user
 | 
				
			||||||
 | 
							panic("foo")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r.GET("/", func(c *gin.Context) {
 | 
				
			||||||
 | 
							c.String(http.StatusOK, "ohai")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Listen and serve on 0.0.0.0:8080
 | 
				
			||||||
 | 
						r.Run(":8080")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### How to write log file
 | 
					### How to write log file
 | 
				
			||||||
```go
 | 
					```go
 | 
				
			||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										27
									
								
								recovery.go
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								recovery.go
									
									
									
									
									
								
							@ -26,13 +26,29 @@ var (
 | 
				
			|||||||
	slash     = []byte("/")
 | 
						slash     = []byte("/")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RecoveryFunc defines the function passable to CustomRecovery.
 | 
				
			||||||
 | 
					type RecoveryFunc func(c *Context, err interface{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
 | 
					// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
 | 
				
			||||||
func Recovery() HandlerFunc {
 | 
					func Recovery() HandlerFunc {
 | 
				
			||||||
	return RecoveryWithWriter(DefaultErrorWriter)
 | 
						return RecoveryWithWriter(DefaultErrorWriter)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
 | 
				
			||||||
 | 
					func CustomRecovery(handle RecoveryFunc) HandlerFunc {
 | 
				
			||||||
 | 
						return RecoveryWithWriter(DefaultErrorWriter, handle)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
 | 
					// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
 | 
				
			||||||
func RecoveryWithWriter(out io.Writer) HandlerFunc {
 | 
					func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
 | 
				
			||||||
 | 
						if len(recovery) > 0 {
 | 
				
			||||||
 | 
							return CustomRecoveryWithWriter(out, recovery[0])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return CustomRecoveryWithWriter(out, defaultHandleRecovery)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
 | 
				
			||||||
 | 
					func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
 | 
				
			||||||
	var logger *log.Logger
 | 
						var logger *log.Logger
 | 
				
			||||||
	if out != nil {
 | 
						if out != nil {
 | 
				
			||||||
		logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
 | 
							logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
 | 
				
			||||||
@ -70,13 +86,12 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc {
 | 
				
			|||||||
							timeFormat(time.Now()), err, stack, reset)
 | 
												timeFormat(time.Now()), err, stack, reset)
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					 | 
				
			||||||
				// If the connection is dead, we can't write a status to it.
 | 
					 | 
				
			||||||
				if brokenPipe {
 | 
									if brokenPipe {
 | 
				
			||||||
 | 
										// If the connection is dead, we can't write a status to it.
 | 
				
			||||||
					c.Error(err.(error)) // nolint: errcheck
 | 
										c.Error(err.(error)) // nolint: errcheck
 | 
				
			||||||
					c.Abort()
 | 
										c.Abort()
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					c.AbortWithStatus(http.StatusInternalServerError)
 | 
										handle(c, err)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
@ -84,6 +99,10 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func defaultHandleRecovery(c *Context, err interface{}) {
 | 
				
			||||||
 | 
						c.AbortWithStatus(http.StatusInternalServerError)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// stack returns a nicely formatted stack frame, skipping skip frames.
 | 
					// stack returns a nicely formatted stack frame, skipping skip frames.
 | 
				
			||||||
func stack(skip int) []byte {
 | 
					func stack(skip int) []byte {
 | 
				
			||||||
	buf := new(bytes.Buffer) // the returned data
 | 
						buf := new(bytes.Buffer) // the returned data
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										106
									
								
								recovery_test.go
									
									
									
									
									
								
							
							
						
						
									
										106
									
								
								recovery_test.go
									
									
									
									
									
								
							@ -62,7 +62,7 @@ func TestPanicInHandler(t *testing.T) {
 | 
				
			|||||||
	assert.Equal(t, http.StatusInternalServerError, w.Code)
 | 
						assert.Equal(t, http.StatusInternalServerError, w.Code)
 | 
				
			||||||
	assert.Contains(t, buffer.String(), "panic recovered")
 | 
						assert.Contains(t, buffer.String(), "panic recovered")
 | 
				
			||||||
	assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
 | 
						assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
 | 
				
			||||||
	assert.Contains(t, buffer.String(), "TestPanicInHandler")
 | 
						assert.Contains(t, buffer.String(), t.Name())
 | 
				
			||||||
	assert.NotContains(t, buffer.String(), "GET /recovery")
 | 
						assert.NotContains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Debug mode prints the request
 | 
						// Debug mode prints the request
 | 
				
			||||||
@ -144,3 +144,107 @@ func TestPanicWithBrokenPipe(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCustomRecoveryWithWriter(t *testing.T) {
 | 
				
			||||||
 | 
						errBuffer := new(bytes.Buffer)
 | 
				
			||||||
 | 
						buffer := new(bytes.Buffer)
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						handleRecovery := func(c *Context, err interface{}) {
 | 
				
			||||||
 | 
							errBuffer.WriteString(err.(string))
 | 
				
			||||||
 | 
							c.AbortWithStatus(http.StatusBadRequest)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						router.Use(CustomRecoveryWithWriter(buffer, handleRecovery))
 | 
				
			||||||
 | 
						router.GET("/recovery", func(_ *Context) {
 | 
				
			||||||
 | 
							panic("Oupps, Houston, we have a problem")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						// RUN
 | 
				
			||||||
 | 
						w := performRequest(router, "GET", "/recovery")
 | 
				
			||||||
 | 
						// TEST
 | 
				
			||||||
 | 
						assert.Equal(t, http.StatusBadRequest, w.Code)
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "panic recovered")
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), t.Name())
 | 
				
			||||||
 | 
						assert.NotContains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Debug mode prints the request
 | 
				
			||||||
 | 
						SetMode(DebugMode)
 | 
				
			||||||
 | 
						// RUN
 | 
				
			||||||
 | 
						w = performRequest(router, "GET", "/recovery")
 | 
				
			||||||
 | 
						// TEST
 | 
				
			||||||
 | 
						assert.Equal(t, http.StatusBadRequest, w.Code)
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						SetMode(TestMode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCustomRecovery(t *testing.T) {
 | 
				
			||||||
 | 
						errBuffer := new(bytes.Buffer)
 | 
				
			||||||
 | 
						buffer := new(bytes.Buffer)
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						DefaultErrorWriter = buffer
 | 
				
			||||||
 | 
						handleRecovery := func(c *Context, err interface{}) {
 | 
				
			||||||
 | 
							errBuffer.WriteString(err.(string))
 | 
				
			||||||
 | 
							c.AbortWithStatus(http.StatusBadRequest)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						router.Use(CustomRecovery(handleRecovery))
 | 
				
			||||||
 | 
						router.GET("/recovery", func(_ *Context) {
 | 
				
			||||||
 | 
							panic("Oupps, Houston, we have a problem")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						// RUN
 | 
				
			||||||
 | 
						w := performRequest(router, "GET", "/recovery")
 | 
				
			||||||
 | 
						// TEST
 | 
				
			||||||
 | 
						assert.Equal(t, http.StatusBadRequest, w.Code)
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "panic recovered")
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), t.Name())
 | 
				
			||||||
 | 
						assert.NotContains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Debug mode prints the request
 | 
				
			||||||
 | 
						SetMode(DebugMode)
 | 
				
			||||||
 | 
						// RUN
 | 
				
			||||||
 | 
						w = performRequest(router, "GET", "/recovery")
 | 
				
			||||||
 | 
						// TEST
 | 
				
			||||||
 | 
						assert.Equal(t, http.StatusBadRequest, w.Code)
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						SetMode(TestMode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) {
 | 
				
			||||||
 | 
						errBuffer := new(bytes.Buffer)
 | 
				
			||||||
 | 
						buffer := new(bytes.Buffer)
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						DefaultErrorWriter = buffer
 | 
				
			||||||
 | 
						handleRecovery := func(c *Context, err interface{}) {
 | 
				
			||||||
 | 
							errBuffer.WriteString(err.(string))
 | 
				
			||||||
 | 
							c.AbortWithStatus(http.StatusBadRequest)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						router.Use(RecoveryWithWriter(DefaultErrorWriter, handleRecovery))
 | 
				
			||||||
 | 
						router.GET("/recovery", func(_ *Context) {
 | 
				
			||||||
 | 
							panic("Oupps, Houston, we have a problem")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						// RUN
 | 
				
			||||||
 | 
						w := performRequest(router, "GET", "/recovery")
 | 
				
			||||||
 | 
						// TEST
 | 
				
			||||||
 | 
						assert.Equal(t, http.StatusBadRequest, w.Code)
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "panic recovered")
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), t.Name())
 | 
				
			||||||
 | 
						assert.NotContains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Debug mode prints the request
 | 
				
			||||||
 | 
						SetMode(DebugMode)
 | 
				
			||||||
 | 
						// RUN
 | 
				
			||||||
 | 
						w = performRequest(router, "GET", "/recovery")
 | 
				
			||||||
 | 
						// TEST
 | 
				
			||||||
 | 
						assert.Equal(t, http.StatusBadRequest, w.Code)
 | 
				
			||||||
 | 
						assert.Contains(t, buffer.String(), "GET /recovery")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						SetMode(TestMode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user