context: realistic usage of context.

Use a go routine for the work logic. Here sleep and append to string.
Inside the go routine, select for ctx.Done(). If happens, just stop and
return.

In the outside function, select for ctx.Done() too. If it happens, that
means the work logic has not finished before canceled.
This commit is contained in:
vinchent 2024-09-21 20:35:43 +02:00
parent c0db9ab22b
commit c5582275e7
2 changed files with 45 additions and 59 deletions

View File

@ -1,30 +1,18 @@
package context
import (
"context"
"fmt"
"net/http"
)
type Store interface {
Fetch() string
Cancel()
Fetch(ctx context.Context) (string, error)
}
func Server(store Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := make(chan string, 1)
go func() {
data <- store.Fetch()
}()
select {
case d := <-data:
fmt.Fprint(w, d)
case <-ctx.Done():
store.Cancel()
}
data, _ := store.Fetch(r.Context())
fmt.Fprint(w, data)
}
}

View File

@ -2,6 +2,7 @@ package context
import (
"context"
"log"
"net/http"
"net/http/httptest"
"testing"
@ -9,31 +10,32 @@ import (
)
type SpyStore struct {
response string
cancelled bool
t *testing.T
response string
t *testing.T
}
func (s *SpyStore) Fetch() string {
time.Sleep(100 * time.Millisecond)
return s.response
}
func (s *SpyStore) Fetch(ctx context.Context) (string, error) {
data := make(chan string, 1)
go func() {
var result string
for _, c := range s.response {
select {
case <-ctx.Done():
log.Println("spy store got cancelled")
return
default:
time.Sleep(10 * time.Millisecond)
result += string(c)
}
}
data <- result
}()
func (s *SpyStore) Cancel() {
s.cancelled = true
}
func (s *SpyStore) assertWasCancelled() {
s.t.Helper()
if !s.cancelled {
s.t.Error("store was not told to cancel")
}
}
func (s *SpyStore) assertWasNotCancelled() {
s.t.Helper()
if s.cancelled {
s.t.Error("it should not have cancelled the store")
select {
case <-ctx.Done():
return "", ctx.Err()
case res := <-data:
return res, nil
}
}
@ -51,27 +53,23 @@ func TestServer(t *testing.T) {
if response.Body.String() != data {
t.Errorf(`got "%s", want "%s"`, response.Body.String(), data)
}
store.assertWasNotCancelled()
})
t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) {
data := "hello, world"
store := &SpyStore{response: data, t: t}
srv := Server(store)
request := httptest.NewRequest(http.MethodGet, "/", nil)
cancellingCtx, cancel := context.WithCancel(request.Context())
// Wait 5ms to call cancel
time.AfterFunc(5*time.Millisecond, cancel)
request = request.WithContext(cancellingCtx)
response := httptest.NewRecorder()
srv.ServeHTTP(response, request)
store.assertWasCancelled()
})
// t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) {
// data := "hello, world"
// store := &SpyStore{response: data, t: t}
// srv := Server(store)
//
// request := httptest.NewRequest(http.MethodGet, "/", nil)
//
// cancellingCtx, cancel := context.WithCancel(request.Context())
// // Wait 5ms to call cancel
// time.AfterFunc(5*time.Millisecond, cancel)
//
// request = request.WithContext(cancellingCtx)
//
// response := httptest.NewRecorder()
//
// srv.ServeHTTP(response, request)
// })
}