common/mclock: add NewTimer and Timer.Reset (#20634)

These methods can be helpful when migrating existing timer code.
This commit is contained in:
Felix Lange 2020-02-11 16:36:49 +01:00 committed by GitHub
parent dcffb7777f
commit c22fdec3c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 207 additions and 64 deletions

View File

@ -31,44 +31,93 @@ func Now() AbsTime {
return AbsTime(monotime.Now()) return AbsTime(monotime.Now())
} }
// Add returns t + d. // Add returns t + d as absolute time.
func (t AbsTime) Add(d time.Duration) AbsTime { func (t AbsTime) Add(d time.Duration) AbsTime {
return t + AbsTime(d) return t + AbsTime(d)
} }
// Sub returns t - t2 as a duration.
func (t AbsTime) Sub(t2 AbsTime) time.Duration {
return time.Duration(t - t2)
}
// The Clock interface makes it possible to replace the monotonic system clock with // The Clock interface makes it possible to replace the monotonic system clock with
// a simulated clock. // a simulated clock.
type Clock interface { type Clock interface {
Now() AbsTime Now() AbsTime
Sleep(time.Duration) Sleep(time.Duration)
After(time.Duration) <-chan time.Time NewTimer(time.Duration) ChanTimer
After(time.Duration) <-chan AbsTime
AfterFunc(d time.Duration, f func()) Timer AfterFunc(d time.Duration, f func()) Timer
} }
// Timer represents a cancellable event returned by AfterFunc // Timer is a cancellable event created by AfterFunc.
type Timer interface { type Timer interface {
// Stop cancels the timer. It returns false if the timer has already
// expired or been stopped.
Stop() bool Stop() bool
} }
// ChanTimer is a cancellable event created by NewTimer.
type ChanTimer interface {
Timer
// The channel returned by C receives a value when the timer expires.
C() <-chan AbsTime
// Reset reschedules the timer with a new timeout.
// It should be invoked only on stopped or expired timers with drained channels.
Reset(time.Duration)
}
// System implements Clock using the system clock. // System implements Clock using the system clock.
type System struct{} type System struct{}
// Now returns the current monotonic time. // Now returns the current monotonic time.
func (System) Now() AbsTime { func (c System) Now() AbsTime {
return AbsTime(monotime.Now()) return AbsTime(monotime.Now())
} }
// Sleep blocks for the given duration. // Sleep blocks for the given duration.
func (System) Sleep(d time.Duration) { func (c System) Sleep(d time.Duration) {
time.Sleep(d) time.Sleep(d)
} }
// NewTimer creates a timer which can be rescheduled.
func (c System) NewTimer(d time.Duration) ChanTimer {
ch := make(chan AbsTime, 1)
t := time.AfterFunc(d, func() {
// This send is non-blocking because that's how time.Timer
// behaves. It doesn't matter in the happy case, but does
// when Reset is misused.
select {
case ch <- c.Now():
default:
}
})
return &systemTimer{t, ch}
}
// After returns a channel which receives the current time after d has elapsed. // After returns a channel which receives the current time after d has elapsed.
func (System) After(d time.Duration) <-chan time.Time { func (c System) After(d time.Duration) <-chan AbsTime {
return time.After(d) ch := make(chan AbsTime, 1)
time.AfterFunc(d, func() { ch <- c.Now() })
return ch
} }
// AfterFunc runs f on a new goroutine after the duration has elapsed. // AfterFunc runs f on a new goroutine after the duration has elapsed.
func (System) AfterFunc(d time.Duration, f func()) Timer { func (c System) AfterFunc(d time.Duration, f func()) Timer {
return time.AfterFunc(d, f) return time.AfterFunc(d, f)
} }
type systemTimer struct {
*time.Timer
ch <-chan AbsTime
}
func (st *systemTimer) Reset(d time.Duration) {
st.Timer.Reset(d)
}
func (st *systemTimer) C() <-chan AbsTime {
return st.ch
}

View File

@ -17,6 +17,7 @@
package mclock package mclock
import ( import (
"container/heap"
"sync" "sync"
"time" "time"
) )
@ -32,18 +33,24 @@ import (
// the timeout using a channel or semaphore. // the timeout using a channel or semaphore.
type Simulated struct { type Simulated struct {
now AbsTime now AbsTime
scheduled []*simTimer scheduled simTimerHeap
mu sync.RWMutex mu sync.RWMutex
cond *sync.Cond cond *sync.Cond
lastId uint64
} }
// simTimer implements Timer on the virtual clock. // simTimer implements ChanTimer on the virtual clock.
type simTimer struct { type simTimer struct {
do func() at AbsTime
at AbsTime index int // position in s.scheduled
id uint64 s *Simulated
s *Simulated do func()
ch <-chan AbsTime
}
func (s *Simulated) init() {
if s.cond == nil {
s.cond = sync.NewCond(&s.mu)
}
} }
// Run moves the clock by the given duration, executing all timers before that duration. // Run moves the clock by the given duration, executing all timers before that duration.
@ -53,14 +60,9 @@ func (s *Simulated) Run(d time.Duration) {
end := s.now + AbsTime(d) end := s.now + AbsTime(d)
var do []func() var do []func()
for len(s.scheduled) > 0 { for len(s.scheduled) > 0 && s.scheduled[0].at <= end {
ev := s.scheduled[0] ev := heap.Pop(&s.scheduled).(*simTimer)
if ev.at > end {
break
}
s.now = ev.at
do = append(do, ev.do) do = append(do, ev.do)
s.scheduled = s.scheduled[1:]
} }
s.now = end s.now = end
s.mu.Unlock() s.mu.Unlock()
@ -102,14 +104,22 @@ func (s *Simulated) Sleep(d time.Duration) {
<-s.After(d) <-s.After(d)
} }
// NewTimer creates a timer which fires when the clock has advanced by d.
func (s *Simulated) NewTimer(d time.Duration) ChanTimer {
s.mu.Lock()
defer s.mu.Unlock()
ch := make(chan AbsTime, 1)
var timer *simTimer
timer = s.schedule(d, func() { ch <- timer.at })
timer.ch = ch
return timer
}
// After returns a channel which receives the current time after the clock // After returns a channel which receives the current time after the clock
// has advanced by d. // has advanced by d.
func (s *Simulated) After(d time.Duration) <-chan time.Time { func (s *Simulated) After(d time.Duration) <-chan AbsTime {
after := make(chan time.Time, 1) return s.NewTimer(d).C()
s.AfterFunc(d, func() {
after <- (time.Time{}).Add(time.Duration(s.now))
})
return after
} }
// AfterFunc runs fn after the clock has advanced by d. Unlike with the system // AfterFunc runs fn after the clock has advanced by d. Unlike with the system
@ -117,46 +127,83 @@ func (s *Simulated) After(d time.Duration) <-chan time.Time {
func (s *Simulated) AfterFunc(d time.Duration, fn func()) Timer { func (s *Simulated) AfterFunc(d time.Duration, fn func()) Timer {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
return s.schedule(d, fn)
}
func (s *Simulated) schedule(d time.Duration, fn func()) *simTimer {
s.init() s.init()
at := s.now + AbsTime(d) at := s.now + AbsTime(d)
s.lastId++
id := s.lastId
l, h := 0, len(s.scheduled)
ll := h
for l != h {
m := (l + h) / 2
if (at < s.scheduled[m].at) || ((at == s.scheduled[m].at) && (id < s.scheduled[m].id)) {
h = m
} else {
l = m + 1
}
}
ev := &simTimer{do: fn, at: at, s: s} ev := &simTimer{do: fn, at: at, s: s}
s.scheduled = append(s.scheduled, nil) heap.Push(&s.scheduled, ev)
copy(s.scheduled[l+1:], s.scheduled[l:ll])
s.scheduled[l] = ev
s.cond.Broadcast() s.cond.Broadcast()
return ev return ev
} }
func (ev *simTimer) Stop() bool { func (ev *simTimer) Stop() bool {
s := ev.s ev.s.mu.Lock()
s.mu.Lock() defer ev.s.mu.Unlock()
defer s.mu.Unlock()
for i := 0; i < len(s.scheduled); i++ { if ev.index < 0 {
if s.scheduled[i] == ev { return false
s.scheduled = append(s.scheduled[:i], s.scheduled[i+1:]...)
s.cond.Broadcast()
return true
}
} }
return false heap.Remove(&ev.s.scheduled, ev.index)
ev.s.cond.Broadcast()
ev.index = -1
return true
} }
func (s *Simulated) init() { func (ev *simTimer) Reset(d time.Duration) {
if s.cond == nil { if ev.ch == nil {
s.cond = sync.NewCond(&s.mu) panic("mclock: Reset() on timer created by AfterFunc")
} }
ev.s.mu.Lock()
defer ev.s.mu.Unlock()
ev.at = ev.s.now.Add(d)
if ev.index < 0 {
heap.Push(&ev.s.scheduled, ev) // already expired
} else {
heap.Fix(&ev.s.scheduled, ev.index) // hasn't fired yet, reschedule
}
ev.s.cond.Broadcast()
}
func (ev *simTimer) C() <-chan AbsTime {
if ev.ch == nil {
panic("mclock: C() on timer created by AfterFunc")
}
return ev.ch
}
type simTimerHeap []*simTimer
func (h *simTimerHeap) Len() int {
return len(*h)
}
func (h *simTimerHeap) Less(i, j int) bool {
return (*h)[i].at < (*h)[j].at
}
func (h *simTimerHeap) Swap(i, j int) {
(*h)[i], (*h)[j] = (*h)[j], (*h)[i]
(*h)[i].index = i
(*h)[j].index = j
}
func (h *simTimerHeap) Push(x interface{}) {
t := x.(*simTimer)
t.index = len(*h)
*h = append(*h, t)
}
func (h *simTimerHeap) Pop() interface{} {
end := len(*h) - 1
t := (*h)[end]
t.index = -1
(*h)[end] = nil
*h = (*h)[:end]
return t
} }

View File

@ -25,14 +25,16 @@ var _ Clock = System{}
var _ Clock = new(Simulated) var _ Clock = new(Simulated)
func TestSimulatedAfter(t *testing.T) { func TestSimulatedAfter(t *testing.T) {
const timeout = 30 * time.Minute
const adv = time.Minute
var ( var (
c Simulated timeout = 30 * time.Minute
end = c.Now().Add(timeout) offset = 99 * time.Hour
ch = c.After(timeout) adv = 11 * time.Minute
c Simulated
) )
c.Run(offset)
end := c.Now().Add(timeout)
ch := c.After(timeout)
for c.Now() < end.Add(-adv) { for c.Now() < end.Add(-adv) {
c.Run(adv) c.Run(adv)
select { select {
@ -45,8 +47,8 @@ func TestSimulatedAfter(t *testing.T) {
c.Run(adv) c.Run(adv)
select { select {
case stamp := <-ch: case stamp := <-ch:
want := time.Time{}.Add(timeout) want := AbsTime(0).Add(offset).Add(timeout)
if !stamp.Equal(want) { if stamp != want {
t.Errorf("Wrong time sent on timer channel: got %v, want %v", stamp, want) t.Errorf("Wrong time sent on timer channel: got %v, want %v", stamp, want)
} }
default: default:
@ -113,3 +115,48 @@ func TestSimulatedSleep(t *testing.T) {
t.Fatal("Sleep didn't return in time") t.Fatal("Sleep didn't return in time")
} }
} }
func TestSimulatedTimerReset(t *testing.T) {
var (
c Simulated
timeout = 1 * time.Hour
)
timer := c.NewTimer(timeout)
c.Run(2 * timeout)
select {
case ftime := <-timer.C():
if ftime != AbsTime(timeout) {
t.Fatalf("wrong time %v sent on timer channel, want %v", ftime, AbsTime(timeout))
}
default:
t.Fatal("timer didn't fire")
}
timer.Reset(timeout)
c.Run(2 * timeout)
select {
case ftime := <-timer.C():
if ftime != AbsTime(3*timeout) {
t.Fatalf("wrong time %v sent on timer channel, want %v", ftime, AbsTime(3*timeout))
}
default:
t.Fatal("timer didn't fire again")
}
}
func TestSimulatedTimerStop(t *testing.T) {
var (
c Simulated
timeout = 1 * time.Hour
)
timer := c.NewTimer(timeout)
c.Run(2 * timeout)
if timer.Stop() {
t.Errorf("Stop returned true for fired timer")
}
select {
case <-timer.C():
default:
t.Fatal("timer didn't fire")
}
}