middleware: refactor tests
This commit is contained in:
parent
fc4103cff3
commit
3b16d6b16a
@ -41,6 +41,7 @@ func Timeout(d time.Duration) framework.ControllerHandler {
|
|||||||
log.Println("finish")
|
log.Println("finish")
|
||||||
case <-durationCtx.Done():
|
case <-durationCtx.Done():
|
||||||
c.SetHasTimeout()
|
c.SetHasTimeout()
|
||||||
|
c.GetResponseWriter().WriteHeader(http.StatusRequestTimeout)
|
||||||
c.GetResponseWriter().Write([]byte("time out"))
|
c.GetResponseWriter().Write([]byte("time out"))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -3,7 +3,6 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -16,50 +15,89 @@ func TestTimeout(t *testing.T) {
|
|||||||
t.Run("Test timeout handler", func(t *testing.T) {
|
t.Run("Test timeout handler", func(t *testing.T) {
|
||||||
timeoutHandler := Timeout(1 * time.Millisecond)
|
timeoutHandler := Timeout(1 * time.Millisecond)
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
response := httptest.NewRecorder()
|
|
||||||
c := framework.NewContext(response, request)
|
|
||||||
|
|
||||||
longHandler := func(c *framework.Context) error {
|
longHandler := func(c *framework.Context) error {
|
||||||
log.Println("TEST")
|
|
||||||
time.Sleep(2 * time.Millisecond)
|
time.Sleep(2 * time.Millisecond)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetHandlers([]framework.ControllerHandler{longHandler})
|
res := prepareMiddlewareTest(t, timeoutHandler, longHandler)
|
||||||
|
|
||||||
err := timeoutHandler(c)
|
assertCode(t, res.StatusCode, http.StatusRequestTimeout)
|
||||||
if err != nil {
|
assertBody(t, res.Body, "time out")
|
||||||
t.Fatal(err)
|
})
|
||||||
}
|
t.Run("Test no timeout", func(t *testing.T) {
|
||||||
res := response.Result()
|
timeoutHandler := Timeout(2 * time.Millisecond)
|
||||||
buf, _ := io.ReadAll(res.Body)
|
|
||||||
if cmp := bytes.Compare(buf, []byte("time out")); cmp != 0 {
|
quickHandler := func(c *framework.Context) error {
|
||||||
t.Errorf("got %q, want time out", string(buf))
|
// time.Sleep(1 * time.Millisecond)
|
||||||
|
c.WriteJSON(http.StatusOK, "ok")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res := prepareMiddlewareTest(t, timeoutHandler, quickHandler)
|
||||||
|
|
||||||
|
assertCode(t, res.StatusCode, http.StatusOK)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRecover(t *testing.T) {
|
func TestRecover(t *testing.T) {
|
||||||
|
t.Run("Test panic", func(t *testing.T) {
|
||||||
recoverer := Recovery()
|
recoverer := Recovery()
|
||||||
|
|
||||||
|
panicHandler := func(c *framework.Context) error {
|
||||||
|
panic("panic")
|
||||||
|
}
|
||||||
|
|
||||||
|
res := prepareMiddlewareTest(t, recoverer, panicHandler)
|
||||||
|
|
||||||
|
assertCode(t, res.StatusCode, http.StatusInternalServerError)
|
||||||
|
})
|
||||||
|
t.Run("Test no panic", func(t *testing.T) {
|
||||||
|
recoverer := Recovery()
|
||||||
|
|
||||||
|
normalHandler := func(c *framework.Context) error {
|
||||||
|
c.WriteJSON(http.StatusOK, "ok")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
res := prepareMiddlewareTest(t, recoverer, normalHandler)
|
||||||
|
|
||||||
|
assertCode(t, res.StatusCode, http.StatusOK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareMiddlewareTest(
|
||||||
|
t testing.TB,
|
||||||
|
mid framework.ControllerHandler,
|
||||||
|
in framework.ControllerHandler,
|
||||||
|
) *http.Response {
|
||||||
|
t.Helper()
|
||||||
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
response := httptest.NewRecorder()
|
response := httptest.NewRecorder()
|
||||||
c := framework.NewContext(response, request)
|
c := framework.NewContext(response, request)
|
||||||
|
|
||||||
panicHandler := func(c *framework.Context) error {
|
c.SetHandlers([]framework.ControllerHandler{in})
|
||||||
log.Println("TEST")
|
|
||||||
panic("panic")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetHandlers([]framework.ControllerHandler{panicHandler})
|
err := mid(c)
|
||||||
|
|
||||||
err := recoverer(c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res := response.Result()
|
res := response.Result()
|
||||||
if res.StatusCode != http.StatusInternalServerError {
|
return res
|
||||||
t.Errorf("status code got %d, want %d", res.StatusCode, http.StatusInternalServerError)
|
}
|
||||||
|
|
||||||
|
func assertCode(t testing.TB, got int, want int) {
|
||||||
|
t.Helper()
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("status code got %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertBody(t testing.TB, got io.Reader, want string) {
|
||||||
|
t.Helper()
|
||||||
|
buf, _ := io.ReadAll(got)
|
||||||
|
if cmp := bytes.Compare(buf, []byte(want)); cmp != 0 {
|
||||||
|
t.Errorf("got %q, want %q", string(buf), want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user