middleware: refactor tests

This commit is contained in:
Muyao CHEN 2024-09-25 14:09:29 +02:00
parent fc4103cff3
commit 3b16d6b16a
2 changed files with 64 additions and 25 deletions

View File

@ -41,6 +41,7 @@ func Timeout(d time.Duration) framework.ControllerHandler {
log.Println("finish") log.Println("finish")
case <-durationCtx.Done(): case <-durationCtx.Done():
c.SetHasTimeout() c.SetHasTimeout()
c.GetResponseWriter().WriteHeader(http.StatusRequestTimeout)
c.GetResponseWriter().Write([]byte("time out")) c.GetResponseWriter().Write([]byte("time out"))
} }
return nil return nil

View File

@ -3,7 +3,6 @@ package middleware
import ( import (
"bytes" "bytes"
"io" "io"
"log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -16,50 +15,89 @@ func TestTimeout(t *testing.T) {
t.Run("Test timeout handler", func(t *testing.T) { t.Run("Test timeout handler", func(t *testing.T) {
timeoutHandler := Timeout(1 * time.Millisecond) timeoutHandler := Timeout(1 * time.Millisecond)
request := httptest.NewRequest(http.MethodGet, "/", nil)
response := httptest.NewRecorder()
c := framework.NewContext(response, request)
longHandler := func(c *framework.Context) error { longHandler := func(c *framework.Context) error {
log.Println("TEST")
time.Sleep(2 * time.Millisecond) time.Sleep(2 * time.Millisecond)
return nil return nil
} }
c.SetHandlers([]framework.ControllerHandler{longHandler}) res := prepareMiddlewareTest(t, timeoutHandler, longHandler)
err := timeoutHandler(c) assertCode(t, res.StatusCode, http.StatusRequestTimeout)
if err != nil { assertBody(t, res.Body, "time out")
t.Fatal(err) })
} t.Run("Test no timeout", func(t *testing.T) {
res := response.Result() timeoutHandler := Timeout(2 * time.Millisecond)
buf, _ := io.ReadAll(res.Body)
if cmp := bytes.Compare(buf, []byte("time out")); cmp != 0 { quickHandler := func(c *framework.Context) error {
t.Errorf("got %q, want time out", string(buf)) // time.Sleep(1 * time.Millisecond)
c.WriteJSON(http.StatusOK, "ok")
return nil
} }
res := prepareMiddlewareTest(t, timeoutHandler, quickHandler)
assertCode(t, res.StatusCode, http.StatusOK)
}) })
} }
func TestRecover(t *testing.T) { func TestRecover(t *testing.T) {
t.Run("Test panic", func(t *testing.T) {
recoverer := Recovery() recoverer := Recovery()
panicHandler := func(c *framework.Context) error {
panic("panic")
}
res := prepareMiddlewareTest(t, recoverer, panicHandler)
assertCode(t, res.StatusCode, http.StatusInternalServerError)
})
t.Run("Test no panic", func(t *testing.T) {
recoverer := Recovery()
normalHandler := func(c *framework.Context) error {
c.WriteJSON(http.StatusOK, "ok")
return nil
}
res := prepareMiddlewareTest(t, recoverer, normalHandler)
assertCode(t, res.StatusCode, http.StatusOK)
})
}
func prepareMiddlewareTest(
t testing.TB,
mid framework.ControllerHandler,
in framework.ControllerHandler,
) *http.Response {
t.Helper()
request := httptest.NewRequest(http.MethodGet, "/", nil) request := httptest.NewRequest(http.MethodGet, "/", nil)
response := httptest.NewRecorder() response := httptest.NewRecorder()
c := framework.NewContext(response, request) c := framework.NewContext(response, request)
panicHandler := func(c *framework.Context) error { c.SetHandlers([]framework.ControllerHandler{in})
log.Println("TEST")
panic("panic")
}
c.SetHandlers([]framework.ControllerHandler{panicHandler}) err := mid(c)
err := recoverer(c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
res := response.Result() res := response.Result()
if res.StatusCode != http.StatusInternalServerError { return res
t.Errorf("status code got %d, want %d", res.StatusCode, http.StatusInternalServerError) }
func assertCode(t testing.TB, got int, want int) {
t.Helper()
if got != want {
t.Errorf("status code got %d, want %d", got, want)
}
}
func assertBody(t testing.TB, got io.Reader, want string) {
t.Helper()
buf, _ := io.ReadAll(got)
if cmp := bytes.Compare(buf, []byte(want)); cmp != 0 {
t.Errorf("got %q, want %q", string(buf), want)
} }
} }