From a05ea5fcc95579db2eaa2c79e44cf59172cfca72 Mon Sep 17 00:00:00 2001 From: Ethan Buchman Date: Sun, 4 Mar 2018 03:15:26 -0500 Subject: [PATCH] implement replay protection --- examples/basecoin/app/app_test.go | 19 +++++++++++++++++++ types/errors.go | 16 +++++----------- types/tx_msg.go | 14 ++++++++++++++ x/auth/ante.go | 16 ++++++++++------ 4 files changed, 48 insertions(+), 17 deletions(-) diff --git a/examples/basecoin/app/app_test.go b/examples/basecoin/app/app_test.go index 1349b67679..a187058a8c 100644 --- a/examples/basecoin/app/app_test.go +++ b/examples/basecoin/app/app_test.go @@ -169,6 +169,9 @@ func TestSendMsgWithAccounts(t *testing.T) { assert.Equal(t, acc1, res1) // Sign the tx + chainID := "" // TODO: InitChain should get the ChainID + sequence := int64(0) + sig := priv1.Sign(sdk.StdSignBytes(chainID, sequence, sendMsg)) tx := sdk.NewStdTx(sendMsg, []sdk.StdSignature{{ PubKey: priv1.PubKey(), Signature: priv1.Sign(sendMsg.GetSignBytes()), @@ -189,6 +192,22 @@ func TestSendMsgWithAccounts(t *testing.T) { res3 := bapp.accountMapper.GetAccount(ctxDeliver, addr2) assert.Equal(t, fmt.Sprintf("%v", res2.GetCoins()), "67foocoin") assert.Equal(t, fmt.Sprintf("%v", res3.GetCoins()), "10foocoin") + + // Delivering again should cause replay error + res = bapp.Deliver(tx) + assert.Equal(t, sdk.CodeInvalidSequence, res.Code, res.Log) + + // bumping the txnonce number without resigning should be an auth error + sequence += 1 + tx.Signatures[0].Sequence = sequence + res = bapp.Deliver(tx) + assert.Equal(t, sdk.CodeUnauthorized, res.Code, res.Log) + + // resigning the tx with the bumped sequence should work + sig = priv1.Sign(sdk.StdSignBytes(chainID, sequence, tx.Msg)) + tx.Signatures[0].Signature = sig + res = bapp.Deliver(tx) + assert.Equal(t, sdk.CodeOK, res.Code, res.Log) } //func TestWhatCoolMsg(t *testing.T) { diff --git a/types/errors.go b/types/errors.go index e52fe625f1..9d8175e309 100644 --- a/types/errors.go +++ b/types/errors.go @@ -22,12 +22,11 @@ const ( CodeOK CodeType = 0 CodeInternal CodeType = 1 CodeTxParse CodeType = 2 - CodeBadNonce CodeType = 3 + CodeInvalidSequence CodeType = 3 CodeUnauthorized CodeType = 4 CodeInsufficientFunds CodeType = 5 CodeUnknownRequest CodeType = 6 CodeUnrecognizedAddress CodeType = 7 - CodeInvalidSequence CodeType = 8 CodeGenesisParse CodeType = 0xdead // TODO: remove ? ) @@ -41,8 +40,8 @@ func CodeToDefaultMsg(code CodeType) string { return "Tx parse error" case CodeGenesisParse: return "Genesis parse error" - case CodeBadNonce: - return "Bad nonce" + case CodeInvalidSequence: + return "Invalid sequence" case CodeUnauthorized: return "Unauthorized" case CodeInsufficientFunds: @@ -51,8 +50,6 @@ func CodeToDefaultMsg(code CodeType) string { return "Unknown request" case CodeUnrecognizedAddress: return "Unrecognized address" - case CodeInvalidSequence: - return "Invalid sequence" default: return fmt.Sprintf("Unknown code %d", code) } @@ -72,8 +69,8 @@ func ErrTxParse(msg string) Error { func ErrGenesisParse(msg string) Error { return newError(CodeGenesisParse, msg) } -func ErrBadNonce(msg string) Error { - return newError(CodeBadNonce, msg) +func ErrInvalidSequence(msg string) Error { + return newError(CodeInvalidSequence, msg) } func ErrUnauthorized(msg string) Error { return newError(CodeUnauthorized, msg) @@ -87,9 +84,6 @@ func ErrUnknownRequest(msg string) Error { func ErrUnrecognizedAddress(addr Address) Error { return newError(CodeUnrecognizedAddress, addr.String()) } -func ErrInvalidSequence(msg string) Error { - return newError(CodeInvalidSequence, msg) -} //---------------------------------------- // Error & sdkError diff --git a/types/tx_msg.go b/types/tx_msg.go index e97fc331fe..c88107c518 100644 --- a/types/tx_msg.go +++ b/types/tx_msg.go @@ -1,5 +1,7 @@ package types +import "encoding/json" + // Transactions messages must fulfill the Msg type Msg interface { @@ -77,6 +79,18 @@ type StdSignDoc struct { AltBytes []byte `json:"alt_bytes"` // TODO: do we really want this ? } +func StdSignBytes(chainID string, sequence int64, msg Msg) []byte { + bz, err := json.Marshal(StdSignDoc{ + ChainID: chainID, + Sequence: sequence, + MsgBytes: msg.GetSignBytes(), + }) + if err != nil { + panic(err) + } + return bz +} + //------------------------------------- // Application function variable used to unmarshal transaction bytes diff --git a/x/auth/ante.go b/x/auth/ante.go index 17e9ccdd94..7ab677e88d 100644 --- a/x/auth/ante.go +++ b/x/auth/ante.go @@ -31,13 +31,16 @@ func NewAnteHandler(accountMapper sdk.AccountMapper) sdk.AnteHandler { // Collect accounts to set in the context var signerAccs = make([]sdk.Account, len(signerAddrs)) - signBytes := msg.GetSignBytes() - // First sig is the fee payer. - // Check sig and nonce, deduct fee. + // signBytes uses the sequence of the fee payer + // (ie. the first account) + payerAddr, payerSig := signerAddrs[0], sigs[0] + signBytes := sdk.StdSignBytes(ctx.ChainID(), payerSig.Sequence, msg) + + // Check fee payer sig and nonce, and deduct fee. // This is done first because it only // requires fetching 1 account. - payerAcc, res := processSig(ctx, accountMapper, signerAddrs[0], sigs[0], signBytes) + payerAcc, res := processSig(ctx, accountMapper, payerAddr, payerSig, signBytes) if !res.IsOK() { return ctx, res, true } @@ -47,7 +50,8 @@ func NewAnteHandler(accountMapper sdk.AccountMapper) sdk.AnteHandler { // Check sig and nonce for the rest. for i := 1; i < len(sigs); i++ { - signerAcc, res := processSig(ctx, accountMapper, signerAddrs[i], sigs[i], signBytes) + signerAddr, sig := signerAddrs[i], sigs[i] + signerAcc, res := processSig(ctx, accountMapper, signerAddr, sig, signBytes) if !res.IsOK() { return ctx, res, true } @@ -90,7 +94,7 @@ func processSig(ctx sdk.Context, am sdk.AccountMapper, addr sdk.Address, sig sdk // Check sig. if !sig.PubKey.VerifyBytes(signBytes, sig.Signature) { - return nil, sdk.ErrUnauthorized("").Result() + return nil, sdk.ErrUnauthorized("signature verification failed").Result() } // Save the account.