accounts/keystore: scan key directory without locks held (#15171)

The accountCache contains a file cache, and remembers from
scan to scan what files were present earlier. Thus, whenever
there's a change, the scan phase only bothers processing new
and removed files.
Authored by Martin Holst Swende on 2017-10-09 12:40:50 +02:00; committed by Felix Lange
parent 7a045af05b
commit 88b1db7288
4 changed files with 302 additions and 107 deletions
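As a rough illustration of the commit message above (not code from this commit): the sketch below compares the file names and modification times remembered from the previous scan against the current directory listing, using plain maps instead of the set library used further down. The helper name diffScans and the example key file names are made up, and it also covers the modified-file case that the new scan code handles.

package main

import (
    "fmt"
    "time"
)

// diffScans classifies the current listing against the previous scan:
// files never seen before, files that disappeared, and files whose
// modification time moved past the previous scan's high-water mark.
func diffScans(prev map[string]bool, prevMtime time.Time, cur map[string]time.Time) (created, missing, modified []string) {
    for path, mtime := range cur {
        switch {
        case !prev[path]:
            created = append(created, path)
        case mtime.After(prevMtime):
            modified = append(modified, path)
        }
    }
    for path := range prev {
        if _, ok := cur[path]; !ok {
            missing = append(missing, path)
        }
    }
    return created, missing, modified
}

func main() {
    lastScan := time.Now().Add(-time.Minute)
    prev := map[string]bool{"key-a": true, "key-b": true}
    cur := map[string]time.Time{
        "key-b": time.Now(), // rewritten since the last scan
        "key-c": time.Now(), // newly created
    }
    fmt.Println(diffScans(prev, lastScan, cur)) // [key-c] [key-a] [key-b]
}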


@ -31,6 +31,7 @@ import (
    "github.com/ethereum/go-ethereum/accounts"
    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/log"
+   "gopkg.in/fatih/set.v0"
)

// Minimum amount of time between cache reloads. This limit applies if the platform does
@ -71,6 +72,14 @@ type accountCache struct {
    byAddr   map[common.Address][]accounts.Account
    throttle *time.Timer
    notify   chan struct{}
+   fileC    fileCache
+}
+
+// fileCache is a cache of files seen during scan of keystore
+type fileCache struct {
+    all   *set.SetNonTS // list of all files
+    mtime time.Time     // latest mtime seen
+    mu    sync.RWMutex
}

func newAccountCache(keydir string) (*accountCache, chan struct{}) {
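The new fileC field stores the previously seen paths in a set from gopkg.in/fatih/set.v0. Assuming that package's v0 API (NewNonTS, Add, Difference, List), which is what the scan code below appears to rely on, the set arithmetic looks roughly like this; the paths are made-up examples, not part of the commit.

package main

import (
    "fmt"

    "gopkg.in/fatih/set.v0"
)

func main() {
    // Paths remembered from the previous scan vs. paths found now.
    prev := set.NewNonTS("keydir/key-a", "keydir/key-b")
    now := set.NewNonTS("keydir/key-b", "keydir/key-c")

    // Difference(x, y) holds the elements of x that are not in y.
    deleted := set.Difference(prev, now) // keydir/key-a
    created := set.Difference(now, prev) // keydir/key-c

    fmt.Println(deleted.List(), created.List())
}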
@ -78,6 +87,7 @@ func newAccountCache(keydir string) (*accountCache, chan struct{}) {
        keydir: keydir,
        byAddr: make(map[common.Address][]accounts.Account),
        notify: make(chan struct{}, 1),
+       fileC:  fileCache{all: set.NewNonTS()},
    }
    ac.watcher = newWatcher(ac)
    return ac, ac.notify
@ -127,6 +137,23 @@ func (ac *accountCache) delete(removed accounts.Account) {
    }
}

+// deleteByFile removes an account referenced by the given path.
+func (ac *accountCache) deleteByFile(path string) {
+    ac.mu.Lock()
+    defer ac.mu.Unlock()
+    i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Path >= path })
+
+    if i < len(ac.all) && ac.all[i].URL.Path == path {
+        removed := ac.all[i]
+        ac.all = append(ac.all[:i], ac.all[i+1:]...)
+        if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
+            delete(ac.byAddr, removed.Address)
+        } else {
+            ac.byAddr[removed.Address] = ba
+        }
+    }
+}
+
func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.Account {
    for i := range slice {
        if slice[i] == elem {
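deleteByFile above depends on ac.all staying sorted by URL, so a binary search can find the entry for a path. A standalone sketch of the same sort.Search pattern, using a hypothetical slice of plain strings rather than the commit's account type:

package main

import (
    "fmt"
    "sort"
)

// removeSorted drops path from an already sorted slice, mirroring the
// sort.Search + append pattern used by deleteByFile.
func removeSorted(paths []string, path string) []string {
    i := sort.Search(len(paths), func(i int) bool { return paths[i] >= path })
    if i < len(paths) && paths[i] == path {
        paths = append(paths[:i], paths[i+1:]...)
    }
    return paths
}

func main() {
    paths := []string{"keydir/key-a", "keydir/key-b", "keydir/key-c"}
    fmt.Println(removeSorted(paths, "keydir/key-b")) // [keydir/key-a keydir/key-c]
}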
@ -167,15 +194,16 @@ func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) {
    default:
        err := &AmbiguousAddrError{Addr: a.Address, Matches: make([]accounts.Account, len(matches))}
        copy(err.Matches, matches)
+       sort.Sort(accountsByURL(err.Matches))
        return accounts.Account{}, err
    }
}

func (ac *accountCache) maybeReload() {
    ac.mu.Lock()
-   defer ac.mu.Unlock()
    if ac.watcher.running {
+       ac.mu.Unlock()
        return // A watcher is running and will keep the cache up-to-date.
    }
    if ac.throttle == nil {
@ -184,12 +212,15 @@ func (ac *accountCache) maybeReload() {
        select {
        case <-ac.throttle.C:
        default:
+           ac.mu.Unlock()
            return // The cache was reloaded recently.
        }
    }
+   // No watcher running, start it.
    ac.watcher.start()
-   ac.reload()
    ac.throttle.Reset(minReloadInterval)
+   ac.mu.Unlock()
+   ac.scanAccounts()
}

func (ac *accountCache) close() {
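The reshuffled locking in maybeReload is the core of this change: the throttle check still happens under ac.mu, but the expensive directory scan now runs after the lock is released. A minimal sketch of that pattern, with a hypothetical reloader type standing in for accountCache:

package main

import (
    "sync"
    "time"
)

type reloader struct {
    mu       sync.Mutex
    throttle *time.Timer
}

func (r *reloader) maybeReload(minInterval time.Duration, scan func()) {
    r.mu.Lock()
    if r.throttle == nil {
        r.throttle = time.NewTimer(0)
    } else {
        select {
        case <-r.throttle.C:
        default:
            r.mu.Unlock()
            return // a scan ran recently, skip this one
        }
    }
    r.throttle.Reset(minInterval)
    r.mu.Unlock() // release the lock before the slow work
    scan()
}

func main() {
    r := new(reloader)
    r.maybeReload(2*time.Second, func() { /* slow directory walk */ })
    r.maybeReload(2*time.Second, func() { /* skipped: throttled */ })
}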
@ -205,54 +236,76 @@ func (ac *accountCache) close() {
    ac.mu.Unlock()
}

-// reload caches addresses of existing accounts.
-// Callers must hold ac.mu.
-func (ac *accountCache) reload() {
-    accounts, err := ac.scan()
-    if err != nil {
-        log.Debug("Failed to reload keystore contents", "err", err)
-    }
-    ac.all = accounts
-    sort.Sort(ac.all)
-    for k := range ac.byAddr {
-        delete(ac.byAddr, k)
-    }
-    for _, a := range accounts {
-        ac.byAddr[a.Address] = append(ac.byAddr[a.Address], a)
-    }
-    select {
-    case ac.notify <- struct{}{}:
-    default:
-    }
-    log.Debug("Reloaded keystore contents", "accounts", len(ac.all))
-}
-
-func (ac *accountCache) scan() ([]accounts.Account, error) {
-    files, err := ioutil.ReadDir(ac.keydir)
-    if err != nil {
-        return nil, err
-    }
-    var (
-        buf     = new(bufio.Reader)
-        addrs   []accounts.Account
-        keyJSON struct {
-            Address string `json:"address"`
-        }
-    )
+// scanFiles performs a new scan on the given directory, compares against the already
+// cached filenames, and returns file sets: new, missing , modified
+func (fc *fileCache) scanFiles(keyDir string) (set.Interface, set.Interface, set.Interface, error) {
+    t0 := time.Now()
+    files, err := ioutil.ReadDir(keyDir)
+    t1 := time.Now()
+    if err != nil {
+        return nil, nil, nil, err
+    }
+    fc.mu.RLock()
+    prevMtime := fc.mtime
+    fc.mu.RUnlock()
+
+    filesNow := set.NewNonTS()
+    moddedFiles := set.NewNonTS()
+    var newMtime time.Time
    for _, fi := range files {
-        path := filepath.Join(ac.keydir, fi.Name())
+        modTime := fi.ModTime()
+        path := filepath.Join(keyDir, fi.Name())
        if skipKeyFile(fi) {
            log.Trace("Ignoring file on account scan", "path", path)
            continue
        }
-        logger := log.New("path", path)
-        fd, err := os.Open(path)
-        if err != nil {
-            logger.Trace("Failed to open keystore file", "err", err)
-            continue
-        }
+        filesNow.Add(path)
+        if modTime.After(prevMtime) {
+            moddedFiles.Add(path)
+        }
+        if modTime.After(newMtime) {
+            newMtime = modTime
+        }
+    }
+    t2 := time.Now()
+
+    fc.mu.Lock()
+    // Missing = previous - current
+    missing := set.Difference(fc.all, filesNow)
+    // New = current - previous
+    newFiles := set.Difference(filesNow, fc.all)
+    // Modified = modified - new
+    modified := set.Difference(moddedFiles, newFiles)
+    fc.all = filesNow
+    fc.mtime = newMtime
+    fc.mu.Unlock()
+    t3 := time.Now()
+    log.Debug("FS scan times", "list", t1.Sub(t0), "set", t2.Sub(t1), "diff", t3.Sub(t2))
+    return newFiles, missing, modified, nil
+}
+
+// scanAccounts checks if any changes have occurred on the filesystem, and
+// updates the account cache accordingly
+func (ac *accountCache) scanAccounts() error {
+    newFiles, missingFiles, modified, err := ac.fileC.scanFiles(ac.keydir)
+    t1 := time.Now()
+    if err != nil {
+        log.Debug("Failed to reload keystore contents", "err", err)
+        return err
+    }
+    var (
+        buf     = new(bufio.Reader)
+        keyJSON struct {
+            Address string `json:"address"`
+        }
+    )
+    readAccount := func(path string) *accounts.Account {
+        fd, err := os.Open(path)
+        if err != nil {
+            log.Trace("Failed to open keystore file", "path", path, "err", err)
+            return nil
+        }
+        defer fd.Close()
        buf.Reset(fd)
        // Parse the address.
        keyJSON.Address = ""
@ -260,15 +313,45 @@ func (ac *accountCache) scan() ([]accounts.Account, error) {
        addr := common.HexToAddress(keyJSON.Address)
        switch {
        case err != nil:
-            logger.Debug("Failed to decode keystore key", "err", err)
+            log.Debug("Failed to decode keystore key", "path", path, "err", err)
        case (addr == common.Address{}):
-            logger.Debug("Failed to decode keystore key", "err", "missing or zero address")
+            log.Debug("Failed to decode keystore key", "path", path, "err", "missing or zero address")
        default:
-            addrs = append(addrs, accounts.Account{Address: addr, URL: accounts.URL{Scheme: KeyStoreScheme, Path: path}})
+            return &accounts.Account{Address: addr, URL: accounts.URL{Scheme: KeyStoreScheme, Path: path}}
        }
-        fd.Close()
+        return nil
    }
-    return addrs, err
+
+    for _, p := range newFiles.List() {
+        path, _ := p.(string)
+        a := readAccount(path)
+        if a != nil {
+            ac.add(*a)
+        }
+    }
+    for _, p := range missingFiles.List() {
+        path, _ := p.(string)
+        ac.deleteByFile(path)
+    }
+    for _, p := range modified.List() {
+        path, _ := p.(string)
+        a := readAccount(path)
+        ac.deleteByFile(path)
+        if a != nil {
+            ac.add(*a)
+        }
+    }
+    t2 := time.Now()
+
+    select {
+    case ac.notify <- struct{}{}:
+    default:
+    }
+    log.Trace("Handled keystore changes", "time", t2.Sub(t1))
+    return nil
}

func skipKeyFile(fi os.FileInfo) bool {
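One small detail worth calling out: scanAccounts finishes with a non-blocking send on ac.notify, so a missing or slow subscriber can never stall a scan. The pattern in isolation (illustrative sketch, not the commit's code):

package main

import "fmt"

func main() {
    notify := make(chan struct{}, 1)

    // Non-blocking notification: if the one-slot buffer is already
    // full and nobody is receiving, the update is simply dropped.
    select {
    case notify <- struct{}{}:
        fmt.Println("notification queued")
    default:
        fmt.Println("notification skipped")
    }
}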


@ -18,6 +18,7 @@ package keystore
import (
    "fmt"
+   "io/ioutil"
    "math/rand"
    "os"
    "path/filepath"
@ -295,3 +296,101 @@ func TestCacheFind(t *testing.T) {
        }
    }
}
func waitForAccounts(wantAccounts []accounts.Account, ks *KeyStore) error {
var list []accounts.Account
for d := 200 * time.Millisecond; d < 8*time.Second; d *= 2 {
list = ks.Accounts()
if reflect.DeepEqual(list, wantAccounts) {
// ks should have also received change notifications
select {
case <-ks.changes:
default:
return fmt.Errorf("wasn't notified of new accounts")
}
return nil
}
time.Sleep(d)
}
return fmt.Errorf("\ngot %v\nwant %v", list, wantAccounts)
}
// TestUpdatedKeyfileContents tests that updating the contents of a keystore file
// is noticed by the watcher, and the account cache is updated accordingly
func TestUpdatedKeyfileContents(t *testing.T) {
t.Parallel()
// Create a temporary kesytore to test with
rand.Seed(time.Now().UnixNano())
dir := filepath.Join(os.TempDir(), fmt.Sprintf("eth-keystore-watch-test-%d-%d", os.Getpid(), rand.Int()))
ks := NewKeyStore(dir, LightScryptN, LightScryptP)
list := ks.Accounts()
if len(list) > 0 {
t.Error("initial account list not empty:", list)
}
time.Sleep(100 * time.Millisecond)
// Create the directory and copy a key file into it.
os.MkdirAll(dir, 0700)
defer os.RemoveAll(dir)
file := filepath.Join(dir, "aaa")
// Place one of our testfiles in there
if err := cp.CopyFile(file, cachetestAccounts[0].URL.Path); err != nil {
t.Fatal(err)
}
// ks should see the account.
wantAccounts := []accounts.Account{cachetestAccounts[0]}
wantAccounts[0].URL = accounts.URL{Scheme: KeyStoreScheme, Path: file}
if err := waitForAccounts(wantAccounts, ks); err != nil {
t.Error(err)
return
}
// Now replace file contents
if err := forceCopyFile(file, cachetestAccounts[1].URL.Path); err != nil {
t.Fatal(err)
return
}
wantAccounts = []accounts.Account{cachetestAccounts[1]}
wantAccounts[0].URL = accounts.URL{Scheme: KeyStoreScheme, Path: file}
if err := waitForAccounts(wantAccounts, ks); err != nil {
t.Errorf("First replacement failed")
t.Error(err)
return
}
// Now replace file contents again
if err := forceCopyFile(file, cachetestAccounts[2].URL.Path); err != nil {
t.Fatal(err)
return
}
wantAccounts = []accounts.Account{cachetestAccounts[2]}
wantAccounts[0].URL = accounts.URL{Scheme: KeyStoreScheme, Path: file}
if err := waitForAccounts(wantAccounts, ks); err != nil {
t.Errorf("Second replacement failed")
t.Error(err)
return
}
// Now replace file contents with crap
if err := ioutil.WriteFile(file, []byte("foo"), 0644); err != nil {
t.Fatal(err)
return
}
if err := waitForAccounts([]accounts.Account{}, ks); err != nil {
t.Errorf("Emptying account file failed")
t.Error(err)
return
}
}
// forceCopyFile is like cp.CopyFile, but doesn't complain if the destination exists.
func forceCopyFile(dst, src string) error {
data, err := ioutil.ReadFile(src)
if err != nil {
return err
}
return ioutil.WriteFile(dst, data, 0644)
}


@ -272,82 +272,104 @@ func TestWalletNotifierLifecycle(t *testing.T) {
        t.Errorf("wallet notifier didn't terminate after unsubscribe")
    }
}

+type walletEvent struct {
+    accounts.WalletEvent
+    a accounts.Account
+}
+
// Tests that wallet notifications and correctly fired when accounts are added
// or deleted from the keystore.
func TestWalletNotifications(t *testing.T) {
-    // Create a temporary kesytore to test with
    dir, ks := tmpKeyStore(t, false)
    defer os.RemoveAll(dir)

-    // Subscribe to the wallet feed
-    updates := make(chan accounts.WalletEvent, 1)
-    sub := ks.Subscribe(updates)
+    // Subscribe to the wallet feed and collect events.
+    var (
+        events  []walletEvent
+        updates = make(chan accounts.WalletEvent)
+        sub     = ks.Subscribe(updates)
+    )
    defer sub.Unsubscribe()
+    go func() {
+        for {
+            select {
+            case ev := <-updates:
+                events = append(events, walletEvent{ev, ev.Wallet.Accounts()[0]})
+            case <-sub.Err():
+                close(updates)
+                return
+            }
+        }
+    }()

-    // Randomly add and remove account and make sure events and wallets are in sync
-    live := make(map[common.Address]accounts.Account)
+    // Randomly add and remove accounts.
+    var (
+        live       = make(map[common.Address]accounts.Account)
+        wantEvents []walletEvent
+    )
    for i := 0; i < 1024; i++ {
-        // Execute a creation or deletion and ensure event arrival
        if create := len(live) == 0 || rand.Int()%4 > 0; create {
            // Add a new account and ensure wallet notifications arrives
            account, err := ks.NewAccount("")
            if err != nil {
                t.Fatalf("failed to create test account: %v", err)
            }
-            select {
-            case event := <-updates:
-                if event.Kind != accounts.WalletArrived {
-                    t.Errorf("non-arrival event on account creation")
-                }
-                if event.Wallet.Accounts()[0] != account {
-                    t.Errorf("account mismatch on created wallet: have %v, want %v", event.Wallet.Accounts()[0], account)
-                }
-            default:
-                t.Errorf("wallet arrival event not fired on account creation")
-            }
            live[account.Address] = account
+            wantEvents = append(wantEvents, walletEvent{accounts.WalletEvent{Kind: accounts.WalletArrived}, account})
        } else {
-            // Select a random account to delete (crude, but works)
+            // Delete a random account.
            var account accounts.Account
            for _, a := range live {
                account = a
                break
            }
-            // Remove an account and ensure wallet notifiaction arrives
            if err := ks.Delete(account, ""); err != nil {
                t.Fatalf("failed to delete test account: %v", err)
            }
-            select {
-            case event := <-updates:
-                if event.Kind != accounts.WalletDropped {
-                    t.Errorf("non-drop event on account deletion")
-                }
-                if event.Wallet.Accounts()[0] != account {
-                    t.Errorf("account mismatch on deleted wallet: have %v, want %v", event.Wallet.Accounts()[0], account)
-                }
-            default:
-                t.Errorf("wallet departure event not fired on account creation")
-            }
            delete(live, account.Address)
+            wantEvents = append(wantEvents, walletEvent{accounts.WalletEvent{Kind: accounts.WalletDropped}, account})
        }
-        // Retrieve the list of wallets and ensure it matches with our required live set
-        liveList := make([]accounts.Account, 0, len(live))
-        for _, account := range live {
-            liveList = append(liveList, account)
-        }
-        sort.Sort(accountsByURL(liveList))
-
-        wallets := ks.Wallets()
-        if len(liveList) != len(wallets) {
-            t.Errorf("wallet list doesn't match required accounts: have %v, want %v", wallets, liveList)
-        } else {
-            for j, wallet := range wallets {
-                if accs := wallet.Accounts(); len(accs) != 1 {
-                    t.Errorf("wallet %d: contains invalid number of accounts: have %d, want 1", j, len(accs))
-                } else if accs[0] != liveList[j] {
-                    t.Errorf("wallet %d: account mismatch: have %v, want %v", j, accs[0], liveList[j])
-                }
-            }
-        }
-    }
+    }
+
+    // Shut down the event collector and check events.
+    sub.Unsubscribe()
+    <-updates
+    checkAccounts(t, live, ks.Wallets())
+    checkEvents(t, wantEvents, events)
+}
+
+// checkAccounts checks that all known live accounts are present in the wallet list.
+func checkAccounts(t *testing.T, live map[common.Address]accounts.Account, wallets []accounts.Wallet) {
+    if len(live) != len(wallets) {
+        t.Errorf("wallet list doesn't match required accounts: have %d, want %d", len(wallets), len(live))
+        return
+    }
+    liveList := make([]accounts.Account, 0, len(live))
+    for _, account := range live {
+        liveList = append(liveList, account)
+    }
+    sort.Sort(accountsByURL(liveList))
+    for j, wallet := range wallets {
+        if accs := wallet.Accounts(); len(accs) != 1 {
+            t.Errorf("wallet %d: contains invalid number of accounts: have %d, want 1", j, len(accs))
+        } else if accs[0] != liveList[j] {
+            t.Errorf("wallet %d: account mismatch: have %v, want %v", j, accs[0], liveList[j])
+        }
+    }
+}
+
+// checkEvents checks that all events in 'want' are present in 'have'. Events may be present multiple times.
+func checkEvents(t *testing.T, want []walletEvent, have []walletEvent) {
+    for _, wantEv := range want {
+        nmatch := 0
+        for ; len(have) > 0; nmatch++ {
+            if have[0].Kind != wantEv.Kind || have[0].a != wantEv.a {
+                break
+            }
+            have = have[1:]
+        }
+        if nmatch == 0 {
+            t.Fatalf("can't find event with Kind=%v for %x", wantEv.Kind, wantEv.a.Address)
+        }
+    }
+}
}


@ -70,7 +70,6 @@ func (w *watcher) loop() {
        return
    }
    defer notify.Stop(w.ev)
    logger.Trace("Started watching keystore folder")
    defer logger.Trace("Stopped watching keystore folder")
@ -82,9 +81,9 @@ func (w *watcher) loop() {
    // When an event occurs, the reload call is delayed a bit so that
    // multiple events arriving quickly only cause a single reload.
    var (
        debounce          = time.NewTimer(0)
        debounceDuration  = 500 * time.Millisecond
-       inCycle, hadEvent bool
+       rescanTriggered   = false
    )
    defer debounce.Stop()
    for {
@ -92,22 +91,14 @@ func (w *watcher) loop() {
        case <-w.quit:
            return
        case <-w.ev:
-            if !inCycle {
+            // Trigger the scan (with delay), if not already triggered
+            if !rescanTriggered {
                debounce.Reset(debounceDuration)
-                inCycle = true
-            } else {
-                hadEvent = true
+                rescanTriggered = true
            }
        case <-debounce.C:
-            w.ac.mu.Lock()
-            w.ac.reload()
-            w.ac.mu.Unlock()
-            if hadEvent {
-                debounce.Reset(debounceDuration)
-                inCycle, hadEvent = true, false
-            } else {
-                inCycle, hadEvent = false, false
-            }
+            w.ac.scanAccounts()
+            rescanTriggered = false
        }
    }
}
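The rewritten watcher loop above is a plain timer debounce: the first filesystem event arms the timer, later events inside the window are absorbed, and the single expiry triggers one scanAccounts call. A standalone sketch of the same pattern with made-up channels, not the watcher's code:

package main

import (
    "fmt"
    "time"
)

func main() {
    events := make(chan struct{}, 16)
    quit := time.After(2 * time.Second)

    debounce := time.NewTimer(0)
    <-debounce.C // drain the initial firing
    armed := false

    // Simulate a quick burst of filesystem events.
    for i := 0; i < 5; i++ {
        events <- struct{}{}
    }

    for {
        select {
        case <-events:
            if !armed {
                debounce.Reset(500 * time.Millisecond)
                armed = true
            }
        case <-debounce.C:
            fmt.Println("rescanning once for the whole burst")
            armed = false
        case <-quit:
            return
        }
    }
}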
} }