diff --git a/core/vm/common.go b/core/vm/common.go index 779cee006..17de38dec 100644 --- a/core/vm/common.go +++ b/core/vm/common.go @@ -34,7 +34,21 @@ func calcMemSize(off, l *big.Int) *big.Int { // getData returns a slice from the data based on the start and size and pads // up to size with zero's. This function is overflow safe. -func getData(data []byte, start, size *big.Int) []byte { +func getData(data []byte, start uint64, size uint64) []byte { + length := uint64(len(data)) + if start > length { + start = length + } + end := start + size + if end > length { + end = length + } + return common.RightPadBytes(data[start:end], int(size)) +} + +// getDataBig returns a slice from the data based on the start and size and pads +// up to size with zero's. This function is overflow safe. +func getDataBig(data []byte, start *big.Int, size *big.Int) []byte { dlen := big.NewInt(int64(len(data))) s := math.BigMin(start, dlen) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index c59779dac..b885d42bb 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -149,30 +149,41 @@ func (c *dataCopy) Run(in []byte) ([]byte, error) { // bigModExp implements a native big integer exponential modular operation. type bigModExp struct{} +var ( + big1 = big.NewInt(1) + big4 = big.NewInt(4) + big8 = big.NewInt(8) + big16 = big.NewInt(16) + big32 = big.NewInt(32) + big64 = big.NewInt(64) + big96 = big.NewInt(96) + big480 = big.NewInt(480) + big1024 = big.NewInt(1024) + big3072 = big.NewInt(3072) + big199680 = big.NewInt(199680) +) + // RequiredGas returns the gas required to execute the pre-compiled contract. func (c *bigModExp) RequiredGas(input []byte) uint64 { - // Pad the input with zeroes to the minimum size to read the field lengths - input = common.RightPadBytes(input, 96) - var ( - baseLen = new(big.Int).SetBytes(input[:32]) - expLen = new(big.Int).SetBytes(input[32:64]) - modLen = new(big.Int).SetBytes(input[64:96]) + baseLen = new(big.Int).SetBytes(getData(input, 0, 32)) + expLen = new(big.Int).SetBytes(getData(input, 32, 32)) + modLen = new(big.Int).SetBytes(getData(input, 64, 32)) ) - input = input[96:] - + if len(input) > 96 { + input = input[96:] + } else { + input = input[:0] + } // Retrieve the head 32 bytes of exp for the adjusted exponent length var expHead *big.Int if big.NewInt(int64(len(input))).Cmp(baseLen) <= 0 { expHead = new(big.Int) } else { - offset := int(baseLen.Uint64()) - - input = common.RightPadBytes(input, offset+32) - if expLen.Cmp(big.NewInt(32)) > 0 { - expHead = new(big.Int).SetBytes(input[offset : offset+32]) + if expLen.Cmp(big32) > 0 { + expHead = new(big.Int).SetBytes(getData(input, baseLen.Uint64(), 32)) } else { - expHead = new(big.Int).SetBytes(input[offset : offset+int(expLen.Uint64())]) + expHead = new(big.Int).SetBytes(getData(input, baseLen.Uint64(), expLen.Uint64())) } } // Calculate the adjusted exponent length @@ -181,29 +192,29 @@ func (c *bigModExp) RequiredGas(input []byte) uint64 { msb = bitlen - 1 } adjExpLen := new(big.Int) - if expLen.Cmp(big.NewInt(32)) > 0 { - adjExpLen.Sub(expLen, big.NewInt(32)) - adjExpLen.Mul(big.NewInt(8), adjExpLen) + if expLen.Cmp(big32) > 0 { + adjExpLen.Sub(expLen, big32) + adjExpLen.Mul(big8, adjExpLen) } adjExpLen.Add(adjExpLen, big.NewInt(int64(msb))) // Calculate the gas cost of the operation gas := new(big.Int).Set(math.BigMax(modLen, baseLen)) switch { - case gas.Cmp(big.NewInt(64)) <= 0: + case gas.Cmp(big64) <= 0: gas.Mul(gas, gas) - case gas.Cmp(big.NewInt(1024)) <= 0: + case gas.Cmp(big1024) <= 0: gas = new(big.Int).Add( - new(big.Int).Div(new(big.Int).Mul(gas, gas), big.NewInt(4)), - new(big.Int).Sub(new(big.Int).Mul(big.NewInt(96), gas), big.NewInt(3072)), + new(big.Int).Div(new(big.Int).Mul(gas, gas), big4), + new(big.Int).Sub(new(big.Int).Mul(big96, gas), big3072), ) default: gas = new(big.Int).Add( - new(big.Int).Div(new(big.Int).Mul(gas, gas), big.NewInt(16)), - new(big.Int).Sub(new(big.Int).Mul(big.NewInt(480), gas), big.NewInt(199680)), + new(big.Int).Div(new(big.Int).Mul(gas, gas), big16), + new(big.Int).Sub(new(big.Int).Mul(big480, gas), big199680), ) } - gas.Mul(gas, math.BigMax(adjExpLen, big.NewInt(1))) + gas.Mul(gas, math.BigMax(adjExpLen, big1)) gas.Div(gas, new(big.Int).SetUint64(params.ModExpQuadCoeffDiv)) if gas.BitLen() > 64 { @@ -213,23 +224,25 @@ func (c *bigModExp) RequiredGas(input []byte) uint64 { } func (c *bigModExp) Run(input []byte) ([]byte, error) { - // Pad the input with zeroes to the minimum size to read the field lengths - input = common.RightPadBytes(input, 96) - var ( - baseLen = new(big.Int).SetBytes(input[:32]).Uint64() - expLen = new(big.Int).SetBytes(input[32:64]).Uint64() - modLen = new(big.Int).SetBytes(input[64:96]).Uint64() + baseLen = new(big.Int).SetBytes(getData(input, 0, 32)).Uint64() + expLen = new(big.Int).SetBytes(getData(input, 32, 32)).Uint64() + modLen = new(big.Int).SetBytes(getData(input, 64, 32)).Uint64() ) - input = input[96:] - - // Pad the input with zeroes to the minimum size to read the field contents - input = common.RightPadBytes(input, int(baseLen+expLen+modLen)) - + if len(input) > 96 { + input = input[96:] + } else { + input = input[:0] + } + // Handle a special case when both the base and mod length is zero + if baseLen == 0 && modLen == 0 { + return []byte{}, nil + } + // Retrieve the operands and execute the exponentiation var ( - base = new(big.Int).SetBytes(input[:baseLen]) - exp = new(big.Int).SetBytes(input[baseLen : baseLen+expLen]) - mod = new(big.Int).SetBytes(input[baseLen+expLen : baseLen+expLen+modLen]) + base = new(big.Int).SetBytes(getData(input, 0, baseLen)) + exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) + mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) ) if mod.BitLen() == 0 { // Modulo 0 is undefined, return zero @@ -286,14 +299,11 @@ func (c *bn256Add) RequiredGas(input []byte) uint64 { } func (c *bn256Add) Run(input []byte) ([]byte, error) { - // Ensure we have enough data to operate on - input = common.RightPadBytes(input, 128) - - x, err := newCurvePoint(input[:64]) + x, err := newCurvePoint(getData(input, 0, 64)) if err != nil { return nil, err } - y, err := newCurvePoint(input[64:128]) + y, err := newCurvePoint(getData(input, 64, 64)) if err != nil { return nil, err } @@ -310,14 +320,11 @@ func (c *bn256ScalarMul) RequiredGas(input []byte) uint64 { } func (c *bn256ScalarMul) Run(input []byte) ([]byte, error) { - // Ensure we have enough data to operate on - input = common.RightPadBytes(input, 96) - - p, err := newCurvePoint(input[:64]) + p, err := newCurvePoint(getData(input, 0, 64)) if err != nil { return nil, err } - p.ScalarMult(p, new(big.Int).SetBytes(input[64:96])) + p.ScalarMult(p, new(big.Int).SetBytes(getData(input, 64, 32))) return p.Marshal(), nil } diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 4a2ed8101..9bc860b5f 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -325,7 +325,7 @@ func TestPrecompiledBn256Pairing(t *testing.T) { } // Behcnmarks the sample inputs from the elliptic curve pairing check EIP 197. -func BenchmarkPrecompiledPairing(bench *testing.B) { +func BenchmarkPrecompiledBn256Pairing(bench *testing.B) { for _, test := range bn256PairingTests { benchmarkPrecompiled("08", test, bench) } diff --git a/core/vm/instructions.go b/core/vm/instructions.go index f5164fcdd..4f9e45ffe 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -337,7 +337,7 @@ func opCallValue(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack } func opCalldataLoad(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { - stack.push(new(big.Int).SetBytes(getData(contract.Input, stack.pop(), common.Big32))) + stack.push(new(big.Int).SetBytes(getDataBig(contract.Input, stack.pop(), big32))) return nil, nil } @@ -352,7 +352,7 @@ func opCalldataCopy(pc *uint64, evm *EVM, contract *Contract, memory *Memory, st cOff = stack.pop() l = stack.pop() ) - memory.Set(mOff.Uint64(), l.Uint64(), getData(contract.Input, cOff, l)) + memory.Set(mOff.Uint64(), l.Uint64(), getDataBig(contract.Input, cOff, l)) evm.interpreter.intPool.put(mOff, cOff, l) return nil, nil @@ -380,7 +380,7 @@ func opCodeCopy(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack cOff = stack.pop() l = stack.pop() ) - codeCopy := getData(contract.Code, cOff, l) + codeCopy := getDataBig(contract.Code, cOff, l) memory.Set(mOff.Uint64(), l.Uint64(), codeCopy) @@ -395,7 +395,7 @@ func opExtCodeCopy(pc *uint64, evm *EVM, contract *Contract, memory *Memory, sta cOff = stack.pop() l = stack.pop() ) - codeCopy := getData(evm.StateDB.GetCode(addr), cOff, l) + codeCopy := getDataBig(evm.StateDB.GetCode(addr), cOff, l) memory.Set(mOff.Uint64(), l.Uint64(), codeCopy)