Wallets record default address in keystore

This commit is contained in:
whyrusleeping 2019-10-17 19:18:40 +09:00
parent 75db4576f5
commit d818e20df5
12 changed files with 123 additions and 26 deletions

View File

@ -93,6 +93,7 @@ type FullNode interface {
WalletSign(context.Context, address.Address, []byte) (*types.Signature, error) WalletSign(context.Context, address.Address, []byte) (*types.Signature, error)
WalletSignMessage(context.Context, address.Address, *types.Message) (*types.SignedMessage, error) WalletSignMessage(context.Context, address.Address, *types.Message) (*types.SignedMessage, error)
WalletDefaultAddress(context.Context) (address.Address, error) WalletDefaultAddress(context.Context) (address.Address, error)
WalletSetDefault(context.Context, address.Address) error
WalletExport(context.Context, address.Address) (*types.KeyInfo, error) WalletExport(context.Context, address.Address) (*types.KeyInfo, error)
WalletImport(context.Context, *types.KeyInfo) (address.Address, error) WalletImport(context.Context, *types.KeyInfo) (address.Address, error)

View File

@ -70,6 +70,7 @@ type FullNodeStruct struct {
WalletSign func(context.Context, address.Address, []byte) (*types.Signature, error) `perm:"sign"` WalletSign func(context.Context, address.Address, []byte) (*types.Signature, error) `perm:"sign"`
WalletSignMessage func(context.Context, address.Address, *types.Message) (*types.SignedMessage, error) `perm:"sign"` WalletSignMessage func(context.Context, address.Address, *types.Message) (*types.SignedMessage, error) `perm:"sign"`
WalletDefaultAddress func(context.Context) (address.Address, error) `perm:"write"` WalletDefaultAddress func(context.Context) (address.Address, error) `perm:"write"`
WalletSetDefault func(context.Context, address.Address) error `perm:"admin"`
WalletExport func(context.Context, address.Address) (*types.KeyInfo, error) `perm:"admin"` WalletExport func(context.Context, address.Address) (*types.KeyInfo, error) `perm:"admin"`
WalletImport func(context.Context, *types.KeyInfo) (address.Address, error) `perm:"admin"` WalletImport func(context.Context, *types.KeyInfo) (address.Address, error) `perm:"admin"`
@ -269,6 +270,10 @@ func (c *FullNodeStruct) WalletDefaultAddress(ctx context.Context) (address.Addr
return c.Internal.WalletDefaultAddress(ctx) return c.Internal.WalletDefaultAddress(ctx)
} }
func (c *FullNodeStruct) WalletSetDefault(ctx context.Context, a address.Address) error {
return c.Internal.WalletSetDefault(ctx, a)
}
func (c *FullNodeStruct) WalletExport(ctx context.Context, a address.Address) (*types.KeyInfo, error) { func (c *FullNodeStruct) WalletExport(ctx context.Context, a address.Address) (*types.KeyInfo, error) {
return c.Internal.WalletExport(ctx, a) return c.Internal.WalletExport(ctx, a)
} }

View File

@ -1,5 +1,11 @@
package types package types
import (
"fmt"
)
var ErrKeyInfoNotFound = fmt.Errorf("key info not found")
// KeyInfo is used for storing keys in KeyStore // KeyInfo is used for storing keys in KeyStore
type KeyInfo struct { type KeyInfo struct {
Type string Type string

View File

@ -2,7 +2,6 @@ package wallet
import ( import (
"github.com/filecoin-project/go-lotus/chain/types" "github.com/filecoin-project/go-lotus/chain/types"
"github.com/filecoin-project/go-lotus/node/repo"
) )
type MemKeyStore struct { type MemKeyStore struct {
@ -28,7 +27,7 @@ func (mks *MemKeyStore) List() ([]string, error) {
func (mks *MemKeyStore) Get(k string) (types.KeyInfo, error) { func (mks *MemKeyStore) Get(k string) (types.KeyInfo, error) {
ki, ok := mks.m[k] ki, ok := mks.m[k]
if !ok { if !ok {
return types.KeyInfo{}, repo.ErrKeyNotFound return types.KeyInfo{}, types.ErrKeyInfoNotFound
} }
return ki, nil return ki, nil

View File

@ -7,7 +7,6 @@ import (
"sync" "sync"
"github.com/filecoin-project/go-bls-sigs" "github.com/filecoin-project/go-bls-sigs"
"github.com/filecoin-project/go-lotus/node/repo"
"github.com/minio/blake2b-simd" "github.com/minio/blake2b-simd"
"golang.org/x/xerrors" "golang.org/x/xerrors"
@ -19,6 +18,7 @@ import (
const ( const (
KNamePrefix = "wallet-" KNamePrefix = "wallet-"
KDefault = "default"
) )
type Wallet struct { type Wallet struct {
@ -43,7 +43,7 @@ func (w *Wallet) Sign(ctx context.Context, addr address.Address, msg []byte) (*t
return nil, err return nil, err
} }
if ki == nil { if ki == nil {
return nil, xerrors.Errorf("signing using key '%s': %w", addr.String(), repo.ErrKeyNotFound) return nil, xerrors.Errorf("signing using key '%s': %w", addr.String(), types.ErrKeyInfoNotFound)
} }
switch ki.Type { switch ki.Type {
@ -83,7 +83,7 @@ func (w *Wallet) findKey(addr address.Address) (*Key, error) {
} }
ki, err := w.keystore.Get(KNamePrefix + addr.String()) ki, err := w.keystore.Get(KNamePrefix + addr.String())
if err != nil { if err != nil {
if xerrors.Is(err, repo.ErrKeyNotFound) { if xerrors.Is(err, types.ErrKeyInfoNotFound) {
return nil, nil return nil, nil
} }
return nil, xerrors.Errorf("getting from keystore: %w", err) return nil, xerrors.Errorf("getting from keystore: %w", err)
@ -144,6 +144,39 @@ func (w *Wallet) ListAddrs() ([]address.Address, error) {
return out, nil return out, nil
} }
func (w *Wallet) GetDefault() (address.Address, error) {
w.lk.Lock()
defer w.lk.Unlock()
ki, err := w.keystore.Get(KDefault)
if err != nil {
return address.Undef, xerrors.Errorf("failed to get default key: %w", err)
}
k, err := NewKey(ki)
if err != nil {
return address.Undef, xerrors.Errorf("failed to read default key from keystore: %w", err)
}
return k.Address, nil
}
func (w *Wallet) SetDefault(a address.Address) error {
w.lk.Lock()
defer w.lk.Unlock()
ki, err := w.keystore.Get(KNamePrefix + a.String())
if err != nil {
return err
}
if err := w.keystore.Put(KDefault, ki); err != nil {
return err
}
return nil
}
func GenerateKey(typ string) (*Key, error) { func GenerateKey(typ string) (*Key, error) {
switch typ { switch typ {
case types.KTSecp256k1: case types.KTSecp256k1:
@ -183,6 +216,18 @@ func (w *Wallet) GenerateKey(typ string) (address.Address, error) {
return address.Undef, xerrors.Errorf("saving to keystore: %w", err) return address.Undef, xerrors.Errorf("saving to keystore: %w", err)
} }
w.keys[k.Address] = k w.keys[k.Address] = k
_, err = w.keystore.Get(KDefault)
if err != nil {
if !xerrors.Is(err, types.ErrKeyInfoNotFound) {
return address.Undef, err
}
if err := w.keystore.Put(KDefault, k.KeyInfo); err != nil {
return address.Undef, xerrors.Errorf("failed to set new key as default: %w", err)
}
}
return k.Address, nil return k.Address, nil
} }

View File

@ -22,6 +22,8 @@ var walletCmd = &cli.Command{
walletBalance, walletBalance,
walletExport, walletExport,
walletImport, walletImport,
walletGetDefault,
walletSetDefault,
}, },
} }
@ -108,6 +110,51 @@ var walletBalance = &cli.Command{
}, },
} }
var walletGetDefault = &cli.Command{
Name: "default",
Usage: "Get default wallet address",
Action: func(cctx *cli.Context) error {
api, closer, err := GetFullNodeAPI(cctx)
if err != nil {
return err
}
defer closer()
ctx := ReqContext(cctx)
addr, err := api.WalletDefaultAddress(ctx)
if err != nil {
return err
}
fmt.Printf("%s\n", addr.String())
return nil
},
}
var walletSetDefault = &cli.Command{
Name: "set-default",
Usage: "Set default wallet address",
Action: func(cctx *cli.Context) error {
api, closer, err := GetFullNodeAPI(cctx)
if err != nil {
return err
}
defer closer()
ctx := ReqContext(cctx)
if !cctx.Args().Present() {
return fmt.Errorf("must pass address to set as default")
}
addr, err := address.NewFromString(cctx.Args().First())
if err != nil {
return err
}
return api.WalletSetDefault(ctx, addr)
},
}
var walletExport = &cli.Command{ var walletExport = &cli.Command{
Name: "export", Name: "export",
Usage: "export keys", Usage: "export keys",

View File

@ -54,16 +54,11 @@ func (a *WalletAPI) WalletSignMessage(ctx context.Context, k address.Address, ms
} }
func (a *WalletAPI) WalletDefaultAddress(ctx context.Context) (address.Address, error) { func (a *WalletAPI) WalletDefaultAddress(ctx context.Context) (address.Address, error) {
addrs, err := a.Wallet.ListAddrs() return a.Wallet.GetDefault()
if err != nil { }
return address.Undef, err
}
if len(addrs) == 0 {
return address.Undef, xerrors.New("no addresses in wallet")
}
// TODO: store a default address in the config or 'wallet' portion of the repo func (a *WalletAPI) WalletSetDefault(ctx context.Context, addr address.Address) error {
return addrs[0], nil return a.Wallet.SetDefault(addr)
} }
func (a *WalletAPI) WalletExport(ctx context.Context, addr address.Address) (*types.KeyInfo, error) { func (a *WalletAPI) WalletExport(ctx context.Context, addr address.Address) (*types.KeyInfo, error) {

View File

@ -2,11 +2,11 @@ package lp2p
import ( import (
"crypto/rand" "crypto/rand"
"github.com/filecoin-project/go-lotus/chain/types"
"github.com/filecoin-project/go-lotus/node/repo"
"golang.org/x/xerrors"
"time" "time"
"github.com/filecoin-project/go-lotus/chain/types"
"golang.org/x/xerrors"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p"
connmgr "github.com/libp2p/go-libp2p-connmgr" connmgr "github.com/libp2p/go-libp2p-connmgr"
@ -31,7 +31,7 @@ func PrivKey(ks types.KeyStore) (crypto.PrivKey, error) {
if err == nil { if err == nil {
return crypto.UnmarshalPrivateKey(k.PrivateKey) return crypto.UnmarshalPrivateKey(k.PrivateKey)
} }
if !xerrors.Is(err, repo.ErrKeyNotFound) { if !xerrors.Is(err, types.ErrKeyInfoNotFound) {
return nil, err return nil, err
} }
pk, err := genLibp2pKey() pk, err := genLibp2pKey()

View File

@ -287,7 +287,7 @@ func (fsr *fsLockedRepo) Get(name string) (types.KeyInfo, error) {
fstat, err := os.Stat(keyPath) fstat, err := os.Stat(keyPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, ErrKeyNotFound) return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, types.ErrKeyInfoNotFound)
} else if err != nil { } else if err != nil {
return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, err) return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, err)
} }
@ -354,7 +354,7 @@ func (fsr *fsLockedRepo) Delete(name string) error {
_, err := os.Stat(keyPath) _, err := os.Stat(keyPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return xerrors.Errorf("checking key before delete '%s': %w", name, ErrKeyNotFound) return xerrors.Errorf("checking key before delete '%s': %w", name, types.ErrKeyInfoNotFound)
} else if err != nil { } else if err != nil {
return xerrors.Errorf("checking key before delete '%s': %w", name, err) return xerrors.Errorf("checking key before delete '%s': %w", name, err)
} }

View File

@ -16,8 +16,7 @@ var (
ErrRepoAlreadyLocked = errors.New("repo is already locked") ErrRepoAlreadyLocked = errors.New("repo is already locked")
ErrClosedRepo = errors.New("repo is no longer open") ErrClosedRepo = errors.New("repo is no longer open")
ErrKeyExists = errors.New("key already exists") ErrKeyExists = errors.New("key already exists")
ErrKeyNotFound = errors.New("key not found")
) )
type Repo interface { type Repo interface {

View File

@ -204,7 +204,7 @@ func (lmem *lockedMemRepo) Get(name string) (types.KeyInfo, error) {
key, ok := lmem.mem.keystore[name] key, ok := lmem.mem.keystore[name]
if !ok { if !ok {
return types.KeyInfo{}, xerrors.Errorf("getting key '%s': %w", name, ErrKeyNotFound) return types.KeyInfo{}, xerrors.Errorf("getting key '%s': %w", name, types.ErrKeyInfoNotFound)
} }
return key, nil return key, nil
} }
@ -235,7 +235,7 @@ func (lmem *lockedMemRepo) Delete(name string) error {
_, isThere := lmem.mem.keystore[name] _, isThere := lmem.mem.keystore[name]
if !isThere { if !isThere {
return xerrors.Errorf("deleting key '%s': %w", name, ErrKeyNotFound) return xerrors.Errorf("deleting key '%s': %w", name, types.ErrKeyInfoNotFound)
} }
delete(lmem.mem.keystore, name) delete(lmem.mem.keystore, name)
return nil return nil

View File

@ -90,7 +90,7 @@ func basicTest(t *testing.T, repo Repo) {
k2prim, err := kstr.Get("k2") k2prim, err := kstr.Get("k2")
if assert.Error(t, err, "should not be able to get k2") { if assert.Error(t, err, "should not be able to get k2") {
assert.True(t, xerrors.Is(err, ErrKeyNotFound), "returned error is ErrKeyNotFound") assert.True(t, xerrors.Is(err, types.ErrKeyInfoNotFound), "returned error is ErrKeyNotFound")
} }
assert.Empty(t, k2prim, "there should be no output for k2") assert.Empty(t, k2prim, "there should be no output for k2")
@ -110,6 +110,6 @@ func basicTest(t *testing.T, repo Repo) {
err = kstr.Delete("k2") err = kstr.Delete("k2")
if assert.Error(t, err) { if assert.Error(t, err) {
assert.True(t, xerrors.Is(err, ErrKeyNotFound), "returned errror is ErrKeyNotFound") assert.True(t, xerrors.Is(err, types.ErrKeyInfoNotFound), "returned errror is ErrKeyNotFound")
} }
} }