diff --git a/context/context.go b/context/context.go index a4671f5..b0ab41c 100644 --- a/context/context.go +++ b/context/context.go @@ -3,6 +3,7 @@ package context import ( "context" "fmt" + "log" "net/http" ) @@ -12,7 +13,11 @@ type Store interface { func Server(store Store) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - data, _ := store.Fetch(r.Context()) + data, err := store.Fetch(r.Context()) + if err != nil { + log.Println(err) + return + } fmt.Fprint(w, data) } } diff --git a/context/context_test.go b/context/context_test.go index 47fa206..0fc447b 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -2,6 +2,7 @@ package context import ( "context" + "errors" "log" "net/http" "net/http/httptest" @@ -39,6 +40,24 @@ func (s *SpyStore) Fetch(ctx context.Context) (string, error) { } } +type SpyResponseWriter struct { + written bool +} + +func (s *SpyResponseWriter) Header() http.Header { + s.written = true + return nil +} + +func (s *SpyResponseWriter) Write([]byte) (int, error) { + s.written = true + return 0, errors.New("not implemented") +} + +func (s *SpyResponseWriter) WriteHeader(statusCode int) { + s.written = true +} + func TestServer(t *testing.T) { t.Run("basic get store", func(t *testing.T) { data := "hello, world" @@ -55,21 +74,25 @@ func TestServer(t *testing.T) { } }) - // t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { - // data := "hello, world" - // store := &SpyStore{response: data, t: t} - // srv := Server(store) - // - // request := httptest.NewRequest(http.MethodGet, "/", nil) - // - // cancellingCtx, cancel := context.WithCancel(request.Context()) - // // Wait 5ms to call cancel - // time.AfterFunc(5*time.Millisecond, cancel) - // - // request = request.WithContext(cancellingCtx) - // - // response := httptest.NewRecorder() - // - // srv.ServeHTTP(response, request) - // }) + t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { + data := "hello, world" + store := &SpyStore{response: data, t: t} + srv := Server(store) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + + cancellingCtx, cancel := context.WithCancel(request.Context()) + // Wait 5ms to call cancel + time.AfterFunc(5*time.Millisecond, cancel) + + request = request.WithContext(cancellingCtx) + + response := &SpyResponseWriter{} + + srv.ServeHTTP(response, request) + + if response.written { + t.Error("a response should not have been written") + } + }) }