2023-05-19 17:43:56 +00:00
|
|
|
package shardedmutex
|
|
|
|
|
|
|
|
import (
	"hash/maphash"
	"sync"
	"unsafe"
)
|
|
|
|
|
|
|
|
// cacheline is the assumed CPU cache-line size in bytes. Each shard is padded
// to a multiple of this so neighbouring shards never share a cache line.
const cacheline = 64

// paddedMutex is a sync.Mutex padded out to a whole number of cache lines.
//
// Padding a mutex to a cache line improves performance because adjacent
// shards no longer contend for the same cache line (false sharing) when
// different goroutines lock different shards. Benchmark quoted in support of
// the padding:
//
//	name     old time/op  new time/op  delta
//	Locks-8  74.6ns ± 7%  12.3ns ± 2%  -83.54%  (p=0.000 n=20+18)
//
// The padding length is computed from the real size of sync.Mutex instead of
// assuming it is 8 bytes. The struct size is therefore always a multiple of
// cacheline, whatever the mutex layout is on the current Go version/platform.
type paddedMutex struct {
	mt sync.Mutex
	_  [cacheline - unsafe.Sizeof(sync.Mutex{})%cacheline]byte
}
|
|
|
|
|
|
|
|
type ShardedMutex struct {
|
|
|
|
shards []paddedMutex
|
|
|
|
}
|
|
|
|
|
|
|
|
// New creates a new ShardedMutex with N shards
|
2023-05-19 18:02:47 +00:00
|
|
|
func New(nShards int) ShardedMutex {
|
|
|
|
if nShards < 1 {
|
2023-05-19 17:43:56 +00:00
|
|
|
panic("n_shards cannot be less than 1")
|
|
|
|
}
|
|
|
|
return ShardedMutex{
|
2023-05-19 18:02:47 +00:00
|
|
|
shards: make([]paddedMutex, nShards),
|
2023-05-19 17:43:56 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (sm ShardedMutex) Shards() int {
|
|
|
|
return len(sm.shards)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (sm ShardedMutex) Lock(shard int) {
|
|
|
|
sm.shards[shard].mt.Lock()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (sm ShardedMutex) Unlock(shard int) {
|
|
|
|
sm.shards[shard].mt.Unlock()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (sm ShardedMutex) GetLock(shard int) sync.Locker {
|
|
|
|
return &sm.shards[shard].mt
|
|
|
|
}
|
|
|
|
|
|
|
|
type ShardedMutexFor[K any] struct {
|
|
|
|
inner ShardedMutex
|
|
|
|
|
|
|
|
hasher func(maphash.Seed, K) uint64
|
|
|
|
seed maphash.Seed
|
|
|
|
}
|
|
|
|
|
2023-05-19 18:02:47 +00:00
|
|
|
func NewFor[K any](hasher func(maphash.Seed, K) uint64, nShards int) ShardedMutexFor[K] {
|
2023-05-19 17:43:56 +00:00
|
|
|
return ShardedMutexFor[K]{
|
2023-05-19 18:02:47 +00:00
|
|
|
inner: New(nShards),
|
2023-05-19 17:43:56 +00:00
|
|
|
hasher: hasher,
|
|
|
|
seed: maphash.MakeSeed(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (sm ShardedMutexFor[K]) shardFor(key K) int {
|
|
|
|
return int(sm.hasher(sm.seed, key) % uint64(len(sm.inner.shards)))
|
|
|
|
}
|
|
|
|
|
|
|
|
func (sm ShardedMutexFor[K]) Lock(key K) {
|
|
|
|
sm.inner.Lock(sm.shardFor(key))
|
|
|
|
}
|
|
|
|
func (sm ShardedMutexFor[K]) Unlock(key K) {
|
|
|
|
sm.inner.Unlock(sm.shardFor(key))
|
|
|
|
}
|
|
|
|
func (sm ShardedMutexFor[K]) GetLock(key K) sync.Locker {
|
|
|
|
return sm.inner.GetLock(sm.shardFor(key))
|
|
|
|
}
|