diff --git a/p2p/protocols/accounting.go b/p2p/protocols/accounting.go
new file mode 100644
index 000000000..06a1a5845
--- /dev/null
+++ b/p2p/protocols/accounting.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package protocols
+
+import "github.com/ethereum/go-ethereum/metrics"
+
+//define some metrics
+var (
+ //NOTE: these metrics just define the interfaces and are currently *NOT persisted* over sessions
+ //All metrics are cumulative
+
+ //total amount of units credited
+ mBalanceCredit = metrics.NewRegisteredCounterForced("account.balance.credit", nil)
+ //total amount of units debited
+ mBalanceDebit = metrics.NewRegisteredCounterForced("account.balance.debit", nil)
+ //total amount of bytes credited
+ mBytesCredit = metrics.NewRegisteredCounterForced("account.bytes.credit", nil)
+ //total amount of bytes debited
+ mBytesDebit = metrics.NewRegisteredCounterForced("account.bytes.debit", nil)
+ //total amount of credited messages
+ mMsgCredit = metrics.NewRegisteredCounterForced("account.msg.credit", nil)
+ //total amount of debited messages
+ mMsgDebit = metrics.NewRegisteredCounterForced("account.msg.debit", nil)
+ //how many times local node had to drop remote peers
+ mPeerDrops = metrics.NewRegisteredCounterForced("account.peerdrops", nil)
+ //how many times local node overdrafted and dropped
+ mSelfDrops = metrics.NewRegisteredCounterForced("account.selfdrops", nil)
+)
+
+//Prices defines how prices are being passed on to the accounting instance
+type Prices interface {
+ //Return the Price for a message
+ Price(interface{}) *Price
+}
+
+type Payer bool
+
+const (
+ Sender = Payer(true)
+ Receiver = Payer(false)
+)
+
+//Price represents the costs of a message
+type Price struct {
+ Value uint64 //
+ PerByte bool //True if the price is per byte or for unit
+ Payer Payer
+}
+
+//For gives back the price for a message
+//A protocol provides the message price in absolute value
+//This method then returns the correct signed amount,
+//depending on who pays, which is identified by the `payer` argument:
+//`Send` will pass a `Sender` payer, `Receive` will pass the `Receiver` argument.
+//Thus: If Sending and sender pays, amount positive, otherwise negative
+//If Receiving, and receiver pays, amount positive, otherwise negative
+func (p *Price) For(payer Payer, size uint32) int64 {
+ price := p.Value
+ if p.PerByte {
+ price *= uint64(size)
+ }
+ if p.Payer == payer {
+ return 0 - int64(price)
+ }
+ return int64(price)
+}
+
+//Balance is the actual accounting instance
+//Balance defines the operations needed for accounting
+//Implementations internally maintain the balance for every peer
+type Balance interface {
+ //Adds amount to the local balance with remote node `peer`;
+ //positive amount = credit local node
+ //negative amount = debit local node
+ Add(amount int64, peer *Peer) error
+}
+
+//Accounting implements the Hook interface
+//It interfaces to the balances through the Balance interface,
+//while interfacing with protocols and its prices through the Prices interface
+type Accounting struct {
+ Balance //interface to accounting logic
+ Prices //interface to prices logic
+}
+
+func NewAccounting(balance Balance, po Prices) *Accounting {
+ ah := &Accounting{
+ Prices: po,
+ Balance: balance,
+ }
+ return ah
+}
+
+//Implement Hook.Send
+// Send takes a peer, a size and a msg and
+// - calculates the cost for the local node sending a msg of size to peer using the Prices interface
+// - credits/debits local node using balance interface
+func (ah *Accounting) Send(peer *Peer, size uint32, msg interface{}) error {
+ //get the price for a message (through the protocol spec)
+ price := ah.Price(msg)
+ //this message doesn't need accounting
+ if price == nil {
+ return nil
+ }
+ //evaluate the price for sending messages
+ costToLocalNode := price.For(Sender, size)
+ //do the accounting
+ err := ah.Add(costToLocalNode, peer)
+ //record metrics: just increase counters for user-facing metrics
+ ah.doMetrics(costToLocalNode, size, err)
+ return err
+}
+
+//Implement Hook.Receive
+// Receive takes a peer, a size and a msg and
+// - calculates the cost for the local node receiving a msg of size from peer using the Prices interface
+// - credits/debits local node using balance interface
+func (ah *Accounting) Receive(peer *Peer, size uint32, msg interface{}) error {
+ //get the price for a message (through the protocol spec)
+ price := ah.Price(msg)
+ //this message doesn't need accounting
+ if price == nil {
+ return nil
+ }
+ //evaluate the price for receiving messages
+ costToLocalNode := price.For(Receiver, size)
+ //do the accounting
+ err := ah.Add(costToLocalNode, peer)
+ //record metrics: just increase counters for user-facing metrics
+ ah.doMetrics(costToLocalNode, size, err)
+ return err
+}
+
+//record some metrics
+//this is not an error handling. `err` is returned by both `Send` and `Receive`
+//`err` will only be non-nil if a limit has been violated (overdraft), in which case the peer has been dropped.
+//if the limit has been violated and `err` is thus not nil:
+// * if the price is positive, local node has been credited; thus `err` implicitly signals the REMOTE has been dropped
+// * if the price is negative, local node has been debited, thus `err` implicitly signals LOCAL node "overdraft"
+func (ah *Accounting) doMetrics(price int64, size uint32, err error) {
+ if price > 0 {
+ mBalanceCredit.Inc(price)
+ mBytesCredit.Inc(int64(size))
+ mMsgCredit.Inc(1)
+ if err != nil {
+ //increase the number of times a remote node has been dropped due to "overdraft"
+ mPeerDrops.Inc(1)
+ }
+ } else {
+ mBalanceDebit.Inc(price)
+ mBytesDebit.Inc(int64(size))
+ mMsgDebit.Inc(1)
+ if err != nil {
+ //increase the number of times the local node has done an "overdraft" in respect to other nodes
+ mSelfDrops.Inc(1)
+ }
+ }
+}
diff --git a/p2p/protocols/accounting_simulation_test.go b/p2p/protocols/accounting_simulation_test.go
new file mode 100644
index 000000000..65b737abe
--- /dev/null
+++ b/p2p/protocols/accounting_simulation_test.go
@@ -0,0 +1,310 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package protocols
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/mattn/go-colorable"
+
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/rpc"
+
+ "github.com/ethereum/go-ethereum/node"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/simulations"
+ "github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+)
+
+const (
+ content = "123456789"
+)
+
+var (
+ nodes = flag.Int("nodes", 30, "number of nodes to create (default 30)")
+ msgs = flag.Int("msgs", 100, "number of messages sent by node (default 100)")
+ loglevel = flag.Int("loglevel", 0, "verbosity of logs")
+ rawlog = flag.Bool("rawlog", false, "remove terminal formatting from logs")
+)
+
+func init() {
+ flag.Parse()
+ log.PrintOrigins(true)
+ log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(!*rawlog))))
+}
+
+//TestAccountingSimulation runs a p2p/simulations simulation
+//It creates a *nodes number of nodes, connects each one with each other,
+//then sends out a random selection of messages up to *msgs amount of messages
+//from the test protocol spec.
+//The spec has some accounted messages defined through the Prices interface.
+//The test does accounting for all the message exchanged, and then checks
+//that every node has the same balance with a peer, but with opposite signs.
+//Balance(AwithB) = 0 - Balance(BwithA) or Abs|Balance(AwithB)| == Abs|Balance(BwithA)|
+func TestAccountingSimulation(t *testing.T) {
+ //setup the balances objects for every node
+ bal := newBalances(*nodes)
+ //define the node.Service for this test
+ services := adapters.Services{
+ "accounting": func(ctx *adapters.ServiceContext) (node.Service, error) {
+ return bal.newNode(), nil
+ },
+ }
+ //setup the simulation
+ adapter := adapters.NewSimAdapter(services)
+ net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{DefaultService: "accounting"})
+ defer net.Shutdown()
+
+ // we send msgs messages per node, wait for all messages to arrive
+ bal.wg.Add(*nodes * *msgs)
+ trigger := make(chan enode.ID)
+ go func() {
+ // wait for all of them to arrive
+ bal.wg.Wait()
+ // then trigger a check
+ // the selected node for the trigger is irrelevant,
+ // we just want to trigger the end of the simulation
+ trigger <- net.Nodes[0].ID()
+ }()
+
+ // create nodes and start them
+ for i := 0; i < *nodes; i++ {
+ conf := adapters.RandomNodeConfig()
+ bal.id2n[conf.ID] = i
+ if _, err := net.NewNodeWithConfig(conf); err != nil {
+ t.Fatal(err)
+ }
+ if err := net.Start(conf.ID); err != nil {
+ t.Fatal(err)
+ }
+ }
+ // fully connect nodes
+ for i, n := range net.Nodes {
+ for _, m := range net.Nodes[i+1:] {
+ if err := net.Connect(n.ID(), m.ID()); err != nil {
+ t.Fatal(err)
+ }
+ }
+ }
+
+ // empty action
+ action := func(ctx context.Context) error {
+ return nil
+ }
+ // check always checks out
+ check := func(ctx context.Context, id enode.ID) (bool, error) {
+ return true, nil
+ }
+
+ // run simulation
+ timeout := 30 * time.Second
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ result := simulations.NewSimulation(net).Run(ctx, &simulations.Step{
+ Action: action,
+ Trigger: trigger,
+ Expect: &simulations.Expectation{
+ Nodes: []enode.ID{net.Nodes[0].ID()},
+ Check: check,
+ },
+ })
+
+ if result.Error != nil {
+ t.Fatal(result.Error)
+ }
+
+ // check if balance matrix is symmetric
+ if err := bal.symmetric(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// matrix is a matrix of nodes and its balances
+// matrix is in fact a linear array of size n*n,
+// so the balance for any node A with B is at index
+// A*n + B, while the balance of node B with A is at
+// B*n + A
+// (n entries in the array will not be filled -
+// the balance of a node with itself)
+type matrix struct {
+ n int //number of nodes
+ m []int64 //array of balances
+}
+
+// create a new matrix
+func newMatrix(n int) *matrix {
+ return &matrix{
+ n: n,
+ m: make([]int64, n*n),
+ }
+}
+
+// called from the testBalance's Add accounting function: register balance change
+func (m *matrix) add(i, j int, v int64) error {
+ // index for the balance of local node i with remote nodde j is
+ // i * number of nodes + remote node
+ mi := i*m.n + j
+ // register that balance
+ m.m[mi] += v
+ return nil
+}
+
+// check that the balances are symmetric:
+// balance of node i with node j is the same as j with i but with inverted signs
+func (m *matrix) symmetric() error {
+ //iterate all nodes
+ for i := 0; i < m.n; i++ {
+ //iterate starting +1
+ for j := i + 1; j < m.n; j++ {
+ log.Debug("bal", "1", i, "2", j, "i,j", m.m[i*m.n+j], "j,i", m.m[j*m.n+i])
+ if m.m[i*m.n+j] != -m.m[j*m.n+i] {
+ return fmt.Errorf("value mismatch. m[%v, %v] = %v; m[%v, %v] = %v", i, j, m.m[i*m.n+j], j, i, m.m[j*m.n+i])
+ }
+ }
+ }
+ return nil
+}
+
+// all the balances
+type balances struct {
+ i int
+ *matrix
+ id2n map[enode.ID]int
+ wg *sync.WaitGroup
+}
+
+func newBalances(n int) *balances {
+ return &balances{
+ matrix: newMatrix(n),
+ id2n: make(map[enode.ID]int),
+ wg: &sync.WaitGroup{},
+ }
+}
+
+// create a new testNode for every node created as part of the service
+func (b *balances) newNode() *testNode {
+ defer func() { b.i++ }()
+ return &testNode{
+ bal: b,
+ i: b.i,
+ peers: make([]*testPeer, b.n), //a node will be connected to n-1 peers
+ }
+}
+
+type testNode struct {
+ bal *balances
+ i int
+ lock sync.Mutex
+ peers []*testPeer
+ peerCount int
+}
+
+// do the accounting for the peer's test protocol
+// testNode implements protocols.Balance
+func (t *testNode) Add(a int64, p *Peer) error {
+ //get the index for the remote peer
+ remote := t.bal.id2n[p.ID()]
+ log.Debug("add", "local", t.i, "remote", remote, "amount", a)
+ return t.bal.add(t.i, remote, a)
+}
+
+//run the p2p protocol
+//for every node, represented by testNode, create a remote testPeer
+func (t *testNode) run(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+ spec := createTestSpec()
+ //create accounting hook
+ spec.Hook = NewAccounting(t, &dummyPrices{})
+
+ //create a peer for this node
+ tp := &testPeer{NewPeer(p, rw, spec), t.i, t.bal.id2n[p.ID()], t.bal.wg}
+ t.lock.Lock()
+ t.peers[t.bal.id2n[p.ID()]] = tp
+ t.peerCount++
+ if t.peerCount == t.bal.n-1 {
+ //when all peer connections are established, start sending messages from this peer
+ go t.send()
+ }
+ t.lock.Unlock()
+ return tp.Run(tp.handle)
+}
+
+// p2p message receive handler function
+func (tp *testPeer) handle(ctx context.Context, msg interface{}) error {
+ tp.wg.Done()
+ log.Debug("receive", "from", tp.remote, "to", tp.local, "type", reflect.TypeOf(msg), "msg", msg)
+ return nil
+}
+
+type testPeer struct {
+ *Peer
+ local, remote int
+ wg *sync.WaitGroup
+}
+
+func (t *testNode) send() {
+ log.Debug("start sending")
+ for i := 0; i < *msgs; i++ {
+ //determine randomly to which peer to send
+ whom := rand.Intn(t.bal.n - 1)
+ if whom >= t.i {
+ whom++
+ }
+ t.lock.Lock()
+ p := t.peers[whom]
+ t.lock.Unlock()
+
+ //determine a random message from the spec's messages to be sent
+ which := rand.Intn(len(p.spec.Messages))
+ msg := p.spec.Messages[which]
+ switch msg.(type) {
+ case *perBytesMsgReceiverPays:
+ msg = &perBytesMsgReceiverPays{Content: content[:rand.Intn(len(content))]}
+ case *perBytesMsgSenderPays:
+ msg = &perBytesMsgSenderPays{Content: content[:rand.Intn(len(content))]}
+ }
+ log.Debug("send", "from", t.i, "to", whom, "type", reflect.TypeOf(msg), "msg", msg)
+ p.Send(context.TODO(), msg)
+ }
+}
+
+// define the protocol
+func (t *testNode) Protocols() []p2p.Protocol {
+ return []p2p.Protocol{{
+ Length: 100,
+ Run: t.run,
+ }}
+}
+
+func (t *testNode) APIs() []rpc.API {
+ return nil
+}
+
+func (t *testNode) Start(server *p2p.Server) error {
+ return nil
+}
+
+func (t *testNode) Stop() error {
+ return nil
+}
diff --git a/p2p/protocols/accounting_test.go b/p2p/protocols/accounting_test.go
new file mode 100644
index 000000000..3810ae2c9
--- /dev/null
+++ b/p2p/protocols/accounting_test.go
@@ -0,0 +1,223 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package protocols
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+//dummy Balance implementation
+type dummyBalance struct {
+ amount int64
+ peer *Peer
+}
+
+//dummy Prices implementation
+type dummyPrices struct{}
+
+//a dummy message which needs size based accounting
+//sender pays
+type perBytesMsgSenderPays struct {
+ Content string
+}
+
+//a dummy message which needs size based accounting
+//receiver pays
+type perBytesMsgReceiverPays struct {
+ Content string
+}
+
+//a dummy message which is paid for per unit
+//sender pays
+type perUnitMsgSenderPays struct{}
+
+//receiver pays
+type perUnitMsgReceiverPays struct{}
+
+//a dummy message which has zero as its price
+type zeroPriceMsg struct{}
+
+//a dummy message which has no accounting
+type nilPriceMsg struct{}
+
+//return the price for the defined messages
+func (d *dummyPrices) Price(msg interface{}) *Price {
+ switch msg.(type) {
+ //size based message cost, receiver pays
+ case *perBytesMsgReceiverPays:
+ return &Price{
+ PerByte: true,
+ Value: uint64(100),
+ Payer: Receiver,
+ }
+ //size based message cost, sender pays
+ case *perBytesMsgSenderPays:
+ return &Price{
+ PerByte: true,
+ Value: uint64(100),
+ Payer: Sender,
+ }
+ //unitary cost, receiver pays
+ case *perUnitMsgReceiverPays:
+ return &Price{
+ PerByte: false,
+ Value: uint64(99),
+ Payer: Receiver,
+ }
+ //unitary cost, sender pays
+ case *perUnitMsgSenderPays:
+ return &Price{
+ PerByte: false,
+ Value: uint64(99),
+ Payer: Sender,
+ }
+ case *zeroPriceMsg:
+ return &Price{
+ PerByte: false,
+ Value: uint64(0),
+ Payer: Sender,
+ }
+ case *nilPriceMsg:
+ return nil
+ }
+ return nil
+}
+
+//dummy accounting implementation, only stores values for later check
+func (d *dummyBalance) Add(amount int64, peer *Peer) error {
+ d.amount = amount
+ d.peer = peer
+ return nil
+}
+
+type testCase struct {
+ msg interface{}
+ size uint32
+ sendResult int64
+ recvResult int64
+}
+
+//lowest level unit test
+func TestBalance(t *testing.T) {
+ //create instances
+ balance := &dummyBalance{}
+ prices := &dummyPrices{}
+ //create the spec
+ spec := createTestSpec()
+ //create the accounting hook for the spec
+ acc := NewAccounting(balance, prices)
+ //create a peer
+ id := adapters.RandomNodeConfig().ID
+ p := p2p.NewPeer(id, "testPeer", nil)
+ peer := NewPeer(p, &dummyRW{}, spec)
+ //price depends on size, receiver pays
+ msg := &perBytesMsgReceiverPays{Content: "testBalance"}
+ size, _ := rlp.EncodeToBytes(msg)
+
+ testCases := []testCase{
+ {
+ msg,
+ uint32(len(size)),
+ int64(len(size) * 100),
+ int64(len(size) * -100),
+ },
+ {
+ &perBytesMsgSenderPays{Content: "testBalance"},
+ uint32(len(size)),
+ int64(len(size) * -100),
+ int64(len(size) * 100),
+ },
+ {
+ &perUnitMsgSenderPays{},
+ 0,
+ int64(-99),
+ int64(99),
+ },
+ {
+ &perUnitMsgReceiverPays{},
+ 0,
+ int64(99),
+ int64(-99),
+ },
+ {
+ &zeroPriceMsg{},
+ 0,
+ int64(0),
+ int64(0),
+ },
+ {
+ &nilPriceMsg{},
+ 0,
+ int64(0),
+ int64(0),
+ },
+ }
+ checkAccountingTestCases(t, testCases, acc, peer, balance, true)
+ checkAccountingTestCases(t, testCases, acc, peer, balance, false)
+}
+
+func checkAccountingTestCases(t *testing.T, cases []testCase, acc *Accounting, peer *Peer, balance *dummyBalance, send bool) {
+ for _, c := range cases {
+ var err error
+ var expectedResult int64
+ //reset balance before every check
+ balance.amount = 0
+ if send {
+ err = acc.Send(peer, c.size, c.msg)
+ expectedResult = c.sendResult
+ } else {
+ err = acc.Receive(peer, c.size, c.msg)
+ expectedResult = c.recvResult
+ }
+
+ checkResults(t, err, balance, peer, expectedResult)
+ }
+}
+
+func checkResults(t *testing.T, err error, balance *dummyBalance, peer *Peer, result int64) {
+ if err != nil {
+ t.Fatal(err)
+ }
+ if balance.peer != peer {
+ t.Fatalf("expected Add to be called with peer %v, got %v", peer, balance.peer)
+ }
+ if balance.amount != result {
+ t.Fatalf("Expected balance to be %d but is %d", result, balance.amount)
+ }
+}
+
+//create a test spec
+func createTestSpec() *Spec {
+ spec := &Spec{
+ Name: "test",
+ Version: 42,
+ MaxMsgSize: 10 * 1024,
+ Messages: []interface{}{
+ &perBytesMsgReceiverPays{},
+ &perBytesMsgSenderPays{},
+ &perUnitMsgReceiverPays{},
+ &perUnitMsgSenderPays{},
+ &zeroPriceMsg{},
+ &nilPriceMsg{},
+ },
+ }
+ return spec
+}
diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go
index 615f74b56..7dddd852f 100644
--- a/p2p/protocols/protocol.go
+++ b/p2p/protocols/protocol.go
@@ -122,6 +122,16 @@ type WrappedMsg struct {
Payload []byte
}
+//For accounting, the design is to allow the Spec to describe which and how its messages are priced
+//To access this functionality, we provide a Hook interface which will call accounting methods
+//NOTE: there could be more such (horizontal) hooks in the future
+type Hook interface {
+ //A hook for sending messages
+ Send(peer *Peer, size uint32, msg interface{}) error
+ //A hook for receiving messages
+ Receive(peer *Peer, size uint32, msg interface{}) error
+}
+
// Spec is a protocol specification including its name and version as well as
// the types of messages which are exchanged
type Spec struct {
@@ -141,6 +151,9 @@ type Spec struct {
// each message must have a single unique data type
Messages []interface{}
+ //hook for accounting (could be extended to multiple hooks in the future)
+ Hook Hook
+
initOnce sync.Once
codes map[reflect.Type]uint64
types map[uint64]reflect.Type
@@ -274,6 +287,15 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error {
Payload: r,
}
+ //if the accounting hook is set, call it
+ if p.spec.Hook != nil {
+ err := p.spec.Hook.Send(p, wmsg.Size, msg)
+ if err != nil {
+ p.Drop(err)
+ return err
+ }
+ }
+
code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
@@ -336,6 +358,14 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{})
return errorf(ErrDecode, "<= %v: %v", msg, err)
}
+ //if the accounting hook is set, call it
+ if p.spec.Hook != nil {
+ err := p.spec.Hook.Receive(p, wmsg.Size, val)
+ if err != nil {
+ return err
+ }
+ }
+
// call the registered handler callbacks
// a registered callback take the decoded message as argument as an interface
// which the handler is supposed to cast to the appropriate type
diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go
index 4755db3e6..2874af48d 100644
--- a/p2p/protocols/protocol_test.go
+++ b/p2p/protocols/protocol_test.go
@@ -17,12 +17,15 @@
package protocols
import (
+ "bytes"
"context"
"errors"
"fmt"
"testing"
"time"
+ "github.com/ethereum/go-ethereum/rlp"
+
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
@@ -185,6 +188,169 @@ func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) {
}
}
+type dummyHook struct {
+ peer *Peer
+ size uint32
+ msg interface{}
+ send bool
+ err error
+ waitC chan struct{}
+}
+
+type dummyMsg struct {
+ Content string
+}
+
+func (d *dummyHook) Send(peer *Peer, size uint32, msg interface{}) error {
+ d.peer = peer
+ d.size = size
+ d.msg = msg
+ d.send = true
+ return d.err
+}
+
+func (d *dummyHook) Receive(peer *Peer, size uint32, msg interface{}) error {
+ d.peer = peer
+ d.size = size
+ d.msg = msg
+ d.send = false
+ d.waitC <- struct{}{}
+ return d.err
+}
+
+func TestProtocolHook(t *testing.T) {
+ testHook := &dummyHook{
+ waitC: make(chan struct{}, 1),
+ }
+ spec := &Spec{
+ Name: "test",
+ Version: 42,
+ MaxMsgSize: 10 * 1024,
+ Messages: []interface{}{
+ dummyMsg{},
+ },
+ Hook: testHook,
+ }
+
+ runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+ peer := NewPeer(p, rw, spec)
+ ctx := context.TODO()
+ err := peer.Send(ctx, &dummyMsg{
+ Content: "handshake"})
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ handle := func(ctx context.Context, msg interface{}) error {
+ return nil
+ }
+
+ return peer.Run(handle)
+ }
+
+ conf := adapters.RandomNodeConfig()
+ tester := p2ptest.NewProtocolTester(t, conf.ID, 2, runFunc)
+ err := tester.TestExchanges(p2ptest.Exchange{
+ Expects: []p2ptest.Expect{
+ {
+ Code: 0,
+ Msg: &dummyMsg{Content: "handshake"},
+ Peer: tester.Nodes[0].ID(),
+ },
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" {
+ t.Fatal("Expected msg to be set, but it is not")
+ }
+ if !testHook.send {
+ t.Fatal("Expected a send message, but it is not")
+ }
+ if testHook.peer == nil || testHook.peer.ID() != tester.Nodes[0].ID() {
+ t.Fatal("Expected peer ID to be set correctly, but it is not")
+ }
+ if testHook.size != 11 { //11 is the length of the encoded message
+ t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
+ }
+
+ err = tester.TestExchanges(p2ptest.Exchange{
+ Triggers: []p2ptest.Trigger{
+ {
+ Code: 0,
+ Msg: &dummyMsg{Content: "response"},
+ Peer: tester.Nodes[1].ID(),
+ },
+ },
+ })
+
+ <-testHook.waitC
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" {
+ t.Fatal("Expected msg to be set, but it is not")
+ }
+ if testHook.send {
+ t.Fatal("Expected a send message, but it is not")
+ }
+ if testHook.peer == nil || testHook.peer.ID() != tester.Nodes[1].ID() {
+ t.Fatal("Expected peer ID to be set correctly, but it is not")
+ }
+ if testHook.size != 10 { //11 is the length of the encoded message
+ t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
+ }
+
+ testHook.err = fmt.Errorf("dummy error")
+ err = tester.TestExchanges(p2ptest.Exchange{
+ Triggers: []p2ptest.Trigger{
+ {
+ Code: 0,
+ Msg: &dummyMsg{Content: "response"},
+ Peer: tester.Nodes[1].ID(),
+ },
+ },
+ })
+
+ <-testHook.waitC
+
+ time.Sleep(100 * time.Millisecond)
+ err = tester.TestDisconnected(&p2ptest.Disconnect{tester.Nodes[1].ID(), testHook.err})
+ if err != nil {
+ t.Fatalf("Expected a specific disconnect error, but got different one: %v", err)
+ }
+
+}
+
+//We need to test that if the hook is not defined, then message infrastructure
+//(send,receive) still works
+func TestNoHook(t *testing.T) {
+ //create a test spec
+ spec := createTestSpec()
+ //a random node
+ id := adapters.RandomNodeConfig().ID
+ //a peer
+ p := p2p.NewPeer(id, "testPeer", nil)
+ rw := &dummyRW{}
+ peer := NewPeer(p, rw, spec)
+ ctx := context.TODO()
+ msg := &perBytesMsgSenderPays{Content: "testBalance"}
+ //send a message
+ err := peer.Send(ctx, msg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ //simulate receiving a message
+ rw.msg = msg
+ peer.handleIncoming(func(ctx context.Context, msg interface{}) error {
+ return nil
+ })
+ //all should just work and not result in any error
+}
+
func TestProtoHandshakeVersionMismatch(t *testing.T) {
runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error()))
}
@@ -386,3 +552,39 @@ func XTestMultiplePeersDropOther(t *testing.T) {
fmt.Errorf("subprotocol error"),
)
}
+
+//dummy implementation of a MsgReadWriter
+//this allows for quick and easy unit tests without
+//having to build up the complete protocol
+type dummyRW struct {
+ msg interface{}
+ size uint32
+ code uint64
+}
+
+func (d *dummyRW) WriteMsg(msg p2p.Msg) error {
+ return nil
+}
+
+func (d *dummyRW) ReadMsg() (p2p.Msg, error) {
+ enc := bytes.NewReader(d.getDummyMsg())
+ return p2p.Msg{
+ Code: d.code,
+ Size: d.size,
+ Payload: enc,
+ ReceivedAt: time.Now(),
+ }, nil
+}
+
+func (d *dummyRW) getDummyMsg() []byte {
+ r, _ := rlp.EncodeToBytes(d.msg)
+ var b bytes.Buffer
+ wmsg := WrappedMsg{
+ Context: b.Bytes(),
+ Size: uint32(len(r)),
+ Payload: r,
+ }
+ rr, _ := rlp.EncodeToBytes(wmsg)
+ d.size = uint32(len(rr))
+ return rr
+}