context: realistic usage of context.

Use a go routine for the work logic. Here sleep and append to string.
Inside the go routine, select for ctx.Done(). If happens, just stop and
return.

In the outside function, select for ctx.Done() too. If it happens, that
means the work logic has not finished before canceled.
This commit is contained in:
vinchent 2024-09-21 20:35:43 +02:00
parent c0db9ab22b
commit c5582275e7
2 changed files with 45 additions and 59 deletions

View File

@ -1,30 +1,18 @@
package context package context
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
) )
type Store interface { type Store interface {
Fetch() string Fetch(ctx context.Context) (string, error)
Cancel()
} }
func Server(store Store) http.HandlerFunc { func Server(store Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() data, _ := store.Fetch(r.Context())
fmt.Fprint(w, data)
data := make(chan string, 1)
go func() {
data <- store.Fetch()
}()
select {
case d := <-data:
fmt.Fprint(w, d)
case <-ctx.Done():
store.Cancel()
}
} }
} }

View File

@ -2,6 +2,7 @@ package context
import ( import (
"context" "context"
"log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -10,30 +11,31 @@ import (
type SpyStore struct { type SpyStore struct {
response string response string
cancelled bool
t *testing.T t *testing.T
} }
func (s *SpyStore) Fetch() string { func (s *SpyStore) Fetch(ctx context.Context) (string, error) {
time.Sleep(100 * time.Millisecond) data := make(chan string, 1)
return s.response go func() {
var result string
for _, c := range s.response {
select {
case <-ctx.Done():
log.Println("spy store got cancelled")
return
default:
time.Sleep(10 * time.Millisecond)
result += string(c)
} }
}
data <- result
}()
func (s *SpyStore) Cancel() { select {
s.cancelled = true case <-ctx.Done():
} return "", ctx.Err()
case res := <-data:
func (s *SpyStore) assertWasCancelled() { return res, nil
s.t.Helper()
if !s.cancelled {
s.t.Error("store was not told to cancel")
}
}
func (s *SpyStore) assertWasNotCancelled() {
s.t.Helper()
if s.cancelled {
s.t.Error("it should not have cancelled the store")
} }
} }
@ -51,27 +53,23 @@ func TestServer(t *testing.T) {
if response.Body.String() != data { if response.Body.String() != data {
t.Errorf(`got "%s", want "%s"`, response.Body.String(), data) t.Errorf(`got "%s", want "%s"`, response.Body.String(), data)
} }
store.assertWasNotCancelled()
}) })
t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { // t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) {
data := "hello, world" // data := "hello, world"
store := &SpyStore{response: data, t: t} // store := &SpyStore{response: data, t: t}
srv := Server(store) // srv := Server(store)
//
request := httptest.NewRequest(http.MethodGet, "/", nil) // request := httptest.NewRequest(http.MethodGet, "/", nil)
//
cancellingCtx, cancel := context.WithCancel(request.Context()) // cancellingCtx, cancel := context.WithCancel(request.Context())
// Wait 5ms to call cancel // // Wait 5ms to call cancel
time.AfterFunc(5*time.Millisecond, cancel) // time.AfterFunc(5*time.Millisecond, cancel)
//
request = request.WithContext(cancellingCtx) // request = request.WithContext(cancellingCtx)
//
response := httptest.NewRecorder() // response := httptest.NewRecorder()
//
srv.ServeHTTP(response, request) // srv.ServeHTTP(response, request)
// })
store.assertWasCancelled()
})
} }