package shardedmutex import ( "fmt" "hash/maphash" "runtime" "sync" "sync/atomic" "testing" "time" ) func TestLockingDifferentShardsDoesNotBlock(t *testing.T) { shards := 16 sm := New(shards) done := make(chan struct{}) go func() { select { case <-done: return case <-time.After(5 * time.Second): panic("test locked up") } }() for i := 0; i < shards; i++ { sm.Lock(i) } close(done) } func TestLockingSameShardsBlocks(t *testing.T) { shards := 16 sm := New(shards) wg := sync.WaitGroup{} wg.Add(shards) ch := make(chan int, shards) for i := 0; i < shards; i++ { go func(i int) { if i != 15 { sm.Lock(i) } wg.Done() wg.Wait() sm.Lock((15 + i) % shards) ch <- i sm.Unlock(i) }(i) } wg.Wait() for i := 0; i < 2*shards; i++ { runtime.Gosched() } for i := 0; i < shards; i++ { if a := <-ch; a != i { t.Errorf("got %d instead of %d", a, i) } } } func TestShardedByString(t *testing.T) { shards := 16 sm := NewFor(maphash.String, shards) wg1 := sync.WaitGroup{} wg1.Add(shards * 20) wg2 := sync.WaitGroup{} wg2.Add(shards * 20) active := atomic.Int32{} max := atomic.Int32{} for i := 0; i < shards*20; i++ { go func(i int) { wg1.Done() wg1.Wait() sm.Lock(fmt.Sprintf("goroutine %d", i)) activeNew := active.Add(1) for { curMax := max.Load() if curMax >= activeNew { break } if max.CompareAndSwap(curMax, activeNew) { break } } for j := 0; j < 100; j++ { runtime.Gosched() } active.Add(-1) sm.Unlock(fmt.Sprintf("goroutine %d", i)) wg2.Done() }(i) } wg2.Wait() if max.Load() != 16 { t.Fatal("max load not achieved", max.Load()) } } func BenchmarkShardedMutex(b *testing.B) { shards := 16 sm := New(shards) done := atomic.Int32{} go func() { for { sm.Lock(0) sm.Unlock(0) if done.Load() != 0 { return } } }() for i := 0; i < 100; i++ { runtime.Gosched() } b.ResetTimer() for i := 0; i < b.N; i++ { sm.Lock(1) sm.Unlock(1) } done.Add(1) } func BenchmarkShardedMutexOf(b *testing.B) { shards := 16 sm := NewFor(maphash.String, shards) str1 := "string1" str2 := "string2" done := atomic.Int32{} go func() { for { sm.Lock(str1) sm.Unlock(str1) if done.Load() != 0 { return } } }() for i := 0; i < 100; i++ { runtime.Gosched() } b.ResetTimer() for i := 0; i < b.N; i++ { sm.Lock(str2) sm.Unlock(str2) } done.Add(1) }