diff --git a/racer/racer.go b/racer/racer.go index f287b0b..7cf2444 100644 --- a/racer/racer.go +++ b/racer/racer.go @@ -8,7 +8,9 @@ import ( var ErrRacerTimeout = errors.New("timeout") -func Racer(a, b string, timeout time.Duration) (string, error) { +const tenSecondTimeout = 10 * time.Second + +func ConfigurableRacer(a, b string, timeout time.Duration) (string, error) { select { // wait until channel closed case <-ping(a): @@ -20,6 +22,10 @@ func Racer(a, b string, timeout time.Duration) (string, error) { } } +func Racer(a, b string) (string, error) { + return ConfigurableRacer(a, b, tenSecondTimeout) +} + func ping(url string) chan struct{} { ch := make(chan struct{}) go func() { diff --git a/racer/racer_test.go b/racer/racer_test.go index 56d2bdc..e552eea 100644 --- a/racer/racer_test.go +++ b/racer/racer_test.go @@ -7,6 +7,12 @@ import ( "time" ) +type SpyRacerTimeout struct{} + +func (s *SpyRacerTimeout) Timeout() <-chan time.Time { + return time.After(1 * time.Millisecond) +} + func TestRacer(t *testing.T) { t.Run("compares speeds of servers, returning the url of the fasted one", func(t *testing.T) { slowServer := makeDelayedServer(20 * time.Millisecond) @@ -19,7 +25,7 @@ func TestRacer(t *testing.T) { fastURL := fastServer.URL want := fastURL - got, _ := Racer(slowURL, fastURL, 10*time.Second) + got, _ := Racer(slowURL, fastURL) if got != want { t.Errorf("got %q, want %q", got, want) @@ -27,13 +33,13 @@ func TestRacer(t *testing.T) { }) t.Run("returns an error if a server doesn't respond within 10s", func(t *testing.T) { - serverA := makeDelayedServer(11 * time.Second) + serverA := makeDelayedServer(25 * time.Millisecond) defer serverA.Close() - serverB := makeDelayedServer(11 * time.Second) + serverB := makeDelayedServer(25 * time.Millisecond) defer serverB.Close() - _, err := Racer(serverA.URL, serverB.URL, 10*time.Second) + _, err := ConfigurableRacer(serverA.URL, serverB.URL, 20*time.Millisecond) if err == nil { t.Error("expected an error but didn't got one") }