diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go
index 8580c6121..99b0957ab 100644
--- a/cmd/devp2p/discv4cmd.go
+++ b/cmd/devp2p/discv4cmd.go
@@ -19,11 +19,14 @@ package main
import (
"fmt"
"net"
+ "os"
"strings"
"time"
+ "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
@@ -40,6 +43,7 @@ var (
discv4ResolveCommand,
discv4ResolveJSONCommand,
discv4CrawlCommand,
+ discv4TestCommand,
},
}
discv4PingCommand = cli.Command{
@@ -74,6 +78,12 @@ var (
Action: discv4Crawl,
Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag},
}
+ discv4TestCommand = cli.Command{
+ Name: "test",
+ Usage: "Runs tests against a node",
+ Action: discv4Test,
+ Flags: []cli.Flag{remoteEnodeFlag, testPatternFlag, testListen1Flag, testListen2Flag},
+ }
)
var (
@@ -98,6 +108,25 @@ var (
Usage: "Time limit for the crawl.",
Value: 30 * time.Minute,
}
+ remoteEnodeFlag = cli.StringFlag{
+ Name: "remote",
+ Usage: "Enode of the remote node under test",
+ EnvVar: "REMOTE_ENODE",
+ }
+ testPatternFlag = cli.StringFlag{
+ Name: "run",
+ Usage: "Pattern of test suite(s) to run",
+ }
+ testListen1Flag = cli.StringFlag{
+ Name: "listen1",
+ Usage: "IP address of the first tester",
+ Value: v4test.Listen1,
+ }
+ testListen2Flag = cli.StringFlag{
+ Name: "listen2",
+ Usage: "IP address of the second tester",
+ Value: v4test.Listen2,
+ }
)
func discv4Ping(ctx *cli.Context) error {
@@ -184,6 +213,28 @@ func discv4Crawl(ctx *cli.Context) error {
return nil
}
+func discv4Test(ctx *cli.Context) error {
+ // Configure test package globals.
+ if !ctx.IsSet(remoteEnodeFlag.Name) {
+ return fmt.Errorf("Missing -%v", remoteEnodeFlag.Name)
+ }
+ v4test.Remote = ctx.String(remoteEnodeFlag.Name)
+ v4test.Listen1 = ctx.String(testListen1Flag.Name)
+ v4test.Listen2 = ctx.String(testListen2Flag.Name)
+
+ // Filter and run test cases.
+ tests := v4test.AllTests
+ if ctx.IsSet(testPatternFlag.Name) {
+ tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name))
+ }
+ results := utesting.RunTests(tests, os.Stdout)
+ if fails := utesting.CountFailures(results); fails > 0 {
+ return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests))
+ }
+ fmt.Printf("%v/%v passed\n", len(tests), len(tests))
+ return nil
+}
+
// startV4 starts an ephemeral discovery V4 node.
func startV4(ctx *cli.Context) *discover.UDPv4 {
ln, config := makeDiscoveryConfig(ctx)
diff --git a/cmd/devp2p/internal/v4test/discv4tests.go b/cmd/devp2p/internal/v4test/discv4tests.go
new file mode 100644
index 000000000..140b96bfa
--- /dev/null
+++ b/cmd/devp2p/internal/v4test/discv4tests.go
@@ -0,0 +1,467 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+
+package v4test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "reflect"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/p2p/discover/v4wire"
+)
+
+const (
+ expiration = 20 * time.Second
+ wrongPacket = 66
+ macSize = 256 / 8
+)
+
+var (
+ // Remote node under test
+ Remote string
+ // IP where the first tester is listening, port will be assigned
+ Listen1 string = "127.0.0.1"
+ // IP where the second tester is listening, port will be assigned
+ // Before running the test, you may have to `sudo ifconfig lo0 add 127.0.0.2` (on MacOS at least)
+ Listen2 string = "127.0.0.2"
+)
+
+type pingWithJunk struct {
+ Version uint
+ From, To v4wire.Endpoint
+ Expiration uint64
+ JunkData1 uint
+ JunkData2 []byte
+}
+
+func (req *pingWithJunk) Name() string { return "PING/v4" }
+func (req *pingWithJunk) Kind() byte { return v4wire.PingPacket }
+
+type pingWrongType struct {
+ Version uint
+ From, To v4wire.Endpoint
+ Expiration uint64
+}
+
+func (req *pingWrongType) Name() string { return "WRONG/v4" }
+func (req *pingWrongType) Kind() byte { return wrongPacket }
+
+func futureExpiration() uint64 {
+ return uint64(time.Now().Add(expiration).Unix())
+}
+
+// This test just sends a PING packet and expects a response.
+func BasicPing(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// checkPong verifies that reply is a valid PONG matching the given ping hash.
+func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error {
+ if reply == nil || reply.Kind() != v4wire.PongPacket {
+ return fmt.Errorf("expected PONG reply, got %v", reply)
+ }
+ pong := reply.(*v4wire.Pong)
+ if !bytes.Equal(pong.ReplyTok, pingHash) {
+ return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash)
+ }
+ wantEndpoint := te.localEndpoint(te.l1)
+ if !reflect.DeepEqual(pong.To, wantEndpoint) {
+ return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint)
+ }
+ if v4wire.Expired(pong.Expiration) {
+ return fmt.Errorf("PONG is expired (%v)", pong.Expiration)
+ }
+ return nil
+}
+
+// This test sends a PING packet with wrong 'to' field and expects a PONG response.
+func PingWrongTo(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: wrongEndpoint,
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with wrong 'from' field and expects a PONG response.
+func PingWrongFrom(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: wrongEndpoint,
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with additional data at the end and expects a PONG
+// response. The remote node should respond because EIP-8 mandates ignoring additional
+// trailing data.
+func PingExtraData(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ pingHash := te.send(te.l1, &pingWithJunk{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ JunkData1: 42,
+ JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1},
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with additional data and wrong 'from' field
+// and expects a PONG response.
+func PingExtraDataWrongFrom(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ req := pingWithJunk{
+ Version: 4,
+ From: wrongEndpoint,
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ JunkData1: 42,
+ JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1},
+ }
+ pingHash := te.send(te.l1, &req)
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with an expiration in the past.
+// The remote node should not respond.
+func PingPastExpiration(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: -futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if reply != nil {
+ t.Fatal("Expected no reply, got", reply)
+ }
+}
+
+// This test sends an invalid packet. The remote node should not respond.
+func WrongPacketType(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ te.send(te.l1, &pingWrongType{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if reply != nil {
+ t.Fatal("Expected no reply, got", reply)
+ }
+}
+
+// This test verifies that the default behaviour of ignoring 'from' fields is unaffected by
+// the bonding process. After bonding, it pings the target with a different from endpoint.
+func BondThenPingWithWrongFrom(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: wrongEndpoint,
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test just sends FINDNODE. The remote node should not reply
+// because the endpoint proof has not completed.
+func FindnodeWithoutEndpointProof(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ req := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(req.Target[:])
+ te.send(te.l1, &req)
+
+ reply, _, _ := te.read(te.l1)
+ if reply != nil {
+ t.Fatal("Expected no response, got", reply)
+ }
+}
+
+// BasicFindnode sends a FINDNODE request after performing the endpoint
+// proof. The remote node should respond.
+func BasicFindnode(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ findnode := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l1, &findnode)
+
+ reply, _, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal("read find nodes", err)
+ }
+ if reply.Kind() != v4wire.NeighborsPacket {
+ t.Fatal("Expected neighbors, got", reply.Name())
+ }
+}
+
+// This test sends an unsolicited NEIGHBORS packet after the endpoint proof, then sends
+// FINDNODE to read the remote table. The remote node should not return the node contained
+// in the unsolicited NEIGHBORS packet.
+func UnsolicitedNeighbors(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ // Send unsolicited NEIGHBORS response.
+ fakeKey, _ := crypto.GenerateKey()
+ encFakeKey := v4wire.EncodePubkey(&fakeKey.PublicKey)
+ neighbors := v4wire.Neighbors{
+ Expiration: futureExpiration(),
+ Nodes: []v4wire.Node{{
+ ID: encFakeKey,
+ IP: net.IP{1, 2, 3, 4},
+ UDP: 30303,
+ TCP: 30303,
+ }},
+ }
+ te.send(te.l1, &neighbors)
+
+ // Check if the remote node included the fake node.
+ te.send(te.l1, &v4wire.Findnode{
+ Expiration: futureExpiration(),
+ Target: encFakeKey,
+ })
+
+ reply, _, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal("read find nodes", err)
+ }
+ if reply.Kind() != v4wire.NeighborsPacket {
+ t.Fatal("Expected neighbors, got", reply.Name())
+ }
+ nodes := reply.(*v4wire.Neighbors).Nodes
+ if contains(nodes, encFakeKey) {
+ t.Fatal("neighbors response contains node from earlier unsolicited neighbors response")
+ }
+}
+
+// This test sends FINDNODE with an expiration timestamp in the past.
+// The remote node should not respond.
+func FindnodePastExpiration(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ findnode := v4wire.Findnode{Expiration: -futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l1, &findnode)
+
+ for {
+ reply, _, _ := te.read(te.l1)
+ if reply == nil {
+ return
+ } else if reply.Kind() == v4wire.NeighborsPacket {
+ t.Fatal("Unexpected NEIGHBORS response for expired FINDNODE request")
+ }
+ }
+}
+
+// bond performs the endpoint proof with the remote node.
+func bond(t *utesting.T, te *testenv) {
+ te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ var gotPing, gotPong bool
+ for !gotPing || !gotPong {
+ req, hash, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch req.(type) {
+ case *v4wire.Ping:
+ te.send(te.l1, &v4wire.Pong{
+ To: te.remoteEndpoint(),
+ ReplyTok: hash,
+ Expiration: futureExpiration(),
+ })
+ gotPing = true
+ case *v4wire.Pong:
+ // TODO: maybe verify pong data here
+ gotPong = true
+ }
+ }
+}
+
+// This test attempts to perform a traffic amplification attack against a
+// 'victim' endpoint using FINDNODE. In this attack scenario, the attacker
+// attempts to complete the endpoint proof non-interactively by sending a PONG
+// with mismatching reply token from the 'victim' endpoint. The attack works if
+// the remote node does not verify the PONG reply token field correctly. The
+// attacker could then perform traffic amplification by sending many FINDNODE
+// requests to the discovery node, which would reply to the 'victim' address.
+func FindnodeAmplificationInvalidPongHash(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ // Send PING to start endpoint verification.
+ te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ var gotPing, gotPong bool
+ for !gotPing || !gotPong {
+ req, _, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch req.(type) {
+ case *v4wire.Ping:
+ // Send PONG from this node ID, but with invalid ReplyTok.
+ te.send(te.l1, &v4wire.Pong{
+ To: te.remoteEndpoint(),
+ ReplyTok: make([]byte, macSize),
+ Expiration: futureExpiration(),
+ })
+ gotPing = true
+ case *v4wire.Pong:
+ gotPong = true
+ }
+ }
+
+ // Now send FINDNODE. The remote node should not respond because our
+ // PONG did not reference the PING hash.
+ findnode := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l1, &findnode)
+
+ // If we receive a NEIGHBORS response, the attack worked and the test fails.
+ reply, _, _ := te.read(te.l1)
+ if reply != nil && reply.Kind() == v4wire.NeighborsPacket {
+ t.Error("Got neighbors")
+ }
+}
+
+// This test attempts to perform a traffic amplification attack using FINDNODE.
+// The attack works if the remote node does not verify the IP address of FINDNODE
+// against the endpoint verification proof done by PING/PONG.
+func FindnodeAmplificationWrongIP(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ // Do the endpoint proof from the l1 IP.
+ bond(t, te)
+
+ // Now send FINDNODE from the same node ID, but different IP address.
+ // The remote node should not respond.
+ findnode := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l2, &findnode)
+
+ // If we receive a NEIGHBORS response, the attack worked and the test fails.
+ reply, _, _ := te.read(te.l2)
+ if reply != nil {
+ t.Error("Got NEIGHORS response for FINDNODE from wrong IP")
+ }
+}
+
+var AllTests = []utesting.Test{
+ {Name: "Ping/Basic", Fn: BasicPing},
+ {Name: "Ping/WrongTo", Fn: PingWrongTo},
+ {Name: "Ping/WrongFrom", Fn: PingWrongFrom},
+ {Name: "Ping/ExtraData", Fn: PingExtraData},
+ {Name: "Ping/ExtraDataWrongFrom", Fn: PingExtraDataWrongFrom},
+ {Name: "Ping/PastExpiration", Fn: PingPastExpiration},
+ {Name: "Ping/WrongPacketType", Fn: WrongPacketType},
+ {Name: "Ping/BondThenPingWithWrongFrom", Fn: BondThenPingWithWrongFrom},
+ {Name: "Findnode/WithoutEndpointProof", Fn: FindnodeWithoutEndpointProof},
+ {Name: "Findnode/BasicFindnode", Fn: BasicFindnode},
+ {Name: "Findnode/UnsolicitedNeighbors", Fn: UnsolicitedNeighbors},
+ {Name: "Findnode/PastExpiration", Fn: FindnodePastExpiration},
+ {Name: "Amplification/InvalidPongHash", Fn: FindnodeAmplificationInvalidPongHash},
+ {Name: "Amplification/WrongIP", Fn: FindnodeAmplificationWrongIP},
+}
diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go
new file mode 100644
index 000000000..928659418
--- /dev/null
+++ b/cmd/devp2p/internal/v4test/framework.go
@@ -0,0 +1,123 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+
+package v4test
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/p2p/discover/v4wire"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+const waitTime = 300 * time.Millisecond
+
+type testenv struct {
+ l1, l2 net.PacketConn
+ key *ecdsa.PrivateKey
+ remote *enode.Node
+ remoteAddr *net.UDPAddr
+}
+
+func newTestEnv(remote string, listen1, listen2 string) *testenv {
+ l1, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen1))
+ if err != nil {
+ panic(err)
+ }
+ l2, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen2))
+ if err != nil {
+ panic(err)
+ }
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic(err)
+ }
+ node, err := enode.Parse(enode.ValidSchemes, remote)
+ if err != nil {
+ panic(err)
+ }
+ if node.IP() == nil || node.UDP() == 0 {
+ var ip net.IP
+ var tcpPort, udpPort int
+ if ip = node.IP(); ip == nil {
+ ip = net.ParseIP("127.0.0.1")
+ }
+ if tcpPort = node.TCP(); tcpPort == 0 {
+ tcpPort = 30303
+ }
+ if udpPort = node.TCP(); udpPort == 0 {
+ udpPort = 30303
+ }
+ node = enode.NewV4(node.Pubkey(), ip, tcpPort, udpPort)
+ }
+ addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()}
+ return &testenv{l1, l2, key, node, addr}
+}
+
+func (te *testenv) close() {
+ te.l1.Close()
+ te.l2.Close()
+}
+
+func (te *testenv) send(c net.PacketConn, req v4wire.Packet) []byte {
+ packet, hash, err := v4wire.Encode(te.key, req)
+ if err != nil {
+ panic(fmt.Errorf("can't encode %v packet: %v", req.Name(), err))
+ }
+ if _, err := c.WriteTo(packet, te.remoteAddr); err != nil {
+ panic(fmt.Errorf("can't send %v: %v", req.Name(), err))
+ }
+ return hash
+}
+
+func (te *testenv) read(c net.PacketConn) (v4wire.Packet, []byte, error) {
+ buf := make([]byte, 2048)
+ if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
+ return nil, nil, err
+ }
+ n, _, err := c.ReadFrom(buf)
+ if err != nil {
+ return nil, nil, err
+ }
+ p, _, hash, err := v4wire.Decode(buf[:n])
+ return p, hash, err
+}
+
+func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint {
+ addr := c.LocalAddr().(*net.UDPAddr)
+ return v4wire.Endpoint{
+ IP: addr.IP.To4(),
+ UDP: uint16(addr.Port),
+ TCP: 0,
+ }
+}
+
+func (te *testenv) remoteEndpoint() v4wire.Endpoint {
+ return v4wire.NewEndpoint(te.remoteAddr, 0)
+}
+
+func contains(ns []v4wire.Node, key v4wire.Pubkey) bool {
+ for _, n := range ns {
+ if n.ID == key {
+ return true
+ }
+ }
+ return false
+}
diff --git a/go.sum b/go.sum
index 0fc59d736..7ef5af8a5 100644
--- a/go.sum
+++ b/go.sum
@@ -207,6 +207,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
diff --git a/internal/utesting/utesting.go b/internal/utesting/utesting.go
new file mode 100644
index 000000000..23c748cae
--- /dev/null
+++ b/internal/utesting/utesting.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Package utesting provides a standalone replacement for package testing.
+//
+// This package exists because package testing cannot easily be embedded into a
+// standalone go program. It provides an API that mirrors the standard library
+// testing API.
+package utesting
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "regexp"
+ "runtime"
+ "sync"
+ "time"
+)
+
+// Test represents a single test.
+type Test struct {
+ Name string
+ Fn func(*T)
+}
+
+// Result is the result of a test execution.
+type Result struct {
+ Name string
+ Failed bool
+ Output string
+ Duration time.Duration
+}
+
+// MatchTests returns the tests whose name matches a regular expression.
+func MatchTests(tests []Test, expr string) []Test {
+ var results []Test
+ re, err := regexp.Compile(expr)
+ if err != nil {
+ return nil
+ }
+ for _, test := range tests {
+ if re.MatchString(test.Name) {
+ results = append(results, test)
+ }
+ }
+ return results
+}
+
+// RunTests executes all given tests in order and returns their results.
+// If the report writer is non-nil, a test report is written to it in real time.
+func RunTests(tests []Test, report io.Writer) []Result {
+ results := make([]Result, len(tests))
+ for i, test := range tests {
+ start := time.Now()
+ results[i].Name = test.Name
+ results[i].Failed, results[i].Output = Run(test)
+ results[i].Duration = time.Since(start)
+ if report != nil {
+ printResult(results[i], report)
+ }
+ }
+ return results
+}
+
+func printResult(r Result, w io.Writer) {
+ pd := r.Duration.Truncate(100 * time.Microsecond)
+ if r.Failed {
+ fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd)
+ fmt.Fprintln(w, r.Output)
+ } else {
+ fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd)
+ }
+}
+
+// CountFailures returns the number of failed tests in the result slice.
+func CountFailures(rr []Result) int {
+ count := 0
+ for _, r := range rr {
+ if r.Failed {
+ count++
+ }
+ }
+ return count
+}
+
+// Run executes a single test.
+func Run(test Test) (bool, string) {
+ t := new(T)
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ defer func() {
+ if err := recover(); err != nil {
+ buf := make([]byte, 4096)
+ i := runtime.Stack(buf, false)
+ t.Logf("panic: %v\n\n%s", err, buf[:i])
+ t.Fail()
+ }
+ }()
+ test.Fn(t)
+ }()
+ <-done
+ return t.failed, t.output.String()
+}
+
+// T is the value given to the test function. The test can signal failures
+// and log output by calling methods on this object.
+type T struct {
+ mu sync.Mutex
+ failed bool
+ output bytes.Buffer
+}
+
+// FailNow marks the test as having failed and stops its execution by calling
+// runtime.Goexit (which then runs all deferred calls in the current goroutine).
+func (t *T) FailNow() {
+ t.Fail()
+ runtime.Goexit()
+}
+
+// Fail marks the test as having failed but continues execution.
+func (t *T) Fail() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.failed = true
+}
+
+// Failed reports whether the test has failed.
+func (t *T) Failed() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.failed
+}
+
+// Log formats its arguments using default formatting, analogous to Println, and records
+// the text in the error log.
+func (t *T) Log(vs ...interface{}) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ fmt.Fprintln(&t.output, vs...)
+}
+
+// Logf formats its arguments according to the format, analogous to Printf, and records
+// the text in the error log. A final newline is added if not provided.
+func (t *T) Logf(format string, vs ...interface{}) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if len(format) == 0 || format[len(format)-1] != '\n' {
+ format += "\n"
+ }
+ fmt.Fprintf(&t.output, format, vs...)
+}
+
+// Error is equivalent to Log followed by Fail.
+func (t *T) Error(vs ...interface{}) {
+ t.Log(vs...)
+ t.Fail()
+}
+
+// Errorf is equivalent to Logf followed by Fail.
+func (t *T) Errorf(format string, vs ...interface{}) {
+ t.Logf(format, vs...)
+ t.Fail()
+}
+
+// Fatal is equivalent to Log followed by FailNow.
+func (t *T) Fatal(vs ...interface{}) {
+ t.Log(vs...)
+ t.FailNow()
+}
+
+// Fatalf is equivalent to Logf followed by FailNow.
+func (t *T) Fatalf(format string, vs ...interface{}) {
+ t.Logf(format, vs...)
+ t.FailNow()
+}
diff --git a/internal/utesting/utesting_test.go b/internal/utesting/utesting_test.go
new file mode 100644
index 000000000..1403a5c8f
--- /dev/null
+++ b/internal/utesting/utesting_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package utesting
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestTest(t *testing.T) {
+ tests := []Test{
+ {
+ Name: "successful test",
+ Fn: func(t *T) {},
+ },
+ {
+ Name: "failing test",
+ Fn: func(t *T) {
+ t.Log("output")
+ t.Error("failed")
+ },
+ },
+ {
+ Name: "panicking test",
+ Fn: func(t *T) {
+ panic("oh no")
+ },
+ },
+ }
+ results := RunTests(tests, nil)
+
+ if results[0].Failed || results[0].Output != "" {
+ t.Fatalf("wrong result for successful test: %#v", results[0])
+ }
+ if !results[1].Failed || results[1].Output != "output\nfailed\n" {
+ t.Fatalf("wrong result for failing test: %#v", results[1])
+ }
+ if !results[2].Failed || !strings.HasPrefix(results[2].Output, "panic: oh no\n") {
+ t.Fatalf("wrong result for panicking test: %#v", results[2])
+ }
+}