p2p/discover: add config option for discv5 protocol ID (#26041)

This option is occasionally useful for advanced uses of the discv5 protocol.

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
RichΛrd 2022-11-30 17:03:34 -04:00 committed by GitHub
parent 1b8a392153
commit c1aa1db69e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 23 deletions

View File

@ -86,7 +86,7 @@ func newConn(dest *enode.Node, log logger) *conn {
localNode: ln, localNode: ln,
remote: dest, remote: dest,
remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()}, remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()},
codec: v5wire.NewCodec(ln, key, mclock.System{}), codec: v5wire.NewCodec(ln, key, mclock.System{}, nil),
log: log, log: log,
} }
} }

View File

@ -35,16 +35,24 @@ type UDPConn interface {
LocalAddr() net.Addr LocalAddr() net.Addr
} }
type V5Config struct {
ProtocolID *[6]byte
}
// Config holds settings for the discovery listener. // Config holds settings for the discovery listener.
type Config struct { type Config struct {
// These settings are required and configure the UDP listener: // These settings are required and configure the UDP listener:
PrivateKey *ecdsa.PrivateKey PrivateKey *ecdsa.PrivateKey
// These settings are optional: // These settings are optional:
NetRestrict *netutil.Netlist // list of allowed IP networks NetRestrict *netutil.Netlist // list of allowed IP networks
Bootnodes []*enode.Node // list of bootstrap nodes Bootnodes []*enode.Node // list of bootstrap nodes
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
Log log.Logger // if set, log messages go here Log log.Logger // if set, log messages go here
// V5ProtocolID configures the discv5 protocol identifier.
V5ProtocolID *[6]byte
ValidSchemes enr.IdentityScheme // allowed identity schemes ValidSchemes enr.IdentityScheme // allowed identity schemes
Clock mclock.Clock Clock mclock.Clock
} }

View File

@ -154,7 +154,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
callDoneCh: make(chan *callV5), callDoneCh: make(chan *callV5),
respTimeoutCh: make(chan *callTimeout), respTimeoutCh: make(chan *callTimeout),
// state of dispatch // state of dispatch
codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock), codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock, cfg.V5ProtocolID),
activeCallByNode: make(map[enode.ID]*callV5), activeCallByNode: make(map[enode.ID]*callV5),
activeCallByAuth: make(map[v5wire.Nonce]*callV5), activeCallByAuth: make(map[v5wire.Nonce]*callV5),
callQueue: make(map[enode.ID][]*callV5), callQueue: make(map[enode.ID][]*callV5),

View File

@ -98,7 +98,7 @@ const (
randomPacketMsgSize = 20 randomPacketMsgSize = 20
) )
var protocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'} var DefaultProtocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'}
// Errors. // Errors.
var ( var (
@ -134,10 +134,11 @@ var (
// Codec encodes and decodes Discovery v5 packets. // Codec encodes and decodes Discovery v5 packets.
// This type is not safe for concurrent use. // This type is not safe for concurrent use.
type Codec struct { type Codec struct {
sha256 hash.Hash sha256 hash.Hash
localnode *enode.LocalNode localnode *enode.LocalNode
privkey *ecdsa.PrivateKey privkey *ecdsa.PrivateKey
sc *SessionCache sc *SessionCache
protocolID [6]byte
// encoder buffers // encoder buffers
buf bytes.Buffer // whole packet buf bytes.Buffer // whole packet
@ -150,12 +151,16 @@ type Codec struct {
} }
// NewCodec creates a wire codec. // NewCodec creates a wire codec.
func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *Codec { func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, protocolID *[6]byte) *Codec {
c := &Codec{ c := &Codec{
sha256: sha256.New(), sha256: sha256.New(),
localnode: ln, localnode: ln,
privkey: key, privkey: key,
sc: NewSessionCache(1024, clock), sc: NewSessionCache(1024, clock),
protocolID: DefaultProtocolID,
}
if protocolID != nil {
c.protocolID = *protocolID
} }
return c return c
} }
@ -255,7 +260,7 @@ func (c *Codec) makeHeader(toID enode.ID, flag byte, authsizeExtra int) Header {
} }
return Header{ return Header{
StaticHeader: StaticHeader{ StaticHeader: StaticHeader{
ProtocolID: protocolID, ProtocolID: c.protocolID,
Version: version, Version: version,
Flag: flag, Flag: flag,
AuthSize: uint16(authsize), AuthSize: uint16(authsize),
@ -434,7 +439,7 @@ func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node,
c.reader.Reset(staticHeader) c.reader.Reset(staticHeader)
binary.Read(&c.reader, binary.BigEndian, &head.StaticHeader) binary.Read(&c.reader, binary.BigEndian, &head.StaticHeader)
remainingInput := len(input) - sizeofStaticPacketData remainingInput := len(input) - sizeofStaticPacketData
if err := head.checkValid(remainingInput); err != nil { if err := head.checkValid(remainingInput, c.protocolID); err != nil {
return enode.ID{}, nil, nil, err return enode.ID{}, nil, nil, err
} }
@ -621,7 +626,7 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
// checkValid performs some basic validity checks on the header. // checkValid performs some basic validity checks on the header.
// The packetLen here is the length remaining after the static header. // The packetLen here is the length remaining after the static header.
func (h *StaticHeader) checkValid(packetLen int) error { func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {
if h.ProtocolID != protocolID { if h.ProtocolID != protocolID {
return errInvalidHeader return errInvalidHeader
} }

View File

@ -504,8 +504,8 @@ type handshakeTestNode struct {
func newHandshakeTest() *handshakeTest { func newHandshakeTest() *handshakeTest {
t := new(handshakeTest) t := new(handshakeTest)
t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock, DefaultProtocolID)
t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock, DefaultProtocolID)
return t return t
} }
@ -514,11 +514,11 @@ func (t *handshakeTest) close() {
t.nodeB.ln.Database().Close() t.nodeB.ln.Database().Close()
} }
func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock, protocolID [6]byte) {
db, _ := enode.OpenDB("") db, _ := enode.OpenDB("")
n.ln = enode.NewLocalNode(db, key) n.ln = enode.NewLocalNode(db, key)
n.ln.SetStaticIP(ip) n.ln.SetStaticIP(ip)
n.c = NewCodec(n.ln, key, clock) n.c = NewCodec(n.ln, key, clock, nil)
} }
func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p Packet) ([]byte, Nonce) { func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p Packet) ([]byte, Nonce) {