diff --git a/cmd/devp2p/internal/v5test/framework.go b/cmd/devp2p/internal/v5test/framework.go index 6ccbbd075..f31677e51 100644 --- a/cmd/devp2p/internal/v5test/framework.go +++ b/cmd/devp2p/internal/v5test/framework.go @@ -86,7 +86,7 @@ func newConn(dest *enode.Node, log logger) *conn { localNode: ln, remote: dest, remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()}, - codec: v5wire.NewCodec(ln, key, mclock.System{}), + codec: v5wire.NewCodec(ln, key, mclock.System{}, nil), log: log, } } diff --git a/p2p/discover/common.go b/p2p/discover/common.go index e389821fd..c36e8dcc3 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -35,16 +35,24 @@ type UDPConn interface { LocalAddr() net.Addr } +type V5Config struct { + ProtocolID *[6]byte +} + // Config holds settings for the discovery listener. type Config struct { // These settings are required and configure the UDP listener: PrivateKey *ecdsa.PrivateKey // These settings are optional: - NetRestrict *netutil.Netlist // list of allowed IP networks - Bootnodes []*enode.Node // list of bootstrap nodes - Unhandled chan<- ReadPacket // unhandled packets are sent on this channel - Log log.Logger // if set, log messages go here + NetRestrict *netutil.Netlist // list of allowed IP networks + Bootnodes []*enode.Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel + Log log.Logger // if set, log messages go here + + // V5ProtocolID configures the discv5 protocol identifier. + V5ProtocolID *[6]byte + ValidSchemes enr.IdentityScheme // allowed identity schemes Clock mclock.Clock } diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 321c5bd2a..57d624498 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -154,7 +154,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { callDoneCh: make(chan *callV5), respTimeoutCh: make(chan *callTimeout), // state of dispatch - codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock), + codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock, cfg.V5ProtocolID), activeCallByNode: make(map[enode.ID]*callV5), activeCallByAuth: make(map[v5wire.Nonce]*callV5), callQueue: make(map[enode.ID][]*callV5), diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index e41d7f4c4..d979ab0f9 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -98,7 +98,7 @@ const ( randomPacketMsgSize = 20 ) -var protocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'} +var DefaultProtocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'} // Errors. var ( @@ -134,10 +134,11 @@ var ( // Codec encodes and decodes Discovery v5 packets. // This type is not safe for concurrent use. type Codec struct { - sha256 hash.Hash - localnode *enode.LocalNode - privkey *ecdsa.PrivateKey - sc *SessionCache + sha256 hash.Hash + localnode *enode.LocalNode + privkey *ecdsa.PrivateKey + sc *SessionCache + protocolID [6]byte // encoder buffers buf bytes.Buffer // whole packet @@ -150,12 +151,16 @@ type Codec struct { } // NewCodec creates a wire codec. -func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *Codec { +func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, protocolID *[6]byte) *Codec { c := &Codec{ - sha256: sha256.New(), - localnode: ln, - privkey: key, - sc: NewSessionCache(1024, clock), + sha256: sha256.New(), + localnode: ln, + privkey: key, + sc: NewSessionCache(1024, clock), + protocolID: DefaultProtocolID, + } + if protocolID != nil { + c.protocolID = *protocolID } return c } @@ -255,7 +260,7 @@ func (c *Codec) makeHeader(toID enode.ID, flag byte, authsizeExtra int) Header { } return Header{ StaticHeader: StaticHeader{ - ProtocolID: protocolID, + ProtocolID: c.protocolID, Version: version, Flag: flag, AuthSize: uint16(authsize), @@ -434,7 +439,7 @@ func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node, c.reader.Reset(staticHeader) binary.Read(&c.reader, binary.BigEndian, &head.StaticHeader) remainingInput := len(input) - sizeofStaticPacketData - if err := head.checkValid(remainingInput); err != nil { + if err := head.checkValid(remainingInput, c.protocolID); err != nil { return enode.ID{}, nil, nil, err } @@ -621,7 +626,7 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet // checkValid performs some basic validity checks on the header. // The packetLen here is the length remaining after the static header. -func (h *StaticHeader) checkValid(packetLen int) error { +func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error { if h.ProtocolID != protocolID { return errInvalidHeader } diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go index a08cffa2a..25df73283 100644 --- a/p2p/discover/v5wire/encoding_test.go +++ b/p2p/discover/v5wire/encoding_test.go @@ -504,8 +504,8 @@ type handshakeTestNode struct { func newHandshakeTest() *handshakeTest { t := new(handshakeTest) - t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) - t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) + t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock, DefaultProtocolID) + t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock, DefaultProtocolID) return t } @@ -514,11 +514,11 @@ func (t *handshakeTest) close() { t.nodeB.ln.Database().Close() } -func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { +func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock, protocolID [6]byte) { db, _ := enode.OpenDB("") n.ln = enode.NewLocalNode(db, key) n.ln.SetStaticIP(ip) - n.c = NewCodec(n.ln, key, clock) + n.c = NewCodec(n.ln, key, clock, nil) } func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p Packet) ([]byte, Nonce) {