simplify account unlocking

This commit is contained in:
zelig 2015-06-18 15:12:39 +01:00
parent 9f6016e877
commit 1d72aaa0cd
2 changed files with 85 additions and 56 deletions

View File

@ -49,11 +49,6 @@ var (
ErrNoKeys = errors.New("no keys in store") ErrNoKeys = errors.New("no keys in store")
) )
const (
// Default unlock duration (in seconds) when an account is unlocked from the console
DefaultAccountUnlockDuration = 300
)
type Account struct { type Account struct {
Address common.Address Address common.Address
} }
@ -114,28 +109,58 @@ func (am *Manager) Sign(a Account, toSign []byte) (signature []byte, err error)
return signature, err return signature, err
} }
// TimedUnlock unlocks the account with the given address. // unlock indefinitely
// When timeout has passed, the account will be locked again. func (am *Manager) Unlock(addr common.Address, keyAuth string) error {
return am.TimedUnlock(addr, keyAuth, 0)
}
// Unlock unlocks the account with the given address. The account
// stays unlocked for the duration of timeout
// it timeout is 0 the account is unlocked for the entire session
func (am *Manager) TimedUnlock(addr common.Address, keyAuth string, timeout time.Duration) error { func (am *Manager) TimedUnlock(addr common.Address, keyAuth string, timeout time.Duration) error {
key, err := am.keyStore.GetKey(addr, keyAuth) key, err := am.keyStore.GetKey(addr, keyAuth)
if err != nil { if err != nil {
return err return err
} }
u := am.addUnlocked(addr, key) var u *unlocked
go am.dropLater(addr, u, timeout) am.mutex.Lock()
defer am.mutex.Unlock()
var found bool
u, found = am.unlocked[addr]
if found {
// terminate dropLater for this key to avoid unexpected drops.
if u.abort != nil {
close(u.abort)
}
}
if timeout > 0 {
u = &unlocked{Key: key, abort: make(chan struct{})}
go am.expire(addr, u, timeout)
} else {
u = &unlocked{Key: key}
}
am.unlocked[addr] = u
return nil return nil
} }
// Unlock unlocks the account with the given address. The account func (am *Manager) expire(addr common.Address, u *unlocked, timeout time.Duration) {
// stays unlocked until the program exits or until a TimedUnlock t := time.NewTimer(timeout)
// timeout (started after the call to Unlock) expires. defer t.Stop()
func (am *Manager) Unlock(addr common.Address, keyAuth string) error { select {
key, err := am.keyStore.GetKey(addr, keyAuth) case <-u.abort:
if err != nil { // just quit
return err case <-t.C:
am.mutex.Lock()
// only drop if it's still the same key instance that dropLater
// was launched with. we can check that using pointer equality
// because the map stores a new pointer every time the key is
// unlocked.
if am.unlocked[addr] == u {
zeroKey(u.PrivateKey)
delete(am.unlocked, addr)
}
am.mutex.Unlock()
} }
am.addUnlocked(addr, key)
return nil
} }
func (am *Manager) NewAccount(auth string) (Account, error) { func (am *Manager) NewAccount(auth string) (Account, error) {
@ -162,43 +187,6 @@ func (am *Manager) Accounts() ([]Account, error) {
return accounts, err return accounts, err
} }
func (am *Manager) addUnlocked(addr common.Address, key *crypto.Key) *unlocked {
u := &unlocked{Key: key, abort: make(chan struct{})}
am.mutex.Lock()
prev, found := am.unlocked[addr]
if found {
// terminate dropLater for this key to avoid unexpected drops.
close(prev.abort)
// the key is zeroed here instead of in dropLater because
// there might not actually be a dropLater running for this
// key, i.e. when Unlock was used.
zeroKey(prev.PrivateKey)
}
am.unlocked[addr] = u
am.mutex.Unlock()
return u
}
func (am *Manager) dropLater(addr common.Address, u *unlocked, timeout time.Duration) {
t := time.NewTimer(timeout)
defer t.Stop()
select {
case <-u.abort:
// just quit
case <-t.C:
am.mutex.Lock()
// only drop if it's still the same key instance that dropLater
// was launched with. we can check that using pointer equality
// because the map stores a new pointer every time the key is
// unlocked.
if am.unlocked[addr] == u {
zeroKey(u.PrivateKey)
delete(am.unlocked, addr)
}
am.mutex.Unlock()
}
}
// zeroKey zeroes a private key in memory. // zeroKey zeroes a private key in memory.
func zeroKey(k *ecdsa.PrivateKey) { func zeroKey(k *ecdsa.PrivateKey) {
b := k.D.Bits() b := k.D.Bits()

View File

@ -18,7 +18,7 @@ func TestSign(t *testing.T) {
pass := "" // not used but required by API pass := "" // not used but required by API
a1, err := am.NewAccount(pass) a1, err := am.NewAccount(pass)
toSign := randentropy.GetEntropyCSPRNG(32) toSign := randentropy.GetEntropyCSPRNG(32)
am.Unlock(a1.Address, "") am.Unlock(a1.Address, "", 0)
_, err = am.Sign(a1, toSign) _, err = am.Sign(a1, toSign)
if err != nil { if err != nil {
@ -58,6 +58,47 @@ func TestTimedUnlock(t *testing.T) {
if err != ErrLocked { if err != ErrLocked {
t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
} }
}
func TestOverrideUnlock(t *testing.T) {
dir, ks := tmpKeyStore(t, crypto.NewKeyStorePassphrase)
defer os.RemoveAll(dir)
am := NewManager(ks)
pass := "foo"
a1, err := am.NewAccount(pass)
toSign := randentropy.GetEntropyCSPRNG(32)
// Unlock indefinitely
if err = am.Unlock(a1.Address, pass); err != nil {
t.Fatal(err)
}
// Signing without passphrase works because account is temp unlocked
_, err = am.Sign(a1, toSign)
if err != nil {
t.Fatal("Signing shouldn't return an error after unlocking, got ", err)
}
// reset unlock to a shorter period, invalidates the previous unlock
if err = am.TimedUnlock(a1.Address, pass, 100*time.Millisecond); err != nil {
t.Fatal(err)
}
// Signing without passphrase still works because account is temp unlocked
_, err = am.Sign(a1, toSign)
if err != nil {
t.Fatal("Signing shouldn't return an error after unlocking, got ", err)
}
// Signing fails again after automatic locking
time.Sleep(150 * time.Millisecond)
_, err = am.Sign(a1, toSign)
if err != ErrLocked {
t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
}
} }
func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore2) (string, crypto.KeyStore2) { func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore2) (string, crypto.KeyStore2) {