forked from cerc-io/plugeth
p2p: disable encryption handshake
The diff is a bit bigger than expected because the protocol handshake logic has moved out of Peer. This is necessary because the protocol handshake will have custom framing in the final protocol.
This commit is contained in:
parent
4322632c59
commit
73f94f3755
@ -1,21 +1,20 @@
|
||||
package p2p
|
||||
|
||||
import (
|
||||
// "binary"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
var clogger = ethlogger.NewLogger("CRYPTOID")
|
||||
|
||||
const (
|
||||
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||
sigLen = 65 // elliptic S256
|
||||
@ -30,26 +29,76 @@ const (
|
||||
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||
)
|
||||
|
||||
type hexkey []byte
|
||||
|
||||
func (self hexkey) String() string {
|
||||
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
|
||||
type conn struct {
|
||||
*frameRW
|
||||
*protoHandshake
|
||||
}
|
||||
|
||||
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
|
||||
remoteID discover.NodeID,
|
||||
sessionToken []byte,
|
||||
err error,
|
||||
) {
|
||||
func newConn(fd net.Conn, hs *protoHandshake) *conn {
|
||||
return &conn{newFrameRW(fd, msgWriteTimeout), hs}
|
||||
}
|
||||
|
||||
// encHandshake represents information about the remote end
|
||||
// of a connection that is negotiated during the encryption handshake.
|
||||
type encHandshake struct {
|
||||
ID discover.NodeID
|
||||
IngressMAC []byte
|
||||
EgressMAC []byte
|
||||
Token []byte
|
||||
}
|
||||
|
||||
// protoHandshake is the RLP structure of the protocol handshake.
|
||||
type protoHandshake struct {
|
||||
Version uint64
|
||||
Name string
|
||||
Caps []Cap
|
||||
ListenPort uint64
|
||||
ID discover.NodeID
|
||||
}
|
||||
|
||||
// setupConn starts a protocol session on the given connection.
|
||||
// It runs the encryption handshake and the protocol handshake.
|
||||
// If dial is non-nil, the connection the local node is the initiator.
|
||||
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||
if dial == nil {
|
||||
var remotePubkey []byte
|
||||
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
|
||||
copy(remoteID[:], remotePubkey)
|
||||
return setupInboundConn(fd, prv, our)
|
||||
} else {
|
||||
remoteID = dial.ID
|
||||
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
|
||||
return setupOutboundConn(fd, prv, our, dial)
|
||||
}
|
||||
return remoteID, sessionToken, err
|
||||
}
|
||||
|
||||
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
|
||||
// var remotePubkey []byte
|
||||
// sessionToken, remotePubkey, err = inboundEncHandshake(fd, prv, nil)
|
||||
// copy(remoteID[:], remotePubkey)
|
||||
|
||||
rw := newFrameRW(fd, msgWriteTimeout)
|
||||
rhs, err := readProtocolHandshake(rw, our)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||
}
|
||||
return &conn{rw, rhs}, nil
|
||||
}
|
||||
|
||||
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||
// remoteID = dial.ID
|
||||
// sessionToken, err = outboundEncHandshake(fd, prv, remoteID[:], nil)
|
||||
|
||||
rw := newFrameRW(fd, msgWriteTimeout)
|
||||
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||
}
|
||||
rhs, err := readProtocolHandshake(rw, our)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("protocol handshake read error: %v", err)
|
||||
}
|
||||
if rhs.ID != dial.ID {
|
||||
return nil, errors.New("dialed node id mismatch")
|
||||
}
|
||||
return &conn{rw, rhs}, nil
|
||||
}
|
||||
|
||||
// outboundEncHandshake negotiates a session token on conn.
|
||||
@ -66,18 +115,9 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sessionToken != nil {
|
||||
clogger.Debugf("session-token: %v", hexkey(sessionToken))
|
||||
}
|
||||
|
||||
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
|
||||
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
||||
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
|
||||
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
|
||||
if _, err = conn.Write(auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clogger.Debugf("initiator handshake: %v", hexkey(auth))
|
||||
|
||||
response := make([]byte, rHSLen)
|
||||
if _, err = io.ReadFull(conn, response); err != nil {
|
||||
@ -88,9 +128,6 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
||||
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
|
||||
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
|
||||
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||
}
|
||||
|
||||
@ -221,12 +258,9 @@ func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionTo
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
||||
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
||||
if _, err = conn.Write(response); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
|
||||
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||
return token, remotePubKey, err
|
||||
}
|
||||
@ -361,3 +395,40 @@ func xor(one, other []byte) (xor []byte) {
|
||||
}
|
||||
return xor
|
||||
}
|
||||
|
||||
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
|
||||
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
|
||||
}
|
||||
|
||||
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
||||
// read and handle remote handshake
|
||||
msg, err := r.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg.Code == discMsg {
|
||||
// disconnect before protocol handshake is valid according to the
|
||||
// spec and we send it ourself if Server.addPeer fails.
|
||||
var reason DiscReason
|
||||
rlp.Decode(msg.Payload, &reason)
|
||||
return nil, discRequestedError(reason)
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
|
||||
}
|
||||
var hs protoHandshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// validate handshake info
|
||||
if hs.Version != our.Version {
|
||||
return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
||||
}
|
||||
if (hs.ID == discover.NodeID{}) {
|
||||
return nil, newPeerError(errPubkeyInvalid, "missing")
|
||||
}
|
||||
return &hs, nil
|
||||
}
|
@ -5,10 +5,12 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
func TestPublicKeyEncoding(t *testing.T) {
|
||||
@ -91,14 +93,14 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
t.Logf("-> %v", hexkey(auth))
|
||||
// t.Logf("-> %v", hexkey(auth))
|
||||
|
||||
// receiver reads auth and responds with response
|
||||
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
t.Logf("<- %v\n", hexkey(response))
|
||||
// t.Logf("<- %v\n", hexkey(response))
|
||||
|
||||
// initiator reads receiver's response and the key exchange completes
|
||||
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
|
||||
@ -132,7 +134,7 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
func TestEncHandshake(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
prv0, _ := crypto.GenerateKey()
|
||||
@ -165,3 +167,58 @@ func TestHandshake(t *testing.T) {
|
||||
t.Error("session token mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupConn(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey()
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
node0 := &discover.Node{
|
||||
ID: discover.PubkeyID(&prv0.PublicKey),
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
TCPPort: 33,
|
||||
}
|
||||
node1 := &discover.Node{
|
||||
ID: discover.PubkeyID(&prv1.PublicKey),
|
||||
IP: net.IP{5, 6, 7, 8},
|
||||
TCPPort: 44,
|
||||
}
|
||||
hs0 := &protoHandshake{
|
||||
Version: baseProtocolVersion,
|
||||
ID: node0.ID,
|
||||
Caps: []Cap{{"a", 0}, {"b", 2}},
|
||||
}
|
||||
hs1 := &protoHandshake{
|
||||
Version: baseProtocolVersion,
|
||||
ID: node1.ID,
|
||||
Caps: []Cap{{"c", 1}, {"d", 3}},
|
||||
}
|
||||
fd0, fd1 := net.Pipe()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
conn0, err := setupConn(fd0, prv0, hs0, node1)
|
||||
if err != nil {
|
||||
t.Errorf("outbound side error: %v", err)
|
||||
return
|
||||
}
|
||||
if conn0.ID != node1.ID {
|
||||
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
|
||||
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
|
||||
}
|
||||
}()
|
||||
|
||||
conn1, err := setupConn(fd1, prv1, hs1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("inbound side error: %v", err)
|
||||
}
|
||||
if conn1.ID != node0.ID {
|
||||
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
|
||||
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
@ -197,7 +197,7 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
|
||||
return msg, err
|
||||
}
|
||||
if !bytes.HasPrefix(start, magicToken) {
|
||||
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
|
||||
return msg, fmt.Errorf("bad magic token %x", start[:4])
|
||||
}
|
||||
size := binary.BigEndian.Uint32(start[4:])
|
||||
|
||||
|
195
p2p/peer.go
195
p2p/peer.go
@ -33,37 +33,14 @@ const (
|
||||
peersMsg = 0x05
|
||||
)
|
||||
|
||||
// handshake is the RLP structure of the protocol handshake.
|
||||
type handshake struct {
|
||||
Version uint64
|
||||
Name string
|
||||
Caps []Cap
|
||||
ListenPort uint64
|
||||
NodeID discover.NodeID
|
||||
}
|
||||
|
||||
// Peer represents a connected remote node.
|
||||
type Peer struct {
|
||||
// Peers have all the log methods.
|
||||
// Use them to display messages related to the peer.
|
||||
*logger.Logger
|
||||
|
||||
infoMu sync.Mutex
|
||||
name string
|
||||
caps []Cap
|
||||
|
||||
ourID, remoteID *discover.NodeID
|
||||
ourName string
|
||||
|
||||
rw *frameRW
|
||||
|
||||
// These fields maintain the running protocols.
|
||||
protocols []Protocol
|
||||
runlock sync.RWMutex // protects running
|
||||
running map[string]*proto
|
||||
|
||||
// disables protocol handshake, for testing
|
||||
noHandshake bool
|
||||
rw *conn
|
||||
running map[string]*protoRW
|
||||
|
||||
protoWG sync.WaitGroup
|
||||
protoErr chan error
|
||||
@ -73,36 +50,27 @@ type Peer struct {
|
||||
|
||||
// NewPeer returns a peer for testing purposes.
|
||||
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||
conn, _ := net.Pipe()
|
||||
peer := newPeer(conn, nil, "", nil, &id)
|
||||
peer.setHandshakeInfo(name, caps)
|
||||
pipe, _ := net.Pipe()
|
||||
conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
|
||||
peer := newPeer(conn, nil)
|
||||
close(peer.closed) // ensures Disconnect doesn't block
|
||||
return peer
|
||||
}
|
||||
|
||||
// ID returns the node's public key.
|
||||
func (p *Peer) ID() discover.NodeID {
|
||||
return *p.remoteID
|
||||
return p.rw.ID
|
||||
}
|
||||
|
||||
// Name returns the node name that the remote node advertised.
|
||||
func (p *Peer) Name() string {
|
||||
// this needs a lock because the information is part of the
|
||||
// protocol handshake.
|
||||
p.infoMu.Lock()
|
||||
name := p.name
|
||||
p.infoMu.Unlock()
|
||||
return name
|
||||
return p.rw.Name
|
||||
}
|
||||
|
||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||
func (p *Peer) Caps() []Cap {
|
||||
// this needs a lock because the information is part of the
|
||||
// protocol handshake.
|
||||
p.infoMu.Lock()
|
||||
caps := p.caps
|
||||
p.infoMu.Unlock()
|
||||
return caps
|
||||
// TODO: maybe return copy
|
||||
return p.rw.Caps
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address of the network connection.
|
||||
@ -126,30 +94,20 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (p *Peer) String() string {
|
||||
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
||||
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
|
||||
}
|
||||
|
||||
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
||||
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
||||
return &Peer{
|
||||
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
||||
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
|
||||
p := &Peer{
|
||||
Logger: logger.NewLogger(logtag),
|
||||
rw: newFrameRW(conn, msgWriteTimeout),
|
||||
ourID: ourID,
|
||||
ourName: ourName,
|
||||
remoteID: remoteID,
|
||||
protocols: protocols,
|
||||
running: make(map[string]*proto),
|
||||
rw: conn,
|
||||
running: matchProtocols(protocols, conn.Caps, conn),
|
||||
disc: make(chan DiscReason),
|
||||
protoErr: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
|
||||
p.infoMu.Lock()
|
||||
p.name = name
|
||||
p.caps = caps
|
||||
p.infoMu.Unlock()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Peer) run() DiscReason {
|
||||
@ -157,16 +115,9 @@ func (p *Peer) run() DiscReason {
|
||||
defer p.closeProtocols()
|
||||
defer close(p.closed)
|
||||
|
||||
p.startProtocols()
|
||||
go func() { readErr <- p.readLoop() }()
|
||||
|
||||
if !p.noHandshake {
|
||||
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
||||
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
||||
p.rw.Close()
|
||||
return DiscProtocolError
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for an error or disconnect.
|
||||
var reason DiscReason
|
||||
select {
|
||||
@ -206,11 +157,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||
}
|
||||
|
||||
func (p *Peer) readLoop() error {
|
||||
if !p.noHandshake {
|
||||
if err := readProtocolHandshake(p, p.rw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for {
|
||||
msg, err := p.rw.ReadMsg()
|
||||
if err != nil {
|
||||
@ -249,88 +195,36 @@ func (p *Peer) handle(msg Msg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
||||
// read and handle remote handshake
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code == discMsg {
|
||||
// disconnect before protocol handshake is valid according to the
|
||||
// spec and we send it ourself if Server.addPeer fails.
|
||||
var reason DiscReason
|
||||
rlp.Decode(msg.Payload, &reason)
|
||||
return discRequestedError(reason)
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errInvalidMsg, "message too big")
|
||||
}
|
||||
var hs handshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return err
|
||||
}
|
||||
// validate handshake info
|
||||
if hs.Version != baseProtocolVersion {
|
||||
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
|
||||
baseProtocolVersion, hs.Version)
|
||||
}
|
||||
if hs.NodeID == *p.remoteID {
|
||||
return newPeerError(errPubkeyForbidden, "node ID mismatch")
|
||||
}
|
||||
// TODO: remove Caps with empty name
|
||||
p.setHandshakeInfo(hs.Name, hs.Caps)
|
||||
p.startSubprotocols(hs.Caps)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
|
||||
var caps []interface{}
|
||||
for _, proto := range ps {
|
||||
caps = append(caps, proto.cap())
|
||||
}
|
||||
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
|
||||
}
|
||||
|
||||
// startProtocols starts matching named subprotocols.
|
||||
func (p *Peer) startSubprotocols(caps []Cap) {
|
||||
// matchProtocols creates structures for matching named subprotocols.
|
||||
func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
|
||||
sort.Sort(capsByName(caps))
|
||||
p.runlock.Lock()
|
||||
defer p.runlock.Unlock()
|
||||
offset := baseProtocolLength
|
||||
result := make(map[string]*protoRW)
|
||||
outer:
|
||||
for _, cap := range caps {
|
||||
for _, proto := range p.protocols {
|
||||
if proto.Name == cap.Name &&
|
||||
proto.Version == cap.Version &&
|
||||
p.running[cap.Name] == nil {
|
||||
p.running[cap.Name] = p.startProto(offset, proto)
|
||||
for _, proto := range protocols {
|
||||
if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
|
||||
result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
|
||||
offset += proto.Length
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
|
||||
rw := &proto{
|
||||
name: impl.Name,
|
||||
in: make(chan Msg),
|
||||
offset: offset,
|
||||
maxcode: impl.Length,
|
||||
w: p.rw,
|
||||
}
|
||||
func (p *Peer) startProtocols() {
|
||||
for _, proto := range p.running {
|
||||
proto := proto
|
||||
p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
|
||||
p.protoWG.Add(1)
|
||||
go func() {
|
||||
err := impl.Run(p, rw)
|
||||
err := proto.Run(p, proto)
|
||||
if err == nil {
|
||||
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
||||
p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
|
||||
err = errors.New("protocol returned")
|
||||
} else {
|
||||
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
||||
p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
|
||||
}
|
||||
select {
|
||||
case p.protoErr <- err:
|
||||
@ -338,16 +232,14 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||
}
|
||||
p.protoWG.Done()
|
||||
}()
|
||||
return rw
|
||||
}
|
||||
}
|
||||
|
||||
// getProto finds the protocol responsible for handling
|
||||
// the given message code.
|
||||
func (p *Peer) getProto(code uint64) (*proto, error) {
|
||||
p.runlock.RLock()
|
||||
defer p.runlock.RUnlock()
|
||||
func (p *Peer) getProto(code uint64) (*protoRW, error) {
|
||||
for _, proto := range p.running {
|
||||
if code >= proto.offset && code < proto.offset+proto.maxcode {
|
||||
if code >= proto.offset && code < proto.offset+proto.Length {
|
||||
return proto, nil
|
||||
}
|
||||
}
|
||||
@ -355,46 +247,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
|
||||
}
|
||||
|
||||
func (p *Peer) closeProtocols() {
|
||||
p.runlock.RLock()
|
||||
for _, p := range p.running {
|
||||
close(p.in)
|
||||
}
|
||||
p.runlock.RUnlock()
|
||||
p.protoWG.Wait()
|
||||
}
|
||||
|
||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||
// this exists because of Server.Broadcast.
|
||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||
p.runlock.RLock()
|
||||
proto, ok := p.running[protoName]
|
||||
p.runlock.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
||||
}
|
||||
if msg.Code >= proto.maxcode {
|
||||
if msg.Code >= proto.Length {
|
||||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||
}
|
||||
msg.Code += proto.offset
|
||||
return p.rw.WriteMsg(msg)
|
||||
}
|
||||
|
||||
type proto struct {
|
||||
name string
|
||||
type protoRW struct {
|
||||
Protocol
|
||||
|
||||
in chan Msg
|
||||
maxcode, offset uint64
|
||||
offset uint64
|
||||
w MsgWriter
|
||||
}
|
||||
|
||||
func (rw *proto) WriteMsg(msg Msg) error {
|
||||
if msg.Code >= rw.maxcode {
|
||||
func (rw *protoRW) WriteMsg(msg Msg) error {
|
||||
if msg.Code >= rw.Length {
|
||||
return newPeerError(errInvalidMsgCode, "not handled")
|
||||
}
|
||||
msg.Code += rw.offset
|
||||
return rw.w.WriteMsg(msg)
|
||||
}
|
||||
|
||||
func (rw *proto) ReadMsg() (Msg, error) {
|
||||
func (rw *protoRW) ReadMsg() (Msg, error) {
|
||||
msg, ok := <-rw.in
|
||||
if !ok {
|
||||
return msg, io.EOF
|
||||
|
105
p2p/peer_test.go
105
p2p/peer_test.go
@ -6,11 +6,9 @@ import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
@ -23,6 +21,7 @@ var discard = Protocol{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("discarding %d\n", msg.Code)
|
||||
if err = msg.Discard(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -30,13 +29,20 @@ var discard = Protocol{
|
||||
},
|
||||
}
|
||||
|
||||
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
|
||||
conn1, conn2 := net.Pipe()
|
||||
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
|
||||
peer.noHandshake = noHandshake
|
||||
func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
|
||||
fd1, fd2 := net.Pipe()
|
||||
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||
for _, p := range protos {
|
||||
hs1.Caps = append(hs1.Caps, p.cap())
|
||||
hs2.Caps = append(hs2.Caps, p.cap())
|
||||
}
|
||||
|
||||
peer := newPeer(newConn(fd1, hs1), protos)
|
||||
errc := make(chan DiscReason, 1)
|
||||
go func() { errc <- peer.run() }()
|
||||
return newFrameRW(conn2, msgWriteTimeout), peer, errc
|
||||
|
||||
return newConn(fd2, hs2), peer, errc
|
||||
}
|
||||
|
||||
func TestPeerProtoReadMsg(t *testing.T) {
|
||||
@ -61,9 +67,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||
rw, _, errc := testPeer([]Protocol{proto})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{proto.cap()})
|
||||
|
||||
EncodeMsg(rw, baseProtocolLength+2, 1)
|
||||
EncodeMsg(rw, baseProtocolLength+3, 2)
|
||||
@ -100,9 +105,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||
rw, _, errc := testPeer([]Protocol{proto})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{proto.cap()})
|
||||
|
||||
EncodeMsg(rw, 18, make([]byte, msgsize))
|
||||
select {
|
||||
@ -130,9 +134,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
rw, peer, _ := testPeer(true, []Protocol{proto})
|
||||
rw, _, _ := testPeer([]Protocol{proto})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{proto.cap()})
|
||||
|
||||
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||
t.Error(err)
|
||||
@ -142,9 +145,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
rw, peer, peerErr := testPeer(true, []Protocol{discard})
|
||||
rw, peer, peerErr := testPeer([]Protocol{discard})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{discard.cap()})
|
||||
|
||||
// test write errors
|
||||
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
||||
@ -160,7 +162,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
read := make(chan struct{})
|
||||
go func() {
|
||||
if err := expectMsg(rw, 16, nil); err != nil {
|
||||
t.Error()
|
||||
t.Error(err)
|
||||
}
|
||||
close(read)
|
||||
}()
|
||||
@ -179,7 +181,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
func TestPeerPing(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
rw, _, _ := testPeer(true, nil)
|
||||
rw, _, _ := testPeer(nil)
|
||||
defer rw.Close()
|
||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -192,7 +194,7 @@ func TestPeerPing(t *testing.T) {
|
||||
func TestPeerDisconnect(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
rw, _, disc := testPeer(true, nil)
|
||||
rw, _, disc := testPeer(nil)
|
||||
defer rw.Close()
|
||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -206,73 +208,6 @@ func TestPeerDisconnect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerHandshake(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
// remote has two matching protocols: a and c
|
||||
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
|
||||
remoteID := randomID()
|
||||
remote.ourID = &remoteID
|
||||
remote.ourName = "remote peer"
|
||||
|
||||
start := make(chan string)
|
||||
stop := make(chan struct{})
|
||||
run := func(p *Peer, rw MsgReadWriter) error {
|
||||
name := rw.(*proto).name
|
||||
if name != "a" && name != "c" {
|
||||
t.Errorf("protocol %q should not be started", name)
|
||||
} else {
|
||||
start <- name
|
||||
}
|
||||
<-stop
|
||||
return nil
|
||||
}
|
||||
protocols := []Protocol{
|
||||
{Name: "a", Version: 1, Length: 1, Run: run},
|
||||
{Name: "b", Version: 2, Length: 1, Run: run},
|
||||
{Name: "c", Version: 3, Length: 1, Run: run},
|
||||
{Name: "d", Version: 4, Length: 1, Run: run},
|
||||
}
|
||||
rw, p, disc := testPeer(false, protocols)
|
||||
p.remoteID = remote.ourID
|
||||
defer rw.Close()
|
||||
|
||||
// run the handshake
|
||||
remoteProtocols := []Protocol{protocols[0], protocols[2]}
|
||||
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
|
||||
t.Fatalf("handshake write error: %v", err)
|
||||
}
|
||||
if err := readProtocolHandshake(remote, rw); err != nil {
|
||||
t.Fatalf("handshake read error: %v", err)
|
||||
}
|
||||
|
||||
// check that all protocols have been started
|
||||
var started []string
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case name := <-start:
|
||||
started = append(started, name)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
sort.Strings(started)
|
||||
if !reflect.DeepEqual(started, []string{"a", "c"}) {
|
||||
t.Errorf("wrong protocols started: %v", started)
|
||||
}
|
||||
|
||||
// check that metadata has been set
|
||||
if p.ID() != remoteID {
|
||||
t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
|
||||
}
|
||||
if p.Name() != remote.ourName {
|
||||
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
|
||||
}
|
||||
|
||||
close(stop)
|
||||
expectMsg(rw, discMsg, nil)
|
||||
t.Logf("disc reason: %v", <-disc)
|
||||
}
|
||||
|
||||
func TestNewPeer(t *testing.T) {
|
||||
name := "nodename"
|
||||
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
@ -83,9 +82,11 @@ type Server struct {
|
||||
|
||||
// Hooks for testing. These are useful because we can inhibit
|
||||
// the whole protocol stack.
|
||||
handshakeFunc
|
||||
setupFunc
|
||||
newPeerHook
|
||||
|
||||
ourHandshake *protoHandshake
|
||||
|
||||
lock sync.RWMutex
|
||||
running bool
|
||||
listener net.Listener
|
||||
@ -99,7 +100,7 @@ type Server struct {
|
||||
peerConnect chan *discover.Node
|
||||
}
|
||||
|
||||
type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
|
||||
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
|
||||
type newPeerHook func(*Peer)
|
||||
|
||||
// Peers returns all connected peers.
|
||||
@ -170,8 +171,8 @@ func (srv *Server) Start() (err error) {
|
||||
srv.peers = make(map[discover.NodeID]*Peer)
|
||||
srv.peerConnect = make(chan *discover.Node)
|
||||
|
||||
if srv.handshakeFunc == nil {
|
||||
srv.handshakeFunc = encHandshake
|
||||
if srv.setupFunc == nil {
|
||||
srv.setupFunc = setupConn
|
||||
}
|
||||
if srv.Blacklist == nil {
|
||||
srv.Blacklist = NewBlacklist()
|
||||
@ -183,11 +184,17 @@ func (srv *Server) Start() (err error) {
|
||||
}
|
||||
|
||||
// dial stuff
|
||||
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
|
||||
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.ntab = dt
|
||||
srv.ntab = ntab
|
||||
|
||||
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()}
|
||||
for _, p := range srv.Protocols {
|
||||
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
||||
}
|
||||
|
||||
if srv.Dialer == nil {
|
||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||
}
|
||||
@ -347,18 +354,17 @@ func (srv *Server) findPeers() {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
|
||||
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
||||
// TODO: handle/store session token
|
||||
conn.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
|
||||
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
|
||||
fd.Close()
|
||||
srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
ourID := srv.ntab.Self()
|
||||
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
|
||||
if ok, reason := srv.addPeer(remoteID, p); !ok {
|
||||
p := newPeer(conn, srv.Protocols)
|
||||
if ok, reason := srv.addPeer(conn.ID, p); !ok {
|
||||
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
||||
p.politeDisconnect(reason)
|
||||
return
|
||||
@ -394,7 +400,7 @@ func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
|
||||
|
||||
func (srv *Server) removePeer(p *Peer) {
|
||||
srv.lock.Lock()
|
||||
delete(srv.peers, *p.remoteID)
|
||||
delete(srv.peers, p.ID())
|
||||
srv.lock.Unlock()
|
||||
srv.peerWG.Done()
|
||||
}
|
||||
|
@ -21,8 +21,12 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
newPeerHook: pf,
|
||||
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
|
||||
return randomID(), nil, err
|
||||
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||
id := randomID()
|
||||
return &conn{
|
||||
frameRW: newFrameRW(fd, msgWriteTimeout),
|
||||
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
@ -116,9 +120,7 @@ func TestServerBroadcast(t *testing.T) {
|
||||
|
||||
var connected sync.WaitGroup
|
||||
srv := startTestServer(t, func(p *Peer) {
|
||||
p.protocols = []Protocol{discard}
|
||||
p.startSubprotocols([]Cap{discard.cap()})
|
||||
p.noHandshake = true
|
||||
p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
|
||||
connected.Done()
|
||||
})
|
||||
defer srv.Stop()
|
||||
|
Loading…
Reference in New Issue
Block a user