diff --git a/api/api.go b/api/api.go index aa3a56df2..235b4e8e7 100644 --- a/api/api.go +++ b/api/api.go @@ -93,6 +93,7 @@ type FullNode interface { WalletSign(context.Context, address.Address, []byte) (*types.Signature, error) WalletSignMessage(context.Context, address.Address, *types.Message) (*types.SignedMessage, error) WalletDefaultAddress(context.Context) (address.Address, error) + WalletSetDefault(context.Context, address.Address) error WalletExport(context.Context, address.Address) (*types.KeyInfo, error) WalletImport(context.Context, *types.KeyInfo) (address.Address, error) diff --git a/api/struct.go b/api/struct.go index 93d670676..3db672e46 100644 --- a/api/struct.go +++ b/api/struct.go @@ -70,6 +70,7 @@ type FullNodeStruct struct { WalletSign func(context.Context, address.Address, []byte) (*types.Signature, error) `perm:"sign"` WalletSignMessage func(context.Context, address.Address, *types.Message) (*types.SignedMessage, error) `perm:"sign"` WalletDefaultAddress func(context.Context) (address.Address, error) `perm:"write"` + WalletSetDefault func(context.Context, address.Address) error `perm:"admin"` WalletExport func(context.Context, address.Address) (*types.KeyInfo, error) `perm:"admin"` WalletImport func(context.Context, *types.KeyInfo) (address.Address, error) `perm:"admin"` @@ -269,6 +270,10 @@ func (c *FullNodeStruct) WalletDefaultAddress(ctx context.Context) (address.Addr return c.Internal.WalletDefaultAddress(ctx) } +func (c *FullNodeStruct) WalletSetDefault(ctx context.Context, a address.Address) error { + return c.Internal.WalletSetDefault(ctx, a) +} + func (c *FullNodeStruct) WalletExport(ctx context.Context, a address.Address) (*types.KeyInfo, error) { return c.Internal.WalletExport(ctx, a) } diff --git a/chain/types/keystore.go b/chain/types/keystore.go index 54cf05802..994e6d960 100644 --- a/chain/types/keystore.go +++ b/chain/types/keystore.go @@ -1,5 +1,11 @@ package types +import ( + "fmt" +) + +var ErrKeyInfoNotFound = fmt.Errorf("key info not found") + // KeyInfo is used for storing keys in KeyStore type KeyInfo struct { Type string diff --git a/chain/wallet/memkeystore.go b/chain/wallet/memkeystore.go index 14e6a157f..bd6e3dc89 100644 --- a/chain/wallet/memkeystore.go +++ b/chain/wallet/memkeystore.go @@ -2,7 +2,6 @@ package wallet import ( "github.com/filecoin-project/go-lotus/chain/types" - "github.com/filecoin-project/go-lotus/node/repo" ) type MemKeyStore struct { @@ -28,7 +27,7 @@ func (mks *MemKeyStore) List() ([]string, error) { func (mks *MemKeyStore) Get(k string) (types.KeyInfo, error) { ki, ok := mks.m[k] if !ok { - return types.KeyInfo{}, repo.ErrKeyNotFound + return types.KeyInfo{}, types.ErrKeyInfoNotFound } return ki, nil diff --git a/chain/wallet/wallet.go b/chain/wallet/wallet.go index 3f3ee86c2..dad6b7db6 100644 --- a/chain/wallet/wallet.go +++ b/chain/wallet/wallet.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/filecoin-project/go-bls-sigs" - "github.com/filecoin-project/go-lotus/node/repo" "github.com/minio/blake2b-simd" "golang.org/x/xerrors" @@ -19,6 +18,7 @@ import ( const ( KNamePrefix = "wallet-" + KDefault = "default" ) type Wallet struct { @@ -43,7 +43,7 @@ func (w *Wallet) Sign(ctx context.Context, addr address.Address, msg []byte) (*t return nil, err } if ki == nil { - return nil, xerrors.Errorf("signing using key '%s': %w", addr.String(), repo.ErrKeyNotFound) + return nil, xerrors.Errorf("signing using key '%s': %w", addr.String(), types.ErrKeyInfoNotFound) } switch ki.Type { @@ -83,7 +83,7 @@ func (w *Wallet) findKey(addr address.Address) (*Key, error) { } ki, err := w.keystore.Get(KNamePrefix + addr.String()) if err != nil { - if xerrors.Is(err, repo.ErrKeyNotFound) { + if xerrors.Is(err, types.ErrKeyInfoNotFound) { return nil, nil } return nil, xerrors.Errorf("getting from keystore: %w", err) @@ -144,6 +144,39 @@ func (w *Wallet) ListAddrs() ([]address.Address, error) { return out, nil } +func (w *Wallet) GetDefault() (address.Address, error) { + w.lk.Lock() + defer w.lk.Unlock() + + ki, err := w.keystore.Get(KDefault) + if err != nil { + return address.Undef, xerrors.Errorf("failed to get default key: %w", err) + } + + k, err := NewKey(ki) + if err != nil { + return address.Undef, xerrors.Errorf("failed to read default key from keystore: %w", err) + } + + return k.Address, nil +} + +func (w *Wallet) SetDefault(a address.Address) error { + w.lk.Lock() + defer w.lk.Unlock() + + ki, err := w.keystore.Get(KNamePrefix + a.String()) + if err != nil { + return err + } + + if err := w.keystore.Put(KDefault, ki); err != nil { + return err + } + + return nil +} + func GenerateKey(typ string) (*Key, error) { switch typ { case types.KTSecp256k1: @@ -183,6 +216,18 @@ func (w *Wallet) GenerateKey(typ string) (address.Address, error) { return address.Undef, xerrors.Errorf("saving to keystore: %w", err) } w.keys[k.Address] = k + + _, err = w.keystore.Get(KDefault) + if err != nil { + if !xerrors.Is(err, types.ErrKeyInfoNotFound) { + return address.Undef, err + } + + if err := w.keystore.Put(KDefault, k.KeyInfo); err != nil { + return address.Undef, xerrors.Errorf("failed to set new key as default: %w", err) + } + } + return k.Address, nil } diff --git a/cli/wallet.go b/cli/wallet.go index 6361f5c69..6394526ed 100644 --- a/cli/wallet.go +++ b/cli/wallet.go @@ -22,6 +22,8 @@ var walletCmd = &cli.Command{ walletBalance, walletExport, walletImport, + walletGetDefault, + walletSetDefault, }, } @@ -108,6 +110,51 @@ var walletBalance = &cli.Command{ }, } +var walletGetDefault = &cli.Command{ + Name: "default", + Usage: "Get default wallet address", + Action: func(cctx *cli.Context) error { + api, closer, err := GetFullNodeAPI(cctx) + if err != nil { + return err + } + defer closer() + ctx := ReqContext(cctx) + + addr, err := api.WalletDefaultAddress(ctx) + if err != nil { + return err + } + + fmt.Printf("%s\n", addr.String()) + return nil + }, +} + +var walletSetDefault = &cli.Command{ + Name: "set-default", + Usage: "Set default wallet address", + Action: func(cctx *cli.Context) error { + api, closer, err := GetFullNodeAPI(cctx) + if err != nil { + return err + } + defer closer() + ctx := ReqContext(cctx) + + if !cctx.Args().Present() { + return fmt.Errorf("must pass address to set as default") + } + + addr, err := address.NewFromString(cctx.Args().First()) + if err != nil { + return err + } + + return api.WalletSetDefault(ctx, addr) + }, +} + var walletExport = &cli.Command{ Name: "export", Usage: "export keys", diff --git a/node/impl/full/wallet.go b/node/impl/full/wallet.go index 4482118fe..01d0ee79c 100644 --- a/node/impl/full/wallet.go +++ b/node/impl/full/wallet.go @@ -54,16 +54,11 @@ func (a *WalletAPI) WalletSignMessage(ctx context.Context, k address.Address, ms } func (a *WalletAPI) WalletDefaultAddress(ctx context.Context) (address.Address, error) { - addrs, err := a.Wallet.ListAddrs() - if err != nil { - return address.Undef, err - } - if len(addrs) == 0 { - return address.Undef, xerrors.New("no addresses in wallet") - } + return a.Wallet.GetDefault() +} - // TODO: store a default address in the config or 'wallet' portion of the repo - return addrs[0], nil +func (a *WalletAPI) WalletSetDefault(ctx context.Context, addr address.Address) error { + return a.Wallet.SetDefault(addr) } func (a *WalletAPI) WalletExport(ctx context.Context, addr address.Address) (*types.KeyInfo, error) { diff --git a/node/modules/lp2p/libp2p.go b/node/modules/lp2p/libp2p.go index c68796f57..74fbc516d 100644 --- a/node/modules/lp2p/libp2p.go +++ b/node/modules/lp2p/libp2p.go @@ -2,11 +2,11 @@ package lp2p import ( "crypto/rand" - "github.com/filecoin-project/go-lotus/chain/types" - "github.com/filecoin-project/go-lotus/node/repo" - "golang.org/x/xerrors" "time" + "github.com/filecoin-project/go-lotus/chain/types" + "golang.org/x/xerrors" + logging "github.com/ipfs/go-log" "github.com/libp2p/go-libp2p" connmgr "github.com/libp2p/go-libp2p-connmgr" @@ -31,7 +31,7 @@ func PrivKey(ks types.KeyStore) (crypto.PrivKey, error) { if err == nil { return crypto.UnmarshalPrivateKey(k.PrivateKey) } - if !xerrors.Is(err, repo.ErrKeyNotFound) { + if !xerrors.Is(err, types.ErrKeyInfoNotFound) { return nil, err } pk, err := genLibp2pKey() diff --git a/node/repo/fsrepo.go b/node/repo/fsrepo.go index 611f57d7b..ab33eda95 100644 --- a/node/repo/fsrepo.go +++ b/node/repo/fsrepo.go @@ -287,7 +287,7 @@ func (fsr *fsLockedRepo) Get(name string) (types.KeyInfo, error) { fstat, err := os.Stat(keyPath) if os.IsNotExist(err) { - return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, ErrKeyNotFound) + return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, types.ErrKeyInfoNotFound) } else if err != nil { return types.KeyInfo{}, xerrors.Errorf("opening key '%s': %w", name, err) } @@ -354,7 +354,7 @@ func (fsr *fsLockedRepo) Delete(name string) error { _, err := os.Stat(keyPath) if os.IsNotExist(err) { - return xerrors.Errorf("checking key before delete '%s': %w", name, ErrKeyNotFound) + return xerrors.Errorf("checking key before delete '%s': %w", name, types.ErrKeyInfoNotFound) } else if err != nil { return xerrors.Errorf("checking key before delete '%s': %w", name, err) } diff --git a/node/repo/interface.go b/node/repo/interface.go index 9ec3a263a..c769e230c 100644 --- a/node/repo/interface.go +++ b/node/repo/interface.go @@ -16,8 +16,7 @@ var ( ErrRepoAlreadyLocked = errors.New("repo is already locked") ErrClosedRepo = errors.New("repo is no longer open") - ErrKeyExists = errors.New("key already exists") - ErrKeyNotFound = errors.New("key not found") + ErrKeyExists = errors.New("key already exists") ) type Repo interface { diff --git a/node/repo/memrepo.go b/node/repo/memrepo.go index 8441b9013..4fd64e30b 100644 --- a/node/repo/memrepo.go +++ b/node/repo/memrepo.go @@ -204,7 +204,7 @@ func (lmem *lockedMemRepo) Get(name string) (types.KeyInfo, error) { key, ok := lmem.mem.keystore[name] if !ok { - return types.KeyInfo{}, xerrors.Errorf("getting key '%s': %w", name, ErrKeyNotFound) + return types.KeyInfo{}, xerrors.Errorf("getting key '%s': %w", name, types.ErrKeyInfoNotFound) } return key, nil } @@ -235,7 +235,7 @@ func (lmem *lockedMemRepo) Delete(name string) error { _, isThere := lmem.mem.keystore[name] if !isThere { - return xerrors.Errorf("deleting key '%s': %w", name, ErrKeyNotFound) + return xerrors.Errorf("deleting key '%s': %w", name, types.ErrKeyInfoNotFound) } delete(lmem.mem.keystore, name) return nil diff --git a/node/repo/repo_test.go b/node/repo/repo_test.go index a939ef462..f05118c4c 100644 --- a/node/repo/repo_test.go +++ b/node/repo/repo_test.go @@ -90,7 +90,7 @@ func basicTest(t *testing.T, repo Repo) { k2prim, err := kstr.Get("k2") if assert.Error(t, err, "should not be able to get k2") { - assert.True(t, xerrors.Is(err, ErrKeyNotFound), "returned error is ErrKeyNotFound") + assert.True(t, xerrors.Is(err, types.ErrKeyInfoNotFound), "returned error is ErrKeyNotFound") } assert.Empty(t, k2prim, "there should be no output for k2") @@ -110,6 +110,6 @@ func basicTest(t *testing.T, repo Repo) { err = kstr.Delete("k2") if assert.Error(t, err) { - assert.True(t, xerrors.Is(err, ErrKeyNotFound), "returned errror is ErrKeyNotFound") + assert.True(t, xerrors.Is(err, types.ErrKeyInfoNotFound), "returned errror is ErrKeyNotFound") } }