package decode import ( "fmt" "google.golang.org/protobuf/encoding/protowire" ) // rejectNonADR027TxRaw rejects txBytes that do not follow ADR-027. This is NOT // a generic ADR-027 checker, it only applies decoding TxRaw. Specifically, it // only checks that: // // - Field numbers are in ascending order (1, 2, and potentially multiple 3s) // - Varints are as short as possible // // All other ADR-027 edge cases (e.g. default values) are not applicable with // TxRaw. func rejectNonADR027TxRaw(txBytes []byte) error { // Make sure all fields are ordered in ascending order with this variable. prevTagNum := protowire.Number(0) for len(txBytes) > 0 { tagNum, wireType, m := protowire.ConsumeTag(txBytes) if m < 0 { return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) } // Paranoia from possible varint decoding which can trivially // be wrong due to the precarious nature of the format being tricked: // https://cyber.orijtech.com/advisory/varint-decode-limitless if m > len(txBytes) { return fmt.Errorf("invalid length from decoding (%d) > len(txBytes) (%d)", m, len(txBytes)) } // TxRaw only has bytes fields. if wireType != protowire.BytesType { return fmt.Errorf("expected %d wire type, got %d", protowire.BytesType, wireType) } // Make sure fields are ordered in ascending order. if tagNum < prevTagNum { return fmt.Errorf("txRaw must follow ADR-027, got tagNum %d after tagNum %d", tagNum, prevTagNum) } prevTagNum = tagNum // All 3 fields of TxRaw have wireType == 2, so their next component // is a varint, so we can safely call ConsumeVarint here. // Byte structure: // Inner fields are verified in `DefaultTxDecoder` lengthPrefix, m := protowire.ConsumeVarint(txBytes[m:]) if m < 0 { return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) } // We make sure that this varint is as short as possible. n := varintMinLength(lengthPrefix) if n != m { return fmt.Errorf("length prefix varint for tagNum %d is not as short as possible, read %d, only need %d", tagNum, m, n) } // Skip over the bytes that store fieldNumber and wireType bytes. _, _, m = protowire.ConsumeField(txBytes) if m < 0 { return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) } txBytes = txBytes[m:] } return nil } // varintMinLength returns the minimum number of bytes necessary to encode an // uint using varint encoding. func varintMinLength(n uint64) int { switch { // Note: 1<