diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 8b034a0e7..8270300ba 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -25,21 +25,17 @@ import ( // memoryGasCost calculates the quadratic gas for memory expansion. It does so // only for the memory region that is expanded, not the total memory. func memoryGasCost(mem *Memory, newMemSize uint64) (uint64, error) { - if newMemSize == 0 { return 0, nil } - // The maximum that will fit in a uint64 is max_word_count - 1 - // anything above that will result in an overflow. - // Additionally, a newMemSize which results in a - // newMemSizeWords larger than 0xFFFFFFFF will cause the square operation - // to overflow. - // The constant 0x1FFFFFFFE0 is the highest number that can be used without - // overflowing the gas calculation + // The maximum that will fit in a uint64 is max_word_count - 1. Anything above + // that will result in an overflow. Additionally, a newMemSize which results in + // a newMemSizeWords larger than 0xFFFFFFFF will cause the square operation to + // overflow. The constant 0x1FFFFFFFE0 is the highest number that can be used + // without overflowing the gas calculation. if newMemSize > 0x1FFFFFFFE0 { return 0, errGasUintOverflow } - newMemSizeWords := toWordSize(newMemSize) newMemSize = newMemSizeWords * 32 diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go index 1b91aee56..2c1e11894 100644 --- a/core/vm/gas_table_test.go +++ b/core/vm/gas_table_test.go @@ -19,18 +19,21 @@ package vm import "testing" func TestMemoryGasCost(t *testing.T) { - //size := uint64(math.MaxUint64 - 64) - size := uint64(0xffffffffe0) - v, err := memoryGasCost(&Memory{}, size) - if err != nil { - t.Error("didn't expect error:", err) + tests := []struct { + size uint64 + cost uint64 + overflow bool + }{ + {0x1fffffffe0, 36028809887088637, false}, + {0x1fffffffe1, 0, true}, } - if v != 36028899963961341 { - t.Errorf("Expected: 36028899963961341, got %d", v) - } - - _, err = memoryGasCost(&Memory{}, size+1) - if err == nil { - t.Error("expected error") + for i, tt := range tests { + v, err := memoryGasCost(&Memory{}, tt.size) + if (err == errGasUintOverflow) != tt.overflow { + t.Errorf("test %d: overflow mismatch: have %v, want %v", i, err == errGasUintOverflow, tt.overflow) + } + if v != tt.cost { + t.Errorf("test %d: gas cost mismatch: have %v, want %v", i, v, tt.cost) + } } }