forked from cerc-io/plugeth
p2p: use RLPx frames for messaging
This commit is contained in:
parent
51e01cceca
commit
736e632215
@ -32,14 +32,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
*frameRW
|
MsgReadWriter
|
||||||
*protoHandshake
|
*protoHandshake
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(fd net.Conn, hs *protoHandshake) *conn {
|
|
||||||
return &conn{newFrameRW(fd, msgWriteTimeout), hs}
|
|
||||||
}
|
|
||||||
|
|
||||||
// encHandshake contains the state of the encryption handshake.
|
// encHandshake contains the state of the encryption handshake.
|
||||||
type encHandshake struct {
|
type encHandshake struct {
|
||||||
remoteID discover.NodeID
|
remoteID discover.NodeID
|
||||||
@ -115,17 +111,16 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
|
|||||||
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
// Run the protocol handshake using authenticated messages.
|
||||||
// TODO: move buffering setup here (out of newFrameRW)
|
// TODO: move buffering setup here (out of newFrameRW)
|
||||||
phsrw := newRlpxFrameRW(fd, secrets)
|
rw := newRlpxFrameRW(fd, secrets)
|
||||||
rhs, err := readProtocolHandshake(phsrw, our)
|
rhs, err := readProtocolHandshake(rw, our)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := writeProtocolHandshake(phsrw, our); err != nil {
|
// TODO: validate that handshake node ID matches
|
||||||
|
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||||
}
|
}
|
||||||
|
return &conn{&lockedRW{wrapped: rw}, rhs}, nil
|
||||||
rw := newFrameRW(fd, msgWriteTimeout)
|
|
||||||
return &conn{rw, rhs}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||||
@ -136,20 +131,18 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
|
|||||||
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
// Run the protocol handshake using authenticated messages.
|
||||||
// TODO: move buffering setup here (out of newFrameRW)
|
// TODO: move buffering setup here (out of newFrameRW)
|
||||||
phsrw := newRlpxFrameRW(fd, secrets)
|
rw := newRlpxFrameRW(fd, secrets)
|
||||||
if err := writeProtocolHandshake(phsrw, our); err != nil {
|
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||||
}
|
}
|
||||||
rhs, err := readProtocolHandshake(phsrw, our)
|
rhs, err := readProtocolHandshake(rw, our)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("protocol handshake read error: %v", err)
|
return nil, fmt.Errorf("protocol handshake read error: %v", err)
|
||||||
}
|
}
|
||||||
if rhs.ID != dial.ID {
|
if rhs.ID != dial.ID {
|
||||||
return nil, errors.New("dialed node id mismatch")
|
return nil, errors.New("dialed node id mismatch")
|
||||||
}
|
}
|
||||||
|
return &conn{&lockedRW{wrapped: rw}, rhs}, nil
|
||||||
rw := newFrameRW(fd, msgWriteTimeout)
|
|
||||||
return &conn{rw, rhs}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// outboundEncHandshake negotiates a session token on conn.
|
// outboundEncHandshake negotiates a session token on conn.
|
||||||
|
@ -119,6 +119,25 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
|
|||||||
return w.WriteMsg(NewMsg(code, data...))
|
return w.WriteMsg(NewMsg(code, data...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// lockedRW wraps a MsgReadWriter with locks around
|
||||||
|
// ReadMsg and WriteMsg.
|
||||||
|
type lockedRW struct {
|
||||||
|
rmu, wmu sync.Mutex
|
||||||
|
wrapped MsgReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *lockedRW) ReadMsg() (Msg, error) {
|
||||||
|
rw.rmu.Lock()
|
||||||
|
defer rw.rmu.Unlock()
|
||||||
|
return rw.wrapped.ReadMsg()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *lockedRW) WriteMsg(msg Msg) error {
|
||||||
|
rw.wmu.Lock()
|
||||||
|
defer rw.wmu.Unlock()
|
||||||
|
return rw.wrapped.WriteMsg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
|
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
|
||||||
// As required by the interface, ReadMsg and WriteMsg can be called from
|
// As required by the interface, ReadMsg and WriteMsg can be called from
|
||||||
// multiple goroutines.
|
// multiple goroutines.
|
||||||
|
21
p2p/peer.go
21
p2p/peer.go
@ -40,6 +40,7 @@ type Peer struct {
|
|||||||
// Use them to display messages related to the peer.
|
// Use them to display messages related to the peer.
|
||||||
*logger.Logger
|
*logger.Logger
|
||||||
|
|
||||||
|
conn net.Conn
|
||||||
rw *conn
|
rw *conn
|
||||||
running map[string]*protoRW
|
running map[string]*protoRW
|
||||||
|
|
||||||
@ -52,8 +53,9 @@ type Peer struct {
|
|||||||
// NewPeer returns a peer for testing purposes.
|
// NewPeer returns a peer for testing purposes.
|
||||||
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||||
pipe, _ := net.Pipe()
|
pipe, _ := net.Pipe()
|
||||||
conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
|
msgpipe, _ := MsgPipe()
|
||||||
peer := newPeer(conn, nil)
|
conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
|
||||||
|
peer := newPeer(pipe, conn, nil)
|
||||||
close(peer.closed) // ensures Disconnect doesn't block
|
close(peer.closed) // ensures Disconnect doesn't block
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
@ -76,12 +78,12 @@ func (p *Peer) Caps() []Cap {
|
|||||||
|
|
||||||
// RemoteAddr returns the remote address of the network connection.
|
// RemoteAddr returns the remote address of the network connection.
|
||||||
func (p *Peer) RemoteAddr() net.Addr {
|
func (p *Peer) RemoteAddr() net.Addr {
|
||||||
return p.rw.RemoteAddr()
|
return p.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the local address of the network connection.
|
// LocalAddr returns the local address of the network connection.
|
||||||
func (p *Peer) LocalAddr() net.Addr {
|
func (p *Peer) LocalAddr() net.Addr {
|
||||||
return p.rw.LocalAddr()
|
return p.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect terminates the peer connection with the given reason.
|
// Disconnect terminates the peer connection with the given reason.
|
||||||
@ -98,10 +100,11 @@ func (p *Peer) String() string {
|
|||||||
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
|
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
|
||||||
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
|
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr())
|
||||||
p := &Peer{
|
p := &Peer{
|
||||||
Logger: logger.NewLogger(logtag),
|
Logger: logger.NewLogger(logtag),
|
||||||
|
conn: fd,
|
||||||
rw: conn,
|
rw: conn,
|
||||||
running: matchProtocols(protocols, conn.Caps, conn),
|
running: matchProtocols(protocols, conn.Caps, conn),
|
||||||
disc: make(chan DiscReason),
|
disc: make(chan DiscReason),
|
||||||
@ -138,7 +141,7 @@ loop:
|
|||||||
// We rely on protocols to abort if there is a write error. It
|
// We rely on protocols to abort if there is a write error. It
|
||||||
// might be more robust to handle them here as well.
|
// might be more robust to handle them here as well.
|
||||||
p.DebugDetailf("Read error: %v\n", err)
|
p.DebugDetailf("Read error: %v\n", err)
|
||||||
p.rw.Close()
|
p.conn.Close()
|
||||||
return DiscNetworkError
|
return DiscNetworkError
|
||||||
case err := <-p.protoErr:
|
case err := <-p.protoErr:
|
||||||
reason = discReasonForError(err)
|
reason = discReasonForError(err)
|
||||||
@ -161,14 +164,14 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
|
|||||||
EncodeMsg(p.rw, discMsg, uint(reason))
|
EncodeMsg(p.rw, discMsg, uint(reason))
|
||||||
// Wait for the other side to close the connection.
|
// Wait for the other side to close the connection.
|
||||||
// Discard any data that they send until then.
|
// Discard any data that they send until then.
|
||||||
io.Copy(ioutil.Discard, p.rw)
|
io.Copy(ioutil.Discard, p.conn)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(disconnectGracePeriod):
|
case <-time.After(disconnectGracePeriod):
|
||||||
}
|
}
|
||||||
p.rw.Close()
|
p.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) readLoop() error {
|
func (p *Peer) readLoop() error {
|
||||||
|
@ -3,6 +3,7 @@ package p2p
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -29,8 +30,8 @@ var discard = Protocol{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
|
func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) {
|
||||||
fd1, fd2 := net.Pipe()
|
fd1, _ := net.Pipe()
|
||||||
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||||
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||||
for _, p := range protos {
|
for _, p := range protos {
|
||||||
@ -38,11 +39,12 @@ func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
|
|||||||
hs2.Caps = append(hs2.Caps, p.cap())
|
hs2.Caps = append(hs2.Caps, p.cap())
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := newPeer(newConn(fd1, hs1), protos)
|
p1, p2 := MsgPipe()
|
||||||
|
peer := newPeer(fd1, &conn{p1, hs1}, protos)
|
||||||
errc := make(chan DiscReason, 1)
|
errc := make(chan DiscReason, 1)
|
||||||
go func() { errc <- peer.run() }()
|
go func() { errc <- peer.run() }()
|
||||||
|
|
||||||
return newConn(fd2, hs2), peer, errc
|
return p1, &conn{p2, hs2}, peer, errc
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerProtoReadMsg(t *testing.T) {
|
func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
@ -67,8 +69,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rw, _, errc := testPeer([]Protocol{proto})
|
closer, rw, _, errc := testPeer([]Protocol{proto})
|
||||||
defer rw.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
EncodeMsg(rw, baseProtocolLength+2, 1)
|
EncodeMsg(rw, baseProtocolLength+2, 1)
|
||||||
EncodeMsg(rw, baseProtocolLength+3, 2)
|
EncodeMsg(rw, baseProtocolLength+3, 2)
|
||||||
@ -105,8 +107,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rw, _, errc := testPeer([]Protocol{proto})
|
closer, rw, _, errc := testPeer([]Protocol{proto})
|
||||||
defer rw.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
EncodeMsg(rw, 18, make([]byte, msgsize))
|
EncodeMsg(rw, 18, make([]byte, msgsize))
|
||||||
select {
|
select {
|
||||||
@ -134,8 +136,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rw, _, _ := testPeer([]Protocol{proto})
|
closer, rw, _, _ := testPeer([]Protocol{proto})
|
||||||
defer rw.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@ -145,8 +147,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
|||||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, peer, peerErr := testPeer([]Protocol{discard})
|
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
|
||||||
defer rw.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
// test write errors
|
// test write errors
|
||||||
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
||||||
@ -181,8 +183,8 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
|||||||
func TestPeerPing(t *testing.T) {
|
func TestPeerPing(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, _, _ := testPeer(nil)
|
closer, rw, _, _ := testPeer(nil)
|
||||||
defer rw.Close()
|
defer closer.Close()
|
||||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -194,15 +196,15 @@ func TestPeerPing(t *testing.T) {
|
|||||||
func TestPeerDisconnect(t *testing.T) {
|
func TestPeerDisconnect(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, _, disc := testPeer(nil)
|
closer, rw, _, disc := testPeer(nil)
|
||||||
defer rw.Close()
|
defer closer.Close()
|
||||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
rw.Close() // make test end faster
|
closer.Close() // make test end faster
|
||||||
if reason := <-disc; reason != DiscRequested {
|
if reason := <-disc; reason != DiscRequested {
|
||||||
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
|
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
|
||||||
}
|
}
|
||||||
|
@ -358,14 +358,15 @@ func (srv *Server) findPeers() {
|
|||||||
|
|
||||||
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
||||||
// TODO: handle/store session token
|
// TODO: handle/store session token
|
||||||
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
// TODO: reenable deadlines
|
||||||
|
// fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||||
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
|
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fd.Close()
|
fd.Close()
|
||||||
srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p := newPeer(conn, srv.Protocols)
|
p := newPeer(fd, conn, srv.Protocols)
|
||||||
if ok, reason := srv.addPeer(conn.ID, p); !ok {
|
if ok, reason := srv.addPeer(conn.ID, p); !ok {
|
||||||
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
||||||
p.politeDisconnect(reason)
|
p.politeDisconnect(reason)
|
||||||
@ -375,7 +376,7 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
|||||||
srvlog.Debugf("Added %v\n", p)
|
srvlog.Debugf("Added %v\n", p)
|
||||||
srvjslog.LogJson(&logger.P2PConnected{
|
srvjslog.LogJson(&logger.P2PConnected{
|
||||||
RemoteId: fmt.Sprintf("%x", conn.ID[:]),
|
RemoteId: fmt.Sprintf("%x", conn.ID[:]),
|
||||||
RemoteAddress: conn.RemoteAddr().String(),
|
RemoteAddress: fd.RemoteAddr().String(),
|
||||||
RemoteVersionString: conn.Name,
|
RemoteVersionString: conn.Name,
|
||||||
NumConnections: srv.PeerCount(),
|
NumConnections: srv.PeerCount(),
|
||||||
})
|
})
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,8 +24,14 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
|||||||
newPeerHook: pf,
|
newPeerHook: pf,
|
||||||
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||||
id := randomID()
|
id := randomID()
|
||||||
|
rw := newRlpxFrameRW(fd, secrets{
|
||||||
|
MAC: zero16,
|
||||||
|
AES: zero16,
|
||||||
|
IngressMAC: sha3.NewKeccak256(),
|
||||||
|
EgressMAC: sha3.NewKeccak256(),
|
||||||
|
})
|
||||||
return &conn{
|
return &conn{
|
||||||
frameRW: newFrameRW(fd, msgWriteTimeout),
|
MsgReadWriter: rw,
|
||||||
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
|
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
@ -143,9 +150,7 @@ func TestServerBroadcast(t *testing.T) {
|
|||||||
|
|
||||||
// broadcast one message
|
// broadcast one message
|
||||||
srv.Broadcast("discard", 0, "foo")
|
srv.Broadcast("discard", 0, "foo")
|
||||||
goldbuf := new(bytes.Buffer)
|
golden := unhex("66e94e166f0a2c3b884cfa59ca34")
|
||||||
writeMsg(goldbuf, NewMsg(16, "foo"))
|
|
||||||
golden := goldbuf.Bytes()
|
|
||||||
|
|
||||||
// check that the message has been written everywhere
|
// check that the message has been written everywhere
|
||||||
for i, conn := range conns {
|
for i, conn := range conns {
|
||||||
|
Loading…
Reference in New Issue
Block a user