middleware: refactor tests
This commit is contained in:
		@ -41,6 +41,7 @@ func Timeout(d time.Duration) framework.ControllerHandler {
 | 
			
		||||
			log.Println("finish")
 | 
			
		||||
		case <-durationCtx.Done():
 | 
			
		||||
			c.SetHasTimeout()
 | 
			
		||||
			c.GetResponseWriter().WriteHeader(http.StatusRequestTimeout)
 | 
			
		||||
			c.GetResponseWriter().Write([]byte("time out"))
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,6 @@ package middleware
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"testing"
 | 
			
		||||
@ -16,50 +15,89 @@ func TestTimeout(t *testing.T) {
 | 
			
		||||
	t.Run("Test timeout handler", func(t *testing.T) {
 | 
			
		||||
		timeoutHandler := Timeout(1 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
		request := httptest.NewRequest(http.MethodGet, "/", nil)
 | 
			
		||||
		response := httptest.NewRecorder()
 | 
			
		||||
		c := framework.NewContext(response, request)
 | 
			
		||||
 | 
			
		||||
		longHandler := func(c *framework.Context) error {
 | 
			
		||||
			log.Println("TEST")
 | 
			
		||||
			time.Sleep(2 * time.Millisecond)
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.SetHandlers([]framework.ControllerHandler{longHandler})
 | 
			
		||||
		res := prepareMiddlewareTest(t, timeoutHandler, longHandler)
 | 
			
		||||
 | 
			
		||||
		err := timeoutHandler(c)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		res := response.Result()
 | 
			
		||||
		buf, _ := io.ReadAll(res.Body)
 | 
			
		||||
		if cmp := bytes.Compare(buf, []byte("time out")); cmp != 0 {
 | 
			
		||||
			t.Errorf("got %q, want time out", string(buf))
 | 
			
		||||
		assertCode(t, res.StatusCode, http.StatusRequestTimeout)
 | 
			
		||||
		assertBody(t, res.Body, "time out")
 | 
			
		||||
	})
 | 
			
		||||
	t.Run("Test no timeout", func(t *testing.T) {
 | 
			
		||||
		timeoutHandler := Timeout(2 * time.Millisecond)
 | 
			
		||||
 | 
			
		||||
		quickHandler := func(c *framework.Context) error {
 | 
			
		||||
			// time.Sleep(1 * time.Millisecond)
 | 
			
		||||
			c.WriteJSON(http.StatusOK, "ok")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res := prepareMiddlewareTest(t, timeoutHandler, quickHandler)
 | 
			
		||||
 | 
			
		||||
		assertCode(t, res.StatusCode, http.StatusOK)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRecover(t *testing.T) {
 | 
			
		||||
	recoverer := Recovery()
 | 
			
		||||
	t.Run("Test panic", func(t *testing.T) {
 | 
			
		||||
		recoverer := Recovery()
 | 
			
		||||
 | 
			
		||||
		panicHandler := func(c *framework.Context) error {
 | 
			
		||||
			panic("panic")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res := prepareMiddlewareTest(t, recoverer, panicHandler)
 | 
			
		||||
 | 
			
		||||
		assertCode(t, res.StatusCode, http.StatusInternalServerError)
 | 
			
		||||
	})
 | 
			
		||||
	t.Run("Test no panic", func(t *testing.T) {
 | 
			
		||||
		recoverer := Recovery()
 | 
			
		||||
 | 
			
		||||
		normalHandler := func(c *framework.Context) error {
 | 
			
		||||
			c.WriteJSON(http.StatusOK, "ok")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res := prepareMiddlewareTest(t, recoverer, normalHandler)
 | 
			
		||||
 | 
			
		||||
		assertCode(t, res.StatusCode, http.StatusOK)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func prepareMiddlewareTest(
 | 
			
		||||
	t testing.TB,
 | 
			
		||||
	mid framework.ControllerHandler,
 | 
			
		||||
	in framework.ControllerHandler,
 | 
			
		||||
) *http.Response {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	request := httptest.NewRequest(http.MethodGet, "/", nil)
 | 
			
		||||
	response := httptest.NewRecorder()
 | 
			
		||||
	c := framework.NewContext(response, request)
 | 
			
		||||
 | 
			
		||||
	panicHandler := func(c *framework.Context) error {
 | 
			
		||||
		log.Println("TEST")
 | 
			
		||||
		panic("panic")
 | 
			
		||||
	}
 | 
			
		||||
	c.SetHandlers([]framework.ControllerHandler{in})
 | 
			
		||||
 | 
			
		||||
	c.SetHandlers([]framework.ControllerHandler{panicHandler})
 | 
			
		||||
 | 
			
		||||
	err := recoverer(c)
 | 
			
		||||
	err := mid(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := response.Result()
 | 
			
		||||
	if res.StatusCode != http.StatusInternalServerError {
 | 
			
		||||
		t.Errorf("status code got %d, want %d", res.StatusCode, http.StatusInternalServerError)
 | 
			
		||||
	return res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func assertCode(t testing.TB, got int, want int) {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	if got != want {
 | 
			
		||||
		t.Errorf("status code got %d, want %d", got, want)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func assertBody(t testing.TB, got io.Reader, want string) {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	buf, _ := io.ReadAll(got)
 | 
			
		||||
	if cmp := bytes.Compare(buf, []byte(want)); cmp != 0 {
 | 
			
		||||
		t.Errorf("got %q, want %q", string(buf), want)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user