p2p: integrate p2p/discover

Overview of changes:

- ClientIdentity has been removed, use discover.NodeID
- Server now requires a private key to be set (instead of public key)
- Server performs the encryption handshake before launching Peer
- Dial logic takes peers from discover table
- Encryption handshake code has been cleaned up a bit
- baseProtocol is gone because we don't exchange peers anymore
- Some parts of baseProtocol have moved into Peer instead
This commit is contained in:
Felix Lange 2015-02-05 03:07:58 +01:00
parent 739066ec56
commit 5bdc115943
15 changed files with 1080 additions and 1683 deletions

View File

@ -1,68 +0,0 @@
package p2p
import (
"fmt"
"runtime"
)
// ClientIdentity represents the identity of a peer.
type ClientIdentity interface {
String() string // human readable identity
Pubkey() []byte // 512-bit public key
}
type SimpleClientIdentity struct {
clientIdentifier string
version string
customIdentifier string
os string
implementation string
privkey []byte
pubkey []byte
}
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
clientIdentity := &SimpleClientIdentity{
clientIdentifier: clientIdentifier,
version: version,
customIdentifier: customIdentifier,
os: runtime.GOOS,
implementation: runtime.Version(),
pubkey: pubkey,
}
return clientIdentity
}
func (c *SimpleClientIdentity) init() {
}
func (c *SimpleClientIdentity) String() string {
var id string
if len(c.customIdentifier) > 0 {
id = "/" + c.customIdentifier
}
return fmt.Sprintf("%s/v%s%s/%s/%s",
c.clientIdentifier,
c.version,
id,
c.os,
c.implementation)
}
func (c *SimpleClientIdentity) Privkey() []byte {
return c.privkey
}
func (c *SimpleClientIdentity) Pubkey() []byte {
return c.pubkey
}
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
c.customIdentifier = customIdentifier
}
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
return c.customIdentifier
}

View File

@ -1,35 +0,0 @@
package p2p
import (
"bytes"
"fmt"
"runtime"
"testing"
)
func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
key := clientIdentity.Pubkey()
if !bytes.Equal(key, []byte("pubkey")) {
t.Errorf("Expected Pubkey to be %x, got %x", key, []byte("pubkey"))
}
clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
customIdentifier := clientIdentity.GetCustomIdentifier()
if customIdentifier != "test" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
}
clientIdentity.SetCustomIdentifier("test2")
customIdentifier = clientIdentity.GetCustomIdentifier()
if customIdentifier != "test2" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
}
clientString = clientIdentity.String()
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
}

View File

@ -10,28 +10,25 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/obscuren/ecies"
)
var clogger = ethlogger.NewLogger("CRYPTOID")
const (
sskLen int = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen int = 65 // elliptic S256
pubLen int = 64 // 512 bit pubkey in uncompressed representation without format byte
shaLen int = 32 // hash length (for nonce etc)
msgLen int = 194 // sigLen + shaLen + pubLen + shaLen + 1 = 194
resLen int = 97 // pubLen + shaLen + 1
iHSLen int = 307 // size of the final ECIES payload sent as initiator's handshake
rHSLen int = 210 // size of the final ECIES payload sent as receiver's handshake
)
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen = 65 // elliptic S256
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
shaLen = 32 // hash length (for nonce etc)
// secretRW implements a message read writer with encryption and authentication
// it is initialised by cryptoId.Run() after a successful crypto handshake
// aesSecret, macSecret, egressMac, ingress
type secretRW struct {
aesSecret, macSecret, egressMac, ingressMac []byte
}
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
authRespLen = pubLen + shaLen + 1
eciesBytes = 65 + 16 + 32
iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
)
type hexkey []byte
@ -39,150 +36,73 @@ func (self hexkey) String() string {
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
}
var nonceF = func(b []byte) (n int, err error) {
return rand.Read(b)
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
remoteID discover.NodeID,
sessionToken []byte,
err error,
) {
if dial == nil {
var remotePubkey []byte
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
copy(remoteID[:], remotePubkey)
} else {
remoteID = dial.ID
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
}
return remoteID, sessionToken, err
}
var step = 0
var detnonceF = func(b []byte) (n int, err error) {
step++
copy(b, crypto.Sha3([]byte("privacy"+string(step))))
fmt.Printf("detkey %v: %v\n", step, hexkey(b))
return
}
var keyF = func() (priv *ecdsa.PrivateKey, err error) {
priv, err = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
// outboundEncHandshake negotiates a session token on conn.
// it should be called on the dialing side of the connection.
//
// privateKey is the local client's private key
// remotePublicKey is the remote peer's node ID
// sessionToken is the token from a previous session with this node.
func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (
newSessionToken []byte,
err error,
) {
auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken)
if err != nil {
return
}
return
}
var detkeyF = func() (priv *ecdsa.PrivateKey, err error) {
s := make([]byte, 32)
detnonceF(s)
priv = crypto.ToECDSA(s)
return
}
/*
NewSecureSession(connection, privateKey, remotePublicKey, sessionToken, initiator) is called when the peer connection starts to set up a secure session by performing a crypto handshake.
connection is (a buffered) network connection.
privateKey is the local client's private key (*ecdsa.PrivateKey)
remotePublicKey is the remote peer's node Id ([]byte)
sessionToken is the token from the previous session with this same peer. Nil if no token is found.
initiator is a boolean flag. True if the node is the initiator of the connection (ie., remote is an outbound peer reached by dialing out). False if the connection was established by accepting a call from the remote peer via a listener.
It returns a secretRW which implements the MsgReadWriter interface.
*/
func NewSecureSession(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePubKeyS []byte, sessionToken []byte, initiator bool) (token []byte, rw *secretRW, err error) {
var auth, initNonce, recNonce []byte
var read int
var randomPrivKey *ecdsa.PrivateKey
var remoteRandomPubKey *ecdsa.PublicKey
clogger.Debugf("attempting session with %v", hexkey(remotePubKeyS))
if initiator {
if auth, initNonce, randomPrivKey, _, err = startHandshake(prvKey, remotePubKeyS, sessionToken); err != nil {
return
return nil, err
}
if sessionToken != nil {
clogger.Debugf("session-token: %v", hexkey(sessionToken))
}
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
randomPublicKeyS, _ := ExportPublicKey(&randomPrivKey.PublicKey)
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
if _, err = conn.Write(auth); err != nil {
return
return nil, err
}
clogger.Debugf("initiator handshake (sent to %v):\n%v", hexkey(remotePubKeyS), hexkey(auth))
var response []byte = make([]byte, rHSLen)
if read, err = conn.Read(response); err != nil || read == 0 {
return
clogger.Debugf("initiator handshake: %v", hexkey(auth))
response := make([]byte, rHSLen)
if _, err = io.ReadFull(conn, response); err != nil {
return nil, err
}
if read != rHSLen {
err = fmt.Errorf("remote receiver's handshake has invalid length. expect %v, got %v", rHSLen, read)
return
}
// write out auth message
// wait for response, then call complete
if recNonce, remoteRandomPubKey, _, err = completeHandshake(response, prvKey); err != nil {
return
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey)
if err != nil {
return nil, err
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
remoteRandomPubKeyS, _ := ExportPublicKey(remoteRandomPubKey)
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
} else {
auth = make([]byte, iHSLen)
clogger.Debugf("waiting for initiator handshake (from %v)", hexkey(remotePubKeyS))
if read, err = conn.Read(auth); err != nil {
return
}
if read != iHSLen {
err = fmt.Errorf("remote initiator's handshake has invalid length. expect %v, got %v", iHSLen, read)
return
}
clogger.Debugf("received initiator handshake (from %v):\n%v", hexkey(remotePubKeyS), hexkey(auth))
// we are listening connection. we are responders in the handshake.
// Extract info from the authentication. The initiator starts by sending us a handshake that we need to respond to.
// so we read auth message first, then respond
var response []byte
if response, recNonce, initNonce, randomPrivKey, remoteRandomPubKey, err = respondToHandshake(auth, prvKey, remotePubKeyS, sessionToken); err != nil {
return
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
if _, err = conn.Write(response); err != nil {
return
}
clogger.Debugf("receiver handshake (sent to %v):\n%v", hexkey(remotePubKeyS), hexkey(response))
}
return newSession(initiator, initNonce, recNonce, auth, randomPrivKey, remoteRandomPubKey)
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
}
/*
ImportPublicKey creates a 512 bit *ecsda.PublicKey from a byte slice. It accepts the simple 64 byte uncompressed format or the 65 byte format given by calling elliptic.Marshal on the EC point represented by the key. Any other length will result in an invalid public key error.
*/
func ImportPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
var pubKey65 []byte
switch len(pubKey) {
case 64:
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
return crypto.ToECDSAPub(pubKey65), nil
}
/*
ExportPublicKey exports a *ecdsa.PublicKey into a byte slice using a simple 64-byte format. and is used for simple serialisation in network communication
*/
func ExportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
if pubKeyEC == nil {
return nil, fmt.Errorf("no ECDSA public key given")
}
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
}
/* startHandshake is called by if the node is the initiator of the connection.
The caller provides the public key of the peer as conjuctured from lookup based on IP:port, given as user input or proven by signatures. The caller must have access to persistant information about the peers, and pass the previous session token as an argument to cryptoId.
The first return value is the auth message that is to be sent out to the remote receiver.
*/
func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (auth []byte, initNonce []byte, randomPrvKey *ecdsa.PrivateKey, remotePubKey *ecdsa.PublicKey, err error) {
// authMsg creates the initiator handshake.
func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (
auth, initNonce []byte,
randomPrvKey *ecdsa.PrivateKey,
err error,
) {
// session init, common to both parties
if remotePubKey, err = ImportPublicKey(remotePubKeyS); err != nil {
remotePubKey, err := importPublicKey(remotePubKeyS)
if err != nil {
return
}
@ -203,20 +123,18 @@ func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte
//E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0)
// E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1)
// allocate msgLen long message,
var msg []byte = make([]byte, msgLen)
initNonce = msg[msgLen-shaLen-1 : msgLen-1]
fmt.Printf("init-nonce: ")
if _, err = nonceF(initNonce); err != nil {
var msg []byte = make([]byte, authMsgLen)
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
if _, err = rand.Read(initNonce); err != nil {
return
}
// create known message
// ecdh-shared-secret^nonce for new peers
// token^nonce for old peers
var sharedSecret = Xor(sessionToken, initNonce)
var sharedSecret = xor(sessionToken, initNonce)
// generate random keypair to use for signing
fmt.Printf("init-random-ecdhe-private-key: ")
if randomPrvKey, err = keyF(); err != nil {
if randomPrvKey, err = crypto.GenerateKey(); err != nil {
return
}
// sign shared secret (message known to both parties): shared-secret
@ -232,11 +150,11 @@ func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte
copy(msg, signature) // copy signed-shared-secret
// H(ecdhe-random-pubk)
var randomPubKey64 []byte
if randomPubKey64, err = ExportPublicKey(&randomPrvKey.PublicKey); err != nil {
if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil {
return
}
var pubKey64 []byte
if pubKey64, err = ExportPublicKey(&prvKey.PublicKey); err != nil {
if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil {
return
}
copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64))
@ -244,36 +162,98 @@ func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte
copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64)
// nonce is already in the slice
// stick tokenFlag byte to the end
msg[msgLen-1] = tokenFlag
msg[authMsgLen-1] = tokenFlag
// encrypt using remote-pubk
// auth = eciesEncrypt(remote-pubk, msg)
if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil {
return
}
return
}
/*
respondToHandshake is called by peer if it accepted (but not initiated) the connection from the remote. It is passed the initiator handshake received, the public key and session token belonging to the remote initiator.
The first return value is the authentication response (aka receiver handshake) that is to be sent to the remote initiator.
*/
func respondToHandshake(auth []byte, prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (authResp []byte, respNonce []byte, initNonce []byte, randomPrivKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey, err error) {
// completeHandshake is called when the initiator receives an
// authentication response (aka receiver handshake). It completes the
// handshake by reading off parameters the remote peer provides needed
// to set up the secure session.
func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (
respNonce []byte,
remoteRandomPubKey *ecdsa.PublicKey,
tokenFlag bool,
err error,
) {
var msg []byte
var remotePubKey *ecdsa.PublicKey
if remotePubKey, err = ImportPublicKey(remotePubKeyS); err != nil {
return
}
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
return
}
respNonce = msg[pubLen : pubLen+shaLen]
var remoteRandomPubKeyS = msg[:pubLen]
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
return
}
if msg[authRespLen-1] == 0x01 {
tokenFlag = true
}
return
}
// inboundEncHandshake negotiates a session token on conn.
// it should be called on the listening side of the connection.
//
// privateKey is the local client's private key
// sessionToken is the token from a previous session with this node.
func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (
token, remotePubKey []byte,
err error,
) {
// we are listening connection. we are responders in the
// handshake. Extract info from the authentication. The initiator
// starts by sending us a handshake that we need to respond to. so
// we read auth message first, then respond.
auth := make([]byte, iHSLen)
if _, err := io.ReadFull(conn, auth); err != nil {
return nil, nil, err
}
response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey)
if err != nil {
return nil, nil, err
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
if _, err = conn.Write(response); err != nil {
return nil, nil, err
}
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
return token, remotePubKey, err
}
// authResp is called by peer if it accepted (but not
// initiated) the connection from the remote. It is passed the initiator
// handshake received and the session token belonging to the
// remote initiator.
//
// The first return value is the authentication response (aka receiver
// handshake) that is to be sent to the remote initiator.
func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) (
authResp, respNonce, initNonce, remotePubKeyS []byte,
randomPrivKey *ecdsa.PrivateKey,
remoteRandomPubKey *ecdsa.PublicKey,
err error,
) {
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
msg, err := crypto.Decrypt(prvKey, auth)
if err != nil {
return
}
remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen]
remotePubKey, _ := importPublicKey(remotePubKeyS)
var tokenFlag byte
if sessionToken == nil {
// no session token found means we need to generate shared secret.
@ -289,42 +269,42 @@ func respondToHandshake(auth []byte, prvKey *ecdsa.PrivateKey, remotePubKeyS, se
}
// the initiator nonce is read off the end of the message
initNonce = msg[msgLen-shaLen-1 : msgLen-1]
// I prove that i own prv key (to derive shared secret, and read nonce off encrypted msg) and that I own shared secret
// they prove they own the private key belonging to ecdhe-random-pubk
// we can now reconstruct the signed message and recover the peers pubkey
var signedMsg = Xor(sessionToken, initNonce)
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
// I prove that i own prv key (to derive shared secret, and read
// nonce off encrypted msg) and that I own shared secret they
// prove they own the private key belonging to ecdhe-random-pubk
// we can now reconstruct the signed message and recover the peers
// pubkey
var signedMsg = xor(sessionToken, initNonce)
var remoteRandomPubKeyS []byte
if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil {
return
}
// convert to ECDSA standard
if remoteRandomPubKey, err = ImportPublicKey(remoteRandomPubKeyS); err != nil {
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
return
}
// now we find ourselves a long task too, fill it random
var resp = make([]byte, resLen)
var resp = make([]byte, authRespLen)
// generate shaLen long nonce
respNonce = resp[pubLen : pubLen+shaLen]
fmt.Printf("rec-nonce: ")
if _, err = nonceF(respNonce); err != nil {
if _, err = rand.Read(respNonce); err != nil {
return
}
// generate random keypair for session
fmt.Printf("rec-random-ecdhe-private-key: ")
if randomPrivKey, err = keyF(); err != nil {
if randomPrivKey, err = crypto.GenerateKey(); err != nil {
return
}
// responder auth message
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
var randomPubKeyS []byte
if randomPubKeyS, err = ExportPublicKey(&randomPrivKey.PublicKey); err != nil {
if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil {
return
}
copy(resp[:pubLen], randomPubKeyS)
// nonce is already in the slice
resp[resLen-1] = tokenFlag
resp[authRespLen-1] = tokenFlag
// encrypt using remote-pubk
// auth = eciesEncrypt(remote-pubk, msg)
@ -335,70 +315,49 @@ func respondToHandshake(auth []byte, prvKey *ecdsa.PrivateKey, remotePubKeyS, se
return
}
/*
completeHandshake is called when the initiator receives an authentication response (aka receiver handshake). It completes the handshake by reading off parameters the remote peer provides needed to set up the secure session
*/
func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (respNonce []byte, remoteRandomPubKey *ecdsa.PublicKey, tokenFlag bool, err error) {
var msg []byte
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
return
}
respNonce = msg[pubLen : pubLen+shaLen]
var remoteRandomPubKeyS = msg[:pubLen]
if remoteRandomPubKey, err = ImportPublicKey(remoteRandomPubKeyS); err != nil {
return
}
if msg[resLen-1] == 0x01 {
tokenFlag = true
}
return
}
/*
newSession is called after the handshake is completed. The arguments are values negotiated in the handshake and the return value is a new session : a new session Token to be remembered for the next time we connect with this peer. And a MsgReadWriter that implements an encrypted and authenticated connection with key material obtained from the crypto handshake key exchange
*/
func newSession(initiator bool, initNonce, respNonce, auth []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) (sessionToken []byte, rw *secretRW, err error) {
// newSession is called after the handshake is completed. The
// arguments are values negotiated in the handshake. The return value
// is a new session Token to be remembered for the next time we
// connect with this peer.
func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) {
// 3) Now we can trust ecdhe-random-pubk to derive new keys
//ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk)
var dhSharedSecret []byte
pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey)
if dhSharedSecret, err = ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen); err != nil {
return
dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen)
if err != nil {
return nil, err
}
var sharedSecret = crypto.Sha3(append(dhSharedSecret, crypto.Sha3(append(respNonce, initNonce...))...))
sessionToken = crypto.Sha3(sharedSecret)
var aesSecret = crypto.Sha3(append(dhSharedSecret, sharedSecret...))
var macSecret = crypto.Sha3(append(dhSharedSecret, aesSecret...))
var egressMac, ingressMac []byte
if initiator {
egressMac = Xor(macSecret, respNonce)
ingressMac = Xor(macSecret, initNonce)
} else {
egressMac = Xor(macSecret, initNonce)
ingressMac = Xor(macSecret, respNonce)
}
rw = &secretRW{
aesSecret: aesSecret,
macSecret: macSecret,
egressMac: egressMac,
ingressMac: ingressMac,
}
clogger.Debugf("aes-secret: %v", hexkey(aesSecret))
clogger.Debugf("mac-secret: %v", hexkey(macSecret))
clogger.Debugf("egress-mac: %v", hexkey(egressMac))
clogger.Debugf("ingress-mac: %v", hexkey(ingressMac))
return
sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce))
sessionToken := crypto.Sha3(sharedSecret)
return sessionToken, nil
}
// TODO: optimisation
// should use cipher.xorBytes from crypto/cipher/xor.go for fast xor
func Xor(one, other []byte) (xor []byte) {
// importPublicKey unmarshals 512 bit public keys.
func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
var pubKey65 []byte
switch len(pubKey) {
case 64:
// add 'uncompressed key' flag
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
return crypto.ToECDSAPub(pubKey65), nil
}
func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
if pubKeyEC == nil {
return nil, fmt.Errorf("no ECDSA public key given")
}
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
}
func xor(one, other []byte) (xor []byte) {
xor = make([]byte, len(one))
for i := 0; i < len(one); i++ {
xor[i] = one[i] ^ other[i]
}
return
return xor
}

View File

@ -3,10 +3,9 @@ package p2p
import (
"bytes"
"crypto/ecdsa"
"fmt"
"crypto/rand"
"net"
"testing"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/obscuren/ecies"
@ -16,7 +15,7 @@ func TestPublicKeyEncoding(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
pub0s := crypto.FromECDSAPub(pub0)
pub1, err := ImportPublicKey(pub0s)
pub1, err := importPublicKey(pub0s)
if err != nil {
t.Errorf("%v", err)
}
@ -24,18 +23,18 @@ func TestPublicKeyEncoding(t *testing.T) {
if eciesPub1 == nil {
t.Errorf("invalid ecdsa public key")
}
pub1s, err := ExportPublicKey(pub1)
pub1s, err := exportPublicKey(pub1)
if err != nil {
t.Errorf("%v", err)
}
if len(pub1s) != 64 {
t.Errorf("wrong length expect 64, got", len(pub1s))
}
pub2, err := ImportPublicKey(pub1s)
pub2, err := importPublicKey(pub1s)
if err != nil {
t.Errorf("%v", err)
}
pub2s, err := ExportPublicKey(pub2)
pub2s, err := exportPublicKey(pub2)
if err != nil {
t.Errorf("%v", err)
}
@ -69,95 +68,53 @@ func TestSharedSecret(t *testing.T) {
}
func TestCryptoHandshake(t *testing.T) {
testCryptoHandshakeWithGen(false, t)
testCryptoHandshake(newkey(), newkey(), nil, t)
}
func TestTokenCryptoHandshake(t *testing.T) {
testCryptoHandshakeWithGen(true, t)
}
func TestDetCryptoHandshake(t *testing.T) {
defer testlog(t).detach()
tmpkeyF := keyF
keyF = detkeyF
tmpnonceF := nonceF
nonceF = detnonceF
testCryptoHandshakeWithGen(false, t)
keyF = tmpkeyF
nonceF = tmpnonceF
}
func TestDetTokenCryptoHandshake(t *testing.T) {
defer testlog(t).detach()
tmpkeyF := keyF
keyF = detkeyF
tmpnonceF := nonceF
nonceF = detnonceF
testCryptoHandshakeWithGen(true, t)
keyF = tmpkeyF
nonceF = tmpnonceF
}
func testCryptoHandshakeWithGen(token bool, t *testing.T) {
fmt.Printf("init-private-key: ")
prv0, err := keyF()
if err != nil {
t.Errorf("%v", err)
return
}
fmt.Printf("rec-private-key: ")
prv1, err := keyF()
if err != nil {
t.Errorf("%v", err)
return
}
var nonce []byte
if token {
fmt.Printf("session-token: ")
nonce = make([]byte, shaLen)
nonceF(nonce)
}
testCryptoHandshake(prv0, prv1, nonce, t)
func TestCryptoHandshakeWithToken(t *testing.T) {
sessionToken := make([]byte, shaLen)
rand.Read(sessionToken)
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
}
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
var err error
pub0 := &prv0.PublicKey
// pub0 := &prv0.PublicKey
pub1 := &prv1.PublicKey
pub0s := crypto.FromECDSAPub(pub0)
// pub0s := crypto.FromECDSAPub(pub0)
pub1s := crypto.FromECDSAPub(pub1)
// simulate handshake by feeding output to input
// initiator sends handshake 'auth'
auth, initNonce, randomPrivKey, _, err := startHandshake(prv0, pub1s, sessionToken)
auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
if err != nil {
t.Errorf("%v", err)
}
fmt.Printf("-> %v\n", hexkey(auth))
t.Logf("-> %v", hexkey(auth))
// receiver reads auth and responds with response
response, remoteRecNonce, remoteInitNonce, remoteRandomPrivKey, remoteInitRandomPubKey, err := respondToHandshake(auth, prv1, pub0s, sessionToken)
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
if err != nil {
t.Errorf("%v", err)
}
fmt.Printf("<- %v\n", hexkey(response))
t.Logf("<- %v\n", hexkey(response))
// initiator reads receiver's response and the key exchange completes
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
if err != nil {
t.Errorf("%v", err)
t.Errorf("completeHandshake error: %v", err)
}
// now both parties should have the same session parameters
initSessionToken, initSecretRW, err := newSession(true, initNonce, recNonce, auth, randomPrivKey, remoteRandomPubKey)
initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
if err != nil {
t.Errorf("%v", err)
t.Errorf("newSession error: %v", err)
}
recSessionToken, recSecretRW, err := newSession(false, remoteInitNonce, remoteRecNonce, auth, remoteRandomPrivKey, remoteInitRandomPubKey)
recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
if err != nil {
t.Errorf("%v", err)
t.Errorf("newSession error: %v", err)
}
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
@ -173,76 +130,38 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
if !bytes.Equal(initSessionToken, recSessionToken) {
t.Errorf("session tokens do not match")
}
// aesSecret, macSecret, egressMac, ingressMac
if !bytes.Equal(initSecretRW.aesSecret, recSecretRW.aesSecret) {
t.Errorf("AES secrets do not match")
}
if !bytes.Equal(initSecretRW.macSecret, recSecretRW.macSecret) {
t.Errorf("macSecrets do not match")
}
if !bytes.Equal(initSecretRW.egressMac, recSecretRW.ingressMac) {
t.Errorf("initiator's egressMac do not match receiver's ingressMac")
}
if !bytes.Equal(initSecretRW.ingressMac, recSecretRW.egressMac) {
t.Errorf("initiator's inressMac do not match receiver's egressMac")
}
}
func TestPeersHandshake(t *testing.T) {
func TestHandshake(t *testing.T) {
defer testlog(t).detach()
var err error
// var sessionToken []byte
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv0, _ := crypto.GenerateKey()
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
pub0s, _ := exportPublicKey(&prv0.PublicKey)
pub1s, _ := exportPublicKey(&prv1.PublicKey)
rw0, rw1 := net.Pipe()
tokens := make(chan []byte)
prv0s := crypto.FromECDSA(prv0)
pub0s := crypto.FromECDSAPub(pub0)
prv1s := crypto.FromECDSA(prv1)
pub1s := crypto.FromECDSAPub(pub1)
conn1, conn2 := net.Pipe()
initiator := newPeer(conn1, []Protocol{}, nil)
receiver := newPeer(conn2, []Protocol{}, nil)
initiator.dialAddr = &peerAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: pub1s[1:]}
initiator.privateKey = prv0s
// this is cheating. identity of initiator/dialler not available to listener/receiver
// its public key should be looked up based on IP address
receiver.identity = &peerId{nil, pub0s}
receiver.privateKey = prv1s
initiator.pubkeyHook = func(*peerAddr) error { return nil }
receiver.pubkeyHook = func(*peerAddr) error { return nil }
initiator.cryptoHandshake = true
receiver.cryptoHandshake = true
errc0 := make(chan error, 1)
errc1 := make(chan error, 1)
go func() {
_, err := initiator.loop()
errc0 <- err
token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
if err != nil {
t.Errorf("outbound side error: %v", err)
}
tokens <- token
}()
go func() {
_, err := receiver.loop()
errc1 <- err
token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
if err != nil {
t.Errorf("inbound side error: %v", err)
}
if !bytes.Equal(remotePubkey, pub0s) {
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
}
tokens <- token
}()
ready := make(chan bool)
go func() {
<-initiator.cryptoReady
<-receiver.cryptoReady
close(ready)
}()
timeout := time.After(10 * time.Second)
select {
case <-ready:
case <-timeout:
t.Errorf("crypto handshake hanging for too long")
case err = <-errc0:
t.Errorf("peer 0 quit with error: %v", err)
case err = <-errc1:
t.Errorf("peer 1 quit with error: %v", err)
t1, t2 := <-tokens, <-tokens
if !bytes.Equal(t1, t2) {
t.Error("session token mismatch")
}
}

View File

@ -1,6 +1,7 @@
package p2p
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
@ -8,7 +9,10 @@ import (
"io"
"io/ioutil"
"math/big"
"net"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp"
@ -74,11 +78,14 @@ type MsgWriter interface {
// WriteMsg sends a message. It will block until the message's
// Payload has been consumed by the other end.
//
// Note that messages can be sent only once.
// Note that messages can be sent only once because their
// payload reader is drained.
WriteMsg(Msg) error
}
// MsgReadWriter provides reading and writing of encoded messages.
// Implementations should ensure that ReadMsg and WriteMsg can be
// called simultaneously from multiple goroutines.
type MsgReadWriter interface {
MsgReader
MsgWriter
@ -90,8 +97,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...))
}
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
// As required by the interface, ReadMsg and WriteMsg can be called from
// multiple goroutines.
type frameRW struct {
net.Conn // make Conn methods available. be careful.
bufconn *bufio.ReadWriter
// this channel is used to 'lend' bufconn to a caller of ReadMsg
// until the message payload has been consumed. the channel
// receives a value when EOF is reached on the payload, unblocking
// a pending call to ReadMsg.
rsync chan struct{}
// this mutex guards writes to bufconn.
writeMu sync.Mutex
}
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
rsync := make(chan struct{}, 1)
rsync <- struct{}{}
return &frameRW{
Conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
rsync: rsync,
}
}
var magicToken = []byte{34, 64, 8, 145}
func (rw *frameRW) WriteMsg(msg Msg) error {
rw.writeMu.Lock()
defer rw.writeMu.Unlock()
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
if err := writeMsg(rw.bufconn, msg); err != nil {
return err
}
return rw.bufconn.Flush()
}
func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code))
@ -120,12 +164,16 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...)
}
// readMsg reads a message header from r.
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
func (rw *frameRW) ReadMsg() (msg Msg, err error) {
<-rw.rsync // wait until bufconn is ours
// this read timeout applies also to the payload.
// TODO: proper read timeout
rw.SetReadDeadline(time.Now().Add(msgReadTimeout))
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
if _, err = io.ReadFull(rw.bufconn, start); err != nil {
return msg, newPeerError(errRead, "%v", err)
}
if !bytes.HasPrefix(start, magicToken) {
@ -134,17 +182,33 @@ func readMsg(r rlp.ByteReader) (msg Msg, err error) {
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
posr := &postrack{r, 0}
posr := &postrack{rw.bufconn, 0}
s := rlp.NewStream(posr)
if _, err := s.List(); err != nil {
return msg, err
}
code, err := s.Uint()
msg.Code, err = s.Uint()
if err != nil {
return msg, err
}
payloadsize := size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
msg.Size = size - posr.p
if msg.Size <= wholePayloadSize {
// msg is small, read all of it and move on to the next message.
pbuf := make([]byte, msg.Size)
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
return msg, err
}
rw.rsync <- struct{}{} // bufconn is available again
msg.Payload = bytes.NewReader(pbuf)
} else {
// lend bufconn to the caller until it has
// consumed the payload. eofSignal will send a value
// on rw.rsync when EOF is reached.
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
msg.Payload = pr
}
return msg, nil
}
// postrack wraps an rlp.ByteReader with a position counter.
@ -167,6 +231,39 @@ func (r *postrack) ReadByte() (byte, error) {
return b, err
}
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
type eofSignal struct {
wrapped io.Reader
count uint32 // number of bytes left
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
if r.count == 0 {
if r.eof != nil {
r.eof <- struct{}{}
r.eof = nil
}
return 0, io.EOF
}
max := len(buf)
if int(r.count) < len(buf) {
max = int(r.count)
}
n, err := r.wrapped.Read(buf[:max])
r.count -= uint32(n)
if (err != nil || r.count == 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}
// MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter.
@ -198,7 +295,7 @@ type MsgPipeRW struct {
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1)
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
select {
case p.w <- msg:
if msg.Size > 0 {

View File

@ -3,12 +3,11 @@ package p2p
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"runtime"
"testing"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
func TestNewMsg(t *testing.T) {
@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
}
}
func TestEncodeDecodeMsg(t *testing.T) {
msg := NewMsg(3, 1, "000")
buf := new(bytes.Buffer)
if err := writeMsg(buf, msg); err != nil {
t.Fatalf("encodeMsg error: %v", err)
}
// t.Logf("encoded: %x", buf.Bytes())
// func TestEncodeDecodeMsg(t *testing.T) {
// msg := NewMsg(3, 1, "000")
// buf := new(bytes.Buffer)
// if err := writeMsg(buf, msg); err != nil {
// t.Fatalf("encodeMsg error: %v", err)
// }
// // t.Logf("encoded: %x", buf.Bytes())
decmsg, err := readMsg(buf)
if err != nil {
t.Fatalf("readMsg error: %v", err)
}
if decmsg.Code != 3 {
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
}
if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
}
// decmsg, err := readMsg(buf)
// if err != nil {
// t.Fatalf("readMsg error: %v", err)
// }
// if decmsg.Code != 3 {
// t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
// }
// if decmsg.Size != 5 {
// t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
// }
var data struct {
I uint
S string
}
if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err)
}
if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
}
if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
}
}
// var data struct {
// I uint
// S string
// }
// if err := decmsg.Decode(&data); err != nil {
// t.Fatalf("Decode error: %v", err)
// }
// if data.I != 1 {
// t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
// }
// if data.S != "000" {
// t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
// }
// }
func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// func TestDecodeRealMsg(t *testing.T) {
// data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
// msg, err := readMsg(bytes.NewReader(data))
// if err != nil {
// t.Fatalf("unexpected error: %v", err)
// }
if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
}
}
// if msg.Code != 0 {
// t.Errorf("incorrect code %d, want %d", msg.Code, 0)
// }
// }
func ExampleMsgPipe() {
rw1, rw2 := MsgPipe()
@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
go rw1.Close()
}
}
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}

View File

@ -1,10 +1,6 @@
package p2p
import (
"bufio"
"bytes"
"crypto/ecdsa"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
@ -13,179 +9,118 @@ import (
"sync"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
)
// peerAddr is the structure of a peer list element.
// It is also a valid net.Addr.
type peerAddr struct {
IP net.IP
Port uint64
Pubkey []byte // optional
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a message
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
disconnectGracePeriod = 2 * time.Second
)
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the RLP structure of the protocol handshake.
type handshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
NodeID discover.NodeID
}
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
n := addr.Network()
if n != "tcp" && n != "tcp4" && n != "tcp6" {
// for testing with non-TCP
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
}
ta := addr.(*net.TCPAddr)
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
}
func (d peerAddr) Network() string {
if d.IP.To4() != nil {
return "tcp4"
} else {
return "tcp6"
}
}
func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port)
}
func (d *peerAddr) RlpData() interface{} {
return []interface{}{string(d.IP), d.Port, d.Pubkey}
}
// Peer represents a remote peer.
// Peer represents a connected remote node.
type Peer struct {
// Peers have all the log methods.
// Use them to display messages related to the peer.
*logger.Logger
infolock sync.Mutex
identity ClientIdentity
infoMu sync.Mutex
name string
caps []Cap
listenAddr *peerAddr // what remote peer is listening on
dialAddr *peerAddr // non-nil if dialing
// The mutex protects the connection
// so only one protocol can write at a time.
writeMu sync.Mutex
conn net.Conn
bufconn *bufio.ReadWriter
ourID, remoteID *discover.NodeID
ourName string
rw *frameRW
// These fields maintain the running protocols.
protocols []Protocol
runBaseProtocol bool // for testing
cryptoHandshake bool // for testing
cryptoReady chan struct{}
privateKey []byte
runlock sync.RWMutex // protects running
running map[string]*proto
protocolHandshakeEnabled bool
protoWG sync.WaitGroup
protoErr chan error
closed chan struct{}
disc chan DiscReason
activity event.TypeMux // for activity events
slot int // index into Server peer list
// These fields are kept so base protocol can access them.
// TODO: this should be one or more interfaces
ourID ClientIdentity // client id of the Server
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
newPeerAddr chan<- *peerAddr // tell server about received peers
otherPeers func() []*Peer // should return the list of all peers
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
}
// NewPeer returns a peer for testing purposes.
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
conn, _ := net.Pipe()
peer := newPeer(conn, nil, nil)
peer.setHandshakeInfo(id, nil, caps)
close(peer.closed)
peer := newPeer(conn, nil, "", nil, &id)
peer.setHandshakeInfo(name, caps)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
p := newPeer(conn, server.Protocols, dialAddr)
p.ourID = server.Identity
p.newPeerAddr = server.peerConnect
p.otherPeers = server.Peers
p.pubkeyHook = server.verifyPeer
p.runBaseProtocol = true
// laddr can be updated concurrently by NAT traversal.
// newServerPeer must be called with the server lock held.
if server.laddr != nil {
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
}
return p
// ID returns the node's public key.
func (p *Peer) ID() discover.NodeID {
return *p.remoteID
}
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
p := &Peer{
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
conn: conn,
dialAddr: dialAddr,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
cryptoReady: make(chan struct{}),
}
return p
}
// Identity returns the client identity of the remote peer. The
// identity can be nil if the peer has not yet completed the
// handshake.
func (p *Peer) Identity() ClientIdentity {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.identity
}
func (self *Peer) Pubkey() (pubkey []byte) {
self.infolock.Lock()
defer self.infolock.Unlock()
switch {
case self.identity != nil:
pubkey = self.identity.Pubkey()[1:]
case self.dialAddr != nil:
pubkey = self.dialAddr.Pubkey
case self.listenAddr != nil:
pubkey = self.listenAddr.Pubkey
}
return
// Name returns the node name that the remote node advertised.
func (p *Peer) Name() string {
// this needs a lock because the information is part of the
// protocol handshake.
p.infoMu.Lock()
name := p.name
p.infoMu.Unlock()
return name
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.caps
}
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
p.infolock.Lock()
p.identity = id
p.listenAddr = laddr
p.caps = caps
p.infolock.Unlock()
// this needs a lock because the information is part of the
// protocol handshake.
p.infoMu.Lock()
caps := p.caps
p.infoMu.Unlock()
return caps
}
// RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr()
return p.rw.RemoteAddr()
}
// LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr()
return p.rw.LocalAddr()
}
// Disconnect terminates the peer connection with the given reason.
@ -199,201 +134,167 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
kind := "inbound"
p.infolock.Lock()
if p.dialAddr != nil {
kind = "outbound"
}
p.infolock.Unlock()
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
}
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a message
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
return &Peer{
Logger: logger.NewLogger(logtag),
rw: newFrameRW(conn, msgWriteTimeout),
ourID: ourID,
ourName: ourName,
remoteID: remoteID,
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
}
var (
inactivityTimeout = 2 * time.Second
disconnectGracePeriod = 2 * time.Second
)
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
p.infoMu.Lock()
p.name = name
p.caps = caps
p.infoMu.Unlock()
}
func (p *Peer) loop() (reason DiscReason, err error) {
defer p.activity.Stop()
func (p *Peer) run() DiscReason {
var readErr = make(chan error, 1)
defer p.closeProtocols()
defer close(p.closed)
defer p.conn.Close()
defer p.rw.Close()
var readLoop func(chan<- Msg, chan<- error, <-chan bool)
if p.cryptoHandshake {
if readLoop, err = p.handleCryptoHandshake(); err != nil {
// from here on everything can be encrypted, authenticated
return DiscProtocolError, err // no graceful disconnect
// start the read loop
go func() { readErr <- p.readLoop() }()
if p.protocolHandshakeEnabled {
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
p.DebugDetailf("Protocol handshake error: %v\n", err)
return DiscProtocolError
}
} else {
readLoop = p.readLoop
}
// read loop
readMsg := make(chan Msg)
readErr := make(chan error)
readNext := make(chan bool, 1)
protoDone := make(chan struct{}, 1)
go readLoop(readMsg, readErr, readNext)
readNext <- true
close(p.cryptoReady)
if p.runBaseProtocol {
p.startBaseProtocol()
}
loop:
for {
// wait for an error or disconnect
var reason DiscReason
select {
case msg := <-readMsg:
// a new message has arrived.
var wait bool
if wait, err = p.dispatch(msg, protoDone); err != nil {
p.Errorf("msg dispatch error: %v\n", err)
reason = discReasonForError(err)
break loop
}
if !wait {
// Msg has already been read completely, continue with next message.
readNext <- true
}
p.activity.Post(time.Now())
case <-protoDone:
// protocol has consumed the message payload,
// we can continue reading from the socket.
readNext <- true
case err := <-readErr:
// read failed. there is no need to run the
// polite disconnect sequence because the connection
// is probably dead anyway.
// TODO: handle write errors as well
return DiscNetworkError, err
case err = <-p.protoErr:
// We rely on protocols to abort if there is a write error. It
// might be more robust to handle them here as well.
p.DebugDetailf("Read error: %v\n", err)
reason = DiscNetworkError
case err := <-p.protoErr:
reason = discReasonForError(err)
break loop
case reason = <-p.disc:
break loop
}
if reason != DiscNetworkError {
p.politeDisconnect(reason)
}
p.Debugf("Disconnected: %v\n", reason)
return reason
}
// wait for read loop to return.
close(readNext)
<-readErr
// tell the remote end to disconnect
func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{})
go func() {
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
io.Copy(ioutil.Discard, p.conn)
// send reason
EncodeMsg(p.rw, discMsg, uint(reason))
// discard any data that might arrive
io.Copy(ioutil.Discard, p.rw)
close(done)
}()
select {
case <-done:
case <-time.After(disconnectGracePeriod):
}
return reason, err
}
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
for _ = range unblock {
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
if msg, err := readMsg(p.bufconn); err != nil {
errc <- err
} else {
msgc <- msg
func (p *Peer) readLoop() error {
if p.protocolHandshakeEnabled {
if err := readProtocolHandshake(p, p.rw); err != nil {
return err
}
}
close(errc)
for {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
if err = p.handle(msg); err != nil {
return err
}
}
return nil
}
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
func (p *Peer) handle(msg Msg) error {
switch {
case msg.Code == pingMsg:
msg.Discard()
go EncodeMsg(p.rw, pongMsg)
case msg.Code == discMsg:
var reason DiscReason
// no need to discard or for error checking, we'll close the
// connection after this.
rlp.Decode(msg.Payload, &reason)
p.Disconnect(DiscRequested)
return discRequestedError(reason)
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
return msg.Discard()
default:
// it's a subprotocol message
proto, err := p.getProto(msg.Code)
if err != nil {
return false, err
return fmt.Errorf("msg code out of range: %v", msg.Code)
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
proto.in <- msg
}
return nil
}
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
// read and handle remote handshake
msg, err := rw.ReadMsg()
if err != nil {
return false, err
return err
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
wait = true
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
msg.Payload = pr
proto.in <- msg
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
}
return wait, nil
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if hs.NodeID == *p.remoteID {
return newPeerError(errPubkeyForbidden, "node ID mismatch")
}
// TODO: remove Caps with empty name
p.setHandshakeInfo(hs.Name, hs.Caps)
p.startSubprotocols(hs.Caps)
return nil
}
type readLoop func(chan<- Msg, chan<- error, <-chan bool)
func (p *Peer) PrivateKey() (prv *ecdsa.PrivateKey, err error) {
if prv = crypto.ToECDSA(p.privateKey); prv == nil {
err = fmt.Errorf("invalid private key")
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
var caps []interface{}
for _, proto := range ps {
caps = append(caps, proto.cap())
}
return
}
func (p *Peer) handleCryptoHandshake() (loop readLoop, err error) {
// cryptoId is just created for the lifecycle of the handshake
// it is survived by an encrypted readwriter
var initiator bool
var sessionToken []byte
sessionToken = make([]byte, shaLen)
if _, err = rand.Read(sessionToken); err != nil {
return
}
if p.dialAddr != nil { // this should have its own method Outgoing() bool
initiator = true
}
// run on peer
// this bit handles the handshake and creates a secure communications channel with
// var rw *secretRW
var prvKey *ecdsa.PrivateKey
if prvKey, err = p.PrivateKey(); err != nil {
err = fmt.Errorf("unable to access private key for client: %v", err)
return
}
// initialise a new secure session
if sessionToken, _, err = NewSecureSession(p.conn, prvKey, p.Pubkey(), sessionToken, initiator); err != nil {
p.Debugf("unable to setup secure session: %v", err)
return
}
loop = func(msg chan<- Msg, err chan<- error, next <-chan bool) {
// this is the readloop :)
}
return
}
func (p *Peer) startBaseProtocol() {
p.runlock.Lock()
defer p.runlock.Unlock()
p.running[""] = p.startProto(0, Protocol{
Length: baseProtocolLength,
Run: runBaseProtocol,
})
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
}
// startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps))
p.runlock.Lock()
defer p.runlock.Unlock()
offset := baseProtocolLength
@ -412,20 +313,22 @@ outer:
}
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
rw := &proto{
name: impl.Name,
in: make(chan Msg),
offset: offset,
maxcode: impl.Length,
peer: p,
w: p.rw,
}
p.protoWG.Add(1)
go func() {
err := impl.Run(p, rw)
if err == nil {
p.Infof("protocol %q returned", impl.Name)
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
err = newPeerError(errMisc, "protocol returned")
} else {
p.Errorf("protocol %q error: %v\n", impl.Name, err)
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
}
select {
case p.protoErr <- err:
@ -459,6 +362,7 @@ func (p *Peer) closeProtocols() {
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
// this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock()
proto, ok := p.running[protoName]
@ -470,25 +374,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return p.writeMsg(msg, msgWriteTimeout)
}
// writeMsg writes a message to the connection.
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
p.conn.SetWriteDeadline(time.Now().Add(timeout))
if err := writeMsg(p.bufconn, msg); err != nil {
return newPeerError(errWrite, "%v", err)
}
return p.bufconn.Flush()
return p.rw.WriteMsg(msg)
}
type proto struct {
name string
in chan Msg
maxcode, offset uint64
peer *Peer
w MsgWriter
}
func (rw *proto) WriteMsg(msg Msg) error {
@ -496,11 +389,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.peer.writeMsg(msg, msgWriteTimeout)
}
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data...))
return rw.w.WriteMsg(msg)
}
func (rw *proto) ReadMsg() (Msg, error) {
@ -511,26 +400,3 @@ func (rw *proto) ReadMsg() (Msg, error) {
msg.Code -= rw.offset
return msg, nil
}
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
//
type eofSignal struct {
wrapped io.Reader
count int64
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
r.count -= int64(n)
if (err != nil || r.count <= 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}

View File

@ -12,7 +12,6 @@ const (
errInvalidMsgCode
errInvalidMsg
errP2PVersionMismatch
errPubkeyMissing
errPubkeyInvalid
errPubkeyForbidden
errProtocolBreach
@ -22,20 +21,19 @@ const (
)
var errorToString = map[int]string{
errMagicTokenMismatch: "Magic token mismatch",
errRead: "Read error",
errWrite: "Write error",
errMisc: "Misc error",
errInvalidMsgCode: "Invalid message code",
errInvalidMsg: "Invalid message",
errMagicTokenMismatch: "magic token mismatch",
errRead: "read error",
errWrite: "write error",
errMisc: "misc error",
errInvalidMsgCode: "invalid message code",
errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyMissing: "Public key missing",
errPubkeyInvalid: "Public key invalid",
errPubkeyForbidden: "Public key forbidden",
errProtocolBreach: "Protocol Breach",
errPingTimeout: "Ping timeout",
errInvalidNetworkId: "Invalid network id",
errInvalidProtocolVersion: "Invalid protocol version",
errPubkeyInvalid: "public key invalid",
errPubkeyForbidden: "public key forbidden",
errProtocolBreach: "protocol Breach",
errPingTimeout: "ping timeout",
errInvalidNetworkId: "invalid network id",
errInvalidProtocolVersion: "invalid protocol version",
}
type peerError struct {
@ -62,22 +60,22 @@ func (self *peerError) Error() string {
type DiscReason byte
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
DiscRequested DiscReason = iota
DiscNetworkError
DiscProtocolError
DiscUselessPeer
DiscTooManyPeers
DiscAlreadyConnected
DiscIncompatibleVersion
DiscInvalidIdentity
DiscQuitting
DiscUnexpectedIdentity
DiscSelf
DiscReadTimeout
DiscSubprotocolError
)
var discReasonToString = [DiscSubprotocolError + 1]string{
var discReasonToString = [...]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
case errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer

View File

@ -1,15 +1,17 @@
package p2p
import (
"bufio"
"bytes"
"encoding/hex"
"io"
"fmt"
"io/ioutil"
"net"
"reflect"
"sort"
"testing"
"time"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
)
var discard = Protocol{
@ -28,17 +30,13 @@ var discard = Protocol{
},
}
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
conn1, conn2 := net.Pipe()
peer := newPeer(conn1, protos, nil)
peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1)
go func() {
_, err := peer.loop()
errc <- err
}()
return conn2, peer, errc
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
peer.protocolHandshakeEnabled = handshake
errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }()
return newFrameRW(conn2, msgWriteTimeout), peer, errc
}
func TestPeerProtoReadMsg(t *testing.T) {
@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) {
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
if err := expectMsg(rw, 2, []uint{1}); err != nil {
t.Error(err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
if err := expectMsg(rw, 3, []uint{2}); err != nil {
t.Error(err)
}
data, err := ioutil.ReadAll(msg.Payload)
if err != nil {
t.Errorf("payload read error: %v", err)
}
expdata, _ := hex.DecodeString("0183303030")
if !bytes.Equal(expdata, data) {
t.Errorf("incorrect msg data %x", data)
if err := expectMsg(rw, 4, []uint{3}); err != nil {
t.Error(err)
}
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
rw, peer, errc := testPeer(false, []Protocol{proto})
defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, 1, "000"))
EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2)
EncodeMsg(rw, baseProtocolLength+4, 3)
select {
case <-done:
case err := <-errc:
@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
rw, peer, errc := testPeer(false, []Protocol{proto})
defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
EncodeMsg(rw, 18, make([]byte, msgsize))
select {
case <-done:
case err := <-errc:
@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil
},
}
net, peer, _ := testPeer([]Protocol{proto})
defer net.Close()
rw, peer, _ := testPeer(false, []Protocol{proto})
defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()})
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
var data []string
if err := msg.Decode(&data); err != nil {
t.Errorf("payload decode error: %v", err)
}
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
t.Error(err)
}
}
func TestPeerWrite(t *testing.T) {
func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
rw, peer, peerErr := testPeer(false, []Protocol{discard})
defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors
@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) {
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
if err := expectMsg(rw, 16, nil); err != nil {
t.Error()
}
msg.Discard()
close(read)
}()
// test succcessful write
// test successful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
@ -198,104 +176,152 @@ func TestPeerWrite(t *testing.T) {
}
}
func TestPeerActivity(t *testing.T) {
// shorten inactivityTimeout while this test is running
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
func TestPeerPing(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
rw, _, _ := testPeer(false, nil)
defer rw.Close()
if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, pongMsg, nil); err != nil {
t.Error(err)
}
}
sub := peer.activity.Subscribe(time.Time{})
defer sub.Unsubscribe()
func TestPeerDisconnect(t *testing.T) {
defer testlog(t).detach()
for i := 0; i < 6; i++ {
writeMsg(net, NewMsg(16))
rw, _, disc := testPeer(false, nil)
defer rw.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
t.Error(err)
}
rw.Close() // make test end faster
if reason := <-disc; reason != DiscRequested {
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
}
}
func TestPeerHandshake(t *testing.T) {
defer testlog(t).detach()
// remote has two matching protocols: a and c
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
remoteID := randomID()
remote.ourID = &remoteID
remote.ourName = "remote peer"
start := make(chan string)
stop := make(chan struct{})
run := func(p *Peer, rw MsgReadWriter) error {
name := rw.(*proto).name
if name != "a" && name != "c" {
t.Errorf("protocol %q should not be started", name)
} else {
start <- name
}
<-stop
return nil
}
protocols := []Protocol{
{Name: "a", Version: 1, Length: 1, Run: run},
{Name: "b", Version: 2, Length: 1, Run: run},
{Name: "c", Version: 3, Length: 1, Run: run},
{Name: "d", Version: 4, Length: 1, Run: run},
}
rw, p, disc := testPeer(true, protocols)
p.remoteID = remote.ourID
defer rw.Close()
// run the handshake
remoteProtocols := []Protocol{protocols[0], protocols[2]}
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
t.Fatalf("handshake write error: %v", err)
}
if err := readProtocolHandshake(remote, rw); err != nil {
t.Fatalf("handshake read error: %v", err)
}
// check that all protocols have been started
var started []string
for i := 0; i < 2; i++ {
select {
case <-sub.Chan():
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
case name := <-start:
started = append(started, name)
case <-time.After(100 * time.Millisecond):
}
}
sort.Strings(started)
if !reflect.DeepEqual(started, []string{"a", "c"}) {
t.Errorf("wrong protocols started: %v", started)
}
select {
case <-time.After(inactivityTimeout * 2):
case <-sub.Chan():
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
// check that metadata has been set
if p.ID() != remoteID {
t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
}
if p.Name() != remote.ourName {
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
}
close(stop)
t.Logf("disc reason: %v", <-disc)
}
func TestNewPeer(t *testing.T) {
name := "nodename"
caps := []Cap{{"foo", 2}, {"bar", 3}}
id := &peerId{}
p := NewPeer(id, caps)
id := randomID()
p := NewPeer(id, name, caps)
if p.ID() != id {
t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
}
if p.Name() != name {
t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
}
if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
}
if p.Identity() != id {
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
}
// Should not hang.
p.Disconnect(DiscAlreadyConnected)
p.Disconnect(DiscAlreadyConnected) // Should not hang
}
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// expectMsg reads a message from r and verifies that its
// code and encoded RLP content match the provided values.
// If content is nil, the payload is discarded and not verified.
func expectMsg(r MsgReader, code uint64, content interface{}) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
}
if content == nil {
return msg.Discard()
} else {
contentEnc, err := rlp.EncodeToBytes(content)
if err != nil {
panic("content encode error: " + err.Error())
}
// skip over list header in encoded value. this is temporary.
contentEncR := bytes.NewReader(contentEnc)
if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
panic("content must encode as RLP list")
}
contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 8 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
if !bytes.Equal(actualContent, contentEnc) {
return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
}
}
return nil
}

View File

@ -1,10 +1,5 @@
package p2p
import (
"bytes"
"time"
)
// Protocol represents a P2P subprotocol implementation.
type Protocol struct {
// Name should contain the official protocol name,
@ -32,42 +27,6 @@ func (p Protocol) cap() Cap {
return Cap{p.Name, p.Version}
}
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the structure of a handshake list.
type handshake struct {
Version uint64
ID string
Caps []Cap
ListenPort uint64
NodeID []byte
}
func (h *handshake) String() string {
return h.ID
}
func (h *handshake) Pubkey() []byte {
return h.NodeID
}
func (h *handshake) PrivKey() []byte {
return nil
}
// Cap is the structure of a peer capability.
type Cap struct {
Name string
@ -83,210 +42,3 @@ type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
}
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer}
errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err
}
// run main loop
go func() {
for {
if err := bp.handle(rw); err != nil {
errc <- err
break
}
}
}()
return bp.loop(errc)
}
var pingTimeout = 2 * time.Second
func (bp *baseProtocol) loop(quit <-chan error) error {
ping := time.NewTimer(pingTimeout)
activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := EncodeMsg(bp.rw, getPeersMsg)
for err == nil {
select {
case err = <-quit:
return err
case <-getPeersTick.C:
err = EncodeMsg(bp.rw, getPeersMsg)
case event := <-activity.Chan():
ping.Reset(pingTimeout)
lastActive = event.(time.Time)
case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = EncodeMsg(bp.rw, pingMsg)
}
}
}
return err
}
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
// make sure that the payload has been fully consumed
defer msg.Discard()
switch msg.Code {
case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
var reason [1]DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
return discRequestedError(reason[0])
case pingMsg:
return EncodeMsg(bp.rw, pongMsg)
case pongMsg:
case getPeersMsg:
peers := bp.peerList()
// this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available.
// this means that would need to send a response later when
// new peers become available.
//
// TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 {
return EncodeMsg(bp.rw, peersMsg, peers...)
}
case peersMsg:
var peers []*peerAddr
if err := msg.Decode(&peers); err != nil {
return err
}
for _, addr := range peers {
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
}
default:
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
}
return nil
}
func (bp *baseProtocol) readHandshake() error {
// read and handle remote handshake
msg, err := bp.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if len(hs.NodeID) == 0 {
return newPeerError(errPubkeyMissing, "")
}
if len(hs.NodeID) != 64 {
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
}
if da := bp.peer.dialAddr; da != nil {
// verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
// TODO: remove Caps with empty name
var addr *peerAddr
if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
}
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
func (bp *baseProtocol) peerList() []interface{} {
peers := bp.peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -1,167 +0,0 @@
package p2p
import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"github.com/ethereum/go-ethereum/crypto"
)
type peerId struct {
privKey, pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func (self *peerId) PrivKey() (privKey []byte) {
privKey = self.privKey
if len(privKey) == 0 {
privKey = crypto.GenerateNewKeyPair().PublicKey
self.privKey = privKey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
peerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
defer rw1.Close()
wg := new(sync.WaitGroup)
// run matcher, close pipe when addresses have arrived
numPeers := len(peerList) + 1
addrChan := make(chan *peerAddr)
wg.Add(1)
go func() {
i := 0
for got := range addrChan {
var want *peerAddr
switch {
case i < len(peerList):
want = peerList[i]
case i == len(peerList):
want = listenAddr // listenAddr should be the last thing sent
}
t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %+v, want %+v", got, want)
}
i++
if i == numPeers {
break
}
}
if i != numPeers {
t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
}
rw1.Close()
wg.Done()
}()
// run first peer (in background)
peer1 := newTestPeer()
peer1.ourListenAddr = listenAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(peerList))
for i, addr := range peerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
wg.Add(1)
go func() {
runBaseProtocol(peer1, rw1)
wg.Done()
}()
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
// terminate matcher
close(addrChan)
wg.Wait()
}
func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(&peerId{}, nil)
peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := expectMsg(rw2, handshakeMsg); err != nil {
t.Error(err)
}
err := EncodeMsg(rw2, handshakeMsg,
baseProtocolVersion,
"",
[]interface{}{},
0,
make([]byte, 64),
)
if err != nil {
t.Error(err)
}
if err := expectMsg(rw2, getPeersMsg); err != nil {
t.Error(err)
}
if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
t.Error(err)
}
close(done)
}()
if err := runBaseProtocol(peer, rw1); err == nil {
t.Errorf("base protocol returned without error")
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
t.Errorf("base protocol returned wrong error: %v", err)
}
<-done
}
func expectMsg(r MsgReader, code uint64) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if err := msg.Discard(); err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
}
return nil
}

View File

@ -2,37 +2,56 @@ package p2p
import (
"bytes"
"crypto/ecdsa"
"errors"
"fmt"
"io"
"net"
"runtime"
"sync"
"time"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
)
const (
outboundAddressPoolSize = 500
defaultDialTimeout = 10 * time.Second
refreshPeersInterval = 30 * time.Second
portMappingUpdateInterval = 15 * time.Minute
portMappingTimeout = 20 * time.Minute
)
var srvlog = logger.NewLogger("P2P Server")
// MakeName creates a node name that follows the ethereum convention
// for such names. It adds the operation system name and Go runtime version
// the name.
func MakeName(name, version string) string {
return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version())
}
// Server manages all peer connections.
//
// The fields of Server are used as configuration parameters.
// You should set them before starting the Server. Fields may not be
// modified while the server is running.
type Server struct {
// This field must be set to a valid client identity.
Identity ClientIdentity
// This field must be set to a valid secp256k1 private key.
PrivateKey *ecdsa.PrivateKey
// MaxPeers is the maximum number of peers that can be
// connected. It must be greater than zero.
MaxPeers int
// Name sets the node name of this server.
// Use MakeName to create a name that follows existing conventions.
Name string
// Bootstrap nodes are used to establish connectivity
// with the rest of the network.
BootstrapNodes []discover.Node
// Protocols should contain the protocols supported
// by the server. Matching protocols are launched for
// each peer.
@ -62,22 +81,23 @@ type Server struct {
// If NoDial is true, the server will not dial any peers.
NoDial bool
// Hook for testing. This is useful because we can inhibit
// Hooks for testing. These are useful because we can inhibit
// the whole protocol stack.
newPeerFunc peerFunc
handshakeFunc
newPeerHook
lock sync.RWMutex
running bool
listener net.Listener
laddr *net.TCPAddr // real listen addr
peers []*Peer
peerSlots chan int
peerCount int
peers map[discover.NodeID]*Peer
ntab *discover.Table
quit chan struct{}
wg sync.WaitGroup
peerConnect chan *peerAddr
peerDisconnect chan *Peer
loopWG sync.WaitGroup // {dial,listen,nat}Loop
peerWG sync.WaitGroup // active peer goroutines
peerConnect chan *discover.Node
}
// NAT is implemented by NAT traversal methods.
@ -90,7 +110,8 @@ type NAT interface {
String() string
}
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
type newPeerHook func(*Peer)
// Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) {
@ -107,18 +128,15 @@ func (srv *Server) Peers() (peers []*Peer) {
// PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int {
srv.lock.RLock()
defer srv.lock.RUnlock()
return srv.peerCount
n := len(srv.peers)
srv.lock.RUnlock()
return n
}
// SuggestPeer injects an address into the outbound address pool.
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
addr := &peerAddr{ip, uint64(port), nodeID}
select {
case srv.peerConnect <- addr:
default: // don't block
srvlog.Warnf("peer suggestion %v ignored", addr)
}
// SuggestPeer creates a connection to the given Node if it
// is not already connected.
func (srv *Server) SuggestPeer(ip net.IP, port int, id discover.NodeID) {
srv.peerConnect <- &discover.Node{ID: id, Addr: &net.UDPAddr{IP: ip, Port: port}}
}
// Broadcast sends an RLP-encoded message to all connected peers.
@ -152,47 +170,47 @@ func (srv *Server) Start() (err error) {
}
srvlog.Infoln("Starting Server")
// initialize fields
if srv.Identity == nil {
return fmt.Errorf("Server.Identity must be set to a non-nil identity")
// initialize all the fields
if srv.PrivateKey == nil {
return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
}
if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0")
}
srv.quit = make(chan struct{})
srv.peers = make([]*Peer, srv.MaxPeers)
srv.peerSlots = make(chan int, srv.MaxPeers)
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
srv.peerDisconnect = make(chan *Peer)
if srv.newPeerFunc == nil {
srv.newPeerFunc = newServerPeer
srv.peers = make(map[discover.NodeID]*Peer)
srv.peerConnect = make(chan *discover.Node)
if srv.handshakeFunc == nil {
srv.handshakeFunc = encHandshake
}
if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist()
}
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
}
}
// dial stuff
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr)
if err != nil {
return err
}
srv.ntab = dt
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if !srv.NoDial {
srv.wg.Add(1)
srv.loopWG.Add(1)
go srv.dialLoop()
}
if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
}
// make all slots available
for i := range srv.peers {
srv.peerSlots <- i
}
// note: discLoop is not part of WaitGroup
go srv.discLoop()
srv.running = true
return nil
}
@ -205,10 +223,10 @@ func (srv *Server) startListening() error {
srv.ListenAddr = listener.Addr().String()
srv.laddr = listener.Addr().(*net.TCPAddr)
srv.listener = listener
srv.wg.Add(1)
srv.loopWG.Add(1)
go srv.listenLoop()
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
srv.wg.Add(1)
srv.loopWG.Add(1)
go srv.natLoop(srv.laddr.Port)
}
return nil
@ -225,57 +243,41 @@ func (srv *Server) Stop() {
srv.running = false
srv.lock.Unlock()
srvlog.Infoln("Stopping server")
srvlog.Infoln("Stopping Server")
srv.ntab.Close()
if srv.listener != nil {
// this unblocks listener Accept
srv.listener.Close()
}
close(srv.quit)
for _, peer := range srv.Peers() {
srv.loopWG.Wait()
// No new peers can be added at this point because dialLoop and
// listenLoop are down. It is safe to call peerWG.Wait because
// peerWG.Add is not called outside of those loops.
for _, peer := range srv.peers {
peer.Disconnect(DiscQuitting)
}
srv.wg.Wait()
// wait till they actually disconnect
// this is checked by claiming all peerSlots.
// slots become available as the peers disconnect.
for i := 0; i < cap(srv.peerSlots); i++ {
<-srv.peerSlots
}
// terminate discLoop
close(srv.peerDisconnect)
}
func (srv *Server) discLoop() {
for peer := range srv.peerDisconnect {
srv.removePeer(peer)
}
srv.peerWG.Wait()
}
// main loop for adding connections via listening
func (srv *Server) listenLoop() {
defer srv.wg.Done()
defer srv.loopWG.Done()
srvlog.Infoln("Listening on", srv.listener.Addr())
for {
select {
case slot := <-srv.peerSlots:
srvlog.Debugf("grabbed slot %v for listening", slot)
conn, err := srv.listener.Accept()
if err != nil {
srv.peerSlots <- slot
return
}
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
srv.addPeer(conn, nil, slot)
case <-srv.quit:
return
}
srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
srv.peerWG.Add(1)
go srv.startPeer(conn, nil)
}
}
func (srv *Server) natLoop(port int) {
defer srv.wg.Done()
defer srv.loopWG.Done()
for {
srv.updatePortMapping(port)
select {
@ -314,108 +316,131 @@ func (srv *Server) removePortMapping(port int) {
}
func (srv *Server) dialLoop() {
defer srv.wg.Done()
var (
suggest chan *peerAddr
slot *int
slots = srv.peerSlots
)
defer srv.loopWG.Done()
refresh := time.NewTicker(refreshPeersInterval)
defer refresh.Stop()
srv.ntab.Bootstrap(srv.BootstrapNodes)
go srv.findPeers()
dialed := make(chan *discover.Node)
dialing := make(map[discover.NodeID]bool)
// TODO: limit number of active dials
// TODO: ensure only one findPeers goroutine is running
// TODO: pause findPeers when we're at capacity
for {
select {
case i := <-slots:
// we need a peer in slot i, slot reserved
slot = &i
// now we can watch for candidate peers in the next loop
suggest = srv.peerConnect
// do not consume more until candidate peer is found
slots = nil
case <-refresh.C:
case desc := <-suggest:
// candidate peer found, will dial out asyncronously
// if connection fails slot will be released
srvlog.DebugDetailf("dial %v (%v)", desc, *slot)
go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop
slots = srv.peerSlots
// until then we dont care about candidate peers
suggest = nil
go srv.findPeers()
case dest := <-srv.peerConnect:
srv.lock.Lock()
_, isconnected := srv.peers[dest.ID]
srv.lock.Unlock()
if isconnected || dialing[dest.ID] {
continue
}
dialing[dest.ID] = true
srv.peerWG.Add(1)
go func() {
srv.dialNode(dest)
// at this point, the peer has been added
// or discarded. either way, we're not dialing it anymore.
dialed <- dest
}()
case dest := <-dialed:
delete(dialing, dest.ID)
case <-srv.quit:
// give back the currently reserved slot
if slot != nil {
srv.peerSlots <- *slot
}
// TODO: maybe wait for active dials
return
}
}
}
// connect to peer via dial out
func (srv *Server) dialPeer(desc *peerAddr, slot int) {
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
func (srv *Server) dialNode(dest *discover.Node) {
srvlog.Debugf("Dialing %v\n", dest.Addr)
conn, err := srv.Dialer.Dial("tcp", dest.Addr.String())
if err != nil {
srvlog.DebugDetailf("dial error: %v", err)
srv.peerSlots <- slot
return
}
go srv.addPeer(conn, desc, slot)
srv.startPeer(conn, dest)
}
// creates the new peer object and inserts it into its slot
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
func (srv *Server) findPeers() {
far := srv.ntab.Self()
for i := range far {
far[i] = ^far[i]
}
closeToSelf := srv.ntab.Lookup(srv.ntab.Self())
farFromSelf := srv.ntab.Lookup(far)
for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
if i < len(closeToSelf) {
srv.peerConnect <- closeToSelf[i]
}
if i < len(farFromSelf) {
srv.peerConnect <- farFromSelf[i]
}
}
}
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
// TODO: I/O timeout, handle/store session token
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
if err != nil {
conn.Close()
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
return
}
ourID := srv.ntab.Self()
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
if ok, reason := srv.addPeer(remoteID, p); !ok {
p.Disconnect(reason)
return
}
srv.newPeerHook(p)
p.run()
srv.removePeer(p)
}
func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
srv.lock.Lock()
defer srv.lock.Unlock()
if !srv.running {
conn.Close()
srv.peerSlots <- slot // release slot
return nil
switch {
case !srv.running:
return false, DiscQuitting
case len(srv.peers) >= srv.MaxPeers:
return false, DiscTooManyPeers
case srv.peers[id] != nil:
return false, DiscAlreadyConnected
case srv.Blacklist.Exists(id[:]):
return false, DiscUselessPeer
case id == srv.ntab.Self():
return false, DiscSelf
}
peer := srv.newPeerFunc(srv, conn, desc)
peer.slot = slot
srv.peers[slot] = peer
srv.peerCount++
go func() { peer.loop(); srv.peerDisconnect <- peer }()
return peer
srvlog.Debugf("Adding %v\n", p)
srv.peers[id] = p
return true, 0
}
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
func (srv *Server) removePeer(peer *Peer) {
func (srv *Server) removePeer(p *Peer) {
srvlog.Debugf("Removing %v\n", p)
srv.lock.Lock()
defer srv.lock.Unlock()
srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot)
if srv.peers[peer.slot] != peer {
srvlog.Warnln("Invalid peer to remove:", peer)
return
}
// remove from list and index
srv.peerCount--
srv.peers[peer.slot] = nil
// release slot to signal need for a new peer, last!
srv.peerSlots <- peer.slot
delete(srv.peers, *p.remoteID)
srv.lock.Unlock()
srv.peerWG.Done()
}
func (srv *Server) verifyPeer(addr *peerAddr) error {
if srv.Blacklist.Exists(addr.Pubkey) {
return errors.New("blacklisted")
}
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
id := peer.Identity()
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
return errors.New("already connected")
}
}
}
return nil
}
// TODO replace with "Set"
type Blacklist interface {
Get([]byte) (bool, error)
Put([]byte) error

View File

@ -2,19 +2,28 @@ package p2p
import (
"bytes"
"crypto/ecdsa"
"io"
"math/rand"
"net"
"sync"
"testing"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover"
)
func startTestServer(t *testing.T, pf peerFunc) *Server {
func startTestServer(t *testing.T, pf newPeerHook) *Server {
server := &Server{
Identity: &peerId{},
Name: "test",
MaxPeers: 10,
ListenAddr: "127.0.0.1:0",
newPeerFunc: pf,
PrivateKey: newkey(),
newPeerHook: pf,
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
return randomID(), nil, err
},
}
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) {
// start the test server
connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
if conn == nil {
srv := startTestServer(t, func(p *Peer) {
if p == nil {
t.Error("peer func called with nil conn")
}
if dialAddr != nil {
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
connected <- p
})
defer close(connected)
defer srv.Stop()
@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) {
select {
case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
if peer.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr())
peer.LocalAddr(), conn.RemoteAddr())
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) {
func TestServerDial(t *testing.T) {
defer testlog(t).detach()
// run a fake TCP server to handle the connection.
// run a one-shot TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("could not setup listener: %v")
@ -72,41 +76,33 @@ func TestServerDial(t *testing.T) {
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error("acccept error:", err)
t.Error("accept error:", err)
return
}
conn.Close()
accepted <- conn
}()
// start the test server
// start the server
connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
if conn == nil {
t.Error("peer func called with nil conn")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
srv := startTestServer(t, func(p *Peer) { connected <- p })
defer close(connected)
defer srv.Stop()
// tell the server to connect.
connAddr := newPeerAddr(listener.Addr(), nil)
// tell the server to connect
tcpAddr := listener.Addr().(*net.TCPAddr)
connAddr := &discover.Node{Addr: &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port}}
srv.peerConnect <- connAddr
select {
case conn := <-accepted:
select {
case peer := <-connected:
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr())
}
if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
peer.dialAddr, connAddr)
peer.RemoteAddr(), conn.LocalAddr())
}
// TODO: validate more fields
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
}
@ -118,16 +114,16 @@ func TestServerDial(t *testing.T) {
func TestServerBroadcast(t *testing.T) {
defer testlog(t).detach()
var connected sync.WaitGroup
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
peer := newPeer(c, []Protocol{discard}, dialAddr)
peer.startSubprotocols([]Cap{discard.cap()})
srv := startTestServer(t, func(p *Peer) {
p.protocols = []Protocol{discard}
p.startSubprotocols([]Cap{discard.cap()})
connected.Done()
return peer
})
defer srv.Stop()
// dial a bunch of conns
// create a few peers
var conns = make([]net.Conn, 8)
connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second)
@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) {
}
}
}
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
panic("couldn't generate key: " + err.Error())
}
return key
}
func randomID() (id discover.NodeID) {
for i := range id {
id[i] = byte(rand.Intn(255))
}
return id
}

View File

@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {

View File

@ -1,40 +0,0 @@
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}