diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 8580c6121..99b0957ab 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -19,11 +19,14 @@ package main import ( "fmt" "net" + "os" "strings" "time" + "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" @@ -40,6 +43,7 @@ var ( discv4ResolveCommand, discv4ResolveJSONCommand, discv4CrawlCommand, + discv4TestCommand, }, } discv4PingCommand = cli.Command{ @@ -74,6 +78,12 @@ var ( Action: discv4Crawl, Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, } + discv4TestCommand = cli.Command{ + Name: "test", + Usage: "Runs tests against a node", + Action: discv4Test, + Flags: []cli.Flag{remoteEnodeFlag, testPatternFlag, testListen1Flag, testListen2Flag}, + } ) var ( @@ -98,6 +108,25 @@ var ( Usage: "Time limit for the crawl.", Value: 30 * time.Minute, } + remoteEnodeFlag = cli.StringFlag{ + Name: "remote", + Usage: "Enode of the remote node under test", + EnvVar: "REMOTE_ENODE", + } + testPatternFlag = cli.StringFlag{ + Name: "run", + Usage: "Pattern of test suite(s) to run", + } + testListen1Flag = cli.StringFlag{ + Name: "listen1", + Usage: "IP address of the first tester", + Value: v4test.Listen1, + } + testListen2Flag = cli.StringFlag{ + Name: "listen2", + Usage: "IP address of the second tester", + Value: v4test.Listen2, + } ) func discv4Ping(ctx *cli.Context) error { @@ -184,6 +213,28 @@ func discv4Crawl(ctx *cli.Context) error { return nil } +func discv4Test(ctx *cli.Context) error { + // Configure test package globals. + if !ctx.IsSet(remoteEnodeFlag.Name) { + return fmt.Errorf("Missing -%v", remoteEnodeFlag.Name) + } + v4test.Remote = ctx.String(remoteEnodeFlag.Name) + v4test.Listen1 = ctx.String(testListen1Flag.Name) + v4test.Listen2 = ctx.String(testListen2Flag.Name) + + // Filter and run test cases. + tests := v4test.AllTests + if ctx.IsSet(testPatternFlag.Name) { + tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name)) + } + results := utesting.RunTests(tests, os.Stdout) + if fails := utesting.CountFailures(results); fails > 0 { + return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests)) + } + fmt.Printf("%v/%v passed\n", len(tests), len(tests)) + return nil +} + // startV4 starts an ephemeral discovery V4 node. func startV4(ctx *cli.Context) *discover.UDPv4 { ln, config := makeDiscoveryConfig(ctx) diff --git a/cmd/devp2p/internal/v4test/discv4tests.go b/cmd/devp2p/internal/v4test/discv4tests.go new file mode 100644 index 000000000..140b96bfa --- /dev/null +++ b/cmd/devp2p/internal/v4test/discv4tests.go @@ -0,0 +1,467 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package v4test + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "reflect" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" +) + +const ( + expiration = 20 * time.Second + wrongPacket = 66 + macSize = 256 / 8 +) + +var ( + // Remote node under test + Remote string + // IP where the first tester is listening, port will be assigned + Listen1 string = "127.0.0.1" + // IP where the second tester is listening, port will be assigned + // Before running the test, you may have to `sudo ifconfig lo0 add 127.0.0.2` (on MacOS at least) + Listen2 string = "127.0.0.2" +) + +type pingWithJunk struct { + Version uint + From, To v4wire.Endpoint + Expiration uint64 + JunkData1 uint + JunkData2 []byte +} + +func (req *pingWithJunk) Name() string { return "PING/v4" } +func (req *pingWithJunk) Kind() byte { return v4wire.PingPacket } + +type pingWrongType struct { + Version uint + From, To v4wire.Endpoint + Expiration uint64 +} + +func (req *pingWrongType) Name() string { return "WRONG/v4" } +func (req *pingWrongType) Kind() byte { return wrongPacket } + +func futureExpiration() uint64 { + return uint64(time.Now().Add(expiration).Unix()) +} + +// This test just sends a PING packet and expects a response. +func BasicPing(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// checkPong verifies that reply is a valid PONG matching the given ping hash. +func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error { + if reply == nil || reply.Kind() != v4wire.PongPacket { + return fmt.Errorf("expected PONG reply, got %v", reply) + } + pong := reply.(*v4wire.Pong) + if !bytes.Equal(pong.ReplyTok, pingHash) { + return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash) + } + wantEndpoint := te.localEndpoint(te.l1) + if !reflect.DeepEqual(pong.To, wantEndpoint) { + return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint) + } + if v4wire.Expired(pong.Expiration) { + return fmt.Errorf("PONG is expired (%v)", pong.Expiration) + } + return nil +} + +// This test sends a PING packet with wrong 'to' field and expects a PONG response. +func PingWrongTo(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: wrongEndpoint, + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with wrong 'from' field and expects a PONG response. +func PingWrongFrom(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: wrongEndpoint, + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with additional data at the end and expects a PONG +// response. The remote node should respond because EIP-8 mandates ignoring additional +// trailing data. +func PingExtraData(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + pingHash := te.send(te.l1, &pingWithJunk{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + JunkData1: 42, + JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1}, + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with additional data and wrong 'from' field +// and expects a PONG response. +func PingExtraDataWrongFrom(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + req := pingWithJunk{ + Version: 4, + From: wrongEndpoint, + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + JunkData1: 42, + JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1}, + } + pingHash := te.send(te.l1, &req) + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with an expiration in the past. +// The remote node should not respond. +func PingPastExpiration(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: -futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if reply != nil { + t.Fatal("Expected no reply, got", reply) + } +} + +// This test sends an invalid packet. The remote node should not respond. +func WrongPacketType(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + te.send(te.l1, &pingWrongType{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if reply != nil { + t.Fatal("Expected no reply, got", reply) + } +} + +// This test verifies that the default behaviour of ignoring 'from' fields is unaffected by +// the bonding process. After bonding, it pings the target with a different from endpoint. +func BondThenPingWithWrongFrom(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: wrongEndpoint, + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test just sends FINDNODE. The remote node should not reply +// because the endpoint proof has not completed. +func FindnodeWithoutEndpointProof(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + req := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(req.Target[:]) + te.send(te.l1, &req) + + reply, _, _ := te.read(te.l1) + if reply != nil { + t.Fatal("Expected no response, got", reply) + } +} + +// BasicFindnode sends a FINDNODE request after performing the endpoint +// proof. The remote node should respond. +func BasicFindnode(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + findnode := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l1, &findnode) + + reply, _, err := te.read(te.l1) + if err != nil { + t.Fatal("read find nodes", err) + } + if reply.Kind() != v4wire.NeighborsPacket { + t.Fatal("Expected neighbors, got", reply.Name()) + } +} + +// This test sends an unsolicited NEIGHBORS packet after the endpoint proof, then sends +// FINDNODE to read the remote table. The remote node should not return the node contained +// in the unsolicited NEIGHBORS packet. +func UnsolicitedNeighbors(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + // Send unsolicited NEIGHBORS response. + fakeKey, _ := crypto.GenerateKey() + encFakeKey := v4wire.EncodePubkey(&fakeKey.PublicKey) + neighbors := v4wire.Neighbors{ + Expiration: futureExpiration(), + Nodes: []v4wire.Node{{ + ID: encFakeKey, + IP: net.IP{1, 2, 3, 4}, + UDP: 30303, + TCP: 30303, + }}, + } + te.send(te.l1, &neighbors) + + // Check if the remote node included the fake node. + te.send(te.l1, &v4wire.Findnode{ + Expiration: futureExpiration(), + Target: encFakeKey, + }) + + reply, _, err := te.read(te.l1) + if err != nil { + t.Fatal("read find nodes", err) + } + if reply.Kind() != v4wire.NeighborsPacket { + t.Fatal("Expected neighbors, got", reply.Name()) + } + nodes := reply.(*v4wire.Neighbors).Nodes + if contains(nodes, encFakeKey) { + t.Fatal("neighbors response contains node from earlier unsolicited neighbors response") + } +} + +// This test sends FINDNODE with an expiration timestamp in the past. +// The remote node should not respond. +func FindnodePastExpiration(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + findnode := v4wire.Findnode{Expiration: -futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l1, &findnode) + + for { + reply, _, _ := te.read(te.l1) + if reply == nil { + return + } else if reply.Kind() == v4wire.NeighborsPacket { + t.Fatal("Unexpected NEIGHBORS response for expired FINDNODE request") + } + } +} + +// bond performs the endpoint proof with the remote node. +func bond(t *utesting.T, te *testenv) { + te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + var gotPing, gotPong bool + for !gotPing || !gotPong { + req, hash, err := te.read(te.l1) + if err != nil { + t.Fatal(err) + } + switch req.(type) { + case *v4wire.Ping: + te.send(te.l1, &v4wire.Pong{ + To: te.remoteEndpoint(), + ReplyTok: hash, + Expiration: futureExpiration(), + }) + gotPing = true + case *v4wire.Pong: + // TODO: maybe verify pong data here + gotPong = true + } + } +} + +// This test attempts to perform a traffic amplification attack against a +// 'victim' endpoint using FINDNODE. In this attack scenario, the attacker +// attempts to complete the endpoint proof non-interactively by sending a PONG +// with mismatching reply token from the 'victim' endpoint. The attack works if +// the remote node does not verify the PONG reply token field correctly. The +// attacker could then perform traffic amplification by sending many FINDNODE +// requests to the discovery node, which would reply to the 'victim' address. +func FindnodeAmplificationInvalidPongHash(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + // Send PING to start endpoint verification. + te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + var gotPing, gotPong bool + for !gotPing || !gotPong { + req, _, err := te.read(te.l1) + if err != nil { + t.Fatal(err) + } + switch req.(type) { + case *v4wire.Ping: + // Send PONG from this node ID, but with invalid ReplyTok. + te.send(te.l1, &v4wire.Pong{ + To: te.remoteEndpoint(), + ReplyTok: make([]byte, macSize), + Expiration: futureExpiration(), + }) + gotPing = true + case *v4wire.Pong: + gotPong = true + } + } + + // Now send FINDNODE. The remote node should not respond because our + // PONG did not reference the PING hash. + findnode := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l1, &findnode) + + // If we receive a NEIGHBORS response, the attack worked and the test fails. + reply, _, _ := te.read(te.l1) + if reply != nil && reply.Kind() == v4wire.NeighborsPacket { + t.Error("Got neighbors") + } +} + +// This test attempts to perform a traffic amplification attack using FINDNODE. +// The attack works if the remote node does not verify the IP address of FINDNODE +// against the endpoint verification proof done by PING/PONG. +func FindnodeAmplificationWrongIP(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + // Do the endpoint proof from the l1 IP. + bond(t, te) + + // Now send FINDNODE from the same node ID, but different IP address. + // The remote node should not respond. + findnode := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l2, &findnode) + + // If we receive a NEIGHBORS response, the attack worked and the test fails. + reply, _, _ := te.read(te.l2) + if reply != nil { + t.Error("Got NEIGHORS response for FINDNODE from wrong IP") + } +} + +var AllTests = []utesting.Test{ + {Name: "Ping/Basic", Fn: BasicPing}, + {Name: "Ping/WrongTo", Fn: PingWrongTo}, + {Name: "Ping/WrongFrom", Fn: PingWrongFrom}, + {Name: "Ping/ExtraData", Fn: PingExtraData}, + {Name: "Ping/ExtraDataWrongFrom", Fn: PingExtraDataWrongFrom}, + {Name: "Ping/PastExpiration", Fn: PingPastExpiration}, + {Name: "Ping/WrongPacketType", Fn: WrongPacketType}, + {Name: "Ping/BondThenPingWithWrongFrom", Fn: BondThenPingWithWrongFrom}, + {Name: "Findnode/WithoutEndpointProof", Fn: FindnodeWithoutEndpointProof}, + {Name: "Findnode/BasicFindnode", Fn: BasicFindnode}, + {Name: "Findnode/UnsolicitedNeighbors", Fn: UnsolicitedNeighbors}, + {Name: "Findnode/PastExpiration", Fn: FindnodePastExpiration}, + {Name: "Amplification/InvalidPongHash", Fn: FindnodeAmplificationInvalidPongHash}, + {Name: "Amplification/WrongIP", Fn: FindnodeAmplificationWrongIP}, +} diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go new file mode 100644 index 000000000..928659418 --- /dev/null +++ b/cmd/devp2p/internal/v4test/framework.go @@ -0,0 +1,123 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package v4test + +import ( + "crypto/ecdsa" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const waitTime = 300 * time.Millisecond + +type testenv struct { + l1, l2 net.PacketConn + key *ecdsa.PrivateKey + remote *enode.Node + remoteAddr *net.UDPAddr +} + +func newTestEnv(remote string, listen1, listen2 string) *testenv { + l1, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen1)) + if err != nil { + panic(err) + } + l2, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen2)) + if err != nil { + panic(err) + } + key, err := crypto.GenerateKey() + if err != nil { + panic(err) + } + node, err := enode.Parse(enode.ValidSchemes, remote) + if err != nil { + panic(err) + } + if node.IP() == nil || node.UDP() == 0 { + var ip net.IP + var tcpPort, udpPort int + if ip = node.IP(); ip == nil { + ip = net.ParseIP("127.0.0.1") + } + if tcpPort = node.TCP(); tcpPort == 0 { + tcpPort = 30303 + } + if udpPort = node.TCP(); udpPort == 0 { + udpPort = 30303 + } + node = enode.NewV4(node.Pubkey(), ip, tcpPort, udpPort) + } + addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()} + return &testenv{l1, l2, key, node, addr} +} + +func (te *testenv) close() { + te.l1.Close() + te.l2.Close() +} + +func (te *testenv) send(c net.PacketConn, req v4wire.Packet) []byte { + packet, hash, err := v4wire.Encode(te.key, req) + if err != nil { + panic(fmt.Errorf("can't encode %v packet: %v", req.Name(), err)) + } + if _, err := c.WriteTo(packet, te.remoteAddr); err != nil { + panic(fmt.Errorf("can't send %v: %v", req.Name(), err)) + } + return hash +} + +func (te *testenv) read(c net.PacketConn) (v4wire.Packet, []byte, error) { + buf := make([]byte, 2048) + if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil { + return nil, nil, err + } + n, _, err := c.ReadFrom(buf) + if err != nil { + return nil, nil, err + } + p, _, hash, err := v4wire.Decode(buf[:n]) + return p, hash, err +} + +func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint { + addr := c.LocalAddr().(*net.UDPAddr) + return v4wire.Endpoint{ + IP: addr.IP.To4(), + UDP: uint16(addr.Port), + TCP: 0, + } +} + +func (te *testenv) remoteEndpoint() v4wire.Endpoint { + return v4wire.NewEndpoint(te.remoteAddr, 0) +} + +func contains(ns []v4wire.Node, key v4wire.Pubkey) bool { + for _, n := range ns { + if n.ID == key { + return true + } + } + return false +} diff --git a/go.sum b/go.sum index 0fc59d736..7ef5af8a5 100644 --- a/go.sum +++ b/go.sum @@ -207,6 +207,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/utesting/utesting.go b/internal/utesting/utesting.go new file mode 100644 index 000000000..23c748cae --- /dev/null +++ b/internal/utesting/utesting.go @@ -0,0 +1,190 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package utesting provides a standalone replacement for package testing. +// +// This package exists because package testing cannot easily be embedded into a +// standalone go program. It provides an API that mirrors the standard library +// testing API. +package utesting + +import ( + "bytes" + "fmt" + "io" + "regexp" + "runtime" + "sync" + "time" +) + +// Test represents a single test. +type Test struct { + Name string + Fn func(*T) +} + +// Result is the result of a test execution. +type Result struct { + Name string + Failed bool + Output string + Duration time.Duration +} + +// MatchTests returns the tests whose name matches a regular expression. +func MatchTests(tests []Test, expr string) []Test { + var results []Test + re, err := regexp.Compile(expr) + if err != nil { + return nil + } + for _, test := range tests { + if re.MatchString(test.Name) { + results = append(results, test) + } + } + return results +} + +// RunTests executes all given tests in order and returns their results. +// If the report writer is non-nil, a test report is written to it in real time. +func RunTests(tests []Test, report io.Writer) []Result { + results := make([]Result, len(tests)) + for i, test := range tests { + start := time.Now() + results[i].Name = test.Name + results[i].Failed, results[i].Output = Run(test) + results[i].Duration = time.Since(start) + if report != nil { + printResult(results[i], report) + } + } + return results +} + +func printResult(r Result, w io.Writer) { + pd := r.Duration.Truncate(100 * time.Microsecond) + if r.Failed { + fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd) + fmt.Fprintln(w, r.Output) + } else { + fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd) + } +} + +// CountFailures returns the number of failed tests in the result slice. +func CountFailures(rr []Result) int { + count := 0 + for _, r := range rr { + if r.Failed { + count++ + } + } + return count +} + +// Run executes a single test. +func Run(test Test) (bool, string) { + t := new(T) + done := make(chan struct{}) + go func() { + defer close(done) + defer func() { + if err := recover(); err != nil { + buf := make([]byte, 4096) + i := runtime.Stack(buf, false) + t.Logf("panic: %v\n\n%s", err, buf[:i]) + t.Fail() + } + }() + test.Fn(t) + }() + <-done + return t.failed, t.output.String() +} + +// T is the value given to the test function. The test can signal failures +// and log output by calling methods on this object. +type T struct { + mu sync.Mutex + failed bool + output bytes.Buffer +} + +// FailNow marks the test as having failed and stops its execution by calling +// runtime.Goexit (which then runs all deferred calls in the current goroutine). +func (t *T) FailNow() { + t.Fail() + runtime.Goexit() +} + +// Fail marks the test as having failed but continues execution. +func (t *T) Fail() { + t.mu.Lock() + defer t.mu.Unlock() + t.failed = true +} + +// Failed reports whether the test has failed. +func (t *T) Failed() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.failed +} + +// Log formats its arguments using default formatting, analogous to Println, and records +// the text in the error log. +func (t *T) Log(vs ...interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + fmt.Fprintln(&t.output, vs...) +} + +// Logf formats its arguments according to the format, analogous to Printf, and records +// the text in the error log. A final newline is added if not provided. +func (t *T) Logf(format string, vs ...interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + if len(format) == 0 || format[len(format)-1] != '\n' { + format += "\n" + } + fmt.Fprintf(&t.output, format, vs...) +} + +// Error is equivalent to Log followed by Fail. +func (t *T) Error(vs ...interface{}) { + t.Log(vs...) + t.Fail() +} + +// Errorf is equivalent to Logf followed by Fail. +func (t *T) Errorf(format string, vs ...interface{}) { + t.Logf(format, vs...) + t.Fail() +} + +// Fatal is equivalent to Log followed by FailNow. +func (t *T) Fatal(vs ...interface{}) { + t.Log(vs...) + t.FailNow() +} + +// Fatalf is equivalent to Logf followed by FailNow. +func (t *T) Fatalf(format string, vs ...interface{}) { + t.Logf(format, vs...) + t.FailNow() +} diff --git a/internal/utesting/utesting_test.go b/internal/utesting/utesting_test.go new file mode 100644 index 000000000..1403a5c8f --- /dev/null +++ b/internal/utesting/utesting_test.go @@ -0,0 +1,55 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package utesting + +import ( + "strings" + "testing" +) + +func TestTest(t *testing.T) { + tests := []Test{ + { + Name: "successful test", + Fn: func(t *T) {}, + }, + { + Name: "failing test", + Fn: func(t *T) { + t.Log("output") + t.Error("failed") + }, + }, + { + Name: "panicking test", + Fn: func(t *T) { + panic("oh no") + }, + }, + } + results := RunTests(tests, nil) + + if results[0].Failed || results[0].Output != "" { + t.Fatalf("wrong result for successful test: %#v", results[0]) + } + if !results[1].Failed || results[1].Output != "output\nfailed\n" { + t.Fatalf("wrong result for failing test: %#v", results[1]) + } + if !results[2].Failed || !strings.HasPrefix(results[2].Output, "panic: oh no\n") { + t.Fatalf("wrong result for panicking test: %#v", results[2]) + } +}