accounts/scwallet: ordered wallets, tighter events, derivation logs

This commit is contained in:
Péter Szilágyi 2018-04-19 13:39:33 +03:00 committed by Guillaume Ballet
parent 114de0fe2a
commit 386943943f
2 changed files with 82 additions and 64 deletions

View File

@ -36,14 +36,12 @@ import (
"encoding/json"
"io/ioutil"
"os"
"reflect"
"sync"
"time"
"github.com/ebfe/scard"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
)
@ -74,6 +72,7 @@ type Hub struct {
context *scard.Context
datadir string
pairings map[string]smartcardPairing
refreshed time.Time // Time instance when the list of wallets was last refreshed
wallets map[string]*Wallet // Mapping from reader names to wallet instances
updateFeed event.Feed // Event feed to notify wallet additions/removals
@ -82,11 +81,9 @@ type Hub struct {
quit chan chan error
stateLock sync.Mutex // Protects the internals of the hub from racey access
stateLock sync.RWMutex // Protects the internals of the hub from racey access
}
var HubType = reflect.TypeOf(&Hub{})
func (hub *Hub) readPairings() error {
hub.pairings = make(map[string]smartcardPairing)
pairingFile, err := os.Open(hub.datadir + "/smartcards.json")
@ -158,7 +155,6 @@ func NewHub(scheme string, datadir string) (*Hub, error) {
if err != nil {
return nil, err
}
hub := &Hub{
scheme: scheme,
context: context,
@ -166,29 +162,32 @@ func NewHub(scheme string, datadir string) (*Hub, error) {
wallets: make(map[string]*Wallet),
quit: make(chan chan error),
}
if err := hub.readPairings(); err != nil {
return nil, err
}
hub.refreshWallets()
return hub, nil
}
// Wallets implements accounts.Backend, returning all the currently tracked
// devices that appear to be hardware wallets.
// Wallets implements accounts.Backend, returning all the currently tracked smart
// cards that appear to be hardware wallets.
func (hub *Hub) Wallets() []accounts.Wallet {
// Make sure the list of wallets is up to date
hub.stateLock.Lock()
defer hub.stateLock.Unlock()
hub.refreshWallets()
hub.stateLock.RLock()
defer hub.stateLock.RUnlock()
cpy := make([]accounts.Wallet, 0, len(hub.wallets))
for _, wallet := range hub.wallets {
if wallet != nil {
cpy = append(cpy, wallet)
}
for i := 0; i < len(cpy); i++ {
for j := i + 1; j < len(cpy); j++ {
if cpy[i].URL().Cmp(cpy[j].URL()) > 0 {
cpy[i], cpy[j] = cpy[j], cpy[i]
}
}
}
return cpy
}
@ -196,67 +195,77 @@ func (hub *Hub) Wallets() []accounts.Wallet {
// refreshWallets scans the devices attached to the machine and updates the
// list of wallets based on the found devices.
func (hub *Hub) refreshWallets() {
// Don't scan the USB like crazy it the user fetches wallets in a loop
hub.stateLock.RLock()
elapsed := time.Since(hub.refreshed)
hub.stateLock.RUnlock()
if elapsed < refreshThrottling {
return
}
// Retrieve all the smart card reader to check for cards
readers, err := hub.context.ListReaders()
if err != nil {
log.Error("Error listing readers", "err", err)
// This is a perverted hack, the scard library returns an error if no card
// readers are present instead of simply returning an empty list. We don't
// want to fill the user's log with errors, so filter those out.
if err.Error() != "scard: Cannot find a smart card reader." {
log.Error("Failed to enumerate smart card readers", "err", err)
}
}
// Transform the current list of wallets into the new one
hub.stateLock.Lock()
events := []accounts.WalletEvent{}
seen := make(map[string]struct{})
for _, reader := range readers {
if wallet, ok := hub.wallets[reader]; ok {
// We already know about this card; check it's still present
if err := wallet.ping(); err != nil {
log.Debug("Got error pinging wallet", "reader", reader, "err", err)
} else {
seen[reader] = struct{}{}
}
continue
}
// Mark the reader as present
seen[reader] = struct{}{}
// If we alreay know about this card, skip to the next reader, otherwise clean up
if wallet, ok := hub.wallets[reader]; ok {
if err := wallet.ping(); err == nil {
continue
}
wallet.Close()
events = append(events, accounts.WalletEvent{Wallet: wallet, Kind: accounts.WalletDropped})
delete(hub.wallets, reader)
}
// New card detected, try to connect to it
card, err := hub.context.Connect(reader, scard.ShareShared, scard.ProtocolAny)
if err != nil {
log.Debug("Error opening card", "reader", reader, "err", err)
log.Debug("Failed to open smart card", "reader", reader, "err", err)
continue
}
wallet := NewWallet(hub, card)
err = wallet.connect()
if err != nil {
log.Debug("Error connecting to wallet", "reader", reader, "err", err)
if err = wallet.connect(); err != nil {
log.Debug("Failed to connect to smart card", "reader", reader, "err", err)
card.Disconnect(scard.LeaveCard)
continue
}
// Card connected, start tracking in amongs the wallets
hub.wallets[reader] = wallet
events = append(events, accounts.WalletEvent{Wallet: wallet, Kind: accounts.WalletArrived})
log.Info("Found new smartcard wallet", "reader", reader, "publicKey", hexutil.Encode(wallet.PublicKey[:4]))
}
// Remove any wallets we no longer see
for k, wallet := range hub.wallets {
if _, ok := seen[k]; !ok {
log.Info("Wallet disconnected", "pubkey", hexutil.Encode(wallet.PublicKey[:4]), "reader", k)
// Remove any wallets no longer present
for reader, wallet := range hub.wallets {
if _, ok := seen[reader]; !ok {
wallet.Close()
events = append(events, accounts.WalletEvent{Wallet: wallet, Kind: accounts.WalletDropped})
delete(hub.wallets, k)
delete(hub.wallets, reader)
}
}
hub.refreshed = time.Now()
hub.stateLock.Unlock()
for _, event := range events {
hub.updateFeed.Send(event)
}
hub.refreshed = time.Now()
}
// Subscribe implements accounts.Backend, creating an async subscription to
// receive notifications on the addition or removal of wallets.
// receive notifications on the addition or removal of smart card wallets.
func (hub *Hub) Subscribe(sink chan<- accounts.WalletEvent) event.Subscription {
// We need the mutex to reliably start/stop the update loop
hub.stateLock.Lock()
@ -274,16 +283,18 @@ func (hub *Hub) Subscribe(sink chan<- accounts.WalletEvent) event.Subscription {
}
// updater is responsible for maintaining an up-to-date list of wallets managed
// by the hub, and for firing wallet addition/removal events.
// by the smart card hub, and for firing wallet addition/removal events.
func (hub *Hub) updater() {
for {
// TODO: Wait for a USB hotplug event (not supported yet) or a refresh timeout
// <-hub.changes
time.Sleep(refreshCycle)
// Run the wallet refresher
hub.stateLock.Lock()
hub.refreshWallets()
// If all our subscribers left, stop the updater
hub.stateLock.Lock()
if hub.updateScope.Count() == 0 {
hub.updating = false
hub.stateLock.Unlock()

View File

@ -97,6 +97,7 @@ type Wallet struct {
card *scard.Card // A handle to the smartcard interface for the wallet.
session *Session // The secure communication session with the card
log log.Logger // Contextual logger to tag the base with its id
deriveNextPath accounts.DerivationPath // Next derivation path for account auto-discovery
deriveNextAddr common.Address // Next derived account address for auto-discovery
deriveChain ethereum.ChainStateReader // Blockchain state reader to discover used account with
@ -273,7 +274,7 @@ func (w *Wallet) Unpair(pin []byte) error {
func (w *Wallet) URL() accounts.URL {
return accounts.URL{
Scheme: w.Hub.scheme,
Path: fmt.Sprintf("%x", w.PublicKey[1:3]),
Path: fmt.Sprintf("%x", w.PublicKey[1:5]), // Byte #0 isn't unique; 1:5 covers << 64K cards, bump to 1:9 for << 4M
}
}
@ -327,7 +328,7 @@ func (w *Wallet) Open(passphrase string) error {
if err := w.session.authenticate(*pairing); err != nil {
return fmt.Errorf("failed to authenticate card %x: %s", w.PublicKey[:4], err)
}
return nil
return ErrPINNeeded
}
// If no passphrase was supplied, request the PUK from the user
if passphrase == "" {
@ -389,8 +390,8 @@ func (w *Wallet) Close() error {
// selfDerive is an account derivation loop that upon request attempts to find
// new non-zero accounts.
func (w *Wallet) selfDerive() {
w.log.Debug("Smartcard wallet self-derivation started")
defer w.log.Debug("Smartcard wallet self-derivation stopped")
w.log.Debug("Smart card wallet self-derivation started")
defer w.log.Debug("Smart card wallet self-derivation stopped")
// Execute self-derivations until termination or error
var (
@ -419,7 +420,8 @@ func (w *Wallet) selfDerive() {
// Device lock obtained, derive the next batch of accounts
var (
paths []accounts.DerivationPath
nextAccount accounts.Account
nextAcc accounts.Account
nextAddr = w.deriveNextAddr
nextPath = w.deriveNextPath
@ -428,11 +430,11 @@ func (w *Wallet) selfDerive() {
for empty := false; !empty; {
// Retrieve the next derived Ethereum account
if nextAddr == (common.Address{}) {
if nextAccount, err = w.session.derive(nextPath); err != nil {
if nextAcc, err = w.session.derive(nextPath); err != nil {
w.log.Warn("Smartcard wallet account derivation failed", "err", err)
break
}
nextAddr = nextAccount.Address
nextAddr = nextAcc.Address
}
// Check the account's status against the current chain state
var (
@ -459,8 +461,8 @@ func (w *Wallet) selfDerive() {
paths = append(paths, path)
// Display a log message to the user for new (or previously empty accounts)
if _, known := pairing.Accounts[nextAddr]; !known || (!empty && nextAddr == w.deriveNextAddr) {
w.log.Info("Smartcard wallet discovered new account", "address", nextAccount.Address, "path", path, "balance", balance, "nonce", nonce)
if _, known := pairing.Accounts[nextAddr]; !known || !empty || nextAddr != w.deriveNextAddr {
w.log.Info("Smartcard wallet discovered new account", "address", nextAddr, "path", path, "balance", balance, "nonce", nonce)
}
pairing.Accounts[nextAddr] = path
@ -470,12 +472,10 @@ func (w *Wallet) selfDerive() {
nextPath[len(nextPath)-1]++
}
}
// If there are new accounts, write them out
if len(paths) > 0 {
err = w.Hub.setPairing(w, pairing)
}
// Shift the self-derivation forward
w.deriveNextAddr = nextAddr
w.deriveNextPath = nextPath
@ -524,6 +524,13 @@ func (w *Wallet) Accounts() []accounts.Account {
for address, path := range pairing.Accounts {
ret = append(ret, w.makeAccount(address, path))
}
for i := 0; i < len(ret); i++ {
for j := i + 1; j < len(ret); j++ {
if ret[i].URL.Cmp(ret[j].URL) > 0 {
ret[i], ret[j] = ret[j], ret[i]
}
}
}
return ret
}
return nil