diff --git a/restricted/types/transaction.go b/restricted/types/transaction.go index bb70b74..1094898 100644 --- a/restricted/types/transaction.go +++ b/restricted/types/transaction.go @@ -36,6 +36,7 @@ var ( ErrInvalidTxType = errors.New("transaction type not valid in this context") ErrTxTypeNotSupported = errors.New("transaction type not supported") ErrGasFeeCapTooLow = errors.New("fee cap less than base fee") + errShortTypedTx = errors.New("typed transaction too short") errEmptyTypedTx = errors.New("empty typed transaction bytes") ) @@ -103,6 +104,7 @@ type TxData interface { effectiveGasPrice(dst *big.Int, baseFee *big.Int) *big.Int encode(b *bytes.Buffer) error + decode([]byte) error } // EncodeRLP implements rlp.Encoder @@ -195,26 +197,24 @@ func (tx *Transaction) UnmarshalBinary(b []byte) error { // decodeTyped decodes a typed transaction from the canonical format. func (tx *Transaction) decodeTyped(b []byte) (TxData, error) { if len(b) <= 1 { - return nil, errEmptyTypedTx + return nil, errShortTypedTx } + var inner TxData switch b[0] { case AccessListTxType: - var inner AccessListTx - err := rlp.DecodeBytes(b[1:], &inner) - return &inner, err + inner = new(AccessListTx) case DynamicFeeTxType: - var inner DynamicFeeTx - err := rlp.DecodeBytes(b[1:], &inner) - return &inner, err + inner = new(DynamicFeeTx) case BlobTxType: - var inner BlobTx - err := rlp.DecodeBytes(b[1:], &inner) - return &inner, err + inner = new(BlobTx) default: return nil, ErrTxTypeNotSupported } + err := inner.decode(b[1:]) + return inner, err } + // setDecoded sets the inner transaction and size after decoding. func (tx *Transaction) setDecoded(inner TxData, size uint64) { tx.inner = inner diff --git a/restricted/types/transaction_test.go b/restricted/types/transaction_test.go index 397e2b8..df19478 100644 --- a/restricted/types/transaction_test.go +++ b/restricted/types/transaction_test.go @@ -82,7 +82,7 @@ func TestDecodeEmptyTypedTx(t *testing.T) { input := []byte{0x80} var tx Transaction err := rlp.DecodeBytes(input, &tx) - if err != errEmptyTypedTx { + if err != errShortTypedTx { t.Fatal("wrong error:", err) } }