diff --git a/framework/middleware/timeout.go b/framework/middleware/timeout.go index e20d929..47183e9 100644 --- a/framework/middleware/timeout.go +++ b/framework/middleware/timeout.go @@ -41,6 +41,7 @@ func Timeout(d time.Duration) framework.ControllerHandler { log.Println("finish") case <-durationCtx.Done(): c.SetHasTimeout() + c.GetResponseWriter().WriteHeader(http.StatusRequestTimeout) c.GetResponseWriter().Write([]byte("time out")) } return nil diff --git a/framework/middleware/timeout_test.go b/framework/middleware/timeout_test.go index 42e1d67..f8a1a61 100644 --- a/framework/middleware/timeout_test.go +++ b/framework/middleware/timeout_test.go @@ -3,7 +3,6 @@ package middleware import ( "bytes" "io" - "log" "net/http" "net/http/httptest" "testing" @@ -16,50 +15,89 @@ func TestTimeout(t *testing.T) { t.Run("Test timeout handler", func(t *testing.T) { timeoutHandler := Timeout(1 * time.Millisecond) - request := httptest.NewRequest(http.MethodGet, "/", nil) - response := httptest.NewRecorder() - c := framework.NewContext(response, request) - longHandler := func(c *framework.Context) error { - log.Println("TEST") time.Sleep(2 * time.Millisecond) return nil } - c.SetHandlers([]framework.ControllerHandler{longHandler}) + res := prepareMiddlewareTest(t, timeoutHandler, longHandler) - err := timeoutHandler(c) - if err != nil { - t.Fatal(err) - } - res := response.Result() - buf, _ := io.ReadAll(res.Body) - if cmp := bytes.Compare(buf, []byte("time out")); cmp != 0 { - t.Errorf("got %q, want time out", string(buf)) + assertCode(t, res.StatusCode, http.StatusRequestTimeout) + assertBody(t, res.Body, "time out") + }) + t.Run("Test no timeout", func(t *testing.T) { + timeoutHandler := Timeout(2 * time.Millisecond) + + quickHandler := func(c *framework.Context) error { + // time.Sleep(1 * time.Millisecond) + c.WriteJSON(http.StatusOK, "ok") + return nil } + + res := prepareMiddlewareTest(t, timeoutHandler, quickHandler) + + assertCode(t, res.StatusCode, http.StatusOK) }) } func TestRecover(t *testing.T) { - recoverer := Recovery() + t.Run("Test panic", func(t *testing.T) { + recoverer := Recovery() + panicHandler := func(c *framework.Context) error { + panic("panic") + } + + res := prepareMiddlewareTest(t, recoverer, panicHandler) + + assertCode(t, res.StatusCode, http.StatusInternalServerError) + }) + t.Run("Test no panic", func(t *testing.T) { + recoverer := Recovery() + + normalHandler := func(c *framework.Context) error { + c.WriteJSON(http.StatusOK, "ok") + return nil + } + + res := prepareMiddlewareTest(t, recoverer, normalHandler) + + assertCode(t, res.StatusCode, http.StatusOK) + }) +} + +func prepareMiddlewareTest( + t testing.TB, + mid framework.ControllerHandler, + in framework.ControllerHandler, +) *http.Response { + t.Helper() request := httptest.NewRequest(http.MethodGet, "/", nil) response := httptest.NewRecorder() c := framework.NewContext(response, request) - panicHandler := func(c *framework.Context) error { - log.Println("TEST") - panic("panic") - } + c.SetHandlers([]framework.ControllerHandler{in}) - c.SetHandlers([]framework.ControllerHandler{panicHandler}) - - err := recoverer(c) + err := mid(c) if err != nil { t.Fatal(err) } + res := response.Result() - if res.StatusCode != http.StatusInternalServerError { - t.Errorf("status code got %d, want %d", res.StatusCode, http.StatusInternalServerError) + return res +} + +func assertCode(t testing.TB, got int, want int) { + t.Helper() + if got != want { + t.Errorf("status code got %d, want %d", got, want) + } +} + +func assertBody(t testing.TB, got io.Reader, want string) { + t.Helper() + buf, _ := io.ReadAll(got) + if cmp := bytes.Compare(buf, []byte(want)); cmp != 0 { + t.Errorf("got %q, want %q", string(buf), want) } }