diff --git a/framework/middleware/timeout_test.go b/framework/middleware/timeout_test.go index 34632af..42e1d67 100644 --- a/framework/middleware/timeout_test.go +++ b/framework/middleware/timeout_test.go @@ -1,6 +1,9 @@ package middleware import ( + "bytes" + "io" + "log" "net/http" "net/http/httptest" "testing" @@ -17,9 +20,46 @@ func TestTimeout(t *testing.T) { response := httptest.NewRecorder() c := framework.NewContext(response, request) + longHandler := func(c *framework.Context) error { + log.Println("TEST") + time.Sleep(2 * time.Millisecond) + return nil + } + + c.SetHandlers([]framework.ControllerHandler{longHandler}) + err := timeoutHandler(c) if err != nil { t.Fatal(err) } + res := response.Result() + buf, _ := io.ReadAll(res.Body) + if cmp := bytes.Compare(buf, []byte("time out")); cmp != 0 { + t.Errorf("got %q, want time out", string(buf)) + } }) } + +func TestRecover(t *testing.T) { + recoverer := Recovery() + + request := httptest.NewRequest(http.MethodGet, "/", nil) + response := httptest.NewRecorder() + c := framework.NewContext(response, request) + + panicHandler := func(c *framework.Context) error { + log.Println("TEST") + panic("panic") + } + + c.SetHandlers([]framework.ControllerHandler{panicHandler}) + + err := recoverer(c) + if err != nil { + t.Fatal(err) + } + res := response.Result() + if res.StatusCode != http.StatusInternalServerError { + t.Errorf("status code got %d, want %d", res.StatusCode, http.StatusInternalServerError) + } +}