diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a502679..fd2b8fb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (evm) [tharsis#342](https://github.com/tharsis/ethermint/issues/342) Don't clear balance when resetting the account. * (evm) [tharsis#334](https://github.com/tharsis/ethermint/pull/334) Log index changed to the index in block rather than tx. +* (evm) [tharsis#399](https://github.com/tharsis/ethermint/pull/399) Exception in sub-message call reverts the call if it's not propagated. ### API Breaking diff --git a/tests/solidity/suites/exception/contracts/TestRevert.sol b/tests/solidity/suites/exception/contracts/TestRevert.sol new file mode 100644 index 00000000..57d5807c --- /dev/null +++ b/tests/solidity/suites/exception/contracts/TestRevert.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +pragma solidity >=0.8.0; + +contract State { + uint256 a = 0; + function set(uint256 input) public { + a = input; + require(a < 10); + } + function force_set(uint256 input) public { + a = input; + } + function query() public view returns(uint256) { + return a; + } +} + +contract TestRevert { + State state; + uint256 b = 0; + uint256 c = 0; + constructor() { + state = new State(); + } + function try_set(uint256 input) public { + b = input; + try state.set(input) { + } catch (bytes memory) { + } + c = input; + } + function set(uint256 input) public { + state.force_set(input); + } + function query_a() public view returns(uint256) { + return state.query(); + } + function query_b() public view returns(uint256) { + return b; + } + function query_c() public view returns(uint256) { + return c; + } +} diff --git a/tests/solidity/suites/exception/contracts/test/Migrations.sol b/tests/solidity/suites/exception/contracts/test/Migrations.sol new file mode 100644 index 00000000..ef49fe5a --- /dev/null +++ b/tests/solidity/suites/exception/contracts/test/Migrations.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +pragma solidity >=0.8.0; + +contract Migrations { + address public owner = msg.sender; + uint public last_completed_migration; + + modifier restricted() { + require( + msg.sender == owner, + "This function is restricted to the contract's owner" + ); + _; + } + + function setCompleted(uint completed) public restricted { + last_completed_migration = completed; + } +} diff --git a/tests/solidity/suites/exception/migrations/1_initial_migration.js b/tests/solidity/suites/exception/migrations/1_initial_migration.js new file mode 100644 index 00000000..16a7ba52 --- /dev/null +++ b/tests/solidity/suites/exception/migrations/1_initial_migration.js @@ -0,0 +1,5 @@ +const Migrations = artifacts.require("Migrations"); + +module.exports = function (deployer) { + deployer.deploy(Migrations); +}; diff --git a/tests/solidity/suites/exception/package.json b/tests/solidity/suites/exception/package.json new file mode 100644 index 00000000..63e94e15 --- /dev/null +++ b/tests/solidity/suites/exception/package.json @@ -0,0 +1,15 @@ +{ + "name": "exception", + "version": "1.0.0", + "author": "huangyi ", + "license": "GPL-3.0-or-later", + "scripts": { + "test-ganache": "yarn truffle test", + "test-ethermint": "yarn truffle test --network ethermint" + }, + "devDependencies": { + "truffle": "^5.1.42", + "truffle-assertions": "^0.9.2", + "web3": "^1.2.11" + } +} diff --git a/tests/solidity/suites/exception/test/.gitkeep b/tests/solidity/suites/exception/test/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/tests/solidity/suites/exception/test/revert.js b/tests/solidity/suites/exception/test/revert.js new file mode 100644 index 00000000..73d53e4d --- /dev/null +++ b/tests/solidity/suites/exception/test/revert.js @@ -0,0 +1,35 @@ +const TestRevert = artifacts.require("TestRevert") +const truffleAssert = require('truffle-assertions'); + +async function expectRevert(promise) { + try { + await promise; + } catch (error) { + if (error.message.indexOf('revert') === -1) { + expect('revert').to.equal(error.message, 'Wrong kind of exception received'); + } + return; + } + expect.fail('Expected an exception but none was received'); +} + +contract('TestRevert', (accounts) => { + let revert + + beforeEach(async () => { + revert = await TestRevert.new() + }) + it('should revert', async () => { + await revert.try_set(10) + no = await revert.query_a() + assert.equal(no, '0', 'The modification on a should be reverted') + no = await revert.query_b() + assert.equal(no, '10', 'The modification on b should not be reverted') + no = await revert.query_c() + assert.equal(no, '10', 'The modification on c should not be reverted') + + await revert.set(10) + no = await revert.query_a() + assert.equal(no, '10', 'The force set should not be reverted') + }) +}) diff --git a/tests/solidity/suites/exception/truffle-config.js b/tests/solidity/suites/exception/truffle-config.js new file mode 100644 index 00000000..b0dbcc9d --- /dev/null +++ b/tests/solidity/suites/exception/truffle-config.js @@ -0,0 +1,17 @@ +module.exports = { + networks: { + // Development network is just left as truffle's default settings + ethermint: { + host: "127.0.0.1", // Localhost (default: none) + port: 8545, // Standard Ethereum port (default: none) + network_id: "*", // Any network (default: none) + gas: 5000000, // Gas sent with each transaction + gasPrice: 1000000000, // 1 gwei (in wei) + }, + }, + compilers: { + solc: { + version: "0.8.6", + }, + }, +} diff --git a/x/evm/keeper/context_stack.go b/x/evm/keeper/context_stack.go new file mode 100644 index 00000000..9d1bc68c --- /dev/null +++ b/x/evm/keeper/context_stack.go @@ -0,0 +1,87 @@ +package keeper + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// cachedContext is a pair of cache context and its corresponding commit method. +// They are obtained from the return value of `context.CacheContext()`. +type cachedContext struct { + ctx sdk.Context + commit func() +} + +// ContextStack manages the initial context and a stack of cached contexts, +// to support the `StateDB.Snapshot` and `StateDB.RevertToSnapshot` methods. +type ContextStack struct { + // Context of the initial state before transaction execution. + // It's the context used by `StateDB.CommitedState`. + initialCtx sdk.Context + cachedContexts []cachedContext +} + +// CurrentContext returns the top context of cached stack, +// if the stack is empty, returns the initial context. +func (cs *ContextStack) CurrentContext() sdk.Context { + l := len(cs.cachedContexts) + if l == 0 { + return cs.initialCtx + } + return cs.cachedContexts[l-1].ctx +} + +// Reset sets the initial context and clear the cache context stack. +func (cs *ContextStack) Reset(ctx sdk.Context) { + cs.initialCtx = ctx + if len(cs.cachedContexts) > 0 { + cs.cachedContexts = []cachedContext{} + } +} + +// IsEmpty returns true if the cache context stack is empty. +func (cs *ContextStack) IsEmpty() bool { + return len(cs.cachedContexts) == 0 +} + +// Commit commits all the cached contexts from top to bottom in order and clears the stack by setting an empty slice of cache contexts. +func (cs *ContextStack) Commit() { + // commit in order from top to bottom + for i := len(cs.cachedContexts) - 1; i >= 0; i-- { + // keep all the cosmos events + cs.initialCtx.EventManager().EmitEvents(cs.cachedContexts[i].ctx.EventManager().Events()) + if cs.cachedContexts[i].commit == nil { + panic(fmt.Sprintf("commit function at index %d should not be nil", i)) + } else { + cs.cachedContexts[i].commit() + } + } + cs.cachedContexts = []cachedContext{} +} + +// Snapshot pushes a new cached context to the stack, +// and returns the index of it. +func (cs *ContextStack) Snapshot() int { + i := len(cs.cachedContexts) + ctx, commit := cs.CurrentContext().CacheContext() + cs.cachedContexts = append(cs.cachedContexts, cachedContext{ctx: ctx, commit: commit}) + return i +} + +// RevertToSnapshot pops all the cached contexts after the target index (inclusive). +// the target should be snapshot index returned by `Snapshot`. +// This function panics if the index is out of bounds. +func (cs *ContextStack) RevertToSnapshot(target int) { + if target < 0 || target >= len(cs.cachedContexts) { + panic(fmt.Errorf("snapshot index %d out of bound [%d..%d)", target, 0, len(cs.cachedContexts))) + } + cs.cachedContexts = cs.cachedContexts[:target] +} + +// RevertAll discards all the cache contexts. +func (cs *ContextStack) RevertAll() { + if len(cs.cachedContexts) > 0 { + cs.RevertToSnapshot(0) + } +} diff --git a/x/evm/keeper/grpc_query.go b/x/evm/keeper/grpc_query.go index 5ecee5f7..db00b348 100644 --- a/x/evm/keeper/grpc_query.go +++ b/x/evm/keeper/grpc_query.go @@ -265,7 +265,7 @@ func (k Keeper) BlockBloom(c context.Context, req *types.QueryBlockBloomRequest) bloom, found := k.GetBlockBloom(ctx, req.Height) if !found { // if the bloom is not found, query the transient store at the current height - k.ctx = ctx + k.WithContext(ctx) bloomInt := k.GetBlockBloomTransient() if bloomInt.Sign() == 0 { @@ -381,6 +381,7 @@ func (k Keeper) EthCall(c context.Context, req *types.EthCallRequest) (*types.Ms evm := k.NewEVM(msg, ethCfg, params, coinbase) // pass true means execute in query mode, which don't do actual gas refund. res, err := k.ApplyMessage(evm, msg, ethCfg, true) + k.ctxStack.RevertAll() if err != nil { return nil, status.Error(codes.Internal, err.Error()) } @@ -443,15 +444,15 @@ func (k Keeper) EstimateGas(c context.Context, req *types.EthCallRequest) (*type executable := func(gas uint64) (bool, *types.MsgEthereumTxResponse, error) { args.Gas = (*hexutil.Uint64)(&gas) - // Execute the call in an isolated context - k.BeginCachedContext() + // Reset to the initial context + k.WithContext(ctx) msg := args.ToMessage(req.GasCap) evm := k.NewEVM(msg, ethCfg, params, coinbase) // pass true means execute in query mode, which don't do actual gas refund. rsp, err := k.ApplyMessage(evm, msg, ethCfg, true) - k.EndCachedContext() + k.ctxStack.RevertAll() if err != nil { if errors.Is(stacktrace.RootCause(err), core.ErrIntrinsicGas) { diff --git a/x/evm/keeper/keeper.go b/x/evm/keeper/keeper.go index 42df738b..c91a1cfa 100644 --- a/x/evm/keeper/keeper.go +++ b/x/evm/keeper/keeper.go @@ -40,13 +40,11 @@ type Keeper struct { // access historical headers for EVM state transition execution stakingKeeper types.StakingKeeper - // Context for accessing the store, emit events and log info. + // Manage the initial context and cache context stack for accessing the store, + // emit events and log info. // It is kept as a field to make is accessible by the StateDb // functions. Resets on every transaction/block. - ctx sdk.Context - // Context of the committed state (before transaction execution). - // Required for StateDB.CommitedState. Set in `BeginCachedContext`. - committedCtx sdk.Context + ctxStack ContextStack // chain ID number obtained from the context's chain id eip155ChainID *big.Int @@ -86,9 +84,19 @@ func NewKeeper( } } -// CommittedCtx returns the committed context -func (k Keeper) CommittedCtx() sdk.Context { - return k.committedCtx +// Ctx returns the current context from the context stack +func (k Keeper) Ctx() sdk.Context { + return k.ctxStack.CurrentContext() +} + +// CommitCachedContexts commit all the cache contexts created by `StateDB.Snapshot`. +func (k *Keeper) CommitCachedContexts() { + k.ctxStack.Commit() +} + +// CachedContextsEmpty returns true if there's no cache contexts. +func (k *Keeper) CachedContextsEmpty() bool { + return k.ctxStack.IsEmpty() } // Logger returns a module-specific logger. @@ -96,10 +104,9 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger { return ctx.Logger().With("module", types.ModuleName) } -// WithContext sets an updated SDK context to the keeper +// WithContext clears the context stack, and set the initial context. func (k *Keeper) WithContext(ctx sdk.Context) { - k.ctx = ctx - k.committedCtx = ctx + k.ctxStack.Reset(ctx) } // WithChainID sets the chain id to the local variable in the keeper @@ -147,8 +154,8 @@ func (k Keeper) SetBlockBloom(ctx sdk.Context, height int64, bloom ethtypes.Bloo // GetBlockBloomTransient returns bloom bytes for the current block height func (k Keeper) GetBlockBloomTransient() *big.Int { - store := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientBloom) - heightBz := sdk.Uint64ToBigEndian(uint64(k.ctx.BlockHeight())) + store := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientBloom) + heightBz := sdk.Uint64ToBigEndian(uint64(k.Ctx().BlockHeight())) bz := store.Get(heightBz) if len(bz) == 0 { return big.NewInt(0) @@ -160,8 +167,8 @@ func (k Keeper) GetBlockBloomTransient() *big.Int { // SetBlockBloomTransient sets the given bloom bytes to the transient store. This value is reset on // every block. func (k Keeper) SetBlockBloomTransient(bloom *big.Int) { - store := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientBloom) - heightBz := sdk.Uint64ToBigEndian(uint64(k.ctx.BlockHeight())) + store := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientBloom) + heightBz := sdk.Uint64ToBigEndian(uint64(k.Ctx().BlockHeight())) store.Set(heightBz, bloom.Bytes()) } @@ -171,7 +178,7 @@ func (k Keeper) SetBlockBloomTransient(bloom *big.Int) { // GetTxHashTransient returns the hash of current processing transaction func (k Keeper) GetTxHashTransient() common.Hash { - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) bz := store.Get(types.KeyPrefixTransientTxHash) if len(bz) == 0 { return common.Hash{} @@ -182,13 +189,13 @@ func (k Keeper) GetTxHashTransient() common.Hash { // SetTxHashTransient set the hash of processing transaction func (k Keeper) SetTxHashTransient(hash common.Hash) { - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) store.Set(types.KeyPrefixTransientTxHash, hash.Bytes()) } // GetTxIndexTransient returns EVM transaction index on the current block. func (k Keeper) GetTxIndexTransient() uint64 { - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) bz := store.Get(types.KeyPrefixTransientTxIndex) if len(bz) == 0 { return 0 @@ -201,7 +208,7 @@ func (k Keeper) GetTxIndexTransient() uint64 { // value by one and then sets the new index back to the transient store. func (k Keeper) IncreaseTxIndexTransient() { txIndex := k.GetTxIndexTransient() - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) store.Set(types.KeyPrefixTransientTxIndex, sdk.Uint64ToBigEndian(txIndex+1)) } @@ -235,7 +242,7 @@ func (k Keeper) GetAllTxLogs(ctx sdk.Context) []types.TransactionLogs { // GetLogs returns the current logs for a given transaction hash from the KVStore. // This function returns an empty, non-nil slice if no logs are found. func (k Keeper) GetTxLogs(txHash common.Hash) []*ethtypes.Log { - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.KeyPrefixLogs) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.KeyPrefixLogs) bz := store.Get(txHash.Bytes()) if len(bz) == 0 { @@ -250,7 +257,7 @@ func (k Keeper) GetTxLogs(txHash common.Hash) []*ethtypes.Log { // SetLogs sets the logs for a transaction in the KVStore. func (k Keeper) SetLogs(txHash common.Hash, logs []*ethtypes.Log) { - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.KeyPrefixLogs) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.KeyPrefixLogs) txLogs := types.NewTransactionLogsFromEth(txHash, logs) bz := k.cdc.MustMarshal(&txLogs) @@ -266,7 +273,7 @@ func (k Keeper) DeleteTxLogs(ctx sdk.Context, txHash common.Hash) { // GetLogSizeTransient returns EVM log index on the current block. func (k Keeper) GetLogSizeTransient() uint64 { - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) bz := store.Get(types.KeyPrefixTransientLogSize) if len(bz) == 0 { return 0 @@ -279,7 +286,7 @@ func (k Keeper) GetLogSizeTransient() uint64 { // value by one and then sets the new index back to the transient store. func (k Keeper) IncreaseLogSizeTransient() { logSize := k.GetLogSizeTransient() - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) store.Set(types.KeyPrefixTransientLogSize, sdk.Uint64ToBigEndian(logSize+1)) } @@ -308,7 +315,7 @@ func (k Keeper) GetAccountStorage(ctx sdk.Context, address common.Address) (type // ---------------------------------------------------------------------------- func (k Keeper) DeleteState(addr common.Address, key common.Hash) { - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.AddressStoragePrefix(addr)) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.AddressStoragePrefix(addr)) key = types.KeyAddressStorage(addr, key) store.Delete(key.Bytes()) } @@ -329,22 +336,22 @@ func (k Keeper) DeleteCode(addr common.Address) { return } - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.KeyPrefixCode) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.KeyPrefixCode) store.Delete(hash.Bytes()) } // ClearBalance subtracts the EVM all the balance denomination from the address // balance while also updating the total supply. func (k Keeper) ClearBalance(addr sdk.AccAddress) (prevBalance sdk.Coin, err error) { - params := k.GetParams(k.ctx) + params := k.GetParams(k.Ctx()) - prevBalance = k.bankKeeper.GetBalance(k.ctx, addr, params.EvmDenom) + prevBalance = k.bankKeeper.GetBalance(k.Ctx(), addr, params.EvmDenom) if prevBalance.IsPositive() { - if err := k.bankKeeper.SendCoinsFromAccountToModule(k.ctx, addr, types.ModuleName, sdk.Coins{prevBalance}); err != nil { + if err := k.bankKeeper.SendCoinsFromAccountToModule(k.Ctx(), addr, types.ModuleName, sdk.Coins{prevBalance}); err != nil { return sdk.Coin{}, stacktrace.Propagate(err, "failed to transfer to module account") } - if err := k.bankKeeper.BurnCoins(k.ctx, types.ModuleName, sdk.Coins{prevBalance}); err != nil { + if err := k.bankKeeper.BurnCoins(k.Ctx(), types.ModuleName, sdk.Coins{prevBalance}); err != nil { return sdk.Coin{}, stacktrace.Propagate(err, "failed to burn coins from evm module account") } } @@ -358,15 +365,3 @@ func (k Keeper) ResetAccount(addr common.Address) { k.DeleteCode(addr) k.DeleteAccountStorage(addr) } - -// BeginCachedContext create the cached context -func (k *Keeper) BeginCachedContext() (commit func()) { - k.committedCtx = k.ctx - k.ctx, commit = k.ctx.CacheContext() - return -} - -// EndCachedContext recover the committed context -func (k *Keeper) EndCachedContext() { - k.ctx = k.committedCtx -} diff --git a/x/evm/keeper/state_transition.go b/x/evm/keeper/state_transition.go index 0319e015..1cae6d2a 100644 --- a/x/evm/keeper/state_transition.go +++ b/x/evm/keeper/state_transition.go @@ -34,9 +34,9 @@ func (k *Keeper) NewEVM(msg core.Message, config *params.ChainConfig, params typ Transfer: core.Transfer, GetHash: k.GetHashFn(), Coinbase: coinbase, - GasLimit: ethermint.BlockGasLimit(k.ctx), - BlockNumber: big.NewInt(k.ctx.BlockHeight()), - Time: big.NewInt(k.ctx.BlockHeader().Time.Unix()), + GasLimit: ethermint.BlockGasLimit(k.Ctx()), + BlockNumber: big.NewInt(k.Ctx().BlockHeight()), + Time: big.NewInt(k.Ctx().BlockHeader().Time.Unix()), Difficulty: big.NewInt(0), // unused. Only required in PoW context } @@ -65,38 +65,38 @@ func (k Keeper) GetHashFn() vm.GetHashFunc { return func(height uint64) common.Hash { h := int64(height) switch { - case k.ctx.BlockHeight() == h: + case k.Ctx().BlockHeight() == h: // Case 1: The requested height matches the one from the context so we can retrieve the header // hash directly from the context. // Note: The headerHash is only set at begin block, it will be nil in case of a query context - headerHash := k.ctx.HeaderHash() + headerHash := k.Ctx().HeaderHash() if len(headerHash) != 0 { return common.BytesToHash(headerHash) } // only recompute the hash if not set (eg: checkTxState) - contextBlockHeader := k.ctx.BlockHeader() + contextBlockHeader := k.Ctx().BlockHeader() header, err := tmtypes.HeaderFromProto(&contextBlockHeader) if err != nil { - k.Logger(k.ctx).Error("failed to cast tendermint header from proto", "error", err) + k.Logger(k.Ctx()).Error("failed to cast tendermint header from proto", "error", err) return common.Hash{} } headerHash = header.Hash() return common.BytesToHash(headerHash) - case k.ctx.BlockHeight() > h: + case k.Ctx().BlockHeight() > h: // Case 2: if the chain is not the current height we need to retrieve the hash from the store for the // current chain epoch. This only applies if the current height is greater than the requested height. - histInfo, found := k.stakingKeeper.GetHistoricalInfo(k.ctx, h) + histInfo, found := k.stakingKeeper.GetHistoricalInfo(k.Ctx(), h) if !found { - k.Logger(k.ctx).Debug("historical info not found", "height", h) + k.Logger(k.Ctx()).Debug("historical info not found", "height", h) return common.Hash{} } header, err := tmtypes.HeaderFromProto(&histInfo.Header) if err != nil { - k.Logger(k.ctx).Error("failed to cast tendermint header from proto", "error", err) + k.Logger(k.Ctx()).Error("failed to cast tendermint header from proto", "error", err) return common.Hash{} } @@ -128,20 +128,17 @@ func (k Keeper) GetHashFn() vm.GetHashFunc { func (k *Keeper) ApplyTransaction(tx *ethtypes.Transaction) (*types.MsgEthereumTxResponse, error) { defer telemetry.ModuleMeasureSince(types.ModuleName, time.Now(), types.MetricKeyTransitionDB) - params := k.GetParams(k.ctx) + params := k.GetParams(k.Ctx()) ethCfg := params.ChainConfig.EthereumConfig(k.eip155ChainID) // get the latest signer according to the chain rules from the config - signer := ethtypes.MakeSigner(ethCfg, big.NewInt(k.ctx.BlockHeight())) + signer := ethtypes.MakeSigner(ethCfg, big.NewInt(k.Ctx().BlockHeight())) msg, err := tx.AsMessage(signer) if err != nil { return nil, stacktrace.Propagate(err, "failed to return ethereum transaction as core message") } - // we use a cached context to avoid modifying to state in case EVM msg is reverted - commit := k.BeginCachedContext() - // get the coinbase address from the block proposer coinbase, err := k.GetCoinbaseAddress() if err != nil { @@ -158,6 +155,10 @@ func (k *Keeper) ApplyTransaction(tx *ethtypes.Transaction) (*types.MsgEthereumT k.SetTxHashTransient(txHash) k.IncreaseTxIndexTransient() + if !k.ctxStack.IsEmpty() { + panic("context stack shouldn't be dirty before apply message") + } + // pass false to execute in real mode, which do actual gas refunding res, err := k.ApplyMessage(evm, msg, ethCfg, false) if err != nil { @@ -166,26 +167,18 @@ func (k *Keeper) ApplyTransaction(tx *ethtypes.Transaction) (*types.MsgEthereumT res.Hash = txHash.Hex() logs := k.GetTxLogs(txHash) - - // Commit and switch to committed context - if !res.Failed() { - // keep the cosmos events emitted in the cache context - k.committedCtx.EventManager().EmitEvents(k.ctx.EventManager().Events()) - commit() - } - - k.EndCachedContext() - - // Logs needs to be ignored when tx is reverted - // Set the log and bloom filter only when the tx is NOT REVERTED - if !res.Failed() { + if len(logs) > 0 { res.Logs = types.NewLogsFromEth(logs) - // Update block bloom filter in the original context because blockbloom is set in EndBlock + // Update transient block bloom filter bloom := k.GetBlockBloomTransient() bloom.Or(bloom, big.NewInt(0).SetBytes(ethtypes.LogsBloom(logs))) k.SetBlockBloomTransient(bloom) } + // Since we've implemented `RevertToSnapshot` api, so for the vm error cases, + // the state is reverted, so it's ok to call the commit here anyway. + k.CommitCachedContexts() + // update the gas used after refund k.resetGasMeterAndConsumeGas(res.GasUsed) return res, nil @@ -292,7 +285,7 @@ func (k *Keeper) ApplyMessage(evm *vm.EVM, msg core.Message, cfg *params.ChainCo // GetEthIntrinsicGas returns the intrinsic gas cost for the transaction func (k *Keeper) GetEthIntrinsicGas(msg core.Message, cfg *params.ChainConfig, isContractCreation bool) (uint64, error) { - height := big.NewInt(k.ctx.BlockHeight()) + height := big.NewInt(k.Ctx().BlockHeight()) homestead := cfg.IsHomestead(height) istanbul := cfg.IsIstanbul(height) @@ -347,12 +340,12 @@ func (k *Keeper) RefundGas(msg core.Message, leftoverGas uint64) (uint64, error) return leftoverGas, sdkerrors.Wrapf(types.ErrInvalidRefund, "refunded amount value cannot be negative %d", remaining.Int64()) case 1: // positive amount refund - params := k.GetParams(k.ctx) + params := k.GetParams(k.Ctx()) refundedCoins := sdk.Coins{sdk.NewCoin(params.EvmDenom, sdk.NewIntFromBigInt(remaining))} // refund to sender from the fee collector module account, which is the escrow account in charge of collecting tx fees - err := k.bankKeeper.SendCoinsFromModuleToAccount(k.ctx, authtypes.FeeCollectorName, msg.From().Bytes(), refundedCoins) + err := k.bankKeeper.SendCoinsFromModuleToAccount(k.Ctx(), authtypes.FeeCollectorName, msg.From().Bytes(), refundedCoins) if err != nil { err = sdkerrors.Wrapf(sdkerrors.ErrInsufficientFunds, "fee collector account failed to refund fees: %s", err.Error()) return leftoverGas, stacktrace.Propagate(err, "failed to refund %d leftover gas (%s)", leftoverGas, refundedCoins.String()) @@ -368,14 +361,14 @@ func (k *Keeper) RefundGas(msg core.Message, leftoverGas uint64) (uint64, error) // 'gasUsed' func (k *Keeper) resetGasMeterAndConsumeGas(gasUsed uint64) { // reset the gas count - k.ctx.GasMeter().RefundGas(k.ctx.GasMeter().GasConsumed(), "reset the gas count") - k.ctx.GasMeter().ConsumeGas(gasUsed, "apply evm transaction") + k.Ctx().GasMeter().RefundGas(k.Ctx().GasMeter().GasConsumed(), "reset the gas count") + k.Ctx().GasMeter().ConsumeGas(gasUsed, "apply evm transaction") } // GetCoinbaseAddress returns the block proposer's validator operator address. func (k Keeper) GetCoinbaseAddress() (common.Address, error) { - consAddr := sdk.ConsAddress(k.ctx.BlockHeader().ProposerAddress) - validator, found := k.stakingKeeper.GetValidatorByConsAddr(k.ctx, consAddr) + consAddr := sdk.ConsAddress(k.Ctx().BlockHeader().ProposerAddress) + validator, found := k.stakingKeeper.GetValidatorByConsAddr(k.Ctx(), consAddr) if !found { return common.Address{}, stacktrace.Propagate( sdkerrors.Wrap(stakingtypes.ErrNoValidatorFound, consAddr.String()), diff --git a/x/evm/keeper/statedb.go b/x/evm/keeper/statedb.go index 62a62e0b..21eaf912 100644 --- a/x/evm/keeper/statedb.go +++ b/x/evm/keeper/statedb.go @@ -30,7 +30,7 @@ var _ vm.StateDB = &Keeper{} func (k *Keeper) CreateAccount(addr common.Address) { cosmosAddr := sdk.AccAddress(addr.Bytes()) - account := k.accountKeeper.GetAccount(k.ctx, cosmosAddr) + account := k.accountKeeper.GetAccount(k.Ctx(), cosmosAddr) log := "" if account == nil { log = "account created" @@ -39,10 +39,10 @@ func (k *Keeper) CreateAccount(addr common.Address) { k.ResetAccount(addr) } - account = k.accountKeeper.NewAccountWithAddress(k.ctx, cosmosAddr) - k.accountKeeper.SetAccount(k.ctx, account) + account = k.accountKeeper.NewAccountWithAddress(k.Ctx(), cosmosAddr) + k.accountKeeper.SetAccount(k.Ctx(), account) - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( log, "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -58,7 +58,7 @@ func (k *Keeper) CreateAccount(addr common.Address) { // from the module parameters. func (k *Keeper) AddBalance(addr common.Address, amount *big.Int) { if amount.Sign() != 1 { - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "ignored non-positive amount addition", "ethereum-address", addr.Hex(), "amount", amount.Int64(), @@ -68,11 +68,11 @@ func (k *Keeper) AddBalance(addr common.Address, amount *big.Int) { cosmosAddr := sdk.AccAddress(addr.Bytes()) - params := k.GetParams(k.ctx) + params := k.GetParams(k.Ctx()) coins := sdk.Coins{sdk.NewCoin(params.EvmDenom, sdk.NewIntFromBigInt(amount))} - if err := k.bankKeeper.MintCoins(k.ctx, types.ModuleName, coins); err != nil { - k.Logger(k.ctx).Error( + if err := k.bankKeeper.MintCoins(k.Ctx(), types.ModuleName, coins); err != nil { + k.Logger(k.Ctx()).Error( "failed to mint coins when adding balance", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -81,8 +81,8 @@ func (k *Keeper) AddBalance(addr common.Address, amount *big.Int) { return } - if err := k.bankKeeper.SendCoinsFromModuleToAccount(k.ctx, types.ModuleName, cosmosAddr, coins); err != nil { - k.Logger(k.ctx).Error( + if err := k.bankKeeper.SendCoinsFromModuleToAccount(k.Ctx(), types.ModuleName, cosmosAddr, coins); err != nil { + k.Logger(k.Ctx()).Error( "failed to send from module to account when adding balance", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -91,7 +91,7 @@ func (k *Keeper) AddBalance(addr common.Address, amount *big.Int) { return } - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "balance addition", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -104,7 +104,7 @@ func (k *Keeper) AddBalance(addr common.Address, amount *big.Int) { // or the user doesn't have enough funds for the transfer. func (k *Keeper) SubBalance(addr common.Address, amount *big.Int) { if amount.Sign() != 1 { - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "ignored non-positive amount addition", "ethereum-address", addr.Hex(), "amount", amount.Int64(), @@ -114,11 +114,11 @@ func (k *Keeper) SubBalance(addr common.Address, amount *big.Int) { cosmosAddr := sdk.AccAddress(addr.Bytes()) - params := k.GetParams(k.ctx) + params := k.GetParams(k.Ctx()) coins := sdk.Coins{sdk.NewCoin(params.EvmDenom, sdk.NewIntFromBigInt(amount))} - if err := k.bankKeeper.SendCoinsFromAccountToModule(k.ctx, cosmosAddr, types.ModuleName, coins); err != nil { - k.Logger(k.ctx).Debug( + if err := k.bankKeeper.SendCoinsFromAccountToModule(k.Ctx(), cosmosAddr, types.ModuleName, coins); err != nil { + k.Logger(k.Ctx()).Debug( "failed to send from account to module when subtracting balance", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -128,8 +128,8 @@ func (k *Keeper) SubBalance(addr common.Address, amount *big.Int) { return } - if err := k.bankKeeper.BurnCoins(k.ctx, types.ModuleName, coins); err != nil { - k.Logger(k.ctx).Error( + if err := k.bankKeeper.BurnCoins(k.Ctx(), types.ModuleName, coins); err != nil { + k.Logger(k.Ctx()).Error( "failed to burn coins when subtracting balance", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -138,7 +138,7 @@ func (k *Keeper) SubBalance(addr common.Address, amount *big.Int) { return } - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "balance subtraction", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -149,8 +149,8 @@ func (k *Keeper) SubBalance(addr common.Address, amount *big.Int) { // denomination is obtained from the module parameters. func (k *Keeper) GetBalance(addr common.Address) *big.Int { cosmosAddr := sdk.AccAddress(addr.Bytes()) - params := k.GetParams(k.ctx) - balance := k.bankKeeper.GetBalance(k.ctx, cosmosAddr, params.EvmDenom) + params := k.GetParams(k.Ctx()) + balance := k.bankKeeper.GetBalance(k.Ctx(), cosmosAddr, params.EvmDenom) return balance.Amount.BigInt() } @@ -163,9 +163,9 @@ func (k *Keeper) GetBalance(addr common.Address) *big.Int { // sequence (i.e nonce). The function performs a no-op if the account is not found. func (k *Keeper) GetNonce(addr common.Address) uint64 { cosmosAddr := sdk.AccAddress(addr.Bytes()) - nonce, err := k.accountKeeper.GetSequence(k.ctx, cosmosAddr) + nonce, err := k.accountKeeper.GetSequence(k.Ctx(), cosmosAddr) if err != nil { - k.Logger(k.ctx).Error( + k.Logger(k.Ctx()).Error( "account not found", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -180,20 +180,20 @@ func (k *Keeper) GetNonce(addr common.Address) uint64 { // account doesn't exist, a new one will be created from the address. func (k *Keeper) SetNonce(addr common.Address, nonce uint64) { cosmosAddr := sdk.AccAddress(addr.Bytes()) - account := k.accountKeeper.GetAccount(k.ctx, cosmosAddr) + account := k.accountKeeper.GetAccount(k.Ctx(), cosmosAddr) if account == nil { - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "account not found", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), ) // create address if it doesn't exist - account = k.accountKeeper.NewAccountWithAddress(k.ctx, cosmosAddr) + account = k.accountKeeper.NewAccountWithAddress(k.Ctx(), cosmosAddr) } if err := account.SetSequence(nonce); err != nil { - k.Logger(k.ctx).Error( + k.Logger(k.Ctx()).Error( "failed to set nonce", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -204,9 +204,9 @@ func (k *Keeper) SetNonce(addr common.Address, nonce uint64) { return } - k.accountKeeper.SetAccount(k.ctx, account) + k.accountKeeper.SetAccount(k.Ctx(), account) - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "nonce set", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -222,7 +222,7 @@ func (k *Keeper) SetNonce(addr common.Address, nonce uint64) { // exist or is not an EthAccount type, GetCodeHash returns the empty code hash value. func (k *Keeper) GetCodeHash(addr common.Address) common.Hash { cosmosAddr := sdk.AccAddress(addr.Bytes()) - account := k.accountKeeper.GetAccount(k.ctx, cosmosAddr) + account := k.accountKeeper.GetAccount(k.Ctx(), cosmosAddr) if account == nil { return common.BytesToHash(types.EmptyCodeHash) } @@ -244,11 +244,11 @@ func (k *Keeper) GetCode(addr common.Address) []byte { return nil } - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.KeyPrefixCode) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.KeyPrefixCode) code := store.Get(hash.Bytes()) if len(code) == 0 { - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "code not found", "ethereum-address", addr.Hex(), "code-hash", hash.Hex(), @@ -262,20 +262,20 @@ func (k *Keeper) GetCode(addr common.Address) []byte { // code hash to the given account. The code is deleted from the store if it is empty. func (k *Keeper) SetCode(addr common.Address, code []byte) { if bytes.Equal(code, types.EmptyCodeHash) { - k.Logger(k.ctx).Debug("passed in EmptyCodeHash, but expected empty code") + k.Logger(k.Ctx()).Debug("passed in EmptyCodeHash, but expected empty code") } hash := crypto.Keccak256Hash(code) // update account code hash - account := k.accountKeeper.GetAccount(k.ctx, addr.Bytes()) + account := k.accountKeeper.GetAccount(k.Ctx(), addr.Bytes()) if account == nil { - account = k.accountKeeper.NewAccountWithAddress(k.ctx, addr.Bytes()) - k.accountKeeper.SetAccount(k.ctx, account) + account = k.accountKeeper.NewAccountWithAddress(k.Ctx(), addr.Bytes()) + k.accountKeeper.SetAccount(k.Ctx(), account) } ethAccount, isEthAccount := account.(*ethermint.EthAccount) if !isEthAccount { - k.Logger(k.ctx).Error( + k.Logger(k.Ctx()).Error( "invalid account type", "ethereum-address", addr.Hex(), "code-hash", hash.Hex(), @@ -284,9 +284,9 @@ func (k *Keeper) SetCode(addr common.Address, code []byte) { } ethAccount.CodeHash = hash.Hex() - k.accountKeeper.SetAccount(k.ctx, ethAccount) + k.accountKeeper.SetAccount(k.Ctx(), ethAccount) - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.KeyPrefixCode) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.KeyPrefixCode) action := "updated" @@ -298,7 +298,7 @@ func (k *Keeper) SetCode(addr common.Address, code []byte) { store.Set(hash.Bytes(), code) } - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( fmt.Sprintf("code %s", action), "ethereum-address", addr.Hex(), "code-hash", hash.Hex(), @@ -327,7 +327,7 @@ func (k *Keeper) AddRefund(gas uint64) { refund += gas - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) store.Set(types.KeyPrefixTransientRefund, sdk.Uint64ToBigEndian(refund)) } @@ -343,14 +343,14 @@ func (k *Keeper) SubRefund(gas uint64) { refund -= gas - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) store.Set(types.KeyPrefixTransientRefund, sdk.Uint64ToBigEndian(refund)) } // GetRefund returns the amount of gas available for return after the tx execution // finalizes. This value is reset to 0 on every transaction. func (k *Keeper) GetRefund() uint64 { - store := k.ctx.TransientStore(k.transientKey) + store := k.Ctx().TransientStore(k.transientKey) bz := store.Get(types.KeyPrefixTransientRefund) if len(bz) == 0 { @@ -379,19 +379,19 @@ func doGetState(ctx sdk.Context, storeKey sdk.StoreKey, addr common.Address, has // GetCommittedState returns the value set in store for the given key hash. If the key is not registered // this function returns the empty hash. func (k *Keeper) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { - return doGetState(k.committedCtx, k.storeKey, addr, hash) + return doGetState(k.ctxStack.initialCtx, k.storeKey, addr, hash) } // GetState returns the committed state for the given key hash, as all changes are committed directly // to the KVStore. func (k *Keeper) GetState(addr common.Address, hash common.Hash) common.Hash { - return doGetState(k.ctx, k.storeKey, addr, hash) + return doGetState(k.Ctx(), k.storeKey, addr, hash) } // SetState sets the given hashes (key, value) to the KVStore. If the value hash is empty, this // function deletes the key from the store. func (k *Keeper) SetState(addr common.Address, key, value common.Hash) { - store := prefix.NewStore(k.ctx.KVStore(k.storeKey), types.AddressStoragePrefix(addr)) + store := prefix.NewStore(k.Ctx().KVStore(k.storeKey), types.AddressStoragePrefix(addr)) key = types.KeyAddressStorage(addr, key) action := "updated" @@ -402,7 +402,7 @@ func (k *Keeper) SetState(addr common.Address, key, value common.Hash) { store.Set(key.Bytes(), value.Bytes()) } - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( fmt.Sprintf("state %s", action), "ethereum-address", addr.Hex(), "key", key.Hex(), @@ -425,7 +425,7 @@ func (k *Keeper) Suicide(addr common.Address) bool { _, err := k.ClearBalance(cosmosAddr) if err != nil { - k.Logger(k.ctx).Error( + k.Logger(k.Ctx()).Error( "failed to subtract balance on suicide", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -438,10 +438,10 @@ func (k *Keeper) Suicide(addr common.Address) bool { // TODO: (@fedekunze) do we also need to delete the storage state and the code? // Set a single byte to the transient store - store := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientSuicided) + store := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientSuicided) store.Set(addr.Bytes(), []byte{1}) - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "account suicided", "ethereum-address", addr.Hex(), "cosmos-address", cosmosAddr.String(), @@ -454,7 +454,7 @@ func (k *Keeper) Suicide(addr common.Address) bool { // current block. Accounts that are suicided will be returned as non-nil during queries and "cleared" // after the block has been committed. func (k *Keeper) HasSuicided(addr common.Address) bool { - store := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientSuicided) + store := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientSuicided) return store.Has(addr.Bytes()) } @@ -471,7 +471,7 @@ func (k *Keeper) Exist(addr common.Address) bool { } cosmosAddr := sdk.AccAddress(addr.Bytes()) - account := k.accountKeeper.GetAccount(k.ctx, cosmosAddr) + account := k.accountKeeper.GetAccount(k.Ctx(), cosmosAddr) return account != nil } @@ -486,7 +486,7 @@ func (k *Keeper) Empty(addr common.Address) bool { codeHash := types.EmptyCodeHash cosmosAddr := sdk.AccAddress(addr.Bytes()) - account := k.accountKeeper.GetAccount(k.ctx, cosmosAddr) + account := k.accountKeeper.GetAccount(k.Ctx(), cosmosAddr) if account != nil { nonce = account.GetSequence() @@ -537,7 +537,7 @@ func (k *Keeper) PrepareAccessList(sender common.Address, dest *common.Address, // AddressInAccessList returns true if the address is registered on the transient store. func (k *Keeper) AddressInAccessList(addr common.Address) bool { - ts := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientAccessListAddress) + ts := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientAccessListAddress) return ts.Has(addr.Bytes()) } @@ -550,7 +550,7 @@ func (k *Keeper) SlotInAccessList(addr common.Address, slot common.Hash) (addres // addressSlotInAccessList returns true if the address's slot is registered on the transient store. func (k *Keeper) addressSlotInAccessList(addr common.Address, slot common.Hash) bool { - ts := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientAccessListSlot) + ts := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientAccessListSlot) key := append(addr.Bytes(), slot.Bytes()...) return ts.Has(key) } @@ -562,7 +562,7 @@ func (k *Keeper) AddAddressToAccessList(addr common.Address) { return } - ts := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientAccessListAddress) + ts := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientAccessListAddress) ts.Set(addr.Bytes(), []byte{0x1}) } @@ -574,7 +574,7 @@ func (k *Keeper) AddSlotToAccessList(addr common.Address, slot common.Hash) { return } - ts := prefix.NewStore(k.ctx.TransientStore(k.transientKey), types.KeyPrefixTransientAccessListSlot) + ts := prefix.NewStore(k.Ctx().TransientStore(k.transientKey), types.KeyPrefixTransientAccessListSlot) key := append(addr.Bytes(), slot.Bytes()...) ts.Set(key, []byte{0x1}) } @@ -583,16 +583,15 @@ func (k *Keeper) AddSlotToAccessList(addr common.Address, slot common.Hash) { // Snapshotting // ---------------------------------------------------------------------------- -// Snapshot return zero as the state changes won't be committed if the state transition fails. So there -// is no need to snapshot before the VM execution. -// See Cosmos SDK docs for more info: https://docs.cosmos.network/master/core/baseapp.html#delivertx-state-updates +// Snapshot return the index in the cached context stack func (k *Keeper) Snapshot() int { - return 0 + return k.ctxStack.Snapshot() } -// RevertToSnapshot performs a no-op because when a transaction execution fails on the EVM, the state -// won't be persisted during ABCI DeliverTx. -func (k *Keeper) RevertToSnapshot(_ int) {} +// RevertToSnapshot pop all the cached contexts after(including) the snapshot +func (k *Keeper) RevertToSnapshot(target int) { + k.ctxStack.RevertToSnapshot(target) +} // ---------------------------------------------------------------------------- // Log @@ -602,7 +601,7 @@ func (k *Keeper) RevertToSnapshot(_ int) {} // context. This function also fills in the tx hash, block hash, tx index and log index fields before setting the log // to store. func (k *Keeper) AddLog(log *ethtypes.Log) { - log.BlockHash = common.BytesToHash(k.ctx.HeaderHash()) + log.BlockHash = common.BytesToHash(k.Ctx().HeaderHash()) log.TxIndex = uint(k.GetTxIndexTransient()) log.TxHash = k.GetTxHashTransient() @@ -614,7 +613,7 @@ func (k *Keeper) AddLog(log *ethtypes.Log) { k.SetLogs(log.TxHash, logs) - k.Logger(k.ctx).Debug( + k.Logger(k.Ctx()).Debug( "log added", "tx-hash-ethereum", log.TxHash.Hex(), "log-index", int(log.Index), @@ -637,7 +636,7 @@ func (k *Keeper) AddPreimage(_ common.Hash, _ []byte) {} // ForEachStorage uses the store iterator to iterate over all the state keys and perform a callback // function on each of them. func (k *Keeper) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { - store := k.ctx.KVStore(k.storeKey) + store := k.Ctx().KVStore(k.storeKey) prefix := types.AddressStoragePrefix(addr) iterator := sdk.KVStorePrefixIterator(store, prefix) diff --git a/x/evm/keeper/statedb_test.go b/x/evm/keeper/statedb_test.go index 64a6cffb..30defc08 100644 --- a/x/evm/keeper/statedb_test.go +++ b/x/evm/keeper/statedb_test.go @@ -392,7 +392,7 @@ func (suite *KeeperTestSuite) TestCommittedState() { suite.app.EvmKeeper.SetState(suite.address, key, value1) - commit := suite.app.EvmKeeper.BeginCachedContext() + suite.app.EvmKeeper.Snapshot() suite.app.EvmKeeper.SetState(suite.address, key, value2) tmp := suite.app.EvmKeeper.GetState(suite.address, key) @@ -400,8 +400,7 @@ func (suite *KeeperTestSuite) TestCommittedState() { tmp = suite.app.EvmKeeper.GetCommittedState(suite.address, key) suite.Require().Equal(value1, tmp) - commit() - suite.app.EvmKeeper.EndCachedContext() + suite.app.EvmKeeper.CommitCachedContexts() tmp = suite.app.EvmKeeper.GetCommittedState(suite.address, key) suite.Require().Equal(value2, tmp) @@ -476,9 +475,62 @@ func (suite *KeeperTestSuite) TestEmpty() { } func (suite *KeeperTestSuite) TestSnapshot() { - revision := suite.app.EvmKeeper.Snapshot() - suite.Require().Zero(revision) - suite.app.EvmKeeper.RevertToSnapshot(revision) // no-op + + var key = common.BytesToHash([]byte("key")) + var value1 = common.BytesToHash([]byte("value1")) + var value2 = common.BytesToHash([]byte("value2")) + + testCases := []struct { + name string + malleate func() + }{ + {"simple revert", func() { + revision := suite.app.EvmKeeper.Snapshot() + suite.Require().Zero(revision) + + suite.app.EvmKeeper.SetState(suite.address, key, value1) + suite.Require().Equal(value1, suite.app.EvmKeeper.GetState(suite.address, key)) + + suite.app.EvmKeeper.RevertToSnapshot(revision) + + // reverted + suite.Require().Equal(common.Hash{}, suite.app.EvmKeeper.GetState(suite.address, key)) + }}, + {"nested snapshot/revert", func() { + revision1 := suite.app.EvmKeeper.Snapshot() + suite.Require().Zero(revision1) + + suite.app.EvmKeeper.SetState(suite.address, key, value1) + + revision2 := suite.app.EvmKeeper.Snapshot() + + suite.app.EvmKeeper.SetState(suite.address, key, value2) + suite.Require().Equal(value2, suite.app.EvmKeeper.GetState(suite.address, key)) + + suite.app.EvmKeeper.RevertToSnapshot(revision2) + suite.Require().Equal(value1, suite.app.EvmKeeper.GetState(suite.address, key)) + + suite.app.EvmKeeper.RevertToSnapshot(revision1) + suite.Require().Equal(common.Hash{}, suite.app.EvmKeeper.GetState(suite.address, key)) + }}, + {"jump revert", func() { + revision1 := suite.app.EvmKeeper.Snapshot() + suite.app.EvmKeeper.SetState(suite.address, key, value1) + suite.app.EvmKeeper.Snapshot() + suite.app.EvmKeeper.SetState(suite.address, key, value2) + suite.app.EvmKeeper.RevertToSnapshot(revision1) + suite.Require().Equal(common.Hash{}, suite.app.EvmKeeper.GetState(suite.address, key)) + }}, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + suite.SetupTest() + tc.malleate() + // the test case should finish in clean state + suite.Require().True(suite.app.EvmKeeper.CachedContextsEmpty()) + }) + } } func (suite *KeeperTestSuite) CreateTestTx(msg *evmtypes.MsgEthereumTx, priv cryptotypes.PrivKey) authsigning.Tx {