core, core/types: refactored tx chain id checking

Refactored explicit chain id checking in to the Sender deriviation method
This commit is contained in:
Jeffrey Wilcke 2016-11-14 15:59:31 +01:00
parent ca73dea3b9
commit 6c9c1e6712
4 changed files with 48 additions and 12 deletions

View File

@ -1226,11 +1226,8 @@ func TestEIP155Transition(t *testing.T) {
block.AddTx(tx) block.AddTx(tx)
} }
}) })
errExp := "Invalid transaction chain id. Current chain id: 1 tx chain id: 2"
_, err := blockchain.InsertChain(blocks) _, err := blockchain.InsertChain(blocks)
if err == nil { if err != types.ErrInvalidChainId {
t.Error("expected transaction chain id error") t.Error("expected error:", types.ErrInvalidChainId)
} else if err.Error() != errExp {
t.Error("expected:", errExp, "got:", err)
} }
} }

View File

@ -17,7 +17,6 @@
package core package core
import ( import (
"fmt"
"math/big" "math/big"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
@ -73,10 +72,6 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg
} }
// Iterate over and process the individual transactions // Iterate over and process the individual transactions
for i, tx := range block.Transactions() { for i, tx := range block.Transactions() {
if tx.Protected() && tx.ChainId().Cmp(p.config.ChainId) != 0 {
return nil, nil, nil, fmt.Errorf("Invalid transaction chain id. Current chain id: %v tx chain id: %v", p.config.ChainId, tx.ChainId())
}
statedb.StartRecord(tx.Hash(), block.Hash(), i) statedb.StartRecord(tx.Hash(), block.Hash(), i)
receipt, logs, _, err := ApplyTransaction(p.config, p.bc, gp, statedb, header, tx, totalUsedGas, cfg) receipt, logs, _, err := ApplyTransaction(p.config, p.bc, gp, statedb, header, tx, totalUsedGas, cfg)
if err != nil { if err != nil {

View File

@ -21,13 +21,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"reflect"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
) )
var ErrInvalidChainId = errors.New("invalid chaid id for signer")
// sigCache is used to cache the derived sender and contains // sigCache is used to cache the derived sender and contains
// the signer used to derive it. // the signer used to derive it.
type sigCache struct { type sigCache struct {
@ -75,7 +76,7 @@ func Sender(signer Signer, tx *Transaction) (common.Address, error) {
// If the signer used to derive from in a previous // If the signer used to derive from in a previous
// call is not the same as used current, invalidate // call is not the same as used current, invalidate
// the cache. // the cache.
if reflect.TypeOf(sigCache.signer) == reflect.TypeOf(signer) { if sigCache.signer.Equal(signer) {
return sigCache.from, nil return sigCache.from, nil
} }
} }
@ -104,6 +105,8 @@ type Signer interface {
SignECDSA(tx *Transaction, prv *ecdsa.PrivateKey) (*Transaction, error) SignECDSA(tx *Transaction, prv *ecdsa.PrivateKey) (*Transaction, error)
// WithSignature returns a copy of the transaction with the given signature // WithSignature returns a copy of the transaction with the given signature
WithSignature(tx *Transaction, sig []byte) (*Transaction, error) WithSignature(tx *Transaction, sig []byte) (*Transaction, error)
// Checks for equality on the signers
Equal(Signer) bool
} }
// EIP155Transaction implements TransactionInterface using the // EIP155Transaction implements TransactionInterface using the
@ -121,6 +124,11 @@ func NewEIP155Signer(chainId *big.Int) EIP155Signer {
} }
} }
func (s EIP155Signer) Equal(s2 Signer) bool {
eip155, ok := s2.(EIP155Signer)
return ok && eip155.chainId.Cmp(s.chainId) == 0
}
func (s EIP155Signer) SignECDSA(tx *Transaction, prv *ecdsa.PrivateKey) (*Transaction, error) { func (s EIP155Signer) SignECDSA(tx *Transaction, prv *ecdsa.PrivateKey) (*Transaction, error) {
return SignECDSA(s, tx, prv) return SignECDSA(s, tx, prv)
} }
@ -131,6 +139,10 @@ func (s EIP155Signer) PublicKey(tx *Transaction) ([]byte, error) {
return (HomesteadSigner{}).PublicKey(tx) return (HomesteadSigner{}).PublicKey(tx)
} }
if tx.ChainId().Cmp(s.chainId) != 0 {
return nil, ErrInvalidChainId
}
V := normaliseV(s, tx.data.V) V := normaliseV(s, tx.data.V)
if !crypto.ValidateSignatureValues(V, tx.data.R, tx.data.S, true) { if !crypto.ValidateSignatureValues(V, tx.data.R, tx.data.S, true) {
return nil, ErrInvalidSig return nil, ErrInvalidSig
@ -200,6 +212,11 @@ func (s EIP155Signer) SigECDSA(tx *Transaction, prv *ecdsa.PrivateKey) (*Transac
// homestead rules. // homestead rules.
type HomesteadSigner struct{ FrontierSigner } type HomesteadSigner struct{ FrontierSigner }
func (s HomesteadSigner) Equal(s2 Signer) bool {
_, ok := s2.(HomesteadSigner)
return ok
}
// WithSignature returns a new transaction with the given snature. // WithSignature returns a new transaction with the given snature.
// This snature needs to be formatted as described in the yellow paper (v+27). // This snature needs to be formatted as described in the yellow paper (v+27).
func (hs HomesteadSigner) WithSignature(tx *Transaction, sig []byte) (*Transaction, error) { func (hs HomesteadSigner) WithSignature(tx *Transaction, sig []byte) (*Transaction, error) {
@ -251,6 +268,11 @@ func (hs HomesteadSigner) PublicKey(tx *Transaction) ([]byte, error) {
type FrontierSigner struct{} type FrontierSigner struct{}
func (s FrontierSigner) Equal(s2 Signer) bool {
_, ok := s2.(FrontierSigner)
return ok
}
// WithSignature returns a new transaction with the given snature. // WithSignature returns a new transaction with the given snature.
// This snature needs to be formatted as described in the yellow paper (v+27). // This snature needs to be formatted as described in the yellow paper (v+27).
func (fs FrontierSigner) WithSignature(tx *Transaction, sig []byte) (*Transaction, error) { func (fs FrontierSigner) WithSignature(tx *Transaction, sig []byte) (*Transaction, error) {

View File

@ -114,3 +114,25 @@ func TestEIP155SigningVitalik(t *testing.T) {
} }
} }
func TestChainId(t *testing.T) {
key, _ := defaultTestKey()
tx := NewTransaction(0, common.Address{}, new(big.Int), new(big.Int), new(big.Int), nil)
var err error
tx, err = tx.SignECDSA(NewEIP155Signer(big.NewInt(1)), key)
if err != nil {
t.Fatal(err)
}
_, err = Sender(NewEIP155Signer(big.NewInt(2)), tx)
if err != ErrInvalidChainId {
t.Error("expected error:", ErrInvalidChainId)
}
_, err = Sender(NewEIP155Signer(big.NewInt(1)), tx)
if err != nil {
t.Error("expected no error")
}
}