diff --git a/eth/api_backend.go b/eth/api_backend.go index 00424caed..9ac06ffa4 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -18,6 +18,7 @@ package eth import ( "context" + "errors" "math/big" "github.com/ethereum/go-ethereum/accounts" @@ -95,9 +96,12 @@ func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc. } // Otherwise resolve the block number and return its state header, err := b.HeaderByNumber(ctx, blockNr) - if header == nil || err != nil { + if err != nil { return nil, nil, err } + if header == nil { + return nil, nil, errors.New("header not found") + } stateDb, err := b.eth.BlockChain().StateAt(header.Root) return stateDb, header, err } diff --git a/ethclient/ethclient_test.go b/ethclient/ethclient_test.go index 3e8bf974c..74711bd39 100644 --- a/ethclient/ethclient_test.go +++ b/ethclient/ethclient_test.go @@ -17,13 +17,24 @@ package ethclient import ( + "context" + "errors" "fmt" "math/big" "reflect" "testing" + "time" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/params" ) // Verify that Client implements the ethereum interfaces. @@ -150,3 +161,143 @@ func TestToFilterArg(t *testing.T) { }) } } + +var ( + testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + testAddr = crypto.PubkeyToAddress(testKey.PublicKey) + testBalance = big.NewInt(2e10) +) + +func newTestBackend(t *testing.T) (*node.Node, []*types.Block) { + // Generate test chain. + genesis, blocks := generateTestChain() + + // Start Ethereum service. + var ethservice *eth.Ethereum + n, err := node.New(&node.Config{}) + n.Register(func(ctx *node.ServiceContext) (node.Service, error) { + config := ð.Config{Genesis: genesis} + config.Ethash.PowMode = ethash.ModeFake + ethservice, err = eth.New(ctx, config) + return ethservice, err + }) + + // Import the test chain. + if err := n.Start(); err != nil { + t.Fatalf("can't start test node: %v", err) + } + if _, err := ethservice.BlockChain().InsertChain(blocks[1:]); err != nil { + t.Fatalf("can't import test blocks: %v", err) + } + return n, blocks +} + +func generateTestChain() (*core.Genesis, []*types.Block) { + db := rawdb.NewMemoryDatabase() + config := params.AllEthashProtocolChanges + genesis := &core.Genesis{ + Config: config, + Alloc: core.GenesisAlloc{testAddr: {Balance: testBalance}}, + ExtraData: []byte("test genesis"), + Timestamp: 9000, + } + generate := func(i int, g *core.BlockGen) { + g.OffsetTime(5) + g.SetExtra([]byte("test")) + } + gblock := genesis.ToBlock(db) + engine := ethash.NewFaker() + blocks, _ := core.GenerateChain(config, gblock, engine, db, 1, generate) + blocks = append([]*types.Block{gblock}, blocks...) + return genesis, blocks +} + +func TestHeader(t *testing.T) { + backend, chain := newTestBackend(t) + client, _ := backend.Attach() + defer backend.Stop() + defer client.Close() + + tests := map[string]struct { + block *big.Int + want *types.Header + wantErr error + }{ + "genesis": { + block: big.NewInt(0), + want: chain[0].Header(), + }, + "first_block": { + block: big.NewInt(1), + want: chain[1].Header(), + }, + "future_block": { + block: big.NewInt(1000000000), + want: nil, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ec := NewClient(client) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + got, err := ec.HeaderByNumber(ctx, tt.block) + if tt.wantErr != nil && (err == nil || err.Error() != tt.wantErr.Error()) { + t.Fatalf("HeaderByNumber(%v) error = %q, want %q", tt.block, err, tt.wantErr) + } + if got != nil && got.Number.Sign() == 0 { + got.Number = big.NewInt(0) // hack to make DeepEqual work + } + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("HeaderByNumber(%v)\n = %v\nwant %v", tt.block, got, tt.want) + } + }) + } +} + +func TestBalanceAt(t *testing.T) { + backend, _ := newTestBackend(t) + client, _ := backend.Attach() + defer backend.Stop() + defer client.Close() + + tests := map[string]struct { + account common.Address + block *big.Int + want *big.Int + wantErr error + }{ + "valid_account": { + account: testAddr, + block: big.NewInt(1), + want: testBalance, + }, + "non_existent_account": { + account: common.Address{1}, + block: big.NewInt(1), + want: big.NewInt(0), + }, + "future_block": { + account: testAddr, + block: big.NewInt(1000000000), + want: big.NewInt(0), + wantErr: errors.New("header not found"), + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ec := NewClient(client) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + got, err := ec.BalanceAt(ctx, tt.account, tt.block) + if tt.wantErr != nil && (err == nil || err.Error() != tt.wantErr.Error()) { + t.Fatalf("BalanceAt(%x, %v) error = %q, want %q", tt.account, tt.block, err, tt.wantErr) + } + if got.Cmp(tt.want) != 0 { + t.Fatalf("BalanceAt(%x, %v) = %v, want %v", tt.account, tt.block, got, tt.want) + } + }) + } +} diff --git a/les/api_backend.go b/les/api_backend.go index 4fe352136..589cf572d 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -18,6 +18,7 @@ package les import ( "context" + "errors" "math/big" "github.com/ethereum/go-ethereum/accounts" @@ -78,9 +79,12 @@ func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) { header, err := b.HeaderByNumber(ctx, blockNr) - if header == nil || err != nil { + if err != nil { return nil, nil, err } + if header == nil { + return nil, nil, errors.New("header not found") + } return light.NewState(ctx, header, b.eth.odr), header, nil }