package context import ( "context" "errors" "log" "net/http" "net/http/httptest" "testing" "time" ) type SpyStore struct { response string t *testing.T } func (s *SpyStore) Fetch(ctx context.Context) (string, error) { data := make(chan string, 1) go func() { var result string for _, c := range s.response { select { case <-ctx.Done(): log.Println("spy store got cancelled") return default: time.Sleep(10 * time.Millisecond) result += string(c) } } data <- result }() select { case <-ctx.Done(): return "", ctx.Err() case res := <-data: return res, nil } } type SpyResponseWriter struct { written bool } func (s *SpyResponseWriter) Header() http.Header { s.written = true return nil } func (s *SpyResponseWriter) Write([]byte) (int, error) { s.written = true return 0, errors.New("not implemented") } func (s *SpyResponseWriter) WriteHeader(statusCode int) { s.written = true } func TestServer(t *testing.T) { t.Run("basic get store", func(t *testing.T) { data := "hello, world" store := &SpyStore{response: data, t: t} srv := Server(store) request := httptest.NewRequest(http.MethodGet, "/", nil) response := httptest.NewRecorder() srv.ServeHTTP(response, request) if response.Body.String() != data { t.Errorf(`got "%s", want "%s"`, response.Body.String(), data) } }) t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { data := "hello, world" store := &SpyStore{response: data, t: t} srv := Server(store) request := httptest.NewRequest(http.MethodGet, "/", nil) cancellingCtx, cancel := context.WithCancel(request.Context()) // Wait 5ms to call cancel time.AfterFunc(5*time.Millisecond, cancel) request = request.WithContext(cancellingCtx) response := &SpyResponseWriter{} srv.ServeHTTP(response, request) if response.written { t.Error("a response should not have been written") } }) }