package middleware

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.vinchent.xyz/vinchent/go-web/framework"
)

// TestTimeout checks that the Timeout middleware answers with
// 408 Request Timeout (body "time out") when the wrapped handler is too
// slow, and lets the handler's own response through when it finishes in time.
func TestTimeout(t *testing.T) {
	t.Run("Test timeout handler", func(t *testing.T) {
		// The handler sleeps well past the deadline so the outcome does not
		// hinge on timer/scheduler granularity (a 1ms margin is flaky).
		timeoutHandler := Timeout(10 * time.Millisecond)
		longHandler := func(c *framework.Context) error {
			time.Sleep(50 * time.Millisecond)
			return nil
		}

		res := prepareMiddlewareTest(t, timeoutHandler, longHandler)

		assertCode(t, res.StatusCode, http.StatusRequestTimeout)
		assertBody(t, res.Body, "time out")
	})

	t.Run("Test no timeout", func(t *testing.T) {
		// The handler returns immediately; a generous deadline keeps this
		// case from flaking on a loaded machine (e.g. under -race).
		timeoutHandler := Timeout(1 * time.Second)
		quickHandler := func(c *framework.Context) error {
			c.WriteJSON(http.StatusOK, "ok")
			return nil
		}

		res := prepareMiddlewareTest(t, timeoutHandler, quickHandler)

		assertCode(t, res.StatusCode, http.StatusOK)
	})
}

// TestRecover checks that the Recovery middleware turns a handler panic
// into a 500 Internal Server Error and leaves a normal response untouched.
func TestRecover(t *testing.T) {
	t.Run("Test panic", func(t *testing.T) {
		recoverer := Recovery()
		panicHandler := func(c *framework.Context) error {
			panic("panic")
		}

		res := prepareMiddlewareTest(t, recoverer, panicHandler)

		assertCode(t, res.StatusCode, http.StatusInternalServerError)
	})

	t.Run("Test no panic", func(t *testing.T) {
		recoverer := Recovery()
		normalHandler := func(c *framework.Context) error {
			c.WriteJSON(http.StatusOK, "ok")
			return nil
		}

		res := prepareMiddlewareTest(t, recoverer, normalHandler)

		assertCode(t, res.StatusCode, http.StatusOK)
	})
}

// prepareMiddlewareTest builds a framework.Context for a GET "/" request
// backed by an httptest.ResponseRecorder, installs in as the context's only
// handler, runs mid on that context and returns the recorded response.
// It fails the test immediately if mid returns an error.
//
// NOTE(review): mid is presumably expected to reach in through the
// context's handler chain (e.g. a Next call) — confirm against framework.
func prepareMiddlewareTest(
	t testing.TB,
	mid framework.ControllerHandler,
	in framework.ControllerHandler,
) *http.Response {
	t.Helper()

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	response := httptest.NewRecorder()
	c := framework.NewContext(response, request)
	c.SetHandlers([]framework.ControllerHandler{in})

	if err := mid(c); err != nil {
		t.Fatalf("middleware returned error: %v", err)
	}

	res := response.Result()
	// Close the body once the test (and any assertions reading it) is done.
	t.Cleanup(func() {
		if err := res.Body.Close(); err != nil {
			t.Errorf("closing response body: %v", err)
		}
	})
	return res
}

// assertCode reports a test error when the status code got differs from want.
func assertCode(t testing.TB, got int, want int) {
	t.Helper()
	if got != want {
		t.Errorf("status code got %d, want %d", got, want)
	}
}

// assertBody reads all of got and reports a test error when its content is
// not exactly want. A read failure aborts the test rather than being
// misreported as a body mismatch.
func assertBody(t testing.TB, got io.Reader, want string) {
	t.Helper()
	buf, err := io.ReadAll(got)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if !bytes.Equal(buf, []byte(want)) {
		t.Errorf("got %q, want %q", string(buf), want)
	}
}