eth/filters: fix getLogs for pending block (#24949)

* eth/filters: fix pending for getLogs

* add pending method to test backend

* fix block range validation
This commit is contained in:
Sina Mahmoodi 2022-06-07 08:31:19 +02:00 committed by GitHub
parent 7e9514b8c3
commit d9566e39bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 17 deletions

View File

@ -66,6 +66,7 @@ type SimulatedBackend struct {
mu sync.Mutex mu sync.Mutex
pendingBlock *types.Block // Currently pending block that will be imported on request pendingBlock *types.Block // Currently pending block that will be imported on request
pendingState *state.StateDB // Currently pending state that will be the active on request pendingState *state.StateDB // Currently pending state that will be the active on request
pendingReceipts types.Receipts // Currently receipts for the pending block
events *filters.EventSystem // Event system for filtering log events live events *filters.EventSystem // Event system for filtering log events live
@ -84,8 +85,8 @@ func NewSimulatedBackendWithDatabase(database ethdb.Database, alloc core.Genesis
database: database, database: database,
blockchain: blockchain, blockchain: blockchain,
config: genesis.Config, config: genesis.Config,
events: filters.NewEventSystem(&filterBackend{database, blockchain}, false),
} }
backend.events = filters.NewEventSystem(&filterBackend{database, blockchain, backend}, false)
backend.rollback(blockchain.CurrentBlock()) backend.rollback(blockchain.CurrentBlock())
return backend return backend
} }
@ -662,7 +663,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
return fmt.Errorf("invalid transaction nonce: got %d, want %d", tx.Nonce(), nonce) return fmt.Errorf("invalid transaction nonce: got %d, want %d", tx.Nonce(), nonce)
} }
// Include tx in chain // Include tx in chain
blocks, _ := core.GenerateChain(b.config, block, ethash.NewFaker(), b.database, 1, func(number int, block *core.BlockGen) { blocks, receipts := core.GenerateChain(b.config, block, ethash.NewFaker(), b.database, 1, func(number int, block *core.BlockGen) {
for _, tx := range b.pendingBlock.Transactions() { for _, tx := range b.pendingBlock.Transactions() {
block.AddTxWithChain(b.blockchain, tx) block.AddTxWithChain(b.blockchain, tx)
} }
@ -672,6 +673,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
b.pendingBlock = blocks[0] b.pendingBlock = blocks[0]
b.pendingState, _ = state.New(b.pendingBlock.Root(), stateDB.Database(), nil) b.pendingState, _ = state.New(b.pendingBlock.Root(), stateDB.Database(), nil)
b.pendingReceipts = receipts[0]
return nil return nil
} }
@ -683,7 +685,7 @@ func (b *SimulatedBackend) FilterLogs(ctx context.Context, query ethereum.Filter
var filter *filters.Filter var filter *filters.Filter
if query.BlockHash != nil { if query.BlockHash != nil {
// Block filter requested, construct a single-shot filter // Block filter requested, construct a single-shot filter
filter = filters.NewBlockFilter(&filterBackend{b.database, b.blockchain}, *query.BlockHash, query.Addresses, query.Topics) filter = filters.NewBlockFilter(&filterBackend{b.database, b.blockchain, b}, *query.BlockHash, query.Addresses, query.Topics)
} else { } else {
// Initialize unset filter boundaries to run from genesis to chain head // Initialize unset filter boundaries to run from genesis to chain head
from := int64(0) from := int64(0)
@ -695,7 +697,7 @@ func (b *SimulatedBackend) FilterLogs(ctx context.Context, query ethereum.Filter
to = query.ToBlock.Int64() to = query.ToBlock.Int64()
} }
// Construct the range filter // Construct the range filter
filter = filters.NewRangeFilter(&filterBackend{b.database, b.blockchain}, from, to, query.Addresses, query.Topics) filter = filters.NewRangeFilter(&filterBackend{b.database, b.blockchain, b}, from, to, query.Addresses, query.Topics)
} }
// Run the filter and return all the logs // Run the filter and return all the logs
logs, err := filter.Logs(ctx) logs, err := filter.Logs(ctx)
@ -818,6 +820,7 @@ func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList }
type filterBackend struct { type filterBackend struct {
db ethdb.Database db ethdb.Database
bc *core.BlockChain bc *core.BlockChain
backend *SimulatedBackend
} }
func (fb *filterBackend) ChainDb() ethdb.Database { return fb.db } func (fb *filterBackend) ChainDb() ethdb.Database { return fb.db }
@ -834,6 +837,10 @@ func (fb *filterBackend) HeaderByHash(ctx context.Context, hash common.Hash) (*t
return fb.bc.GetHeaderByHash(hash), nil return fb.bc.GetHeaderByHash(hash), nil
} }
func (fb *filterBackend) PendingBlockAndReceipts() (*types.Block, types.Receipts) {
return fb.backend.pendingBlock, fb.backend.pendingReceipts
}
func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) {
number := rawdb.ReadHeaderNumber(fb.db, hash) number := rawdb.ReadHeaderNumber(fb.db, hash)
if number == nil { if number == nil {

View File

@ -36,6 +36,7 @@ type Backend interface {
HeaderByHash(ctx context.Context, blockHash common.Hash) (*types.Header, error) HeaderByHash(ctx context.Context, blockHash common.Hash) (*types.Header, error)
GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error)
GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error)
PendingBlockAndReceipts() (*types.Block, types.Receipts)
SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription
SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription
@ -128,26 +129,35 @@ func (f *Filter) Logs(ctx context.Context) ([]*types.Log, error) {
} }
return f.blockLogs(ctx, header) return f.blockLogs(ctx, header)
} }
// Short-cut if all we care about is pending logs
if f.begin == rpc.PendingBlockNumber.Int64() {
if f.end != rpc.PendingBlockNumber.Int64() {
return nil, errors.New("invalid block range")
}
return f.pendingLogs()
}
// Figure out the limits of the filter range // Figure out the limits of the filter range
header, _ := f.backend.HeaderByNumber(ctx, rpc.LatestBlockNumber) header, _ := f.backend.HeaderByNumber(ctx, rpc.LatestBlockNumber)
if header == nil { if header == nil {
return nil, nil return nil, nil
} }
head := header.Number.Uint64() var (
head = header.Number.Uint64()
if f.begin == -1 { end = uint64(f.end)
pending = f.end == rpc.PendingBlockNumber.Int64()
)
if f.begin == rpc.LatestBlockNumber.Int64() {
f.begin = int64(head) f.begin = int64(head)
} }
end := uint64(f.end) if f.end == rpc.LatestBlockNumber.Int64() || f.end == rpc.PendingBlockNumber.Int64() {
if f.end == -1 {
end = head end = head
} }
// Gather all indexed logs, and finish with non indexed ones // Gather all indexed logs, and finish with non indexed ones
var ( var (
logs []*types.Log logs []*types.Log
err error err error
size, sections = f.backend.BloomStatus()
) )
size, sections := f.backend.BloomStatus()
if indexed := sections * size; indexed > uint64(f.begin) { if indexed := sections * size; indexed > uint64(f.begin) {
if indexed > end { if indexed > end {
logs, err = f.indexedLogs(ctx, end) logs, err = f.indexedLogs(ctx, end)
@ -160,6 +170,13 @@ func (f *Filter) Logs(ctx context.Context) ([]*types.Log, error) {
} }
rest, err := f.unindexedLogs(ctx, end) rest, err := f.unindexedLogs(ctx, end)
logs = append(logs, rest...) logs = append(logs, rest...)
if pending {
pendingLogs, err := f.pendingLogs()
if err != nil {
return nil, err
}
logs = append(logs, pendingLogs...)
}
return logs, err return logs, err
} }
@ -272,6 +289,19 @@ func (f *Filter) checkMatches(ctx context.Context, header *types.Header) (logs [
return nil, nil return nil, nil
} }
// pendingLogs returns the logs matching the filter criteria within the pending block.
func (f *Filter) pendingLogs() ([]*types.Log, error) {
block, receipts := f.backend.PendingBlockAndReceipts()
if bloomFilter(block.Bloom(), f.addresses, f.topics) {
var unfiltered []*types.Log
for _, r := range receipts {
unfiltered = append(unfiltered, r.Logs...)
}
return filterLogs(unfiltered, nil, nil, f.addresses, f.topics), nil
}
return nil, nil
}
func includes(addresses []common.Address, a common.Address) bool { func includes(addresses []common.Address, a common.Address) bool {
for _, addr := range addresses { for _, addr := range addresses {
if addr == a { if addr == a {

View File

@ -106,6 +106,10 @@ func (b *testBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*types
return logs, nil return logs, nil
} }
func (b *testBackend) PendingBlockAndReceipts() (*types.Block, types.Receipts) {
return nil, nil
}
func (b *testBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { func (b *testBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
return b.txFeed.Subscribe(ch) return b.txFeed.Subscribe(ch)
} }

View File

@ -65,6 +65,7 @@ type Backend interface {
BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error)
StateAndHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*state.StateDB, *types.Header, error) StateAndHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*state.StateDB, *types.Header, error)
StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error)
PendingBlockAndReceipts() (*types.Block, types.Receipts)
GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error)
GetTd(ctx context.Context, hash common.Hash) *big.Int GetTd(ctx context.Context, hash common.Hash) *big.Int
GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error)