diff --git a/accounts/keystore/keystore.go b/accounts/keystore/keystore.go index 5b55175b1..47fa376ed 100644 --- a/accounts/keystore/keystore.go +++ b/accounts/keystore/keystore.go @@ -24,7 +24,6 @@ import ( "crypto/ecdsa" crand "crypto/rand" "errors" - "fmt" "math/big" "os" "path/filepath" @@ -67,7 +66,8 @@ type KeyStore struct { updateScope event.SubscriptionScope // Subscription scope tracking current live listeners updating bool // Whether the event notification loop is running - mu sync.RWMutex + mu sync.RWMutex + importMu sync.Mutex // Import Mutex locks the import to prevent two insertions from racing } type unlocked struct { @@ -443,14 +443,21 @@ func (ks *KeyStore) Import(keyJSON []byte, passphrase, newPassphrase string) (ac if err != nil { return accounts.Account{}, err } + ks.importMu.Lock() + defer ks.importMu.Unlock() + if ks.cache.hasAddress(key.Address) { + return accounts.Account{}, errors.New("account already exists") + } return ks.importKey(key, newPassphrase) } // ImportECDSA stores the given key into the key directory, encrypting it with the passphrase. func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (accounts.Account, error) { key := newKeyFromECDSA(priv) + ks.importMu.Lock() + defer ks.importMu.Unlock() if ks.cache.hasAddress(key.Address) { - return accounts.Account{}, fmt.Errorf("account already exists") + return accounts.Account{}, errors.New("account already exists") } return ks.importKey(key, passphrase) } diff --git a/accounts/keystore/keystore_test.go b/accounts/keystore/keystore_test.go index a691c5062..29c251d7c 100644 --- a/accounts/keystore/keystore_test.go +++ b/accounts/keystore/keystore_test.go @@ -23,11 +23,14 @@ import ( "runtime" "sort" "strings" + "sync" + "sync/atomic" "testing" "time" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/event" ) @@ -338,6 +341,88 @@ func TestWalletNotifications(t *testing.T) { checkEvents(t, wantEvents, events) } +// TestImportExport tests the import functionality of a keystore. +func TestImportECDSA(t *testing.T) { + dir, ks := tmpKeyStore(t, true) + defer os.RemoveAll(dir) + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %v", key) + } + if _, err = ks.ImportECDSA(key, "old"); err != nil { + t.Errorf("importing failed: %v", err) + } + if _, err = ks.ImportECDSA(key, "old"); err == nil { + t.Errorf("importing same key twice succeeded") + } + if _, err = ks.ImportECDSA(key, "new"); err == nil { + t.Errorf("importing same key twice succeeded") + } +} + +// TestImportECDSA tests the import and export functionality of a keystore. +func TestImportExport(t *testing.T) { + dir, ks := tmpKeyStore(t, true) + defer os.RemoveAll(dir) + acc, err := ks.NewAccount("old") + if err != nil { + t.Fatalf("failed to create account: %v", acc) + } + json, err := ks.Export(acc, "old", "new") + if err != nil { + t.Fatalf("failed to export account: %v", acc) + } + dir2, ks2 := tmpKeyStore(t, true) + defer os.RemoveAll(dir2) + if _, err = ks2.Import(json, "old", "old"); err == nil { + t.Errorf("importing with invalid password succeeded") + } + acc2, err := ks2.Import(json, "new", "new") + if err != nil { + t.Errorf("importing failed: %v", err) + } + if acc.Address != acc2.Address { + t.Error("imported account does not match exported account") + } + if _, err = ks2.Import(json, "new", "new"); err == nil { + t.Errorf("importing a key twice succeeded") + } + +} + +// TestImportRace tests the keystore on races. +// This test should fail under -race if importing races. +func TestImportRace(t *testing.T) { + dir, ks := tmpKeyStore(t, true) + defer os.RemoveAll(dir) + acc, err := ks.NewAccount("old") + if err != nil { + t.Fatalf("failed to create account: %v", acc) + } + json, err := ks.Export(acc, "old", "new") + if err != nil { + t.Fatalf("failed to export account: %v", acc) + } + dir2, ks2 := tmpKeyStore(t, true) + defer os.RemoveAll(dir2) + var atom uint32 + var wg sync.WaitGroup + wg.Add(2) + for i := 0; i < 2; i++ { + go func() { + defer wg.Done() + if _, err := ks2.Import(json, "new", "new"); err != nil { + atomic.AddUint32(&atom, 1) + } + + }() + } + wg.Wait() + if atom != 1 { + t.Errorf("Import is racy") + } +} + // checkAccounts checks that all known live accounts are present in the wallet list. func checkAccounts(t *testing.T, live map[common.Address]accounts.Account, wallets []accounts.Wallet) { if len(live) != len(wallets) { diff --git a/cmd/geth/accountcmd_test.go b/cmd/geth/accountcmd_test.go index 3ee4278e5..6213e5195 100644 --- a/cmd/geth/accountcmd_test.go +++ b/cmd/geth/accountcmd_test.go @@ -89,18 +89,23 @@ Path of the secret key file: .*UTC--.+--[0-9a-f]{40} } func TestAccountImport(t *testing.T) { - tests := []struct{ key, output string }{ + tests := []struct{ name, key, output string }{ { + name: "correct account", key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", output: "Address: {fcad0b19bb29d4674531d6f115237e16afce377c}\n", }, { + name: "invalid character", key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef1", output: "Fatal: Failed to load the private key: invalid character '1' at end of key file\n", }, } for _, test := range tests { - importAccountWithExpect(t, test.key, test.output) + t.Run(test.name, func(t *testing.T) { + t.Parallel() + importAccountWithExpect(t, test.key, test.output) + }) } }