diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 53a1c6f76..38f5b3b65 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -83,6 +83,7 @@ type UDPv5 struct { callCh chan *callV5 callDoneCh chan *callV5 respTimeoutCh chan *callTimeout + unhandled chan<- ReadPacket // state of dispatch codec codecV5 @@ -156,6 +157,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { callCh: make(chan *callV5), callDoneCh: make(chan *callV5), respTimeoutCh: make(chan *callTimeout), + unhandled: cfg.Unhandled, // state of dispatch codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock, cfg.V5ProtocolID), activeCallByNode: make(map[enode.ID]*callV5), @@ -657,6 +659,14 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { addr := fromAddr.String() fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr) if err != nil { + if t.unhandled != nil && v5wire.IsInvalidHeader(err) { + // The packet seems unrelated to discv5, send it to the next protocol. + // t.log.Trace("Unhandled discv5 packet", "id", fromID, "addr", addr, "err", err) + up := ReadPacket{Data: make([]byte, len(rawpacket)), Addr: fromAddr} + copy(up.Data, rawpacket) + t.unhandled <- up + return nil + } t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err) return err } diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index d979ab0f9..510891062 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -94,6 +94,8 @@ const ( // Should reject packets smaller than minPacketSize. minPacketSize = 63 + maxPacketSize = 1280 + minMessageSize = 48 // this refers to data after static headers randomPacketMsgSize = 20 ) @@ -122,6 +124,13 @@ var ( ErrInvalidReqID = errors.New("request ID larger than 8 bytes") ) +// IsInvalidHeader reports whether 'err' is related to an invalid packet header. When it +// returns false, it is pretty certain that the packet causing the error does not belong +// to discv5. +func IsInvalidHeader(err error) bool { + return err == errTooShort || err == errInvalidHeader || err == errMsgTooShort +} + // Packet sizes. var ( sizeofStaticHeader = binary.Size(StaticHeader{}) @@ -147,6 +156,7 @@ type Codec struct { msgctbuf []byte // message data ciphertext // decoder buffer + decbuf []byte reader bytes.Reader } @@ -158,6 +168,7 @@ func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, pr privkey: key, sc: NewSessionCache(1024, clock), protocolID: DefaultProtocolID, + decbuf: make([]byte, maxPacketSize), } if protocolID != nil { c.protocolID = *protocolID @@ -424,10 +435,13 @@ func (c *Codec) encryptMessage(s *session, p Packet, head *Header, headerData [] } // Decode decodes a discovery packet. -func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node, p Packet, err error) { - if len(input) < minPacketSize { +func (c *Codec) Decode(inputData []byte, addr string) (src enode.ID, n *enode.Node, p Packet, err error) { + if len(inputData) < minPacketSize { return enode.ID{}, nil, nil, errTooShort } + // Copy the packet to a tmp buffer to avoid modifying it. + c.decbuf = append(c.decbuf[:0], inputData...) + input := c.decbuf // Unmask the static header. var head Header copy(head.IV[:], input[:sizeofMaskingIV])