cmd/devp2p: add discv4 test suite (#21163)

This adds a test suite for discovery v4. The test suite is a port of the Hive suite for
discovery, and will replace the current suite on Hive soon-ish. The tests can be
run locally with this command:

    devp2p discv4 test -remote enode//...

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Adam Schmideg 2020-07-07 14:37:33 +02:00 committed by GitHub
parent e5871b928f
commit 6a48ae37b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 887 additions and 0 deletions

View File

@ -19,11 +19,14 @@ package main
import (
"fmt"
"net"
"os"
"strings"
"time"
"github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
@ -40,6 +43,7 @@ var (
discv4ResolveCommand,
discv4ResolveJSONCommand,
discv4CrawlCommand,
discv4TestCommand,
},
}
discv4PingCommand = cli.Command{
@ -74,6 +78,12 @@ var (
Action: discv4Crawl,
Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag},
}
discv4TestCommand = cli.Command{
Name: "test",
Usage: "Runs tests against a node",
Action: discv4Test,
Flags: []cli.Flag{remoteEnodeFlag, testPatternFlag, testListen1Flag, testListen2Flag},
}
)
var (
@ -98,6 +108,25 @@ var (
Usage: "Time limit for the crawl.",
Value: 30 * time.Minute,
}
remoteEnodeFlag = cli.StringFlag{
Name: "remote",
Usage: "Enode of the remote node under test",
EnvVar: "REMOTE_ENODE",
}
testPatternFlag = cli.StringFlag{
Name: "run",
Usage: "Pattern of test suite(s) to run",
}
testListen1Flag = cli.StringFlag{
Name: "listen1",
Usage: "IP address of the first tester",
Value: v4test.Listen1,
}
testListen2Flag = cli.StringFlag{
Name: "listen2",
Usage: "IP address of the second tester",
Value: v4test.Listen2,
}
)
func discv4Ping(ctx *cli.Context) error {
@ -184,6 +213,28 @@ func discv4Crawl(ctx *cli.Context) error {
return nil
}
func discv4Test(ctx *cli.Context) error {
// Configure test package globals.
if !ctx.IsSet(remoteEnodeFlag.Name) {
return fmt.Errorf("Missing -%v", remoteEnodeFlag.Name)
}
v4test.Remote = ctx.String(remoteEnodeFlag.Name)
v4test.Listen1 = ctx.String(testListen1Flag.Name)
v4test.Listen2 = ctx.String(testListen2Flag.Name)
// Filter and run test cases.
tests := v4test.AllTests
if ctx.IsSet(testPatternFlag.Name) {
tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name))
}
results := utesting.RunTests(tests, os.Stdout)
if fails := utesting.CountFailures(results); fails > 0 {
return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests))
}
fmt.Printf("%v/%v passed\n", len(tests), len(tests))
return nil
}
// startV4 starts an ephemeral discovery V4 node.
func startV4(ctx *cli.Context) *discover.UDPv4 {
ln, config := makeDiscoveryConfig(ctx)

View File

@ -0,0 +1,467 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
package v4test
import (
"bytes"
"crypto/rand"
"fmt"
"net"
"reflect"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p/discover/v4wire"
)
const (
expiration = 20 * time.Second
wrongPacket = 66
macSize = 256 / 8
)
var (
// Remote node under test
Remote string
// IP where the first tester is listening, port will be assigned
Listen1 string = "127.0.0.1"
// IP where the second tester is listening, port will be assigned
// Before running the test, you may have to `sudo ifconfig lo0 add 127.0.0.2` (on MacOS at least)
Listen2 string = "127.0.0.2"
)
type pingWithJunk struct {
Version uint
From, To v4wire.Endpoint
Expiration uint64
JunkData1 uint
JunkData2 []byte
}
func (req *pingWithJunk) Name() string { return "PING/v4" }
func (req *pingWithJunk) Kind() byte { return v4wire.PingPacket }
type pingWrongType struct {
Version uint
From, To v4wire.Endpoint
Expiration uint64
}
func (req *pingWrongType) Name() string { return "WRONG/v4" }
func (req *pingWrongType) Kind() byte { return wrongPacket }
func futureExpiration() uint64 {
return uint64(time.Now().Add(expiration).Unix())
}
// This test just sends a PING packet and expects a response.
func BasicPing(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
pingHash := te.send(te.l1, &v4wire.Ping{
Version: 4,
From: te.localEndpoint(te.l1),
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
})
reply, _, _ := te.read(te.l1)
if err := te.checkPong(reply, pingHash); err != nil {
t.Fatal(err)
}
}
// checkPong verifies that reply is a valid PONG matching the given ping hash.
func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error {
if reply == nil || reply.Kind() != v4wire.PongPacket {
return fmt.Errorf("expected PONG reply, got %v", reply)
}
pong := reply.(*v4wire.Pong)
if !bytes.Equal(pong.ReplyTok, pingHash) {
return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash)
}
wantEndpoint := te.localEndpoint(te.l1)
if !reflect.DeepEqual(pong.To, wantEndpoint) {
return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint)
}
if v4wire.Expired(pong.Expiration) {
return fmt.Errorf("PONG is expired (%v)", pong.Expiration)
}
return nil
}
// This test sends a PING packet with wrong 'to' field and expects a PONG response.
func PingWrongTo(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
pingHash := te.send(te.l1, &v4wire.Ping{
Version: 4,
From: te.localEndpoint(te.l1),
To: wrongEndpoint,
Expiration: futureExpiration(),
})
reply, _, _ := te.read(te.l1)
if err := te.checkPong(reply, pingHash); err != nil {
t.Fatal(err)
}
}
// This test sends a PING packet with wrong 'from' field and expects a PONG response.
func PingWrongFrom(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
pingHash := te.send(te.l1, &v4wire.Ping{
Version: 4,
From: wrongEndpoint,
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
})
reply, _, _ := te.read(te.l1)
if err := te.checkPong(reply, pingHash); err != nil {
t.Fatal(err)
}
}
// This test sends a PING packet with additional data at the end and expects a PONG
// response. The remote node should respond because EIP-8 mandates ignoring additional
// trailing data.
func PingExtraData(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
pingHash := te.send(te.l1, &pingWithJunk{
Version: 4,
From: te.localEndpoint(te.l1),
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
JunkData1: 42,
JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1},
})
reply, _, _ := te.read(te.l1)
if err := te.checkPong(reply, pingHash); err != nil {
t.Fatal(err)
}
}
// This test sends a PING packet with additional data and wrong 'from' field
// and expects a PONG response.
func PingExtraDataWrongFrom(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
req := pingWithJunk{
Version: 4,
From: wrongEndpoint,
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
JunkData1: 42,
JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1},
}
pingHash := te.send(te.l1, &req)
reply, _, _ := te.read(te.l1)
if err := te.checkPong(reply, pingHash); err != nil {
t.Fatal(err)
}
}
// This test sends a PING packet with an expiration in the past.
// The remote node should not respond.
func PingPastExpiration(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
te.send(te.l1, &v4wire.Ping{
Version: 4,
From: te.localEndpoint(te.l1),
To: te.remoteEndpoint(),
Expiration: -futureExpiration(),
})
reply, _, _ := te.read(te.l1)
if reply != nil {
t.Fatal("Expected no reply, got", reply)
}
}
// This test sends an invalid packet. The remote node should not respond.
func WrongPacketType(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
te.send(te.l1, &pingWrongType{
Version: 4,
From: te.localEndpoint(te.l1),
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
})
reply, _, _ := te.read(te.l1)
if reply != nil {
t.Fatal("Expected no reply, got", reply)
}
}
// This test verifies that the default behaviour of ignoring 'from' fields is unaffected by
// the bonding process. After bonding, it pings the target with a different from endpoint.
func BondThenPingWithWrongFrom(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
bond(t, te)
wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
pingHash := te.send(te.l1, &v4wire.Ping{
Version: 4,
From: wrongEndpoint,
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
})
reply, _, _ := te.read(te.l1)
if err := te.checkPong(reply, pingHash); err != nil {
t.Fatal(err)
}
}
// This test just sends FINDNODE. The remote node should not reply
// because the endpoint proof has not completed.
func FindnodeWithoutEndpointProof(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
req := v4wire.Findnode{Expiration: futureExpiration()}
rand.Read(req.Target[:])
te.send(te.l1, &req)
reply, _, _ := te.read(te.l1)
if reply != nil {
t.Fatal("Expected no response, got", reply)
}
}
// BasicFindnode sends a FINDNODE request after performing the endpoint
// proof. The remote node should respond.
func BasicFindnode(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
bond(t, te)
findnode := v4wire.Findnode{Expiration: futureExpiration()}
rand.Read(findnode.Target[:])
te.send(te.l1, &findnode)
reply, _, err := te.read(te.l1)
if err != nil {
t.Fatal("read find nodes", err)
}
if reply.Kind() != v4wire.NeighborsPacket {
t.Fatal("Expected neighbors, got", reply.Name())
}
}
// This test sends an unsolicited NEIGHBORS packet after the endpoint proof, then sends
// FINDNODE to read the remote table. The remote node should not return the node contained
// in the unsolicited NEIGHBORS packet.
func UnsolicitedNeighbors(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
bond(t, te)
// Send unsolicited NEIGHBORS response.
fakeKey, _ := crypto.GenerateKey()
encFakeKey := v4wire.EncodePubkey(&fakeKey.PublicKey)
neighbors := v4wire.Neighbors{
Expiration: futureExpiration(),
Nodes: []v4wire.Node{{
ID: encFakeKey,
IP: net.IP{1, 2, 3, 4},
UDP: 30303,
TCP: 30303,
}},
}
te.send(te.l1, &neighbors)
// Check if the remote node included the fake node.
te.send(te.l1, &v4wire.Findnode{
Expiration: futureExpiration(),
Target: encFakeKey,
})
reply, _, err := te.read(te.l1)
if err != nil {
t.Fatal("read find nodes", err)
}
if reply.Kind() != v4wire.NeighborsPacket {
t.Fatal("Expected neighbors, got", reply.Name())
}
nodes := reply.(*v4wire.Neighbors).Nodes
if contains(nodes, encFakeKey) {
t.Fatal("neighbors response contains node from earlier unsolicited neighbors response")
}
}
// This test sends FINDNODE with an expiration timestamp in the past.
// The remote node should not respond.
func FindnodePastExpiration(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
bond(t, te)
findnode := v4wire.Findnode{Expiration: -futureExpiration()}
rand.Read(findnode.Target[:])
te.send(te.l1, &findnode)
for {
reply, _, _ := te.read(te.l1)
if reply == nil {
return
} else if reply.Kind() == v4wire.NeighborsPacket {
t.Fatal("Unexpected NEIGHBORS response for expired FINDNODE request")
}
}
}
// bond performs the endpoint proof with the remote node.
func bond(t *utesting.T, te *testenv) {
te.send(te.l1, &v4wire.Ping{
Version: 4,
From: te.localEndpoint(te.l1),
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
})
var gotPing, gotPong bool
for !gotPing || !gotPong {
req, hash, err := te.read(te.l1)
if err != nil {
t.Fatal(err)
}
switch req.(type) {
case *v4wire.Ping:
te.send(te.l1, &v4wire.Pong{
To: te.remoteEndpoint(),
ReplyTok: hash,
Expiration: futureExpiration(),
})
gotPing = true
case *v4wire.Pong:
// TODO: maybe verify pong data here
gotPong = true
}
}
}
// This test attempts to perform a traffic amplification attack against a
// 'victim' endpoint using FINDNODE. In this attack scenario, the attacker
// attempts to complete the endpoint proof non-interactively by sending a PONG
// with mismatching reply token from the 'victim' endpoint. The attack works if
// the remote node does not verify the PONG reply token field correctly. The
// attacker could then perform traffic amplification by sending many FINDNODE
// requests to the discovery node, which would reply to the 'victim' address.
func FindnodeAmplificationInvalidPongHash(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
// Send PING to start endpoint verification.
te.send(te.l1, &v4wire.Ping{
Version: 4,
From: te.localEndpoint(te.l1),
To: te.remoteEndpoint(),
Expiration: futureExpiration(),
})
var gotPing, gotPong bool
for !gotPing || !gotPong {
req, _, err := te.read(te.l1)
if err != nil {
t.Fatal(err)
}
switch req.(type) {
case *v4wire.Ping:
// Send PONG from this node ID, but with invalid ReplyTok.
te.send(te.l1, &v4wire.Pong{
To: te.remoteEndpoint(),
ReplyTok: make([]byte, macSize),
Expiration: futureExpiration(),
})
gotPing = true
case *v4wire.Pong:
gotPong = true
}
}
// Now send FINDNODE. The remote node should not respond because our
// PONG did not reference the PING hash.
findnode := v4wire.Findnode{Expiration: futureExpiration()}
rand.Read(findnode.Target[:])
te.send(te.l1, &findnode)
// If we receive a NEIGHBORS response, the attack worked and the test fails.
reply, _, _ := te.read(te.l1)
if reply != nil && reply.Kind() == v4wire.NeighborsPacket {
t.Error("Got neighbors")
}
}
// This test attempts to perform a traffic amplification attack using FINDNODE.
// The attack works if the remote node does not verify the IP address of FINDNODE
// against the endpoint verification proof done by PING/PONG.
func FindnodeAmplificationWrongIP(t *utesting.T) {
te := newTestEnv(Remote, Listen1, Listen2)
defer te.close()
// Do the endpoint proof from the l1 IP.
bond(t, te)
// Now send FINDNODE from the same node ID, but different IP address.
// The remote node should not respond.
findnode := v4wire.Findnode{Expiration: futureExpiration()}
rand.Read(findnode.Target[:])
te.send(te.l2, &findnode)
// If we receive a NEIGHBORS response, the attack worked and the test fails.
reply, _, _ := te.read(te.l2)
if reply != nil {
t.Error("Got NEIGHORS response for FINDNODE from wrong IP")
}
}
var AllTests = []utesting.Test{
{Name: "Ping/Basic", Fn: BasicPing},
{Name: "Ping/WrongTo", Fn: PingWrongTo},
{Name: "Ping/WrongFrom", Fn: PingWrongFrom},
{Name: "Ping/ExtraData", Fn: PingExtraData},
{Name: "Ping/ExtraDataWrongFrom", Fn: PingExtraDataWrongFrom},
{Name: "Ping/PastExpiration", Fn: PingPastExpiration},
{Name: "Ping/WrongPacketType", Fn: WrongPacketType},
{Name: "Ping/BondThenPingWithWrongFrom", Fn: BondThenPingWithWrongFrom},
{Name: "Findnode/WithoutEndpointProof", Fn: FindnodeWithoutEndpointProof},
{Name: "Findnode/BasicFindnode", Fn: BasicFindnode},
{Name: "Findnode/UnsolicitedNeighbors", Fn: UnsolicitedNeighbors},
{Name: "Findnode/PastExpiration", Fn: FindnodePastExpiration},
{Name: "Amplification/InvalidPongHash", Fn: FindnodeAmplificationInvalidPongHash},
{Name: "Amplification/WrongIP", Fn: FindnodeAmplificationWrongIP},
}

View File

@ -0,0 +1,123 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
package v4test
import (
"crypto/ecdsa"
"fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover/v4wire"
"github.com/ethereum/go-ethereum/p2p/enode"
)
const waitTime = 300 * time.Millisecond
type testenv struct {
l1, l2 net.PacketConn
key *ecdsa.PrivateKey
remote *enode.Node
remoteAddr *net.UDPAddr
}
func newTestEnv(remote string, listen1, listen2 string) *testenv {
l1, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen1))
if err != nil {
panic(err)
}
l2, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen2))
if err != nil {
panic(err)
}
key, err := crypto.GenerateKey()
if err != nil {
panic(err)
}
node, err := enode.Parse(enode.ValidSchemes, remote)
if err != nil {
panic(err)
}
if node.IP() == nil || node.UDP() == 0 {
var ip net.IP
var tcpPort, udpPort int
if ip = node.IP(); ip == nil {
ip = net.ParseIP("127.0.0.1")
}
if tcpPort = node.TCP(); tcpPort == 0 {
tcpPort = 30303
}
if udpPort = node.TCP(); udpPort == 0 {
udpPort = 30303
}
node = enode.NewV4(node.Pubkey(), ip, tcpPort, udpPort)
}
addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()}
return &testenv{l1, l2, key, node, addr}
}
func (te *testenv) close() {
te.l1.Close()
te.l2.Close()
}
func (te *testenv) send(c net.PacketConn, req v4wire.Packet) []byte {
packet, hash, err := v4wire.Encode(te.key, req)
if err != nil {
panic(fmt.Errorf("can't encode %v packet: %v", req.Name(), err))
}
if _, err := c.WriteTo(packet, te.remoteAddr); err != nil {
panic(fmt.Errorf("can't send %v: %v", req.Name(), err))
}
return hash
}
func (te *testenv) read(c net.PacketConn) (v4wire.Packet, []byte, error) {
buf := make([]byte, 2048)
if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
return nil, nil, err
}
n, _, err := c.ReadFrom(buf)
if err != nil {
return nil, nil, err
}
p, _, hash, err := v4wire.Decode(buf[:n])
return p, hash, err
}
func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint {
addr := c.LocalAddr().(*net.UDPAddr)
return v4wire.Endpoint{
IP: addr.IP.To4(),
UDP: uint16(addr.Port),
TCP: 0,
}
}
func (te *testenv) remoteEndpoint() v4wire.Endpoint {
return v4wire.NewEndpoint(te.remoteAddr, 0)
}
func contains(ns []v4wire.Node, key v4wire.Pubkey) bool {
for _, n := range ns {
if n.ID == key {
return true
}
}
return false
}

1
go.sum
View File

@ -207,6 +207,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -0,0 +1,190 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package utesting provides a standalone replacement for package testing.
//
// This package exists because package testing cannot easily be embedded into a
// standalone go program. It provides an API that mirrors the standard library
// testing API.
package utesting
import (
"bytes"
"fmt"
"io"
"regexp"
"runtime"
"sync"
"time"
)
// Test represents a single test.
type Test struct {
Name string
Fn func(*T)
}
// Result is the result of a test execution.
type Result struct {
Name string
Failed bool
Output string
Duration time.Duration
}
// MatchTests returns the tests whose name matches a regular expression.
func MatchTests(tests []Test, expr string) []Test {
var results []Test
re, err := regexp.Compile(expr)
if err != nil {
return nil
}
for _, test := range tests {
if re.MatchString(test.Name) {
results = append(results, test)
}
}
return results
}
// RunTests executes all given tests in order and returns their results.
// If the report writer is non-nil, a test report is written to it in real time.
func RunTests(tests []Test, report io.Writer) []Result {
results := make([]Result, len(tests))
for i, test := range tests {
start := time.Now()
results[i].Name = test.Name
results[i].Failed, results[i].Output = Run(test)
results[i].Duration = time.Since(start)
if report != nil {
printResult(results[i], report)
}
}
return results
}
func printResult(r Result, w io.Writer) {
pd := r.Duration.Truncate(100 * time.Microsecond)
if r.Failed {
fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd)
fmt.Fprintln(w, r.Output)
} else {
fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd)
}
}
// CountFailures returns the number of failed tests in the result slice.
func CountFailures(rr []Result) int {
count := 0
for _, r := range rr {
if r.Failed {
count++
}
}
return count
}
// Run executes a single test.
func Run(test Test) (bool, string) {
t := new(T)
done := make(chan struct{})
go func() {
defer close(done)
defer func() {
if err := recover(); err != nil {
buf := make([]byte, 4096)
i := runtime.Stack(buf, false)
t.Logf("panic: %v\n\n%s", err, buf[:i])
t.Fail()
}
}()
test.Fn(t)
}()
<-done
return t.failed, t.output.String()
}
// T is the value given to the test function. The test can signal failures
// and log output by calling methods on this object.
type T struct {
mu sync.Mutex
failed bool
output bytes.Buffer
}
// FailNow marks the test as having failed and stops its execution by calling
// runtime.Goexit (which then runs all deferred calls in the current goroutine).
func (t *T) FailNow() {
t.Fail()
runtime.Goexit()
}
// Fail marks the test as having failed but continues execution.
func (t *T) Fail() {
t.mu.Lock()
defer t.mu.Unlock()
t.failed = true
}
// Failed reports whether the test has failed.
func (t *T) Failed() bool {
t.mu.Lock()
defer t.mu.Unlock()
return t.failed
}
// Log formats its arguments using default formatting, analogous to Println, and records
// the text in the error log.
func (t *T) Log(vs ...interface{}) {
t.mu.Lock()
defer t.mu.Unlock()
fmt.Fprintln(&t.output, vs...)
}
// Logf formats its arguments according to the format, analogous to Printf, and records
// the text in the error log. A final newline is added if not provided.
func (t *T) Logf(format string, vs ...interface{}) {
t.mu.Lock()
defer t.mu.Unlock()
if len(format) == 0 || format[len(format)-1] != '\n' {
format += "\n"
}
fmt.Fprintf(&t.output, format, vs...)
}
// Error is equivalent to Log followed by Fail.
func (t *T) Error(vs ...interface{}) {
t.Log(vs...)
t.Fail()
}
// Errorf is equivalent to Logf followed by Fail.
func (t *T) Errorf(format string, vs ...interface{}) {
t.Logf(format, vs...)
t.Fail()
}
// Fatal is equivalent to Log followed by FailNow.
func (t *T) Fatal(vs ...interface{}) {
t.Log(vs...)
t.FailNow()
}
// Fatalf is equivalent to Logf followed by FailNow.
func (t *T) Fatalf(format string, vs ...interface{}) {
t.Logf(format, vs...)
t.FailNow()
}

View File

@ -0,0 +1,55 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package utesting
import (
"strings"
"testing"
)
func TestTest(t *testing.T) {
tests := []Test{
{
Name: "successful test",
Fn: func(t *T) {},
},
{
Name: "failing test",
Fn: func(t *T) {
t.Log("output")
t.Error("failed")
},
},
{
Name: "panicking test",
Fn: func(t *T) {
panic("oh no")
},
},
}
results := RunTests(tests, nil)
if results[0].Failed || results[0].Output != "" {
t.Fatalf("wrong result for successful test: %#v", results[0])
}
if !results[1].Failed || results[1].Output != "output\nfailed\n" {
t.Fatalf("wrong result for failing test: %#v", results[1])
}
if !results[2].Failed || !strings.HasPrefix(results[2].Output, "panic: oh no\n") {
t.Fatalf("wrong result for panicking test: %#v", results[2])
}
}