core/state, core/vm: cleanup refunds

This commit is contained in:
obscuren 2015-06-17 10:20:33 +02:00
parent dfd18d245a
commit 5721fcf668
3 changed files with 13 additions and 21 deletions

View File

@ -21,7 +21,7 @@ type StateDB struct {
stateObjects map[string]*StateObject stateObjects map[string]*StateObject
refund map[string]*big.Int refund *big.Int
thash, bhash common.Hash thash, bhash common.Hash
txIndex int txIndex int
@ -31,7 +31,7 @@ type StateDB struct {
// Create a new state from a given trie // Create a new state from a given trie
func New(root common.Hash, db common.Database) *StateDB { func New(root common.Hash, db common.Database) *StateDB {
trie := trie.NewSecure(root[:], db) trie := trie.NewSecure(root[:], db)
return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: make(map[string]*big.Int), logs: make(map[common.Hash]Logs)} return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)}
} }
func (self *StateDB) PrintRoot() { func (self *StateDB) PrintRoot() {
@ -63,12 +63,8 @@ func (self *StateDB) Logs() Logs {
return logs return logs
} }
func (self *StateDB) Refund(address common.Address, gas *big.Int) { func (self *StateDB) Refund(gas *big.Int) {
addr := address.Str() self.refund.Add(self.refund, gas)
if self.refund[addr] == nil {
self.refund[addr] = new(big.Int)
}
self.refund[addr].Add(self.refund[addr], gas)
} }
/* /*
@ -268,9 +264,7 @@ func (self *StateDB) Copy() *StateDB {
state.stateObjects[k] = stateObject.Copy() state.stateObjects[k] = stateObject.Copy()
} }
for addr, refund := range self.refund { state.refund.Set(self.refund)
state.refund[addr] = new(big.Int).Set(refund)
}
for hash, logs := range self.logs { for hash, logs := range self.logs {
state.logs[hash] = make(Logs, len(logs)) state.logs[hash] = make(Logs, len(logs))
@ -330,15 +324,15 @@ func (s *StateDB) Sync() {
func (self *StateDB) Empty() { func (self *StateDB) Empty() {
self.stateObjects = make(map[string]*StateObject) self.stateObjects = make(map[string]*StateObject)
self.refund = make(map[string]*big.Int) self.refund = new(big.Int)
} }
func (self *StateDB) Refunds() map[string]*big.Int { func (self *StateDB) Refunds() *big.Int {
return self.refund return self.refund
} }
func (self *StateDB) Update() { func (self *StateDB) Update() {
self.refund = make(map[string]*big.Int) self.refund = new(big.Int)
for _, stateObject := range self.stateObjects { for _, stateObject := range self.stateObjects {
if stateObject.dirty { if stateObject.dirty {

View File

@ -241,11 +241,9 @@ func (self *StateTransition) refundGas() {
sender.AddBalance(remaining) sender.AddBalance(remaining)
uhalf := new(big.Int).Div(self.gasUsed(), common.Big2) uhalf := new(big.Int).Div(self.gasUsed(), common.Big2)
for addr, ref := range self.state.Refunds() { refund := common.BigMin(uhalf, self.state.Refunds())
refund := common.BigMin(uhalf, ref)
self.gas.Add(self.gas, refund) self.gas.Add(self.gas, refund)
self.state.AddBalance(common.StringToAddress(addr), refund.Mul(refund, self.msg.GasPrice())) self.state.AddBalance(sender.Address(), refund.Mul(refund, self.msg.GasPrice()))
}
coinbase.RefundGas(self.gas, self.msg.GasPrice()) coinbase.RefundGas(self.gas, self.msg.GasPrice())
} }

View File

@ -690,7 +690,7 @@ func (self *Vm) calculateGasAndSize(context *Context, caller ContextRef, op OpCo
// 0 => non 0 // 0 => non 0
g = params.SstoreSetGas g = params.SstoreSetGas
} else if len(val) > 0 && len(y.Bytes()) == 0 { } else if len(val) > 0 && len(y.Bytes()) == 0 {
statedb.Refund(self.env.Origin(), params.SstoreRefundGas) statedb.Refund(params.SstoreRefundGas)
g = params.SstoreClearGas g = params.SstoreClearGas
} else { } else {
@ -700,7 +700,7 @@ func (self *Vm) calculateGasAndSize(context *Context, caller ContextRef, op OpCo
gas.Set(g) gas.Set(g)
case SUICIDE: case SUICIDE:
if !statedb.IsDeleted(context.Address()) { if !statedb.IsDeleted(context.Address()) {
statedb.Refund(self.env.Origin(), params.SuicideRefundGas) statedb.Refund(params.SuicideRefundGas)
} }
case MLOAD: case MLOAD:
newMemSize = calcMemSize(stack.peek(), u256(32)) newMemSize = calcMemSize(stack.peek(), u256(32))