diff --git a/common/mclock/mclock.go b/common/mclock/mclock.go index d0e0cd78b..3aca257cb 100644 --- a/common/mclock/mclock.go +++ b/common/mclock/mclock.go @@ -31,44 +31,93 @@ func Now() AbsTime { return AbsTime(monotime.Now()) } -// Add returns t + d. +// Add returns t + d as absolute time. func (t AbsTime) Add(d time.Duration) AbsTime { return t + AbsTime(d) } +// Sub returns t - t2 as a duration. +func (t AbsTime) Sub(t2 AbsTime) time.Duration { + return time.Duration(t - t2) +} + // The Clock interface makes it possible to replace the monotonic system clock with // a simulated clock. type Clock interface { Now() AbsTime Sleep(time.Duration) - After(time.Duration) <-chan time.Time + NewTimer(time.Duration) ChanTimer + After(time.Duration) <-chan AbsTime AfterFunc(d time.Duration, f func()) Timer } -// Timer represents a cancellable event returned by AfterFunc +// Timer is a cancellable event created by AfterFunc. type Timer interface { + // Stop cancels the timer. It returns false if the timer has already + // expired or been stopped. Stop() bool } +// ChanTimer is a cancellable event created by NewTimer. +type ChanTimer interface { + Timer + + // The channel returned by C receives a value when the timer expires. + C() <-chan AbsTime + // Reset reschedules the timer with a new timeout. + // It should be invoked only on stopped or expired timers with drained channels. + Reset(time.Duration) +} + // System implements Clock using the system clock. type System struct{} // Now returns the current monotonic time. -func (System) Now() AbsTime { +func (c System) Now() AbsTime { return AbsTime(monotime.Now()) } // Sleep blocks for the given duration. -func (System) Sleep(d time.Duration) { +func (c System) Sleep(d time.Duration) { time.Sleep(d) } +// NewTimer creates a timer which can be rescheduled. +func (c System) NewTimer(d time.Duration) ChanTimer { + ch := make(chan AbsTime, 1) + t := time.AfterFunc(d, func() { + // This send is non-blocking because that's how time.Timer + // behaves. It doesn't matter in the happy case, but does + // when Reset is misused. + select { + case ch <- c.Now(): + default: + } + }) + return &systemTimer{t, ch} +} + // After returns a channel which receives the current time after d has elapsed. -func (System) After(d time.Duration) <-chan time.Time { - return time.After(d) +func (c System) After(d time.Duration) <-chan AbsTime { + ch := make(chan AbsTime, 1) + time.AfterFunc(d, func() { ch <- c.Now() }) + return ch } // AfterFunc runs f on a new goroutine after the duration has elapsed. -func (System) AfterFunc(d time.Duration, f func()) Timer { +func (c System) AfterFunc(d time.Duration, f func()) Timer { return time.AfterFunc(d, f) } + +type systemTimer struct { + *time.Timer + ch <-chan AbsTime +} + +func (st *systemTimer) Reset(d time.Duration) { + st.Timer.Reset(d) +} + +func (st *systemTimer) C() <-chan AbsTime { + return st.ch +} diff --git a/common/mclock/simclock.go b/common/mclock/simclock.go index 4d351252f..766ca0f87 100644 --- a/common/mclock/simclock.go +++ b/common/mclock/simclock.go @@ -17,6 +17,7 @@ package mclock import ( + "container/heap" "sync" "time" ) @@ -32,18 +33,24 @@ import ( // the timeout using a channel or semaphore. type Simulated struct { now AbsTime - scheduled []*simTimer + scheduled simTimerHeap mu sync.RWMutex cond *sync.Cond - lastId uint64 } -// simTimer implements Timer on the virtual clock. +// simTimer implements ChanTimer on the virtual clock. type simTimer struct { - do func() - at AbsTime - id uint64 - s *Simulated + at AbsTime + index int // position in s.scheduled + s *Simulated + do func() + ch <-chan AbsTime +} + +func (s *Simulated) init() { + if s.cond == nil { + s.cond = sync.NewCond(&s.mu) + } } // Run moves the clock by the given duration, executing all timers before that duration. @@ -53,14 +60,9 @@ func (s *Simulated) Run(d time.Duration) { end := s.now + AbsTime(d) var do []func() - for len(s.scheduled) > 0 { - ev := s.scheduled[0] - if ev.at > end { - break - } - s.now = ev.at + for len(s.scheduled) > 0 && s.scheduled[0].at <= end { + ev := heap.Pop(&s.scheduled).(*simTimer) do = append(do, ev.do) - s.scheduled = s.scheduled[1:] } s.now = end s.mu.Unlock() @@ -102,14 +104,22 @@ func (s *Simulated) Sleep(d time.Duration) { <-s.After(d) } +// NewTimer creates a timer which fires when the clock has advanced by d. +func (s *Simulated) NewTimer(d time.Duration) ChanTimer { + s.mu.Lock() + defer s.mu.Unlock() + + ch := make(chan AbsTime, 1) + var timer *simTimer + timer = s.schedule(d, func() { ch <- timer.at }) + timer.ch = ch + return timer +} + // After returns a channel which receives the current time after the clock // has advanced by d. -func (s *Simulated) After(d time.Duration) <-chan time.Time { - after := make(chan time.Time, 1) - s.AfterFunc(d, func() { - after <- (time.Time{}).Add(time.Duration(s.now)) - }) - return after +func (s *Simulated) After(d time.Duration) <-chan AbsTime { + return s.NewTimer(d).C() } // AfterFunc runs fn after the clock has advanced by d. Unlike with the system @@ -117,46 +127,83 @@ func (s *Simulated) After(d time.Duration) <-chan time.Time { func (s *Simulated) AfterFunc(d time.Duration, fn func()) Timer { s.mu.Lock() defer s.mu.Unlock() + + return s.schedule(d, fn) +} + +func (s *Simulated) schedule(d time.Duration, fn func()) *simTimer { s.init() at := s.now + AbsTime(d) - s.lastId++ - id := s.lastId - l, h := 0, len(s.scheduled) - ll := h - for l != h { - m := (l + h) / 2 - if (at < s.scheduled[m].at) || ((at == s.scheduled[m].at) && (id < s.scheduled[m].id)) { - h = m - } else { - l = m + 1 - } - } ev := &simTimer{do: fn, at: at, s: s} - s.scheduled = append(s.scheduled, nil) - copy(s.scheduled[l+1:], s.scheduled[l:ll]) - s.scheduled[l] = ev + heap.Push(&s.scheduled, ev) s.cond.Broadcast() return ev } func (ev *simTimer) Stop() bool { - s := ev.s - s.mu.Lock() - defer s.mu.Unlock() + ev.s.mu.Lock() + defer ev.s.mu.Unlock() - for i := 0; i < len(s.scheduled); i++ { - if s.scheduled[i] == ev { - s.scheduled = append(s.scheduled[:i], s.scheduled[i+1:]...) - s.cond.Broadcast() - return true - } + if ev.index < 0 { + return false } - return false + heap.Remove(&ev.s.scheduled, ev.index) + ev.s.cond.Broadcast() + ev.index = -1 + return true } -func (s *Simulated) init() { - if s.cond == nil { - s.cond = sync.NewCond(&s.mu) +func (ev *simTimer) Reset(d time.Duration) { + if ev.ch == nil { + panic("mclock: Reset() on timer created by AfterFunc") } + + ev.s.mu.Lock() + defer ev.s.mu.Unlock() + ev.at = ev.s.now.Add(d) + if ev.index < 0 { + heap.Push(&ev.s.scheduled, ev) // already expired + } else { + heap.Fix(&ev.s.scheduled, ev.index) // hasn't fired yet, reschedule + } + ev.s.cond.Broadcast() +} + +func (ev *simTimer) C() <-chan AbsTime { + if ev.ch == nil { + panic("mclock: C() on timer created by AfterFunc") + } + return ev.ch +} + +type simTimerHeap []*simTimer + +func (h *simTimerHeap) Len() int { + return len(*h) +} + +func (h *simTimerHeap) Less(i, j int) bool { + return (*h)[i].at < (*h)[j].at +} + +func (h *simTimerHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] + (*h)[i].index = i + (*h)[j].index = j +} + +func (h *simTimerHeap) Push(x interface{}) { + t := x.(*simTimer) + t.index = len(*h) + *h = append(*h, t) +} + +func (h *simTimerHeap) Pop() interface{} { + end := len(*h) - 1 + t := (*h)[end] + t.index = -1 + (*h)[end] = nil + *h = (*h)[:end] + return t } diff --git a/common/mclock/simclock_test.go b/common/mclock/simclock_test.go index 09e4391c1..94aa4f2b3 100644 --- a/common/mclock/simclock_test.go +++ b/common/mclock/simclock_test.go @@ -25,14 +25,16 @@ var _ Clock = System{} var _ Clock = new(Simulated) func TestSimulatedAfter(t *testing.T) { - const timeout = 30 * time.Minute - const adv = time.Minute - var ( - c Simulated - end = c.Now().Add(timeout) - ch = c.After(timeout) + timeout = 30 * time.Minute + offset = 99 * time.Hour + adv = 11 * time.Minute + c Simulated ) + c.Run(offset) + + end := c.Now().Add(timeout) + ch := c.After(timeout) for c.Now() < end.Add(-adv) { c.Run(adv) select { @@ -45,8 +47,8 @@ func TestSimulatedAfter(t *testing.T) { c.Run(adv) select { case stamp := <-ch: - want := time.Time{}.Add(timeout) - if !stamp.Equal(want) { + want := AbsTime(0).Add(offset).Add(timeout) + if stamp != want { t.Errorf("Wrong time sent on timer channel: got %v, want %v", stamp, want) } default: @@ -113,3 +115,48 @@ func TestSimulatedSleep(t *testing.T) { t.Fatal("Sleep didn't return in time") } } + +func TestSimulatedTimerReset(t *testing.T) { + var ( + c Simulated + timeout = 1 * time.Hour + ) + timer := c.NewTimer(timeout) + c.Run(2 * timeout) + select { + case ftime := <-timer.C(): + if ftime != AbsTime(timeout) { + t.Fatalf("wrong time %v sent on timer channel, want %v", ftime, AbsTime(timeout)) + } + default: + t.Fatal("timer didn't fire") + } + + timer.Reset(timeout) + c.Run(2 * timeout) + select { + case ftime := <-timer.C(): + if ftime != AbsTime(3*timeout) { + t.Fatalf("wrong time %v sent on timer channel, want %v", ftime, AbsTime(3*timeout)) + } + default: + t.Fatal("timer didn't fire again") + } +} + +func TestSimulatedTimerStop(t *testing.T) { + var ( + c Simulated + timeout = 1 * time.Hour + ) + timer := c.NewTimer(timeout) + c.Run(2 * timeout) + if timer.Stop() { + t.Errorf("Stop returned true for fired timer") + } + select { + case <-timer.C(): + default: + t.Fatal("timer didn't fire") + } +}