forked from cerc-io/plugeth
		
	common/math: optimized modexp (+ fuzzer) (#25525)
This adds a * core/vm, tests: optimized modexp + fuzzer * common/math: modexp optimizations * core/vm: special case base 1 in big modexp * core/vm: disable fastexp
This commit is contained in:
		
							parent
							
								
									a007ab786c
								
							
						
					
					
						commit
						bed3b10086
					
				
							
								
								
									
										82
									
								
								common/math/modexp.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								common/math/modexp.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,82 @@ | ||||
| // Copyright 2020 The Go Authors. All rights reserved.
 | ||||
| // Use of this source code is governed by a BSD-style
 | ||||
| // license that can be found in the LICENSE file.
 | ||||
| 
 | ||||
| package math | ||||
| 
 | ||||
| import ( | ||||
| 	"math/big" | ||||
| 	"math/bits" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/common" | ||||
| ) | ||||
| 
 | ||||
| // FastExp is semantically equivalent to x.Exp(x,y, m), but is faster for even
 | ||||
| // modulus.
 | ||||
| func FastExp(x, y, m *big.Int) *big.Int { | ||||
| 	// Split m = m1 × m2 where m1 = 2ⁿ
 | ||||
| 	n := m.TrailingZeroBits() | ||||
| 	m1 := new(big.Int).Lsh(common.Big1, n) | ||||
| 	mask := new(big.Int).Sub(m1, common.Big1) | ||||
| 	m2 := new(big.Int).Rsh(m, n) | ||||
| 
 | ||||
| 	// We want z = x**y mod m.
 | ||||
| 	// z1 = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
 | ||||
| 	// z2 = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
 | ||||
| 	z1 := fastExpPow2(x, y, mask) | ||||
| 	z2 := new(big.Int).Exp(x, y, m2) | ||||
| 
 | ||||
| 	// Reconstruct z from z1, z2 using CRT, using algorithm from paper,
 | ||||
| 	// which uses only a single modInverse.
 | ||||
| 	//	p = (z1 - z2) * m2⁻¹ (mod m1)
 | ||||
| 	//	z = z2 + p * m2
 | ||||
| 	z := new(big.Int).Set(z2) | ||||
| 
 | ||||
| 	// Compute (z1 - z2) mod m1 [m1 == 2**n] into z1.
 | ||||
| 	z1 = z1.And(z1, mask) | ||||
| 	z2 = z2.And(z2, mask) | ||||
| 	z1 = z1.Sub(z1, z2) | ||||
| 	if z1.Sign() < 0 { | ||||
| 		z1 = z1.Add(z1, m1) | ||||
| 	} | ||||
| 
 | ||||
| 	// Reuse z2 for p = z1 * m2inv.
 | ||||
| 	m2inv := new(big.Int).ModInverse(m2, m1) | ||||
| 	z2 = z2.Mul(z1, m2inv) | ||||
| 	z2 = z2.And(z2, mask) | ||||
| 
 | ||||
| 	// Reuse z1 for m2 * p.
 | ||||
| 	z = z.Add(z, z1.Mul(z2, m2)) | ||||
| 	z = z.Rem(z, m) | ||||
| 
 | ||||
| 	return z | ||||
| } | ||||
| 
 | ||||
| func fastExpPow2(x, y *big.Int, mask *big.Int) *big.Int { | ||||
| 	z := big.NewInt(1) | ||||
| 	if y.Sign() == 0 { | ||||
| 		return z | ||||
| 	} | ||||
| 	p := new(big.Int).Set(x) | ||||
| 	p = p.And(p, mask) | ||||
| 	if p.Cmp(z) <= 0 { // p <= 1
 | ||||
| 		return p | ||||
| 	} | ||||
| 	if y.Cmp(mask) > 0 { | ||||
| 		y = new(big.Int).And(y, mask) | ||||
| 	} | ||||
| 	t := new(big.Int) | ||||
| 
 | ||||
| 	for _, b := range y.Bits() { | ||||
| 		for i := 0; i < bits.UintSize; i++ { | ||||
| 			if b&1 != 0 { | ||||
| 				z, t = t.Mul(z, p), z | ||||
| 				z = z.And(z, mask) | ||||
| 			} | ||||
| 			p, t = t.Mul(p, p), p | ||||
| 			p = p.And(p, mask) | ||||
| 			b >>= 1 | ||||
| 		} | ||||
| 	} | ||||
| 	return z | ||||
| } | ||||
| @ -380,12 +380,23 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { | ||||
| 		base = new(big.Int).SetBytes(getData(input, 0, baseLen)) | ||||
| 		exp  = new(big.Int).SetBytes(getData(input, baseLen, expLen)) | ||||
| 		mod  = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) | ||||
| 		v    []byte | ||||
| 	) | ||||
| 	if mod.BitLen() == 0 { | ||||
| 	switch { | ||||
| 	case mod.BitLen() == 0: | ||||
| 		// Modulo 0 is undefined, return zero
 | ||||
| 		return common.LeftPadBytes([]byte{}, int(modLen)), nil | ||||
| 	case base.Cmp(common.Big1) == 0: | ||||
| 		//If base == 1, then we can just return base % mod (if mod >= 1, which it is)
 | ||||
| 		v = base.Mod(base, mod).Bytes() | ||||
| 	//case mod.Bit(0) == 0:
 | ||||
| 	//	// Modulo is even
 | ||||
| 	//	v = math.FastExp(base, exp, mod).Bytes()
 | ||||
| 	default: | ||||
| 		// Modulo is odd
 | ||||
| 		v = base.Exp(base, exp, mod).Bytes() | ||||
| 	} | ||||
| 	return common.LeftPadBytes(base.Exp(base, exp, mod).Bytes(), int(modLen)), nil | ||||
| 	return common.LeftPadBytes(v, int(modLen)), nil | ||||
| } | ||||
| 
 | ||||
| // newCurvePoint unmarshals a binary blob into a bn256 elliptic curve point,
 | ||||
|  | ||||
| @ -125,5 +125,7 @@ compile_fuzzer tests/fuzzers/snap  FuzzSRange fuzz_storage_range | ||||
| compile_fuzzer tests/fuzzers/snap  FuzzByteCodes fuzz_byte_codes | ||||
| compile_fuzzer tests/fuzzers/snap  FuzzTrieNodes fuzz_trie_nodes | ||||
| 
 | ||||
| compile_fuzzer tests/fuzzers/modexp  Fuzz fuzzModexp | ||||
| 
 | ||||
| #TODO: move this to tests/fuzzers, if possible | ||||
| compile_fuzzer crypto/blake2b  Fuzz      fuzzBlake2b | ||||
|  | ||||
							
								
								
									
										84
									
								
								tests/fuzzers/modexp/modexp-fuzzer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								tests/fuzzers/modexp/modexp-fuzzer.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,84 @@ | ||||
| // Copyright 2022 The go-ethereum Authors
 | ||||
| // This file is part of the go-ethereum library.
 | ||||
| //
 | ||||
| // The go-ethereum library is free software: you can redistribute it and/or modify
 | ||||
| // it under the terms of the GNU Lesser General Public License as published by
 | ||||
| // the Free Software Foundation, either version 3 of the License, or
 | ||||
| // (at your option) any later version.
 | ||||
| //
 | ||||
| // The go-ethereum library is distributed in the hope that it will be useful,
 | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of
 | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 | ||||
| // GNU Lesser General Public License for more details.
 | ||||
| //
 | ||||
| // You should have received a copy of the GNU Lesser General Public License
 | ||||
| // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | ||||
| 
 | ||||
| package modexp | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math/big" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/common" | ||||
| 	"github.com/ethereum/go-ethereum/common/math" | ||||
| 	"github.com/ethereum/go-ethereum/core/vm" | ||||
| ) | ||||
| 
 | ||||
| // The function must return
 | ||||
| // 1 if the fuzzer should increase priority of the
 | ||||
| //    given input during subsequent fuzzing (for example, the input is lexically
 | ||||
| //    correct and was parsed successfully);
 | ||||
| // -1 if the input must not be added to corpus even if gives new coverage; and
 | ||||
| // 0  otherwise
 | ||||
| // other values are reserved for future use.
 | ||||
| func Fuzz(input []byte) int { | ||||
| 	if len(input) <= 96 { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	// Abort on too expensive inputs
 | ||||
| 	precomp := vm.PrecompiledContractsBerlin[common.BytesToAddress([]byte{5})] | ||||
| 	if gas := precomp.RequiredGas(input); gas > 40_000_000 { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	var ( | ||||
| 		baseLen = new(big.Int).SetBytes(getData(input, 0, 32)).Uint64() | ||||
| 		expLen  = new(big.Int).SetBytes(getData(input, 32, 32)).Uint64() | ||||
| 		modLen  = new(big.Int).SetBytes(getData(input, 64, 32)).Uint64() | ||||
| 	) | ||||
| 	// Handle a special case when both the base and mod length is zero
 | ||||
| 	if baseLen == 0 && modLen == 0 { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	input = input[96:] | ||||
| 	// Retrieve the operands and execute the exponentiation
 | ||||
| 	var ( | ||||
| 		base = new(big.Int).SetBytes(getData(input, 0, baseLen)) | ||||
| 		exp  = new(big.Int).SetBytes(getData(input, baseLen, expLen)) | ||||
| 		mod  = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) | ||||
| 	) | ||||
| 	if mod.BitLen() == 0 { | ||||
| 		// Modulo 0 is undefined, return zero
 | ||||
| 		return -1 | ||||
| 	} | ||||
| 	var a = math.FastExp(new(big.Int).Set(base), new(big.Int).Set(exp), new(big.Int).Set(mod)) | ||||
| 	var b = base.Exp(base, exp, mod) | ||||
| 	if a.Cmp(b) != 0 { | ||||
| 		panic(fmt.Sprintf("Inequality %x != %x", a, b)) | ||||
| 	} | ||||
| 	return 1 | ||||
| } | ||||
| 
 | ||||
| // getData returns a slice from the data based on the start and size and pads
 | ||||
| // up to size with zero's. This function is overflow safe.
 | ||||
| func getData(data []byte, start uint64, size uint64) []byte { | ||||
| 	length := uint64(len(data)) | ||||
| 	if start > length { | ||||
| 		start = length | ||||
| 	} | ||||
| 	end := start + size | ||||
| 	if end > length { | ||||
| 		end = length | ||||
| 	} | ||||
| 	return common.RightPadBytes(data[start:end], int(size)) | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user