diff --git a/countdown/countdown_test.go b/countdown/countdown_test.go index c55330e..1663184 100644 --- a/countdown/countdown_test.go +++ b/countdown/countdown_test.go @@ -2,34 +2,55 @@ package countdown import ( "bytes" + "reflect" "testing" ) -type SpySleeper struct { - Calls int +const ( + write = "write" + sleep = "sleep" +) + +type SpyCountdownOperations struct { + Calls []string } -func (s *SpySleeper) Sleep() { - s.Calls++ +func (s *SpyCountdownOperations) Sleep() { + s.Calls = append(s.Calls, sleep) +} + +func (s *SpyCountdownOperations) Write(p []byte) (n int, err error) { + s.Calls = append(s.Calls, write) + return 0, nil } func TestCountdown(t *testing.T) { - buffer := &bytes.Buffer{} - sleeper := &SpySleeper{} + t.Run("print 3 to Go!", func(t *testing.T) { + buffer := &bytes.Buffer{} + Countdown(buffer, &SpyCountdownOperations{}) - Countdown(buffer, sleeper) - - got := buffer.String() - want := `3 + got := buffer.String() + want := `3 2 1 Go!` - if got != want { - t.Errorf("got %q want %q", got, want) - } + if got != want { + t.Errorf("got %q want %q", got, want) + } + }) - if sleeper.Calls != 3 { - t.Errorf("not enough calls to sleeper, want 3 got %d", sleeper.Calls) - } + t.Run("sleep before every print", func(t *testing.T) { + spySleepPrinter := &SpyCountdownOperations{} + + Countdown(spySleepPrinter, spySleepPrinter) + + want := []string{ + write, sleep, write, sleep, write, sleep, write, + } + + if !reflect.DeepEqual(want, spySleepPrinter.Calls) { + t.Errorf("wanted %v got %v", want, spySleepPrinter.Calls) + } + }) }