p2p: disable encryption handshake

The diff is a bit bigger than expected because the protocol handshake
logic has moved out of Peer. This is necessary because the protocol
handshake will have custom framing in the final protocol.
This commit is contained in:
Felix Lange 2015-02-19 01:52:03 +01:00
parent 4322632c59
commit 73f94f3755
7 changed files with 274 additions and 314 deletions

View File

@ -1,21 +1,20 @@
package p2p package p2p
import ( import (
// "binary"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/crypto/secp256k1"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
) )
var clogger = ethlogger.NewLogger("CRYPTOID")
const ( const (
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen = 65 // elliptic S256 sigLen = 65 // elliptic S256
@ -30,26 +29,76 @@ const (
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
) )
type hexkey []byte type conn struct {
*frameRW
func (self hexkey) String() string { *protoHandshake
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
} }
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) ( func newConn(fd net.Conn, hs *protoHandshake) *conn {
remoteID discover.NodeID, return &conn{newFrameRW(fd, msgWriteTimeout), hs}
sessionToken []byte, }
err error,
) { // encHandshake represents information about the remote end
// of a connection that is negotiated during the encryption handshake.
type encHandshake struct {
ID discover.NodeID
IngressMAC []byte
EgressMAC []byte
Token []byte
}
// protoHandshake is the RLP structure of the protocol handshake.
type protoHandshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
ID discover.NodeID
}
// setupConn starts a protocol session on the given connection.
// It runs the encryption handshake and the protocol handshake.
// If dial is non-nil, the connection the local node is the initiator.
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
if dial == nil { if dial == nil {
var remotePubkey []byte return setupInboundConn(fd, prv, our)
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
copy(remoteID[:], remotePubkey)
} else { } else {
remoteID = dial.ID return setupOutboundConn(fd, prv, our, dial)
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
} }
return remoteID, sessionToken, err }
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
// var remotePubkey []byte
// sessionToken, remotePubkey, err = inboundEncHandshake(fd, prv, nil)
// copy(remoteID[:], remotePubkey)
rw := newFrameRW(fd, msgWriteTimeout)
rhs, err := readProtocolHandshake(rw, our)
if err != nil {
return nil, err
}
if err := writeProtocolHandshake(rw, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err)
}
return &conn{rw, rhs}, nil
}
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
// remoteID = dial.ID
// sessionToken, err = outboundEncHandshake(fd, prv, remoteID[:], nil)
rw := newFrameRW(fd, msgWriteTimeout)
if err := writeProtocolHandshake(rw, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err)
}
rhs, err := readProtocolHandshake(rw, our)
if err != nil {
return nil, fmt.Errorf("protocol handshake read error: %v", err)
}
if rhs.ID != dial.ID {
return nil, errors.New("dialed node id mismatch")
}
return &conn{rw, rhs}, nil
} }
// outboundEncHandshake negotiates a session token on conn. // outboundEncHandshake negotiates a session token on conn.
@ -66,18 +115,9 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
if err != nil { if err != nil {
return nil, err return nil, err
} }
if sessionToken != nil {
clogger.Debugf("session-token: %v", hexkey(sessionToken))
}
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
if _, err = conn.Write(auth); err != nil { if _, err = conn.Write(auth); err != nil {
return nil, err return nil, err
} }
clogger.Debugf("initiator handshake: %v", hexkey(auth))
response := make([]byte, rHSLen) response := make([]byte, rHSLen)
if _, err = io.ReadFull(conn, response); err != nil { if _, err = io.ReadFull(conn, response); err != nil {
@ -88,9 +128,6 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
return nil, err return nil, err
} }
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
} }
@ -221,12 +258,9 @@ func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionTo
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
if _, err = conn.Write(response); err != nil { if _, err = conn.Write(response); err != nil {
return nil, nil, err return nil, nil, err
} }
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey) token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
return token, remotePubKey, err return token, remotePubKey, err
} }
@ -361,3 +395,40 @@ func xor(one, other []byte) (xor []byte) {
} }
return xor return xor
} }
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
}
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
// read and handle remote handshake
msg, err := r.ReadMsg()
if err != nil {
return nil, err
}
if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
var reason DiscReason
rlp.Decode(msg.Payload, &reason)
return nil, discRequestedError(reason)
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
}
var hs protoHandshake
if err := msg.Decode(&hs); err != nil {
return nil, err
}
// validate handshake info
if hs.Version != our.Version {
return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
}
if (hs.ID == discover.NodeID{}) {
return nil, newPeerError(errPubkeyInvalid, "missing")
}
return &hs, nil
}

View File

@ -5,10 +5,12 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"net" "net"
"reflect"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/p2p/discover"
) )
func TestPublicKeyEncoding(t *testing.T) { func TestPublicKeyEncoding(t *testing.T) {
@ -91,14 +93,14 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
t.Logf("-> %v", hexkey(auth)) // t.Logf("-> %v", hexkey(auth))
// receiver reads auth and responds with response // receiver reads auth and responds with response
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1) response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
t.Logf("<- %v\n", hexkey(response)) // t.Logf("<- %v\n", hexkey(response))
// initiator reads receiver's response and the key exchange completes // initiator reads receiver's response and the key exchange completes
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0) recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
@ -132,7 +134,7 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
} }
} }
func TestHandshake(t *testing.T) { func TestEncHandshake(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
prv0, _ := crypto.GenerateKey() prv0, _ := crypto.GenerateKey()
@ -165,3 +167,58 @@ func TestHandshake(t *testing.T) {
t.Error("session token mismatch") t.Error("session token mismatch")
} }
} }
func TestSetupConn(t *testing.T) {
prv0, _ := crypto.GenerateKey()
prv1, _ := crypto.GenerateKey()
node0 := &discover.Node{
ID: discover.PubkeyID(&prv0.PublicKey),
IP: net.IP{1, 2, 3, 4},
TCPPort: 33,
}
node1 := &discover.Node{
ID: discover.PubkeyID(&prv1.PublicKey),
IP: net.IP{5, 6, 7, 8},
TCPPort: 44,
}
hs0 := &protoHandshake{
Version: baseProtocolVersion,
ID: node0.ID,
Caps: []Cap{{"a", 0}, {"b", 2}},
}
hs1 := &protoHandshake{
Version: baseProtocolVersion,
ID: node1.ID,
Caps: []Cap{{"c", 1}, {"d", 3}},
}
fd0, fd1 := net.Pipe()
done := make(chan struct{})
go func() {
defer close(done)
conn0, err := setupConn(fd0, prv0, hs0, node1)
if err != nil {
t.Errorf("outbound side error: %v", err)
return
}
if conn0.ID != node1.ID {
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
}
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
}
}()
conn1, err := setupConn(fd1, prv1, hs1, nil)
if err != nil {
t.Fatalf("inbound side error: %v", err)
}
if conn1.ID != node0.ID {
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
}
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
}
<-done
}

View File

@ -197,7 +197,7 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
return msg, err return msg, err
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken) return msg, fmt.Errorf("bad magic token %x", start[:4])
} }
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])

View File

@ -33,37 +33,14 @@ const (
peersMsg = 0x05 peersMsg = 0x05
) )
// handshake is the RLP structure of the protocol handshake.
type handshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
NodeID discover.NodeID
}
// Peer represents a connected remote node. // Peer represents a connected remote node.
type Peer struct { type Peer struct {
// Peers have all the log methods. // Peers have all the log methods.
// Use them to display messages related to the peer. // Use them to display messages related to the peer.
*logger.Logger *logger.Logger
infoMu sync.Mutex rw *conn
name string running map[string]*protoRW
caps []Cap
ourID, remoteID *discover.NodeID
ourName string
rw *frameRW
// These fields maintain the running protocols.
protocols []Protocol
runlock sync.RWMutex // protects running
running map[string]*proto
// disables protocol handshake, for testing
noHandshake bool
protoWG sync.WaitGroup protoWG sync.WaitGroup
protoErr chan error protoErr chan error
@ -73,36 +50,27 @@ type Peer struct {
// NewPeer returns a peer for testing purposes. // NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
conn, _ := net.Pipe() pipe, _ := net.Pipe()
peer := newPeer(conn, nil, "", nil, &id) conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
peer.setHandshakeInfo(name, caps) peer := newPeer(conn, nil)
close(peer.closed) // ensures Disconnect doesn't block close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
// ID returns the node's public key. // ID returns the node's public key.
func (p *Peer) ID() discover.NodeID { func (p *Peer) ID() discover.NodeID {
return *p.remoteID return p.rw.ID
} }
// Name returns the node name that the remote node advertised. // Name returns the node name that the remote node advertised.
func (p *Peer) Name() string { func (p *Peer) Name() string {
// this needs a lock because the information is part of the return p.rw.Name
// protocol handshake.
p.infoMu.Lock()
name := p.name
p.infoMu.Unlock()
return name
} }
// Caps returns the capabilities (supported subprotocols) of the remote peer. // Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap { func (p *Peer) Caps() []Cap {
// this needs a lock because the information is part of the // TODO: maybe return copy
// protocol handshake. return p.rw.Caps
p.infoMu.Lock()
caps := p.caps
p.infoMu.Unlock()
return caps
} }
// RemoteAddr returns the remote address of the network connection. // RemoteAddr returns the remote address of the network connection.
@ -126,30 +94,20 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (p *Peer) String() string { func (p *Peer) String() string {
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
} }
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { func newPeer(conn *conn, protocols []Protocol) *Peer {
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
return &Peer{ p := &Peer{
Logger: logger.NewLogger(logtag), Logger: logger.NewLogger(logtag),
rw: newFrameRW(conn, msgWriteTimeout), rw: conn,
ourID: ourID, running: matchProtocols(protocols, conn.Caps, conn),
ourName: ourName,
remoteID: remoteID,
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason), disc: make(chan DiscReason),
protoErr: make(chan error), protoErr: make(chan error),
closed: make(chan struct{}), closed: make(chan struct{}),
} }
} return p
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
p.infoMu.Lock()
p.name = name
p.caps = caps
p.infoMu.Unlock()
} }
func (p *Peer) run() DiscReason { func (p *Peer) run() DiscReason {
@ -157,16 +115,9 @@ func (p *Peer) run() DiscReason {
defer p.closeProtocols() defer p.closeProtocols()
defer close(p.closed) defer close(p.closed)
p.startProtocols()
go func() { readErr <- p.readLoop() }() go func() { readErr <- p.readLoop() }()
if !p.noHandshake {
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
p.DebugDetailf("Protocol handshake error: %v\n", err)
p.rw.Close()
return DiscProtocolError
}
}
// Wait for an error or disconnect. // Wait for an error or disconnect.
var reason DiscReason var reason DiscReason
select { select {
@ -206,11 +157,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
} }
func (p *Peer) readLoop() error { func (p *Peer) readLoop() error {
if !p.noHandshake {
if err := readProtocolHandshake(p, p.rw); err != nil {
return err
}
}
for { for {
msg, err := p.rw.ReadMsg() msg, err := p.rw.ReadMsg()
if err != nil { if err != nil {
@ -249,88 +195,36 @@ func (p *Peer) handle(msg Msg) error {
return nil return nil
} }
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { // matchProtocols creates structures for matching named subprotocols.
// read and handle remote handshake func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
var reason DiscReason
rlp.Decode(msg.Payload, &reason)
return discRequestedError(reason)
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errInvalidMsg, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if hs.NodeID == *p.remoteID {
return newPeerError(errPubkeyForbidden, "node ID mismatch")
}
// TODO: remove Caps with empty name
p.setHandshakeInfo(hs.Name, hs.Caps)
p.startSubprotocols(hs.Caps)
return nil
}
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
var caps []interface{}
for _, proto := range ps {
caps = append(caps, proto.cap())
}
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
}
// startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps)) sort.Sort(capsByName(caps))
p.runlock.Lock()
defer p.runlock.Unlock()
offset := baseProtocolLength offset := baseProtocolLength
result := make(map[string]*protoRW)
outer: outer:
for _, cap := range caps { for _, cap := range caps {
for _, proto := range p.protocols { for _, proto := range protocols {
if proto.Name == cap.Name && if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
proto.Version == cap.Version && result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
p.running[cap.Name] == nil {
p.running[cap.Name] = p.startProto(offset, proto)
offset += proto.Length offset += proto.Length
continue outer continue outer
} }
} }
} }
return result
} }
func (p *Peer) startProto(offset uint64, impl Protocol) *proto { func (p *Peer) startProtocols() {
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version) for _, proto := range p.running {
rw := &proto{ proto := proto
name: impl.Name, p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
in: make(chan Msg),
offset: offset,
maxcode: impl.Length,
w: p.rw,
}
p.protoWG.Add(1) p.protoWG.Add(1)
go func() { go func() {
err := impl.Run(p, rw) err := proto.Run(p, proto)
if err == nil { if err == nil {
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
err = errors.New("protocol returned") err = errors.New("protocol returned")
} else { } else {
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
} }
select { select {
case p.protoErr <- err: case p.protoErr <- err:
@ -338,16 +232,14 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
} }
p.protoWG.Done() p.protoWG.Done()
}() }()
return rw }
} }
// getProto finds the protocol responsible for handling // getProto finds the protocol responsible for handling
// the given message code. // the given message code.
func (p *Peer) getProto(code uint64) (*proto, error) { func (p *Peer) getProto(code uint64) (*protoRW, error) {
p.runlock.RLock()
defer p.runlock.RUnlock()
for _, proto := range p.running { for _, proto := range p.running {
if code >= proto.offset && code < proto.offset+proto.maxcode { if code >= proto.offset && code < proto.offset+proto.Length {
return proto, nil return proto, nil
} }
} }
@ -355,46 +247,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
} }
func (p *Peer) closeProtocols() { func (p *Peer) closeProtocols() {
p.runlock.RLock()
for _, p := range p.running { for _, p := range p.running {
close(p.in) close(p.in)
} }
p.runlock.RUnlock()
p.protoWG.Wait() p.protoWG.Wait()
} }
// writeProtoMsg sends the given message on behalf of the given named protocol. // writeProtoMsg sends the given message on behalf of the given named protocol.
// this exists because of Server.Broadcast. // this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock()
proto, ok := p.running[protoName] proto, ok := p.running[protoName]
p.runlock.RUnlock()
if !ok { if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName) return fmt.Errorf("protocol %s not handled by peer", protoName)
} }
if msg.Code >= proto.maxcode { if msg.Code >= proto.Length {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
} }
msg.Code += proto.offset msg.Code += proto.offset
return p.rw.WriteMsg(msg) return p.rw.WriteMsg(msg)
} }
type proto struct { type protoRW struct {
name string Protocol
in chan Msg in chan Msg
maxcode, offset uint64 offset uint64
w MsgWriter w MsgWriter
} }
func (rw *proto) WriteMsg(msg Msg) error { func (rw *protoRW) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode { if msg.Code >= rw.Length {
return newPeerError(errInvalidMsgCode, "not handled") return newPeerError(errInvalidMsgCode, "not handled")
} }
msg.Code += rw.offset msg.Code += rw.offset
return rw.w.WriteMsg(msg) return rw.w.WriteMsg(msg)
} }
func (rw *proto) ReadMsg() (Msg, error) { func (rw *protoRW) ReadMsg() (Msg, error) {
msg, ok := <-rw.in msg, ok := <-rw.in
if !ok { if !ok {
return msg, io.EOF return msg, io.EOF

View File

@ -6,11 +6,9 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -23,6 +21,7 @@ var discard = Protocol{
if err != nil { if err != nil {
return err return err
} }
fmt.Printf("discarding %d\n", msg.Code)
if err = msg.Discard(); err != nil { if err = msg.Discard(); err != nil {
return err return err
} }
@ -30,13 +29,20 @@ var discard = Protocol{
}, },
} }
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
conn1, conn2 := net.Pipe() fd1, fd2 := net.Pipe()
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
peer.noHandshake = noHandshake hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
for _, p := range protos {
hs1.Caps = append(hs1.Caps, p.cap())
hs2.Caps = append(hs2.Caps, p.cap())
}
peer := newPeer(newConn(fd1, hs1), protos)
errc := make(chan DiscReason, 1) errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }() go func() { errc <- peer.run() }()
return newFrameRW(conn2, msgWriteTimeout), peer, errc
return newConn(fd2, hs2), peer, errc
} }
func TestPeerProtoReadMsg(t *testing.T) { func TestPeerProtoReadMsg(t *testing.T) {
@ -61,9 +67,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
}, },
} }
rw, peer, errc := testPeer(true, []Protocol{proto}) rw, _, errc := testPeer([]Protocol{proto})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()})
EncodeMsg(rw, baseProtocolLength+2, 1) EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2) EncodeMsg(rw, baseProtocolLength+3, 2)
@ -100,9 +105,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
}, },
} }
rw, peer, errc := testPeer(true, []Protocol{proto}) rw, _, errc := testPeer([]Protocol{proto})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()})
EncodeMsg(rw, 18, make([]byte, msgsize)) EncodeMsg(rw, 18, make([]byte, msgsize))
select { select {
@ -130,9 +134,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil return nil
}, },
} }
rw, peer, _ := testPeer(true, []Protocol{proto}) rw, _, _ := testPeer([]Protocol{proto})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()})
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
t.Error(err) t.Error(err)
@ -142,9 +145,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
func TestPeerWriteForBroadcast(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, peer, peerErr := testPeer(true, []Protocol{discard}) rw, peer, peerErr := testPeer([]Protocol{discard})
defer rw.Close() defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors // test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
@ -160,7 +162,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
read := make(chan struct{}) read := make(chan struct{})
go func() { go func() {
if err := expectMsg(rw, 16, nil); err != nil { if err := expectMsg(rw, 16, nil); err != nil {
t.Error() t.Error(err)
} }
close(read) close(read)
}() }()
@ -179,7 +181,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
func TestPeerPing(t *testing.T) { func TestPeerPing(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, _, _ := testPeer(true, nil) rw, _, _ := testPeer(nil)
defer rw.Close() defer rw.Close()
if err := EncodeMsg(rw, pingMsg); err != nil { if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err) t.Fatal(err)
@ -192,7 +194,7 @@ func TestPeerPing(t *testing.T) {
func TestPeerDisconnect(t *testing.T) { func TestPeerDisconnect(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, _, disc := testPeer(true, nil) rw, _, disc := testPeer(nil)
defer rw.Close() defer rw.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err) t.Fatal(err)
@ -206,73 +208,6 @@ func TestPeerDisconnect(t *testing.T) {
} }
} }
func TestPeerHandshake(t *testing.T) {
defer testlog(t).detach()
// remote has two matching protocols: a and c
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
remoteID := randomID()
remote.ourID = &remoteID
remote.ourName = "remote peer"
start := make(chan string)
stop := make(chan struct{})
run := func(p *Peer, rw MsgReadWriter) error {
name := rw.(*proto).name
if name != "a" && name != "c" {
t.Errorf("protocol %q should not be started", name)
} else {
start <- name
}
<-stop
return nil
}
protocols := []Protocol{
{Name: "a", Version: 1, Length: 1, Run: run},
{Name: "b", Version: 2, Length: 1, Run: run},
{Name: "c", Version: 3, Length: 1, Run: run},
{Name: "d", Version: 4, Length: 1, Run: run},
}
rw, p, disc := testPeer(false, protocols)
p.remoteID = remote.ourID
defer rw.Close()
// run the handshake
remoteProtocols := []Protocol{protocols[0], protocols[2]}
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
t.Fatalf("handshake write error: %v", err)
}
if err := readProtocolHandshake(remote, rw); err != nil {
t.Fatalf("handshake read error: %v", err)
}
// check that all protocols have been started
var started []string
for i := 0; i < 2; i++ {
select {
case name := <-start:
started = append(started, name)
case <-time.After(100 * time.Millisecond):
}
}
sort.Strings(started)
if !reflect.DeepEqual(started, []string{"a", "c"}) {
t.Errorf("wrong protocols started: %v", started)
}
// check that metadata has been set
if p.ID() != remoteID {
t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
}
if p.Name() != remote.ourName {
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
}
close(stop)
expectMsg(rw, discMsg, nil)
t.Logf("disc reason: %v", <-disc)
}
func TestNewPeer(t *testing.T) { func TestNewPeer(t *testing.T) {
name := "nodename" name := "nodename"
caps := []Cap{{"foo", 2}, {"bar", 3}} caps := []Cap{{"foo", 2}, {"bar", 3}}

View File

@ -5,7 +5,6 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"runtime" "runtime"
"sync" "sync"
@ -83,9 +82,11 @@ type Server struct {
// Hooks for testing. These are useful because we can inhibit // Hooks for testing. These are useful because we can inhibit
// the whole protocol stack. // the whole protocol stack.
handshakeFunc setupFunc
newPeerHook newPeerHook
ourHandshake *protoHandshake
lock sync.RWMutex lock sync.RWMutex
running bool running bool
listener net.Listener listener net.Listener
@ -99,7 +100,7 @@ type Server struct {
peerConnect chan *discover.Node peerConnect chan *discover.Node
} }
type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error) type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
type newPeerHook func(*Peer) type newPeerHook func(*Peer)
// Peers returns all connected peers. // Peers returns all connected peers.
@ -170,8 +171,8 @@ func (srv *Server) Start() (err error) {
srv.peers = make(map[discover.NodeID]*Peer) srv.peers = make(map[discover.NodeID]*Peer)
srv.peerConnect = make(chan *discover.Node) srv.peerConnect = make(chan *discover.Node)
if srv.handshakeFunc == nil { if srv.setupFunc == nil {
srv.handshakeFunc = encHandshake srv.setupFunc = setupConn
} }
if srv.Blacklist == nil { if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist() srv.Blacklist = NewBlacklist()
@ -183,11 +184,17 @@ func (srv *Server) Start() (err error) {
} }
// dial stuff // dial stuff
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT) ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
if err != nil { if err != nil {
return err return err
} }
srv.ntab = dt srv.ntab = ntab
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()}
for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
if srv.Dialer == nil { if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
} }
@ -347,18 +354,17 @@ func (srv *Server) findPeers() {
} }
} }
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) { func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
// TODO: handle/store session token // TODO: handle/store session token
conn.SetDeadline(time.Now().Add(handshakeTimeout)) fd.SetDeadline(time.Now().Add(handshakeTimeout))
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest) conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
if err != nil { if err != nil {
conn.Close() fd.Close()
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err) srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
return return
} }
ourID := srv.ntab.Self() p := newPeer(conn, srv.Protocols)
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID) if ok, reason := srv.addPeer(conn.ID, p); !ok {
if ok, reason := srv.addPeer(remoteID, p); !ok {
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason) srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
p.politeDisconnect(reason) p.politeDisconnect(reason)
return return
@ -394,7 +400,7 @@ func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
func (srv *Server) removePeer(p *Peer) { func (srv *Server) removePeer(p *Peer) {
srv.lock.Lock() srv.lock.Lock()
delete(srv.peers, *p.remoteID) delete(srv.peers, p.ID())
srv.lock.Unlock() srv.lock.Unlock()
srv.peerWG.Done() srv.peerWG.Done()
} }

View File

@ -21,8 +21,12 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(), PrivateKey: newkey(),
newPeerHook: pf, newPeerHook: pf,
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) { setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
return randomID(), nil, err id := randomID()
return &conn{
frameRW: newFrameRW(fd, msgWriteTimeout),
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
}, nil
}, },
} }
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
@ -116,9 +120,7 @@ func TestServerBroadcast(t *testing.T) {
var connected sync.WaitGroup var connected sync.WaitGroup
srv := startTestServer(t, func(p *Peer) { srv := startTestServer(t, func(p *Peer) {
p.protocols = []Protocol{discard} p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
p.startSubprotocols([]Cap{discard.cap()})
p.noHandshake = true
connected.Done() connected.Done()
}) })
defer srv.Stop() defer srv.Stop()