Merge branch 'poc8' into develop
This commit is contained in:
		
						commit
						f06543fd06
					
				| @ -5,10 +5,10 @@ import ( | ||||
| 	"runtime" | ||||
| ) | ||||
| 
 | ||||
| // should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
 | ||||
| // ClientIdentity represents the identity of a peer.
 | ||||
| type ClientIdentity interface { | ||||
| 	String() string | ||||
| 	Pubkey() []byte | ||||
| 	String() string // human readable identity
 | ||||
| 	Pubkey() []byte // 512-bit public key
 | ||||
| } | ||||
| 
 | ||||
| type SimpleClientIdentity struct { | ||||
|  | ||||
| @ -1,275 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	// "fmt"
 | ||||
| 	"net" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/ethutil" | ||||
| ) | ||||
| 
 | ||||
| type Connection struct { | ||||
| 	conn net.Conn | ||||
| 	// conn       NetworkConnection
 | ||||
| 	timeout    time.Duration | ||||
| 	in         chan []byte | ||||
| 	out        chan []byte | ||||
| 	err        chan *PeerError | ||||
| 	closingIn  chan chan bool | ||||
| 	closingOut chan chan bool | ||||
| } | ||||
| 
 | ||||
| // const readBufferLength = 2 //for testing
 | ||||
| 
 | ||||
| const readBufferLength = 1440 | ||||
| const partialsQueueSize = 10 | ||||
| const maxPendingQueueSize = 1 | ||||
| const defaultTimeout = 500 | ||||
| 
 | ||||
| var magicToken = []byte{34, 64, 8, 145} | ||||
| 
 | ||||
| func (self *Connection) Open() { | ||||
| 	go self.startRead() | ||||
| 	go self.startWrite() | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) Close() { | ||||
| 	self.closeIn() | ||||
| 	self.closeOut() | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) closeIn() { | ||||
| 	errc := make(chan bool) | ||||
| 	self.closingIn <- errc | ||||
| 	<-errc | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) closeOut() { | ||||
| 	errc := make(chan bool) | ||||
| 	self.closingOut <- errc | ||||
| 	<-errc | ||||
| } | ||||
| 
 | ||||
| func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { | ||||
| 	return &Connection{ | ||||
| 		conn:       conn, | ||||
| 		timeout:    defaultTimeout, | ||||
| 		in:         make(chan []byte), | ||||
| 		out:        make(chan []byte), | ||||
| 		err:        errchan, | ||||
| 		closingIn:  make(chan chan bool, 1), | ||||
| 		closingOut: make(chan chan bool, 1), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) Read() <-chan []byte { | ||||
| 	return self.in | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) Write() chan<- []byte { | ||||
| 	return self.out | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) Error() <-chan *PeerError { | ||||
| 	return self.err | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) startRead() { | ||||
| 	payloads := make(chan []byte) | ||||
| 	done := make(chan *PeerError) | ||||
| 	pending := [][]byte{} | ||||
| 	var head []byte | ||||
| 	var wait time.Duration // initally 0 (no delay)
 | ||||
| 	read := time.After(wait * time.Millisecond) | ||||
| 
 | ||||
| 	for { | ||||
| 		// if pending empty, nil channel blocks
 | ||||
| 		var in chan []byte | ||||
| 		if len(pending) > 0 { | ||||
| 			in = self.in // enable send case
 | ||||
| 			head = pending[0] | ||||
| 		} else { | ||||
| 			in = nil | ||||
| 		} | ||||
| 
 | ||||
| 		select { | ||||
| 		case <-read: | ||||
| 			go self.read(payloads, done) | ||||
| 		case err := <-done: | ||||
| 			if err == nil { // no error but nothing to read
 | ||||
| 				if len(pending) < maxPendingQueueSize { | ||||
| 					wait = 100 | ||||
| 				} else if wait == 0 { | ||||
| 					wait = 100 | ||||
| 				} else { | ||||
| 					wait = 2 * wait | ||||
| 				} | ||||
| 			} else { | ||||
| 				self.err <- err // report error
 | ||||
| 				wait = 100 | ||||
| 			} | ||||
| 			read = time.After(wait * time.Millisecond) | ||||
| 		case payload := <-payloads: | ||||
| 			pending = append(pending, payload) | ||||
| 			if len(pending) < maxPendingQueueSize { | ||||
| 				wait = 0 | ||||
| 			} else { | ||||
| 				wait = 100 | ||||
| 			} | ||||
| 			read = time.After(wait * time.Millisecond) | ||||
| 		case in <- head: | ||||
| 			pending = pending[1:] | ||||
| 		case errc := <-self.closingIn: | ||||
| 			errc <- true | ||||
| 			close(self.in) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) startWrite() { | ||||
| 	pending := [][]byte{} | ||||
| 	done := make(chan *PeerError) | ||||
| 	writing := false | ||||
| 	for { | ||||
| 		if len(pending) > 0 && !writing { | ||||
| 			writing = true | ||||
| 			go self.write(pending[0], done) | ||||
| 		} | ||||
| 		select { | ||||
| 		case payload := <-self.out: | ||||
| 			pending = append(pending, payload) | ||||
| 		case err := <-done: | ||||
| 			if err == nil { | ||||
| 				pending = pending[1:] | ||||
| 				writing = false | ||||
| 			} else { | ||||
| 				self.err <- err // report error
 | ||||
| 			} | ||||
| 		case errc := <-self.closingOut: | ||||
| 			errc <- true | ||||
| 			close(self.out) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func pack(payload []byte) (packet []byte) { | ||||
| 	length := ethutil.NumberToBytes(uint32(len(payload)), 32) | ||||
| 	// return error if too long?
 | ||||
| 	// Write magic token and payload length (first 8 bytes)
 | ||||
| 	packet = append(magicToken, length...) | ||||
| 	packet = append(packet, payload...) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func avoidPanic(done chan *PeerError) { | ||||
| 	if rec := recover(); rec != nil { | ||||
| 		err := NewPeerError(MiscError, " %v", rec) | ||||
| 		logger.Debugln(err) | ||||
| 		done <- err | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) write(payload []byte, done chan *PeerError) { | ||||
| 	defer avoidPanic(done) | ||||
| 	var err *PeerError | ||||
| 	_, ok := self.conn.Write(pack(payload)) | ||||
| 	if ok != nil { | ||||
| 		err = NewPeerError(WriteError, " %v", ok) | ||||
| 		logger.Debugln(err) | ||||
| 	} | ||||
| 	done <- err | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) read(payloads chan []byte, done chan *PeerError) { | ||||
| 	//defer avoidPanic(done)
 | ||||
| 
 | ||||
| 	partials := make(chan []byte, partialsQueueSize) | ||||
| 	errc := make(chan *PeerError) | ||||
| 	go self.readPartials(partials, errc) | ||||
| 
 | ||||
| 	packet := []byte{} | ||||
| 	length := 8 | ||||
| 	start := true | ||||
| 	var err *PeerError | ||||
| out: | ||||
| 	for { | ||||
| 		// appends partials read via connection until packet is
 | ||||
| 		// - either parseable (>=8bytes)
 | ||||
| 		// - or complete (payload fully consumed)
 | ||||
| 		for len(packet) < length { | ||||
| 			partial, ok := <-partials | ||||
| 			if !ok { // partials channel is closed
 | ||||
| 				err = <-errc | ||||
| 				if err == nil && len(packet) > 0 { | ||||
| 					if start { | ||||
| 						err = NewPeerError(PacketTooShort, "%v", packet) | ||||
| 					} else { | ||||
| 						err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) | ||||
| 					} | ||||
| 				} | ||||
| 				break out | ||||
| 			} | ||||
| 			packet = append(packet, partial...) | ||||
| 		} | ||||
| 		if start { | ||||
| 			// at least 8 bytes read, can validate packet
 | ||||
| 			if bytes.Compare(magicToken, packet[:4]) != 0 { | ||||
| 				err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) | ||||
| 				break | ||||
| 			} | ||||
| 			length = int(ethutil.BytesToNumber(packet[4:8])) | ||||
| 			packet = packet[8:] | ||||
| 
 | ||||
| 			if length > 0 { | ||||
| 				start = false // now consuming payload
 | ||||
| 			} else { //penalize peer but read on
 | ||||
| 				self.err <- NewPeerError(EmptyPayload, "") | ||||
| 				length = 8 | ||||
| 			} | ||||
| 		} else { | ||||
| 			// packet complete (payload fully consumed)
 | ||||
| 			payloads <- packet[:length] | ||||
| 			packet = packet[length:] // resclice packet
 | ||||
| 			start = true | ||||
| 			length = 8 | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// this stops partials read via the connection, should we?
 | ||||
| 	//if err != nil {
 | ||||
| 	//  select {
 | ||||
| 	//    case errc <- err
 | ||||
| 	//  default:
 | ||||
| 	//}
 | ||||
| 	done <- err | ||||
| } | ||||
| 
 | ||||
| func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { | ||||
| 	defer close(partials) | ||||
| 	for { | ||||
| 		// Give buffering some time
 | ||||
| 		self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) | ||||
| 		buffer := make([]byte, readBufferLength) | ||||
| 		// read partial from connection
 | ||||
| 		bytesRead, err := self.conn.Read(buffer) | ||||
| 		if err == nil || err.Error() == "EOF" { | ||||
| 			if bytesRead > 0 { | ||||
| 				partials <- buffer[:bytesRead] | ||||
| 			} | ||||
| 			if err != nil && err.Error() == "EOF" { | ||||
| 				break | ||||
| 			} | ||||
| 		} else { | ||||
| 			// unexpected error, report to errc
 | ||||
| 			err := NewPeerError(ReadError, " %v", err) | ||||
| 			logger.Debugln(err) | ||||
| 			errc <- err | ||||
| 			return // will close partials channel
 | ||||
| 		} | ||||
| 	} | ||||
| 	close(errc) | ||||
| } | ||||
| @ -1,222 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type TestNetworkConnection struct { | ||||
| 	in      chan []byte | ||||
| 	current []byte | ||||
| 	Out     [][]byte | ||||
| 	addr    net.Addr | ||||
| } | ||||
| 
 | ||||
| func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { | ||||
| 	return &TestNetworkConnection{ | ||||
| 		in:      make(chan []byte), | ||||
| 		current: []byte{}, | ||||
| 		Out:     [][]byte{}, | ||||
| 		addr:    addr, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { | ||||
| 	time.Sleep(latency) | ||||
| 	for _, s := range packets { | ||||
| 		self.in <- s | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { | ||||
| 	if len(self.current) == 0 { | ||||
| 		select { | ||||
| 		case self.current = <-self.in: | ||||
| 		default: | ||||
| 			return 0, io.EOF | ||||
| 		} | ||||
| 	} | ||||
| 	length := len(self.current) | ||||
| 	if length > len(buff) { | ||||
| 		copy(buff[:], self.current[:len(buff)]) | ||||
| 		self.current = self.current[len(buff):] | ||||
| 		return len(buff), nil | ||||
| 	} else { | ||||
| 		copy(buff[:length], self.current[:]) | ||||
| 		self.current = []byte{} | ||||
| 		return length, io.EOF | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { | ||||
| 	self.Out = append(self.Out, buff) | ||||
| 	fmt.Printf("net write %v\n%v\n", len(self.Out), buff) | ||||
| 	return len(buff), nil | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) Close() (err error) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { | ||||
| 	return self.addr | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func setupConnection() (*Connection, *TestNetworkConnection) { | ||||
| 	addr := &TestAddr{"test:30303"} | ||||
| 	net := NewTestNetworkConnection(addr) | ||||
| 	conn := NewConnection(net, NewPeerErrorChannel()) | ||||
| 	conn.Open() | ||||
| 	return conn, net | ||||
| } | ||||
| 
 | ||||
| func TestReadingNilPacket(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{}) | ||||
| 	// time.Sleep(10 * time.Millisecond)
 | ||||
| 	select { | ||||
| 	case packet := <-conn.Read(): | ||||
| 		t.Errorf("read %v", packet) | ||||
| 	case err := <-conn.Error(): | ||||
| 		t.Errorf("incorrect error %v", err) | ||||
| 	default: | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestReadingShortPacket(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{0}) | ||||
| 	select { | ||||
| 	case packet := <-conn.Read(): | ||||
| 		t.Errorf("read %v", packet) | ||||
| 	case err := <-conn.Error(): | ||||
| 		if err.Code != PacketTooShort { | ||||
| 			t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) | ||||
| 		} | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestReadingInvalidPacket(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) | ||||
| 	select { | ||||
| 	case packet := <-conn.Read(): | ||||
| 		t.Errorf("read %v", packet) | ||||
| 	case err := <-conn.Error(): | ||||
| 		if err.Code != MagicTokenMismatch { | ||||
| 			t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) | ||||
| 		} | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestReadingInvalidPayload(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) | ||||
| 	select { | ||||
| 	case packet := <-conn.Read(): | ||||
| 		t.Errorf("read %v", packet) | ||||
| 	case err := <-conn.Error(): | ||||
| 		if err.Code != PayloadTooShort { | ||||
| 			t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) | ||||
| 		} | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestReadingEmptyPayload(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	select { | ||||
| 	case packet := <-conn.Read(): | ||||
| 		t.Errorf("read %v", packet) | ||||
| 	default: | ||||
| 	} | ||||
| 	select { | ||||
| 	case err := <-conn.Error(): | ||||
| 		code := err.Code | ||||
| 		if code != EmptyPayload { | ||||
| 			t.Errorf("incorrect error, expected EmptyPayload, got %v", code) | ||||
| 		} | ||||
| 	default: | ||||
| 		t.Errorf("no error, expected EmptyPayload") | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestReadingCompletePacket(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	select { | ||||
| 	case packet := <-conn.Read(): | ||||
| 		if bytes.Compare(packet, []byte{1}) != 0 { | ||||
| 			t.Errorf("incorrect payload read") | ||||
| 		} | ||||
| 	case err := <-conn.Error(): | ||||
| 		t.Errorf("incorrect error %v", err) | ||||
| 	default: | ||||
| 		t.Errorf("nothing read") | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestReadingTwoCompletePackets(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) | ||||
| 
 | ||||
| 	for i := 0; i < 2; i++ { | ||||
| 		time.Sleep(10 * time.Millisecond) | ||||
| 		select { | ||||
| 		case packet := <-conn.Read(): | ||||
| 			if bytes.Compare(packet, []byte{byte(i)}) != 0 { | ||||
| 				t.Errorf("incorrect payload read") | ||||
| 			} | ||||
| 		case err := <-conn.Error(): | ||||
| 			t.Errorf("incorrect error %v", err) | ||||
| 		default: | ||||
| 			t.Errorf("nothing read") | ||||
| 		} | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| func TestWriting(t *testing.T) { | ||||
| 	conn, net := setupConnection() | ||||
| 	conn.Write() <- []byte{0} | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	if len(net.Out) == 0 { | ||||
| 		t.Errorf("no output") | ||||
| 	} else { | ||||
| 		out := net.Out[0] | ||||
| 		if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { | ||||
| 			t.Errorf("incorrect packet %v", out) | ||||
| 		} | ||||
| 	} | ||||
| 	conn.Close() | ||||
| } | ||||
| 
 | ||||
| // hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
 | ||||
							
								
								
									
										202
									
								
								p2p/message.go
									
									
									
									
									
								
							
							
						
						
									
										202
									
								
								p2p/message.go
									
									
									
									
									
								
							| @ -1,75 +1,155 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	// "fmt"
 | ||||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"math/big" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/ethutil" | ||||
| 	"github.com/ethereum/go-ethereum/rlp" | ||||
| ) | ||||
| 
 | ||||
| type MsgCode uint8 | ||||
| 
 | ||||
| // Msg defines the structure of a p2p message.
 | ||||
| //
 | ||||
| // Note that a Msg can only be sent once since the Payload reader is
 | ||||
| // consumed during sending. It is not possible to create a Msg and
 | ||||
| // send it any number of times. If you want to reuse an encoded
 | ||||
| // structure, encode the payload into a byte array and create a
 | ||||
| // separate Msg with a bytes.Reader as Payload for each send.
 | ||||
| type Msg struct { | ||||
| 	code    MsgCode // this is the raw code as per adaptive msg code scheme
 | ||||
| 	data    *ethutil.Value | ||||
| 	encoded []byte | ||||
| 	Code    uint64 | ||||
| 	Size    uint32 // size of the paylod
 | ||||
| 	Payload io.Reader | ||||
| } | ||||
| 
 | ||||
| func (self *Msg) Code() MsgCode { | ||||
| 	return self.code | ||||
| } | ||||
| 
 | ||||
| func (self *Msg) Data() *ethutil.Value { | ||||
| 	return self.data | ||||
| } | ||||
| 
 | ||||
| func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { | ||||
| 
 | ||||
| 	// // data := [][]interface{}{}
 | ||||
| 	// data := []interface{}{}
 | ||||
| 	// for _, value := range params {
 | ||||
| 	// 	if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
 | ||||
| 	// 		data = append(data, encodable.RlpValue())
 | ||||
| 	// 	} else if raw, ok := value.([]interface{}); ok {
 | ||||
| 	// 		data = append(data, raw)
 | ||||
| 	// 	} else {
 | ||||
| 	// 		// data = append(data, interface{}(raw))
 | ||||
| 	// 		err = fmt.Errorf("Unable to encode object of type %T", value)
 | ||||
| 	// 		return
 | ||||
| 	// 	}
 | ||||
| 	// }
 | ||||
| 	return &Msg{ | ||||
| 		code: code, | ||||
| 		data: ethutil.NewValue(interface{}(params)), | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { | ||||
| 	value := ethutil.NewValueFromBytes(encoded) | ||||
| 	// Type of message
 | ||||
| 	code := value.Get(0).Uint() | ||||
| 	// Actual data
 | ||||
| 	data := value.SliceFrom(1) | ||||
| 
 | ||||
| 	msg = &Msg{ | ||||
| 		code: MsgCode(code), | ||||
| 		data: data, | ||||
| 		// data:    ethutil.NewValue(data),
 | ||||
| 		encoded: encoded, | ||||
| // NewMsg creates an RLP-encoded message with the given code.
 | ||||
| func NewMsg(code uint64, params ...interface{}) Msg { | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	for _, p := range params { | ||||
| 		buf.Write(ethutil.Encode(p)) | ||||
| 	} | ||||
| 	return | ||||
| 	return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} | ||||
| } | ||||
| 
 | ||||
| func (self *Msg) Decode(offset MsgCode) { | ||||
| 	self.code = self.code - offset | ||||
| } | ||||
| 
 | ||||
| // encode takes an offset argument to implement adaptive message coding
 | ||||
| // the encoded message is memoized to make msgs relayed to several peers more efficient
 | ||||
| func (self *Msg) Encode(offset MsgCode) (res []byte) { | ||||
| 	if len(self.encoded) == 0 { | ||||
| 		res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() | ||||
| 		self.encoded = res | ||||
| 	} else { | ||||
| 		res = self.encoded | ||||
| func encodePayload(params ...interface{}) []byte { | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	for _, p := range params { | ||||
| 		buf.Write(ethutil.Encode(p)) | ||||
| 	} | ||||
| 	return | ||||
| 	return buf.Bytes() | ||||
| } | ||||
| 
 | ||||
| // Decode parse the RLP content of a message into
 | ||||
| // the given value, which must be a pointer.
 | ||||
| //
 | ||||
| // For the decoding rules, please see package rlp.
 | ||||
| func (msg Msg) Decode(val interface{}) error { | ||||
| 	s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) | ||||
| 	return s.Decode(val) | ||||
| } | ||||
| 
 | ||||
| // Discard reads any remaining payload data into a black hole.
 | ||||
| func (msg Msg) Discard() error { | ||||
| 	_, err := io.Copy(ioutil.Discard, msg.Payload) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| type MsgReader interface { | ||||
| 	ReadMsg() (Msg, error) | ||||
| } | ||||
| 
 | ||||
| type MsgWriter interface { | ||||
| 	// WriteMsg sends an existing message.
 | ||||
| 	// The Payload reader of the message is consumed.
 | ||||
| 	// Note that messages can be sent only once.
 | ||||
| 	WriteMsg(Msg) error | ||||
| 
 | ||||
| 	// EncodeMsg writes an RLP-encoded message with the given
 | ||||
| 	// code and data elements.
 | ||||
| 	EncodeMsg(code uint64, data ...interface{}) error | ||||
| } | ||||
| 
 | ||||
| // MsgReadWriter provides reading and writing of encoded messages.
 | ||||
| type MsgReadWriter interface { | ||||
| 	MsgReader | ||||
| 	MsgWriter | ||||
| } | ||||
| 
 | ||||
| var magicToken = []byte{34, 64, 8, 145} | ||||
| 
 | ||||
| func writeMsg(w io.Writer, msg Msg) error { | ||||
| 	// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
 | ||||
| 	code := ethutil.Encode(uint32(msg.Code)) | ||||
| 	listhdr := makeListHeader(msg.Size + uint32(len(code))) | ||||
| 	payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size | ||||
| 
 | ||||
| 	start := make([]byte, 8) | ||||
| 	copy(start, magicToken) | ||||
| 	binary.BigEndian.PutUint32(start[4:], payloadLen) | ||||
| 
 | ||||
| 	for _, b := range [][]byte{start, listhdr, code} { | ||||
| 		if _, err := w.Write(b); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	_, err := io.CopyN(w, msg.Payload, int64(msg.Size)) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func makeListHeader(length uint32) []byte { | ||||
| 	if length < 56 { | ||||
| 		return []byte{byte(length + 0xc0)} | ||||
| 	} | ||||
| 	enc := big.NewInt(int64(length)).Bytes() | ||||
| 	lenb := byte(len(enc)) + 0xf7 | ||||
| 	return append([]byte{lenb}, enc...) | ||||
| } | ||||
| 
 | ||||
| // readMsg reads a message header from r.
 | ||||
| // It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
 | ||||
| func readMsg(r rlp.ByteReader) (msg Msg, err error) { | ||||
| 	// read magic and payload size
 | ||||
| 	start := make([]byte, 8) | ||||
| 	if _, err = io.ReadFull(r, start); err != nil { | ||||
| 		return msg, newPeerError(errRead, "%v", err) | ||||
| 	} | ||||
| 	if !bytes.HasPrefix(start, magicToken) { | ||||
| 		return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) | ||||
| 	} | ||||
| 	size := binary.BigEndian.Uint32(start[4:]) | ||||
| 
 | ||||
| 	// decode start of RLP message to get the message code
 | ||||
| 	posr := &postrack{r, 0} | ||||
| 	s := rlp.NewStream(posr) | ||||
| 	if _, err := s.List(); err != nil { | ||||
| 		return msg, err | ||||
| 	} | ||||
| 	code, err := s.Uint() | ||||
| 	if err != nil { | ||||
| 		return msg, err | ||||
| 	} | ||||
| 	payloadsize := size - posr.p | ||||
| 	return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil | ||||
| } | ||||
| 
 | ||||
| // postrack wraps an rlp.ByteReader with a position counter.
 | ||||
| type postrack struct { | ||||
| 	r rlp.ByteReader | ||||
| 	p uint32 | ||||
| } | ||||
| 
 | ||||
| func (r *postrack) Read(buf []byte) (int, error) { | ||||
| 	n, err := r.r.Read(buf) | ||||
| 	r.p += uint32(n) | ||||
| 	return n, err | ||||
| } | ||||
| 
 | ||||
| func (r *postrack) ReadByte() (byte, error) { | ||||
| 	b, err := r.r.ReadByte() | ||||
| 	if err == nil { | ||||
| 		r.p++ | ||||
| 	} | ||||
| 	return b, err | ||||
| } | ||||
|  | ||||
| @ -1,38 +1,70 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"io/ioutil" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/ethutil" | ||||
| ) | ||||
| 
 | ||||
| func TestNewMsg(t *testing.T) { | ||||
| 	msg, _ := NewMsg(3, 1, "000") | ||||
| 	if msg.Code() != 3 { | ||||
| 		t.Errorf("incorrect code %v", msg.Code()) | ||||
| 	msg := NewMsg(3, 1, "000") | ||||
| 	if msg.Code != 3 { | ||||
| 		t.Errorf("incorrect code %d, want %d", msg.Code) | ||||
| 	} | ||||
| 	data0 := msg.Data().Get(0).Uint() | ||||
| 	data1 := string(msg.Data().Get(1).Bytes()) | ||||
| 	if data0 != 1 { | ||||
| 		t.Errorf("incorrect data %v", data0) | ||||
| 	if msg.Size != 5 { | ||||
| 		t.Errorf("incorrect size %d, want %d", msg.Size, 5) | ||||
| 	} | ||||
| 	if data1 != "000" { | ||||
| 		t.Errorf("incorrect data %v", data1) | ||||
| 	pl, _ := ioutil.ReadAll(msg.Payload) | ||||
| 	expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} | ||||
| 	if !bytes.Equal(pl, expect) { | ||||
| 		t.Errorf("incorrect payload content, got %x, want %x", pl, expect) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEncodeDecodeMsg(t *testing.T) { | ||||
| 	msg, _ := NewMsg(3, 1, "000") | ||||
| 	encoded := msg.Encode(3) | ||||
| 	msg, _ = NewMsgFromBytes(encoded) | ||||
| 	msg.Decode(3) | ||||
| 	if msg.Code() != 3 { | ||||
| 		t.Errorf("incorrect code %v", msg.Code()) | ||||
| 	msg := NewMsg(3, 1, "000") | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	if err := writeMsg(buf, msg); err != nil { | ||||
| 		t.Fatalf("encodeMsg error: %v", err) | ||||
| 	} | ||||
| 	data0 := msg.Data().Get(0).Uint() | ||||
| 	data1 := msg.Data().Get(1).Str() | ||||
| 	if data0 != 1 { | ||||
| 		t.Errorf("incorrect data %v", data0) | ||||
| 	// t.Logf("encoded: %x", buf.Bytes())
 | ||||
| 
 | ||||
| 	decmsg, err := readMsg(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("readMsg error: %v", err) | ||||
| 	} | ||||
| 	if data1 != "000" { | ||||
| 		t.Errorf("incorrect data %v", data1) | ||||
| 	if decmsg.Code != 3 { | ||||
| 		t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) | ||||
| 	} | ||||
| 	if decmsg.Size != 5 { | ||||
| 		t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) | ||||
| 	} | ||||
| 
 | ||||
| 	var data struct { | ||||
| 		I int | ||||
| 		S string | ||||
| 	} | ||||
| 	if err := decmsg.Decode(&data); err != nil { | ||||
| 		t.Fatalf("Decode error: %v", err) | ||||
| 	} | ||||
| 	if data.I != 1 { | ||||
| 		t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) | ||||
| 	} | ||||
| 	if data.S != "000" { | ||||
| 		t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDecodeRealMsg(t *testing.T) { | ||||
| 	data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") | ||||
| 	msg, err := readMsg(bytes.NewReader(data)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("unexpected error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if msg.Code != 0 { | ||||
| 		t.Errorf("incorrect code %d, want %d", msg.Code, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										220
									
								
								p2p/messenger.go
									
									
									
									
									
								
							
							
						
						
									
										220
									
								
								p2p/messenger.go
									
									
									
									
									
								
							| @ -1,220 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	handlerTimeout = 1000 | ||||
| ) | ||||
| 
 | ||||
| type Handlers map[string](func(p *Peer) Protocol) | ||||
| 
 | ||||
| type Messenger struct { | ||||
| 	conn          *Connection | ||||
| 	peer          *Peer | ||||
| 	handlers      Handlers | ||||
| 	protocolLock  sync.RWMutex | ||||
| 	protocols     []Protocol | ||||
| 	offsets       []MsgCode // offsets for adaptive message idss
 | ||||
| 	protocolTable map[string]int | ||||
| 	quit          chan chan bool | ||||
| 	err           chan *PeerError | ||||
| 	pulse         chan bool | ||||
| } | ||||
| 
 | ||||
| func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { | ||||
| 	baseProtocol := NewBaseProtocol(peer) | ||||
| 	return &Messenger{ | ||||
| 		conn:          conn, | ||||
| 		peer:          peer, | ||||
| 		offsets:       []MsgCode{baseProtocol.Offset()}, | ||||
| 		handlers:      handlers, | ||||
| 		protocols:     []Protocol{baseProtocol}, | ||||
| 		protocolTable: make(map[string]int), | ||||
| 		err:           errchan, | ||||
| 		pulse:         make(chan bool, 1), | ||||
| 		quit:          make(chan chan bool, 1), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Messenger) Start() { | ||||
| 	self.conn.Open() | ||||
| 	go self.messenger() | ||||
| 	self.protocolLock.RLock() | ||||
| 	defer self.protocolLock.RUnlock() | ||||
| 	self.protocols[0].Start() | ||||
| } | ||||
| 
 | ||||
| func (self *Messenger) Stop() { | ||||
| 	// close pulse to stop ping pong monitoring
 | ||||
| 	close(self.pulse) | ||||
| 	self.protocolLock.RLock() | ||||
| 	defer self.protocolLock.RUnlock() | ||||
| 	for _, protocol := range self.protocols { | ||||
| 		protocol.Stop() // could be parallel
 | ||||
| 	} | ||||
| 	q := make(chan bool) | ||||
| 	self.quit <- q | ||||
| 	<-q | ||||
| 	self.conn.Close() | ||||
| } | ||||
| 
 | ||||
| func (self *Messenger) messenger() { | ||||
| 	in := self.conn.Read() | ||||
| 	for { | ||||
| 		select { | ||||
| 		case payload, ok := <-in: | ||||
| 			//dispatches message to the protocol asynchronously
 | ||||
| 			if ok { | ||||
| 				go self.handle(payload) | ||||
| 			} else { | ||||
| 				return | ||||
| 			} | ||||
| 		case q := <-self.quit: | ||||
| 			q <- true | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // handles each message by dispatching to the appropriate protocol
 | ||||
| // using adaptive message codes
 | ||||
| // this function is started as a separate go routine for each message
 | ||||
| // it waits for the protocol response
 | ||||
| // then encodes and sends outgoing messages to the connection's write channel
 | ||||
| func (self *Messenger) handle(payload []byte) { | ||||
| 	// send ping to heartbeat channel signalling time of last message
 | ||||
| 	// select {
 | ||||
| 	// case self.pulse <- true:
 | ||||
| 	// default:
 | ||||
| 	// }
 | ||||
| 	self.pulse <- true | ||||
| 	// initialise message from payload
 | ||||
| 	msg, err := NewMsgFromBytes(payload) | ||||
| 	if err != nil { | ||||
| 		self.err <- NewPeerError(MiscError, " %v", err) | ||||
| 		return | ||||
| 	} | ||||
| 	// retrieves protocol based on message Code
 | ||||
| 	protocol, offset, peerErr := self.getProtocol(msg.Code()) | ||||
| 	if err != nil { | ||||
| 		self.err <- peerErr | ||||
| 		return | ||||
| 	} | ||||
| 	// reset message code based on adaptive offset
 | ||||
| 	msg.Decode(offset) | ||||
| 	// dispatches
 | ||||
| 	response := make(chan *Msg) | ||||
| 	go protocol.HandleIn(msg, response) | ||||
| 	// protocol reponse timeout to prevent leaks
 | ||||
| 	timer := time.After(handlerTimeout * time.Millisecond) | ||||
| 	for { | ||||
| 		select { | ||||
| 		case outgoing, ok := <-response: | ||||
| 			// we check if response channel is not closed
 | ||||
| 			if ok { | ||||
| 				self.conn.Write() <- outgoing.Encode(offset) | ||||
| 			} else { | ||||
| 				return | ||||
| 			} | ||||
| 		case <-timer: | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // negotiated protocols
 | ||||
| // stores offsets needed for adaptive message id scheme
 | ||||
| 
 | ||||
| // based on offsets set at handshake
 | ||||
| // get the right protocol to handle the message
 | ||||
| func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { | ||||
| 	self.protocolLock.RLock() | ||||
| 	defer self.protocolLock.RUnlock() | ||||
| 	base := MsgCode(0) | ||||
| 	for index, offset := range self.offsets { | ||||
| 		if code < offset { | ||||
| 			return self.protocols[index], base, nil | ||||
| 		} | ||||
| 		base = offset | ||||
| 	} | ||||
| 	return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) | ||||
| } | ||||
| 
 | ||||
| func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { | ||||
| 	fmt.Printf("pingpong keepalive started at %v", time.Now()) | ||||
| 
 | ||||
| 	timer := time.After(timeout) | ||||
| 	pinged := false | ||||
| 	for { | ||||
| 		select { | ||||
| 		case _, ok := <-self.pulse: | ||||
| 			if ok { | ||||
| 				pinged = false | ||||
| 				timer = time.After(timeout) | ||||
| 			} else { | ||||
| 				// pulse is closed, stop monitoring
 | ||||
| 				return | ||||
| 			} | ||||
| 		case <-timer: | ||||
| 			if pinged { | ||||
| 				fmt.Printf("timeout at %v", time.Now()) | ||||
| 				timeoutCallback() | ||||
| 				return | ||||
| 			} else { | ||||
| 				fmt.Printf("pinged at %v", time.Now()) | ||||
| 				pingCallback() | ||||
| 				timer = time.After(gracePeriod) | ||||
| 				pinged = true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Messenger) AddProtocols(protocols []string) { | ||||
| 	self.protocolLock.Lock() | ||||
| 	defer self.protocolLock.Unlock() | ||||
| 	i := len(self.offsets) | ||||
| 	offset := self.offsets[i-1] | ||||
| 	for _, name := range protocols { | ||||
| 		protocolFunc, ok := self.handlers[name] | ||||
| 		if ok { | ||||
| 			protocol := protocolFunc(self.peer) | ||||
| 			self.protocolTable[name] = i | ||||
| 			i++ | ||||
| 			offset += protocol.Offset() | ||||
| 			fmt.Println("offset ", name, offset) | ||||
| 
 | ||||
| 			self.offsets = append(self.offsets, offset) | ||||
| 			self.protocols = append(self.protocols, protocol) | ||||
| 			protocol.Start() | ||||
| 		} else { | ||||
| 			fmt.Println("no ", name) | ||||
| 			// protocol not handled
 | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Messenger) Write(protocol string, msg *Msg) error { | ||||
| 	self.protocolLock.RLock() | ||||
| 	defer self.protocolLock.RUnlock() | ||||
| 	i := 0 | ||||
| 	offset := MsgCode(0) | ||||
| 	if len(protocol) > 0 { | ||||
| 		var ok bool | ||||
| 		i, ok = self.protocolTable[protocol] | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("protocol %v not handled by peer", protocol) | ||||
| 		} | ||||
| 		offset = self.offsets[i-1] | ||||
| 	} | ||||
| 	handler := self.protocols[i] | ||||
| 	// checking if protocol status/caps allows the message to be sent out
 | ||||
| 	if handler.HandleOut(msg) { | ||||
| 		self.conn.Write() <- msg.Encode(offset) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @ -1,147 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	// "fmt"
 | ||||
| 	"bytes" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/ethutil" | ||||
| ) | ||||
| 
 | ||||
| func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { | ||||
| 	errchan := NewPeerErrorChannel() | ||||
| 	addr := &TestAddr{"test:30303"} | ||||
| 	net := NewTestNetworkConnection(addr) | ||||
| 	conn := NewConnection(net, errchan) | ||||
| 	mess := NewMessenger(nil, conn, errchan, handlers) | ||||
| 	mess.Start() | ||||
| 	return net, errchan, mess | ||||
| } | ||||
| 
 | ||||
| type TestProtocol struct { | ||||
| 	Msgs []*Msg | ||||
| } | ||||
| 
 | ||||
| func (self *TestProtocol) Start() { | ||||
| } | ||||
| 
 | ||||
| func (self *TestProtocol) Stop() { | ||||
| } | ||||
| 
 | ||||
| func (self *TestProtocol) Offset() MsgCode { | ||||
| 	return MsgCode(5) | ||||
| } | ||||
| 
 | ||||
| func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { | ||||
| 	self.Msgs = append(self.Msgs, msg) | ||||
| 	close(response) | ||||
| } | ||||
| 
 | ||||
| func (self *TestProtocol) HandleOut(msg *Msg) bool { | ||||
| 	if msg.Code() > 3 { | ||||
| 		return false | ||||
| 	} else { | ||||
| 		return true | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TestProtocol) Name() string { | ||||
| 	return "a" | ||||
| } | ||||
| 
 | ||||
| func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { | ||||
| 	msg, _ := NewMsg(code, params...) | ||||
| 	encoded := msg.Encode(offset) | ||||
| 	packet := []byte{34, 64, 8, 145} | ||||
| 	packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) | ||||
| 	return append(packet, encoded...) | ||||
| } | ||||
| 
 | ||||
| func TestRead(t *testing.T) { | ||||
| 	handlers := make(Handlers) | ||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} | ||||
| 	handlers["a"] = func(p *Peer) Protocol { return testProtocol } | ||||
| 	net, _, mess := setupMessenger(handlers) | ||||
| 	mess.AddProtocols([]string{"a"}) | ||||
| 	defer mess.Stop() | ||||
| 	wait := 1 * time.Millisecond | ||||
| 	packet := Packet(16, 1, uint32(1), "000") | ||||
| 	go net.In(0, packet) | ||||
| 	time.Sleep(wait) | ||||
| 	if len(testProtocol.Msgs) != 1 { | ||||
| 		t.Errorf("msg not relayed to correct protocol") | ||||
| 	} else { | ||||
| 		if testProtocol.Msgs[0].Code() != 1 { | ||||
| 			t.Errorf("incorrect msg code relayed to protocol") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestWrite(t *testing.T) { | ||||
| 	handlers := make(Handlers) | ||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} | ||||
| 	handlers["a"] = func(p *Peer) Protocol { return testProtocol } | ||||
| 	net, _, mess := setupMessenger(handlers) | ||||
| 	mess.AddProtocols([]string{"a"}) | ||||
| 	defer mess.Stop() | ||||
| 	wait := 1 * time.Millisecond | ||||
| 	msg, _ := NewMsg(3, uint32(1), "000") | ||||
| 	err := mess.Write("b", msg) | ||||
| 	if err == nil { | ||||
| 		t.Errorf("expect error for unknown protocol") | ||||
| 	} | ||||
| 	err = mess.Write("a", msg) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expect no error for known protocol: %v", err) | ||||
| 	} else { | ||||
| 		time.Sleep(wait) | ||||
| 		if len(net.Out) != 1 { | ||||
| 			t.Errorf("msg not written") | ||||
| 		} else { | ||||
| 			out := net.Out[0] | ||||
| 			packet := Packet(16, 3, uint32(1), "000") | ||||
| 			if bytes.Compare(out, packet) != 0 { | ||||
| 				t.Errorf("incorrect packet %v", out) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPulse(t *testing.T) { | ||||
| 	net, _, mess := setupMessenger(make(Handlers)) | ||||
| 	defer mess.Stop() | ||||
| 	ping := false | ||||
| 	timeout := false | ||||
| 	pingTimeout := 10 * time.Millisecond | ||||
| 	gracePeriod := 200 * time.Millisecond | ||||
| 	go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) | ||||
| 	net.In(0, Packet(0, 1)) | ||||
| 	if ping { | ||||
| 		t.Errorf("ping sent too early") | ||||
| 	} | ||||
| 	time.Sleep(pingTimeout + 100*time.Millisecond) | ||||
| 	if !ping { | ||||
| 		t.Errorf("no ping sent after timeout") | ||||
| 	} | ||||
| 	if timeout { | ||||
| 		t.Errorf("timeout too early") | ||||
| 	} | ||||
| 	ping = false | ||||
| 	net.In(0, Packet(0, 1)) | ||||
| 	time.Sleep(pingTimeout + 100*time.Millisecond) | ||||
| 	if !ping { | ||||
| 		t.Errorf("no ping sent after timeout") | ||||
| 	} | ||||
| 	if timeout { | ||||
| 		t.Errorf("timeout too early") | ||||
| 	} | ||||
| 	ping = false | ||||
| 	time.Sleep(gracePeriod) | ||||
| 	if ping { | ||||
| 		t.Errorf("ping called twice") | ||||
| 	} | ||||
| 	if !timeout { | ||||
| 		t.Errorf("no timeout after grace period") | ||||
| 	} | ||||
| } | ||||
| @ -3,6 +3,7 @@ package p2p | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"time" | ||||
| 
 | ||||
| 	natpmp "github.com/jackpal/go-nat-pmp" | ||||
| ) | ||||
| @ -13,38 +14,37 @@ import ( | ||||
| //  + Register for changes to the external address.
 | ||||
| //  + Re-register port mapping when router reboots.
 | ||||
| //  + A mechanism for keeping a port mapping registered.
 | ||||
| //  + Discover gateway address automatically.
 | ||||
| 
 | ||||
| type natPMPClient struct { | ||||
| 	client *natpmp.Client | ||||
| } | ||||
| 
 | ||||
| func NewNatPMP(gateway net.IP) (nat NAT) { | ||||
| // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
 | ||||
| // address should be the IP of your router.
 | ||||
| func PMP(gateway net.IP) (nat NAT) { | ||||
| 	return &natPMPClient{natpmp.NewClient(gateway)} | ||||
| } | ||||
| 
 | ||||
| func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { | ||||
| 	response, err := n.client.GetExternalAddress() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	ip := response.ExternalIPAddress | ||||
| 	addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) | ||||
| 	return | ||||
| func (*natPMPClient) String() string { | ||||
| 	return "NAT-PMP" | ||||
| } | ||||
| 
 | ||||
| func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, | ||||
| 	description string, timeout int) (mappedExternalPort int, err error) { | ||||
| 	if timeout <= 0 { | ||||
| 		err = fmt.Errorf("timeout must not be <= 0") | ||||
| 		return | ||||
| func (n *natPMPClient) GetExternalAddress() (net.IP, error) { | ||||
| 	response, err := n.client.GetExternalAddress() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return response.ExternalIPAddress[:], nil | ||||
| } | ||||
| 
 | ||||
| func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { | ||||
| 	if lifetime <= 0 { | ||||
| 		return fmt.Errorf("lifetime must not be <= 0") | ||||
| 	} | ||||
| 	// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
 | ||||
| 	response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	mappedExternalPort = int(response.MappedExternalPort) | ||||
| 	return | ||||
| 	_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { | ||||
|  | ||||
							
								
								
									
										198
									
								
								p2p/natupnp.go
									
									
									
									
									
								
							
							
						
						
									
										198
									
								
								p2p/natupnp.go
									
									
									
									
									
								
							| @ -7,6 +7,7 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/xml" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| @ -15,28 +16,46 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	upnpDiscoverAttempts = 3 | ||||
| 	upnpDiscoverTimeout  = 5 * time.Second | ||||
| ) | ||||
| 
 | ||||
| // UPNP returns a NAT port mapper that uses UPnP. It will attempt to
 | ||||
| // discover the address of your router using UDP broadcasts.
 | ||||
| func UPNP() NAT { | ||||
| 	return &upnpNAT{} | ||||
| } | ||||
| 
 | ||||
| type upnpNAT struct { | ||||
| 	serviceURL string | ||||
| 	ourIP      string | ||||
| } | ||||
| 
 | ||||
| func upnpDiscover(attempts int) (nat NAT, err error) { | ||||
| func (n *upnpNAT) String() string { | ||||
| 	return "UPNP" | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) discover() error { | ||||
| 	if n.serviceURL != "" { | ||||
| 		// already discovered
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 	// TODO: try on all network interfaces simultaneously.
 | ||||
| 	// Broadcasting on 0.0.0.0 could select a random interface
 | ||||
| 	// to send on (platform specific).
 | ||||
| 	conn, err := net.ListenPacket("udp4", ":0") | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	socket := conn.(*net.UDPConn) | ||||
| 	defer socket.Close() | ||||
| 
 | ||||
| 	err = socket.SetDeadline(time.Now().Add(10 * time.Second)) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 
 | ||||
| 	conn.SetDeadline(time.Now().Add(10 * time.Second)) | ||||
| 	st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" | ||||
| 	buf := bytes.NewBufferString( | ||||
| 		"M-SEARCH * HTTP/1.1\r\n" + | ||||
| @ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { | ||||
| 			"MX: 2\r\n\r\n") | ||||
| 	message := buf.Bytes() | ||||
| 	answerBytes := make([]byte, 1024) | ||||
| 	for i := 0; i < attempts; i++ { | ||||
| 		_, err = socket.WriteToUDP(message, ssdp) | ||||
| 	for i := 0; i < upnpDiscoverAttempts; i++ { | ||||
| 		_, err = conn.WriteTo(message, ssdp) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 			return err | ||||
| 		} | ||||
| 		var n int | ||||
| 		n, _, err = socket.ReadFromUDP(answerBytes) | ||||
| 		nn, _, err := conn.ReadFrom(answerBytes) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 			// socket.Close()
 | ||||
| 			// return
 | ||||
| 		} | ||||
| 		answer := string(answerBytes[0:n]) | ||||
| 		answer := string(answerBytes[0:nn]) | ||||
| 		if strings.Index(answer, "\r\n"+st) < 0 { | ||||
| 			continue | ||||
| 		} | ||||
| @ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { | ||||
| 		var serviceURL string | ||||
| 		serviceURL, err = getServiceURL(locURL) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 			return err | ||||
| 		} | ||||
| 		var ourIP string | ||||
| 		ourIP, err = getOurIP() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 			return err | ||||
| 		} | ||||
| 		nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} | ||||
| 		n.serviceURL = serviceURL | ||||
| 		n.ourIP = ourIP | ||||
| 		return nil | ||||
| 	} | ||||
| 	return errors.New("UPnP port discovery failed.") | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { | ||||
| 	if err := n.discover(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	info, err := n.getStatusInfo() | ||||
| 	return net.ParseIP(info.externalIpAddress), err | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { | ||||
| 	if err := n.discover(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// A single concatenation would break ARM compilation.
 | ||||
| 	message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||
| 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport) | ||||
| 	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" | ||||
| 	message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" + | ||||
| 		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + | ||||
| 		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" | ||||
| 	message += description + | ||||
| 		"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) + | ||||
| 		"</NewLeaseDuration></u:AddPortMapping>" | ||||
| 
 | ||||
| 	// TODO: check response to see if the port was forwarded
 | ||||
| 	_, err := soapRequest(n.serviceURL, "AddPortMapping", message) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { | ||||
| 	if err := n.discover(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||
| 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + | ||||
| 		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + | ||||
| 		"</u:DeletePortMapping>" | ||||
| 
 | ||||
| 	// TODO: check response to see if the port was deleted
 | ||||
| 	_, err := soapRequest(n.serviceURL, "DeletePortMapping", message) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| type statusInfo struct { | ||||
| 	externalIpAddress string | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { | ||||
| 	message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||
| 		"</u:GetStatusInfo>" | ||||
| 
 | ||||
| 	var response *http.Response | ||||
| 	response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	err = errors.New("UPnP port discovery failed.") | ||||
| 
 | ||||
| 	// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
 | ||||
| 
 | ||||
| 	response.Body.Close() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type statusInfo struct { | ||||
| 	externalIpAddress string | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { | ||||
| 
 | ||||
| 	message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||
| 		"</u:GetStatusInfo>" | ||||
| 
 | ||||
| 	var response *http.Response | ||||
| 	response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
 | ||||
| 
 | ||||
| 	response.Body.Close() | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { | ||||
| 	info, err := n.getStatusInfo() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	addr = net.ParseIP(info.externalIpAddress) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { | ||||
| 	// A single concatenation would break ARM compilation.
 | ||||
| 	message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||
| 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) | ||||
| 	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" | ||||
| 	message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + | ||||
| 		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + | ||||
| 		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" | ||||
| 	message += description + | ||||
| 		"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + | ||||
| 		"</NewLeaseDuration></u:AddPortMapping>" | ||||
| 
 | ||||
| 	var response *http.Response | ||||
| 	response, err = soapRequest(n.serviceURL, "AddPortMapping", message) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: check response to see if the port was forwarded
 | ||||
| 	// log.Println(message, response)
 | ||||
| 	mappedExternalPort = externalPort | ||||
| 	_ = response | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { | ||||
| 
 | ||||
| 	message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||
| 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + | ||||
| 		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + | ||||
| 		"</u:DeletePortMapping>" | ||||
| 
 | ||||
| 	var response *http.Response | ||||
| 	response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: check response to see if the port was deleted
 | ||||
| 	// log.Println(message, response)
 | ||||
| 	_ = response | ||||
| 	return | ||||
| } | ||||
|  | ||||
							
								
								
									
										196
									
								
								p2p/network.go
									
									
									
									
									
								
							
							
						
						
									
										196
									
								
								p2p/network.go
									
									
									
									
									
								
							| @ -1,196 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	DialerTimeout             = 180 //seconds
 | ||||
| 	KeepAlivePeriod           = 60  //minutes
 | ||||
| 	portMappingUpdateInterval = 900 // seconds = 15 mins
 | ||||
| 	upnpDiscoverAttempts      = 3 | ||||
| ) | ||||
| 
 | ||||
| // Dialer is not an interface in net, so we define one
 | ||||
| // *net.Dialer conforms to this
 | ||||
| type Dialer interface { | ||||
| 	Dial(network, address string) (net.Conn, error) | ||||
| } | ||||
| 
 | ||||
| type Network interface { | ||||
| 	Start() error | ||||
| 	Listener(net.Addr) (net.Listener, error) | ||||
| 	Dialer(net.Addr) (Dialer, error) | ||||
| 	NewAddr(string, int) (addr net.Addr, err error) | ||||
| 	ParseAddr(string) (addr net.Addr, err error) | ||||
| } | ||||
| 
 | ||||
| type NAT interface { | ||||
| 	GetExternalAddress() (addr net.IP, err error) | ||||
| 	AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) | ||||
| 	DeletePortMapping(protocol string, externalPort, internalPort int) (err error) | ||||
| } | ||||
| 
 | ||||
| type TCPNetwork struct { | ||||
| 	nat     NAT | ||||
| 	natType NATType | ||||
| 	quit    chan chan bool | ||||
| 	ports   chan string | ||||
| } | ||||
| 
 | ||||
| type NATType int | ||||
| 
 | ||||
| const ( | ||||
| 	NONE = iota | ||||
| 	UPNP | ||||
| 	PMP | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	portMappingTimeout = 1200 // 20 mins
 | ||||
| ) | ||||
| 
 | ||||
| func NewTCPNetwork(natType NATType) (net *TCPNetwork) { | ||||
| 	return &TCPNetwork{ | ||||
| 		natType: natType, | ||||
| 		ports:   make(chan string), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { | ||||
| 	return &net.Dialer{ | ||||
| 		Timeout: DialerTimeout * time.Second, | ||||
| 		// KeepAlive: KeepAlivePeriod * time.Minute,
 | ||||
| 		LocalAddr: addr, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { | ||||
| 	if self.natType == UPNP { | ||||
| 		_, port, _ := net.SplitHostPort(addr.String()) | ||||
| 		if self.quit == nil { | ||||
| 			self.quit = make(chan chan bool) | ||||
| 			go self.updatePortMappings() | ||||
| 		} | ||||
| 		self.ports <- port | ||||
| 	} | ||||
| 	return net.Listen(addr.Network(), addr.String()) | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) Start() (err error) { | ||||
| 	switch self.natType { | ||||
| 	case NONE: | ||||
| 	case UPNP: | ||||
| 		nat, uerr := upnpDiscover(upnpDiscoverAttempts) | ||||
| 		if uerr != nil { | ||||
| 			err = fmt.Errorf("UPNP failed: ", uerr) | ||||
| 		} else { | ||||
| 			self.nat = nat | ||||
| 		} | ||||
| 	case PMP: | ||||
| 		err = fmt.Errorf("PMP not implemented") | ||||
| 	default: | ||||
| 		err = fmt.Errorf("Invalid NAT type: %v", self.natType) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) Stop() { | ||||
| 	q := make(chan bool) | ||||
| 	self.quit <- q | ||||
| 	<-q | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) addPortMapping(lport int) (err error) { | ||||
| 	_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("unable to add port mapping on %v: %v", lport, err) | ||||
| 	} else { | ||||
| 		logger.Debugf("succesfully added port mapping on %v", lport) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) updatePortMappings() { | ||||
| 	timer := time.NewTimer(portMappingUpdateInterval * time.Second) | ||||
| 	lports := []int{} | ||||
| out: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case port := <-self.ports: | ||||
| 			int64lport, _ := strconv.ParseInt(port, 10, 16) | ||||
| 			lport := int(int64lport) | ||||
| 			if err := self.addPortMapping(lport); err != nil { | ||||
| 				lports = append(lports, lport) | ||||
| 			} | ||||
| 		case <-timer.C: | ||||
| 			for lport := range lports { | ||||
| 				if err := self.addPortMapping(lport); err != nil { | ||||
| 				} | ||||
| 			} | ||||
| 		case errc := <-self.quit: | ||||
| 			errc <- true | ||||
| 			break out | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	timer.Stop() | ||||
| 	for lport := range lports { | ||||
| 		if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { | ||||
| 			logger.Debugf("unable to remove port mapping on %v: %v", lport, err) | ||||
| 		} else { | ||||
| 			logger.Debugf("succesfully removed port mapping on %v", lport) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { | ||||
| 	ip, err := self.lookupIP(host) | ||||
| 	if err == nil { | ||||
| 		return &net.TCPAddr{ | ||||
| 			IP:   ip, | ||||
| 			Port: port, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	return nil, err | ||||
| } | ||||
| 
 | ||||
| func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { | ||||
| 	host, port, err := net.SplitHostPort(address) | ||||
| 	if err == nil { | ||||
| 		iport, _ := strconv.Atoi(port) | ||||
| 		addr, e := self.NewAddr(host, iport) | ||||
| 		return addr, e | ||||
| 	} | ||||
| 	return nil, err | ||||
| } | ||||
| 
 | ||||
| func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { | ||||
| 	if ip = net.ParseIP(host); ip != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var ips []net.IP | ||||
| 	ips, err = net.LookupIP(host) | ||||
| 	if err != nil { | ||||
| 		logger.Warnln(err) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(ips) == 0 { | ||||
| 		err = fmt.Errorf("No IP addresses available for %v", host) | ||||
| 		logger.Warnln(err) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(ips) > 1 { | ||||
| 		// Pick a random IP address, simulating round-robin DNS.
 | ||||
| 		rand.Seed(time.Now().UTC().UnixNano()) | ||||
| 		ip = ips[rand.Intn(len(ips))] | ||||
| 	} else { | ||||
| 		ip = ips[0] | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										490
									
								
								p2p/peer.go
									
									
									
									
									
								
							
							
						
						
									
										490
									
								
								p2p/peer.go
									
									
									
									
									
								
							| @ -1,83 +1,455 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"sort" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/event" | ||||
| 	"github.com/ethereum/go-ethereum/logger" | ||||
| ) | ||||
| 
 | ||||
| type Peer struct { | ||||
| 	// quit      chan chan bool
 | ||||
| 	Inbound          bool // inbound (via listener) or outbound (via dialout)
 | ||||
| 	Address          net.Addr | ||||
| 	Host             []byte | ||||
| 	Port             uint16 | ||||
| 	Pubkey           []byte | ||||
| 	Id               string | ||||
| 	Caps             []string | ||||
| 	peerErrorChan    chan *PeerError | ||||
| 	messenger        *Messenger | ||||
| 	peerErrorHandler *PeerErrorHandler | ||||
| 	server           *Server | ||||
| // peerAddr is the structure of a peer list element.
 | ||||
| // It is also a valid net.Addr.
 | ||||
| type peerAddr struct { | ||||
| 	IP     net.IP | ||||
| 	Port   uint64 | ||||
| 	Pubkey []byte // optional
 | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) Messenger() *Messenger { | ||||
| 	return self.messenger | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) PeerErrorChan() chan *PeerError { | ||||
| 	return self.peerErrorChan | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) Server() *Server { | ||||
| 	return self.server | ||||
| } | ||||
| 
 | ||||
| func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { | ||||
| 	peerErrorChan := NewPeerErrorChannel() | ||||
| 	host, port, _ := net.SplitHostPort(address.String()) | ||||
| 	intport, _ := strconv.Atoi(port) | ||||
| 	peer := &Peer{ | ||||
| 		Inbound:       inbound, | ||||
| 		Address:       address, | ||||
| 		Port:          uint16(intport), | ||||
| 		Host:          net.ParseIP(host), | ||||
| 		peerErrorChan: peerErrorChan, | ||||
| 		server:        server, | ||||
| func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { | ||||
| 	n := addr.Network() | ||||
| 	if n != "tcp" && n != "tcp4" && n != "tcp6" { | ||||
| 		// for testing with non-TCP
 | ||||
| 		return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} | ||||
| 	} | ||||
| 	connection := NewConnection(conn, peerErrorChan) | ||||
| 	peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) | ||||
| 	peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) | ||||
| 	ta := addr.(*net.TCPAddr) | ||||
| 	return &peerAddr{ta.IP, uint64(ta.Port), pubkey} | ||||
| } | ||||
| 
 | ||||
| func (d peerAddr) Network() string { | ||||
| 	if d.IP.To4() != nil { | ||||
| 		return "tcp4" | ||||
| 	} else { | ||||
| 		return "tcp6" | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (d peerAddr) String() string { | ||||
| 	return fmt.Sprintf("%v:%d", d.IP, d.Port) | ||||
| } | ||||
| 
 | ||||
| func (d peerAddr) RlpData() interface{} { | ||||
| 	return []interface{}{d.IP, d.Port, d.Pubkey} | ||||
| } | ||||
| 
 | ||||
| // Peer represents a remote peer.
 | ||||
| type Peer struct { | ||||
| 	// Peers have all the log methods.
 | ||||
| 	// Use them to display messages related to the peer.
 | ||||
| 	*logger.Logger | ||||
| 
 | ||||
| 	infolock   sync.Mutex | ||||
| 	identity   ClientIdentity | ||||
| 	caps       []Cap | ||||
| 	listenAddr *peerAddr // what remote peer is listening on
 | ||||
| 	dialAddr   *peerAddr // non-nil if dialing
 | ||||
| 
 | ||||
| 	// The mutex protects the connection
 | ||||
| 	// so only one protocol can write at a time.
 | ||||
| 	writeMu sync.Mutex | ||||
| 	conn    net.Conn | ||||
| 	bufconn *bufio.ReadWriter | ||||
| 
 | ||||
| 	// These fields maintain the running protocols.
 | ||||
| 	protocols       []Protocol | ||||
| 	runBaseProtocol bool // for testing
 | ||||
| 
 | ||||
| 	runlock sync.RWMutex // protects running
 | ||||
| 	running map[string]*proto | ||||
| 
 | ||||
| 	protoWG  sync.WaitGroup | ||||
| 	protoErr chan error | ||||
| 	closed   chan struct{} | ||||
| 	disc     chan DiscReason | ||||
| 
 | ||||
| 	activity event.TypeMux // for activity events
 | ||||
| 
 | ||||
| 	slot int // index into Server peer list
 | ||||
| 
 | ||||
| 	// These fields are kept so base protocol can access them.
 | ||||
| 	// TODO: this should be one or more interfaces
 | ||||
| 	ourID         ClientIdentity        // client id of the Server
 | ||||
| 	ourListenAddr *peerAddr             // listen addr of Server, nil if not listening
 | ||||
| 	newPeerAddr   chan<- *peerAddr      // tell server about received peers
 | ||||
| 	otherPeers    func() []*Peer        // should return the list of all peers
 | ||||
| 	pubkeyHook    func(*peerAddr) error // called at end of handshake to validate pubkey
 | ||||
| } | ||||
| 
 | ||||
| // NewPeer returns a peer for testing purposes.
 | ||||
| func NewPeer(id ClientIdentity, caps []Cap) *Peer { | ||||
| 	conn, _ := net.Pipe() | ||||
| 	peer := newPeer(conn, nil, nil) | ||||
| 	peer.setHandshakeInfo(id, nil, caps) | ||||
| 	close(peer.closed) | ||||
| 	return peer | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) String() string { | ||||
| 	var kind string | ||||
| 	if self.Inbound { | ||||
| 		kind = "inbound" | ||||
| 	} else { | ||||
| func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { | ||||
| 	p := newPeer(conn, server.Protocols, dialAddr) | ||||
| 	p.ourID = server.Identity | ||||
| 	p.newPeerAddr = server.peerConnect | ||||
| 	p.otherPeers = server.Peers | ||||
| 	p.pubkeyHook = server.verifyPeer | ||||
| 	p.runBaseProtocol = true | ||||
| 
 | ||||
| 	// laddr can be updated concurrently by NAT traversal.
 | ||||
| 	// newServerPeer must be called with the server lock held.
 | ||||
| 	if server.laddr != nil { | ||||
| 		p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { | ||||
| 	p := &Peer{ | ||||
| 		Logger:    logger.NewLogger("P2P " + conn.RemoteAddr().String()), | ||||
| 		conn:      conn, | ||||
| 		dialAddr:  dialAddr, | ||||
| 		bufconn:   bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), | ||||
| 		protocols: protocols, | ||||
| 		running:   make(map[string]*proto), | ||||
| 		disc:      make(chan DiscReason), | ||||
| 		protoErr:  make(chan error), | ||||
| 		closed:    make(chan struct{}), | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
| 
 | ||||
| // Identity returns the client identity of the remote peer. The
 | ||||
| // identity can be nil if the peer has not yet completed the
 | ||||
| // handshake.
 | ||||
| func (p *Peer) Identity() ClientIdentity { | ||||
| 	p.infolock.Lock() | ||||
| 	defer p.infolock.Unlock() | ||||
| 	return p.identity | ||||
| } | ||||
| 
 | ||||
| // Caps returns the capabilities (supported subprotocols) of the remote peer.
 | ||||
| func (p *Peer) Caps() []Cap { | ||||
| 	p.infolock.Lock() | ||||
| 	defer p.infolock.Unlock() | ||||
| 	return p.caps | ||||
| } | ||||
| 
 | ||||
| func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { | ||||
| 	p.infolock.Lock() | ||||
| 	p.identity = id | ||||
| 	p.listenAddr = laddr | ||||
| 	p.caps = caps | ||||
| 	p.infolock.Unlock() | ||||
| } | ||||
| 
 | ||||
| // RemoteAddr returns the remote address of the network connection.
 | ||||
| func (p *Peer) RemoteAddr() net.Addr { | ||||
| 	return p.conn.RemoteAddr() | ||||
| } | ||||
| 
 | ||||
| // LocalAddr returns the local address of the network connection.
 | ||||
| func (p *Peer) LocalAddr() net.Addr { | ||||
| 	return p.conn.LocalAddr() | ||||
| } | ||||
| 
 | ||||
| // Disconnect terminates the peer connection with the given reason.
 | ||||
| // It returns immediately and does not wait until the connection is closed.
 | ||||
| func (p *Peer) Disconnect(reason DiscReason) { | ||||
| 	select { | ||||
| 	case p.disc <- reason: | ||||
| 	case <-p.closed: | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // String implements fmt.Stringer.
 | ||||
| func (p *Peer) String() string { | ||||
| 	kind := "inbound" | ||||
| 	p.infolock.Lock() | ||||
| 	if p.dialAddr != nil { | ||||
| 		kind = "outbound" | ||||
| 	} | ||||
| 	return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) | ||||
| 	p.infolock.Unlock() | ||||
| 	return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) Write(protocol string, msg *Msg) error { | ||||
| 	return self.messenger.Write(protocol, msg) | ||||
| const ( | ||||
| 	// maximum amount of time allowed for reading a message
 | ||||
| 	msgReadTimeout = 5 * time.Second | ||||
| 	// maximum amount of time allowed for writing a message
 | ||||
| 	msgWriteTimeout = 5 * time.Second | ||||
| 	// messages smaller than this many bytes will be read at
 | ||||
| 	// once before passing them to a protocol.
 | ||||
| 	wholePayloadSize = 64 * 1024 | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	inactivityTimeout     = 2 * time.Second | ||||
| 	disconnectGracePeriod = 2 * time.Second | ||||
| ) | ||||
| 
 | ||||
| func (p *Peer) loop() (reason DiscReason, err error) { | ||||
| 	defer p.activity.Stop() | ||||
| 	defer p.closeProtocols() | ||||
| 	defer close(p.closed) | ||||
| 	defer p.conn.Close() | ||||
| 
 | ||||
| 	// read loop
 | ||||
| 	readMsg := make(chan Msg) | ||||
| 	readErr := make(chan error) | ||||
| 	readNext := make(chan bool, 1) | ||||
| 	protoDone := make(chan struct{}, 1) | ||||
| 	go p.readLoop(readMsg, readErr, readNext) | ||||
| 	readNext <- true | ||||
| 
 | ||||
| 	if p.runBaseProtocol { | ||||
| 		p.startBaseProtocol() | ||||
| 	} | ||||
| 
 | ||||
| loop: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case msg := <-readMsg: | ||||
| 			// a new message has arrived.
 | ||||
| 			var wait bool | ||||
| 			if wait, err = p.dispatch(msg, protoDone); err != nil { | ||||
| 				p.Errorf("msg dispatch error: %v\n", err) | ||||
| 				reason = discReasonForError(err) | ||||
| 				break loop | ||||
| 			} | ||||
| 			if !wait { | ||||
| 				// Msg has already been read completely, continue with next message.
 | ||||
| 				readNext <- true | ||||
| 			} | ||||
| 			p.activity.Post(time.Now()) | ||||
| 		case <-protoDone: | ||||
| 			// protocol has consumed the message payload,
 | ||||
| 			// we can continue reading from the socket.
 | ||||
| 			readNext <- true | ||||
| 
 | ||||
| 		case err := <-readErr: | ||||
| 			// read failed. there is no need to run the
 | ||||
| 			// polite disconnect sequence because the connection
 | ||||
| 			// is probably dead anyway.
 | ||||
| 			// TODO: handle write errors as well
 | ||||
| 			return DiscNetworkError, err | ||||
| 		case err = <-p.protoErr: | ||||
| 			reason = discReasonForError(err) | ||||
| 			break loop | ||||
| 		case reason = <-p.disc: | ||||
| 			break loop | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// wait for read loop to return.
 | ||||
| 	close(readNext) | ||||
| 	<-readErr | ||||
| 	// tell the remote end to disconnect
 | ||||
| 	done := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) | ||||
| 		p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) | ||||
| 		io.Copy(ioutil.Discard, p.conn) | ||||
| 		close(done) | ||||
| 	}() | ||||
| 	select { | ||||
| 	case <-done: | ||||
| 	case <-time.After(disconnectGracePeriod): | ||||
| 	} | ||||
| 	return reason, err | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) Start() { | ||||
| 	self.peerErrorHandler.Start() | ||||
| 	self.messenger.Start() | ||||
| func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { | ||||
| 	for _ = range unblock { | ||||
| 		p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) | ||||
| 		if msg, err := readMsg(p.bufconn); err != nil { | ||||
| 			errc <- err | ||||
| 		} else { | ||||
| 			msgc <- msg | ||||
| 		} | ||||
| 	} | ||||
| 	close(errc) | ||||
| } | ||||
| 
 | ||||
| func (self *Peer) Stop() { | ||||
| 	self.peerErrorHandler.Stop() | ||||
| 	self.messenger.Stop() | ||||
| 	// q := make(chan bool)
 | ||||
| 	// self.quit <- q
 | ||||
| 	// <-q
 | ||||
| func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { | ||||
| 	proto, err := p.getProto(msg.Code) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	if msg.Size <= wholePayloadSize { | ||||
| 		// optimization: msg is small enough, read all
 | ||||
| 		// of it and move on to the next message
 | ||||
| 		buf, err := ioutil.ReadAll(msg.Payload) | ||||
| 		if err != nil { | ||||
| 			return false, err | ||||
| 		} | ||||
| 		msg.Payload = bytes.NewReader(buf) | ||||
| 		proto.in <- msg | ||||
| 	} else { | ||||
| 		wait = true | ||||
| 		pr := &eofSignal{msg.Payload, protoDone} | ||||
| 		msg.Payload = pr | ||||
| 		proto.in <- msg | ||||
| 	} | ||||
| 	return wait, nil | ||||
| } | ||||
| 
 | ||||
| func (p *Peer) Encode() []interface{} { | ||||
| 	return []interface{}{p.Host, p.Port, p.Pubkey} | ||||
| func (p *Peer) startBaseProtocol() { | ||||
| 	p.runlock.Lock() | ||||
| 	defer p.runlock.Unlock() | ||||
| 	p.running[""] = p.startProto(0, Protocol{ | ||||
| 		Length: baseProtocolLength, | ||||
| 		Run:    runBaseProtocol, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // startProtocols starts matching named subprotocols.
 | ||||
| func (p *Peer) startSubprotocols(caps []Cap) { | ||||
| 	sort.Sort(capsByName(caps)) | ||||
| 
 | ||||
| 	p.runlock.Lock() | ||||
| 	defer p.runlock.Unlock() | ||||
| 	offset := baseProtocolLength | ||||
| outer: | ||||
| 	for _, cap := range caps { | ||||
| 		for _, proto := range p.protocols { | ||||
| 			if proto.Name == cap.Name && | ||||
| 				proto.Version == cap.Version && | ||||
| 				p.running[cap.Name] == nil { | ||||
| 				p.running[cap.Name] = p.startProto(offset, proto) | ||||
| 				offset += proto.Length | ||||
| 				continue outer | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *Peer) startProto(offset uint64, impl Protocol) *proto { | ||||
| 	rw := &proto{ | ||||
| 		in:      make(chan Msg), | ||||
| 		offset:  offset, | ||||
| 		maxcode: impl.Length, | ||||
| 		peer:    p, | ||||
| 	} | ||||
| 	p.protoWG.Add(1) | ||||
| 	go func() { | ||||
| 		err := impl.Run(p, rw) | ||||
| 		if err == nil { | ||||
| 			p.Infof("protocol %q returned", impl.Name) | ||||
| 			err = newPeerError(errMisc, "protocol returned") | ||||
| 		} else { | ||||
| 			p.Errorf("protocol %q error: %v\n", impl.Name, err) | ||||
| 		} | ||||
| 		select { | ||||
| 		case p.protoErr <- err: | ||||
| 		case <-p.closed: | ||||
| 		} | ||||
| 		p.protoWG.Done() | ||||
| 	}() | ||||
| 	return rw | ||||
| } | ||||
| 
 | ||||
| // getProto finds the protocol responsible for handling
 | ||||
| // the given message code.
 | ||||
| func (p *Peer) getProto(code uint64) (*proto, error) { | ||||
| 	p.runlock.RLock() | ||||
| 	defer p.runlock.RUnlock() | ||||
| 	for _, proto := range p.running { | ||||
| 		if code >= proto.offset && code < proto.offset+proto.maxcode { | ||||
| 			return proto, nil | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, newPeerError(errInvalidMsgCode, "%d", code) | ||||
| } | ||||
| 
 | ||||
| func (p *Peer) closeProtocols() { | ||||
| 	p.runlock.RLock() | ||||
| 	for _, p := range p.running { | ||||
| 		close(p.in) | ||||
| 	} | ||||
| 	p.runlock.RUnlock() | ||||
| 	p.protoWG.Wait() | ||||
| } | ||||
| 
 | ||||
| // writeProtoMsg sends the given message on behalf of the given named protocol.
 | ||||
| func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { | ||||
| 	p.runlock.RLock() | ||||
| 	proto, ok := p.running[protoName] | ||||
| 	p.runlock.RUnlock() | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("protocol %s not handled by peer", protoName) | ||||
| 	} | ||||
| 	if msg.Code >= proto.maxcode { | ||||
| 		return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) | ||||
| 	} | ||||
| 	msg.Code += proto.offset | ||||
| 	return p.writeMsg(msg, msgWriteTimeout) | ||||
| } | ||||
| 
 | ||||
| // writeMsg writes a message to the connection.
 | ||||
| func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { | ||||
| 	p.writeMu.Lock() | ||||
| 	defer p.writeMu.Unlock() | ||||
| 	p.conn.SetWriteDeadline(time.Now().Add(timeout)) | ||||
| 	if err := writeMsg(p.bufconn, msg); err != nil { | ||||
| 		return newPeerError(errWrite, "%v", err) | ||||
| 	} | ||||
| 	return p.bufconn.Flush() | ||||
| } | ||||
| 
 | ||||
| type proto struct { | ||||
| 	name            string | ||||
| 	in              chan Msg | ||||
| 	maxcode, offset uint64 | ||||
| 	peer            *Peer | ||||
| } | ||||
| 
 | ||||
| func (rw *proto) WriteMsg(msg Msg) error { | ||||
| 	if msg.Code >= rw.maxcode { | ||||
| 		return newPeerError(errInvalidMsgCode, "not handled") | ||||
| 	} | ||||
| 	msg.Code += rw.offset | ||||
| 	return rw.peer.writeMsg(msg, msgWriteTimeout) | ||||
| } | ||||
| 
 | ||||
| func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { | ||||
| 	return rw.WriteMsg(NewMsg(code, data)) | ||||
| } | ||||
| 
 | ||||
| func (rw *proto) ReadMsg() (Msg, error) { | ||||
| 	msg, ok := <-rw.in | ||||
| 	if !ok { | ||||
| 		return msg, io.EOF | ||||
| 	} | ||||
| 	msg.Code -= rw.offset | ||||
| 	return msg, nil | ||||
| } | ||||
| 
 | ||||
| // eofSignal wraps a reader with eof signaling.
 | ||||
| // the eof channel is closed when the wrapped reader
 | ||||
| // reaches EOF.
 | ||||
| type eofSignal struct { | ||||
| 	wrapped io.Reader | ||||
| 	eof     chan<- struct{} | ||||
| } | ||||
| 
 | ||||
| func (r *eofSignal) Read(buf []byte) (int, error) { | ||||
| 	n, err := r.wrapped.Read(buf) | ||||
| 	if err != nil { | ||||
| 		r.eof <- struct{}{} // tell Peer that msg has been consumed
 | ||||
| 	} | ||||
| 	return n, err | ||||
| } | ||||
|  | ||||
| @ -4,73 +4,121 @@ import ( | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| type ErrorCode int | ||||
| 
 | ||||
| const errorChanCapacity = 10 | ||||
| 
 | ||||
| const ( | ||||
| 	PacketTooShort = iota | ||||
| 	PayloadTooShort | ||||
| 	MagicTokenMismatch | ||||
| 	EmptyPayload | ||||
| 	ReadError | ||||
| 	WriteError | ||||
| 	MiscError | ||||
| 	InvalidMsgCode | ||||
| 	InvalidMsg | ||||
| 	P2PVersionMismatch | ||||
| 	PubkeyMissing | ||||
| 	PubkeyInvalid | ||||
| 	PubkeyForbidden | ||||
| 	ProtocolBreach | ||||
| 	PortMismatch | ||||
| 	PingTimeout | ||||
| 	InvalidGenesis | ||||
| 	InvalidNetworkId | ||||
| 	InvalidProtocolVersion | ||||
| 	errMagicTokenMismatch = iota | ||||
| 	errRead | ||||
| 	errWrite | ||||
| 	errMisc | ||||
| 	errInvalidMsgCode | ||||
| 	errInvalidMsg | ||||
| 	errP2PVersionMismatch | ||||
| 	errPubkeyMissing | ||||
| 	errPubkeyInvalid | ||||
| 	errPubkeyForbidden | ||||
| 	errProtocolBreach | ||||
| 	errPingTimeout | ||||
| 	errInvalidNetworkId | ||||
| 	errInvalidProtocolVersion | ||||
| ) | ||||
| 
 | ||||
| var errorToString = map[ErrorCode]string{ | ||||
| 	PacketTooShort:         "Packet too short", | ||||
| 	PayloadTooShort:        "Payload too short", | ||||
| 	MagicTokenMismatch:     "Magic token mismatch", | ||||
| 	EmptyPayload:           "Empty payload", | ||||
| 	ReadError:              "Read error", | ||||
| 	WriteError:             "Write error", | ||||
| 	MiscError:              "Misc error", | ||||
| 	InvalidMsgCode:         "Invalid message code", | ||||
| 	InvalidMsg:             "Invalid message", | ||||
| 	P2PVersionMismatch:     "P2P Version Mismatch", | ||||
| 	PubkeyMissing:          "Public key missing", | ||||
| 	PubkeyInvalid:          "Public key invalid", | ||||
| 	PubkeyForbidden:        "Public key forbidden", | ||||
| 	ProtocolBreach:         "Protocol Breach", | ||||
| 	PortMismatch:           "Port mismatch", | ||||
| 	PingTimeout:            "Ping timeout", | ||||
| 	InvalidGenesis:         "Invalid genesis block", | ||||
| 	InvalidNetworkId:       "Invalid network id", | ||||
| 	InvalidProtocolVersion: "Invalid protocol version", | ||||
| var errorToString = map[int]string{ | ||||
| 	errMagicTokenMismatch:     "Magic token mismatch", | ||||
| 	errRead:                   "Read error", | ||||
| 	errWrite:                  "Write error", | ||||
| 	errMisc:                   "Misc error", | ||||
| 	errInvalidMsgCode:         "Invalid message code", | ||||
| 	errInvalidMsg:             "Invalid message", | ||||
| 	errP2PVersionMismatch:     "P2P Version Mismatch", | ||||
| 	errPubkeyMissing:          "Public key missing", | ||||
| 	errPubkeyInvalid:          "Public key invalid", | ||||
| 	errPubkeyForbidden:        "Public key forbidden", | ||||
| 	errProtocolBreach:         "Protocol Breach", | ||||
| 	errPingTimeout:            "Ping timeout", | ||||
| 	errInvalidNetworkId:       "Invalid network id", | ||||
| 	errInvalidProtocolVersion: "Invalid protocol version", | ||||
| } | ||||
| 
 | ||||
| type PeerError struct { | ||||
| 	Code    ErrorCode | ||||
| type peerError struct { | ||||
| 	Code    int | ||||
| 	message string | ||||
| } | ||||
| 
 | ||||
| func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { | ||||
| func newPeerError(code int, format string, v ...interface{}) *peerError { | ||||
| 	desc, ok := errorToString[code] | ||||
| 	if !ok { | ||||
| 		panic("invalid error code") | ||||
| 	} | ||||
| 	format = desc + ": " + format | ||||
| 	message := fmt.Sprintf(format, v...) | ||||
| 	return &PeerError{code, message} | ||||
| 	err := &peerError{code, desc} | ||||
| 	if format != "" { | ||||
| 		err.message += ": " + fmt.Sprintf(format, v...) | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (self *PeerError) Error() string { | ||||
| func (self *peerError) Error() string { | ||||
| 	return self.message | ||||
| } | ||||
| 
 | ||||
| func NewPeerErrorChannel() chan *PeerError { | ||||
| 	return make(chan *PeerError, errorChanCapacity) | ||||
| type DiscReason byte | ||||
| 
 | ||||
| const ( | ||||
| 	DiscRequested           DiscReason = 0x00 | ||||
| 	DiscNetworkError                   = 0x01 | ||||
| 	DiscProtocolError                  = 0x02 | ||||
| 	DiscUselessPeer                    = 0x03 | ||||
| 	DiscTooManyPeers                   = 0x04 | ||||
| 	DiscAlreadyConnected               = 0x05 | ||||
| 	DiscIncompatibleVersion            = 0x06 | ||||
| 	DiscInvalidIdentity                = 0x07 | ||||
| 	DiscQuitting                       = 0x08 | ||||
| 	DiscUnexpectedIdentity             = 0x09 | ||||
| 	DiscSelf                           = 0x0a | ||||
| 	DiscReadTimeout                    = 0x0b | ||||
| 	DiscSubprotocolError               = 0x10 | ||||
| ) | ||||
| 
 | ||||
| var discReasonToString = [DiscSubprotocolError + 1]string{ | ||||
| 	DiscRequested:           "Disconnect requested", | ||||
| 	DiscNetworkError:        "Network error", | ||||
| 	DiscProtocolError:       "Breach of protocol", | ||||
| 	DiscUselessPeer:         "Useless peer", | ||||
| 	DiscTooManyPeers:        "Too many peers", | ||||
| 	DiscAlreadyConnected:    "Already connected", | ||||
| 	DiscIncompatibleVersion: "Incompatible P2P protocol version", | ||||
| 	DiscInvalidIdentity:     "Invalid node identity", | ||||
| 	DiscQuitting:            "Client quitting", | ||||
| 	DiscUnexpectedIdentity:  "Unexpected identity", | ||||
| 	DiscSelf:                "Connected to self", | ||||
| 	DiscReadTimeout:         "Read timeout", | ||||
| 	DiscSubprotocolError:    "Subprotocol error", | ||||
| } | ||||
| 
 | ||||
| func (d DiscReason) String() string { | ||||
| 	if len(discReasonToString) < int(d) { | ||||
| 		return fmt.Sprintf("Unknown Reason(%d)", d) | ||||
| 	} | ||||
| 	return discReasonToString[d] | ||||
| } | ||||
| 
 | ||||
| func discReasonForError(err error) DiscReason { | ||||
| 	peerError, ok := err.(*peerError) | ||||
| 	if !ok { | ||||
| 		return DiscSubprotocolError | ||||
| 	} | ||||
| 	switch peerError.Code { | ||||
| 	case errP2PVersionMismatch: | ||||
| 		return DiscIncompatibleVersion | ||||
| 	case errPubkeyMissing, errPubkeyInvalid: | ||||
| 		return DiscInvalidIdentity | ||||
| 	case errPubkeyForbidden: | ||||
| 		return DiscUselessPeer | ||||
| 	case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: | ||||
| 		return DiscProtocolError | ||||
| 	case errPingTimeout: | ||||
| 		return DiscReadTimeout | ||||
| 	case errRead, errWrite, errMisc: | ||||
| 		return DiscNetworkError | ||||
| 	default: | ||||
| 		return DiscSubprotocolError | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -1,101 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"net" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	severityThreshold = 10 | ||||
| ) | ||||
| 
 | ||||
| type DisconnectRequest struct { | ||||
| 	addr   net.Addr | ||||
| 	reason DiscReason | ||||
| } | ||||
| 
 | ||||
| type PeerErrorHandler struct { | ||||
| 	quit           chan chan bool | ||||
| 	address        net.Addr | ||||
| 	peerDisconnect chan DisconnectRequest | ||||
| 	severity       int | ||||
| 	peerErrorChan  chan *PeerError | ||||
| 	blacklist      Blacklist | ||||
| } | ||||
| 
 | ||||
| func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { | ||||
| 	return &PeerErrorHandler{ | ||||
| 		quit:           make(chan chan bool), | ||||
| 		address:        address, | ||||
| 		peerDisconnect: peerDisconnect, | ||||
| 		peerErrorChan:  peerErrorChan, | ||||
| 		blacklist:      blacklist, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *PeerErrorHandler) Start() { | ||||
| 	go self.listen() | ||||
| } | ||||
| 
 | ||||
| func (self *PeerErrorHandler) Stop() { | ||||
| 	q := make(chan bool) | ||||
| 	self.quit <- q | ||||
| 	<-q | ||||
| } | ||||
| 
 | ||||
| func (self *PeerErrorHandler) listen() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peerError, ok := <-self.peerErrorChan: | ||||
| 			if ok { | ||||
| 				logger.Debugf("error %v\n", peerError) | ||||
| 				go self.handle(peerError) | ||||
| 			} else { | ||||
| 				return | ||||
| 			} | ||||
| 		case q := <-self.quit: | ||||
| 			q <- true | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *PeerErrorHandler) handle(peerError *PeerError) { | ||||
| 	reason := DiscReason(' ') | ||||
| 	switch peerError.Code { | ||||
| 	case P2PVersionMismatch: | ||||
| 		reason = DiscIncompatibleVersion | ||||
| 	case PubkeyMissing, PubkeyInvalid: | ||||
| 		reason = DiscInvalidIdentity | ||||
| 	case PubkeyForbidden: | ||||
| 		reason = DiscUselessPeer | ||||
| 	case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: | ||||
| 		reason = DiscProtocolError | ||||
| 	case PingTimeout: | ||||
| 		reason = DiscReadTimeout | ||||
| 	case WriteError, MiscError: | ||||
| 		reason = DiscNetworkError | ||||
| 	case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: | ||||
| 		reason = DiscSubprotocolError | ||||
| 	default: | ||||
| 		self.severity += self.getSeverity(peerError) | ||||
| 	} | ||||
| 
 | ||||
| 	if self.severity >= severityThreshold { | ||||
| 		reason = DiscSubprotocolError | ||||
| 	} | ||||
| 	if reason != DiscReason(' ') { | ||||
| 		self.peerDisconnect <- DisconnectRequest{ | ||||
| 			addr:   self.address, | ||||
| 			reason: reason, | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { | ||||
| 	switch peerError.Code { | ||||
| 	case ReadError: | ||||
| 		return 4 //tolerate 3 :)
 | ||||
| 	default: | ||||
| 		return 1 | ||||
| 	} | ||||
| } | ||||
| @ -1,34 +0,0 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	// "fmt"
 | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestPeerErrorHandler(t *testing.T) { | ||||
| 	address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} | ||||
| 	peerDisconnect := make(chan DisconnectRequest) | ||||
| 	peerErrorChan := NewPeerErrorChannel() | ||||
| 	peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) | ||||
| 	peh.Start() | ||||
| 	defer peh.Stop() | ||||
| 	for i := 0; i < 11; i++ { | ||||
| 		select { | ||||
| 		case <-peerDisconnect: | ||||
| 			t.Errorf("expected no disconnect request") | ||||
| 		default: | ||||
| 		} | ||||
| 		peerErrorChan <- NewPeerError(MiscError, "") | ||||
| 	} | ||||
| 	time.Sleep(1 * time.Millisecond) | ||||
| 	select { | ||||
| 	case request := <-peerDisconnect: | ||||
| 		if request.addr.String() != address.String() { | ||||
| 			t.Errorf("incorrect address %v != %v", request.addr, address) | ||||
| 		} | ||||
| 	default: | ||||
| 		t.Errorf("expected disconnect request") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										303
									
								
								p2p/peer_test.go
									
									
									
									
									
								
							
							
						
						
									
										303
									
								
								p2p/peer_test.go
									
									
									
									
									
								
							| @ -1,96 +1,239 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	// "net"
 | ||||
| 	"encoding/hex" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestPeer(t *testing.T) { | ||||
| 	handlers := make(Handlers) | ||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} | ||||
| 	handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } | ||||
| 	handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } | ||||
| 	addr := &TestAddr{"test:30"} | ||||
| 	conn := NewTestNetworkConnection(addr) | ||||
| 	_, server := SetupTestServer(handlers) | ||||
| 	server.Handshake() | ||||
| 	peer := NewPeer(conn, addr, true, server) | ||||
| 	// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
 | ||||
| 	peer.Start() | ||||
| 	defer peer.Stop() | ||||
| 	time.Sleep(2 * time.Millisecond) | ||||
| 	if len(conn.Out) != 1 { | ||||
| 		t.Errorf("handshake not sent") | ||||
| 	} else { | ||||
| 		out := conn.Out[0] | ||||
| 		packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) | ||||
| 		if bytes.Compare(out, packet) != 0 { | ||||
| 			t.Errorf("incorrect handshake packet %v != %v", out, packet) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) | ||||
| 	conn.In(0, packet) | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 
 | ||||
| 	pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) | ||||
| 	if pro.state != handshakeReceived { | ||||
| 		t.Errorf("handshake not received") | ||||
| 	} | ||||
| 	if peer.Port != 30 { | ||||
| 		t.Errorf("port incorrectly set") | ||||
| 	} | ||||
| 	if peer.Id != "peer" { | ||||
| 		t.Errorf("id incorrectly set") | ||||
| 	} | ||||
| 	if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { | ||||
| 		t.Errorf("pubkey incorrectly set") | ||||
| 	} | ||||
| 	fmt.Println(peer.Caps) | ||||
| 	if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { | ||||
| 		t.Errorf("protocols incorrectly set") | ||||
| 	} | ||||
| 
 | ||||
| 	msg, _ := NewMsg(3) | ||||
| 	err := peer.Write("aaa", msg) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expect no error for known protocol: %v", err) | ||||
| 	} else { | ||||
| 		time.Sleep(1 * time.Millisecond) | ||||
| 		if len(conn.Out) != 2 { | ||||
| 			t.Errorf("msg not written") | ||||
| 		} else { | ||||
| 			out := conn.Out[1] | ||||
| 			packet := Packet(16, 3) | ||||
| 			if bytes.Compare(out, packet) != 0 { | ||||
| 				t.Errorf("incorrect packet %v != %v", out, packet) | ||||
| var discard = Protocol{ | ||||
| 	Name:   "discard", | ||||
| 	Length: 1, | ||||
| 	Run: func(p *Peer, rw MsgReadWriter) error { | ||||
| 		for { | ||||
| 			msg, err := rw.ReadMsg() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if err = msg.Discard(); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	}, | ||||
| } | ||||
| 
 | ||||
| 	msg, _ = NewMsg(2) | ||||
| 	err = peer.Write("ccc", msg) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expect no error for known protocol: %v", err) | ||||
| 	} else { | ||||
| 		time.Sleep(1 * time.Millisecond) | ||||
| 		if len(conn.Out) != 3 { | ||||
| 			t.Errorf("msg not written") | ||||
| 		} else { | ||||
| 			out := conn.Out[2] | ||||
| 			packet := Packet(21, 2) | ||||
| 			if bytes.Compare(out, packet) != 0 { | ||||
| 				t.Errorf("incorrect packet %v != %v", out, packet) | ||||
| func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { | ||||
| 	conn1, conn2 := net.Pipe() | ||||
| 	id := NewSimpleClientIdentity("test", "0", "0", "public key") | ||||
| 	peer := newPeer(conn1, protos, nil) | ||||
| 	peer.ourID = id | ||||
| 	peer.pubkeyHook = func(*peerAddr) error { return nil } | ||||
| 	errc := make(chan error, 1) | ||||
| 	go func() { | ||||
| 		_, err := peer.loop() | ||||
| 		errc <- err | ||||
| 	}() | ||||
| 	return conn2, peer, errc | ||||
| } | ||||
| 
 | ||||
| func TestPeerProtoReadMsg(t *testing.T) { | ||||
| 	defer testlog(t).detach() | ||||
| 
 | ||||
| 	done := make(chan struct{}) | ||||
| 	proto := Protocol{ | ||||
| 		Name:   "a", | ||||
| 		Length: 5, | ||||
| 		Run: func(peer *Peer, rw MsgReadWriter) error { | ||||
| 			msg, err := rw.ReadMsg() | ||||
| 			if err != nil { | ||||
| 				t.Errorf("read error: %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 			if msg.Code != 2 { | ||||
| 				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) | ||||
| 			} | ||||
| 			data, err := ioutil.ReadAll(msg.Payload) | ||||
| 			if err != nil { | ||||
| 				t.Errorf("payload read error: %v", err) | ||||
| 			} | ||||
| 			expdata, _ := hex.DecodeString("0183303030") | ||||
| 			if !bytes.Equal(expdata, data) { | ||||
| 				t.Errorf("incorrect msg data %x", data) | ||||
| 			} | ||||
| 			close(done) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	err = peer.Write("bbb", msg) | ||||
| 	time.Sleep(1 * time.Millisecond) | ||||
| 	if err == nil { | ||||
| 		t.Errorf("expect error for unknown protocol") | ||||
| 	net, peer, errc := testPeer([]Protocol{proto}) | ||||
| 	defer net.Close() | ||||
| 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||
| 
 | ||||
| 	writeMsg(net, NewMsg(18, 1, "000")) | ||||
| 	select { | ||||
| 	case <-done: | ||||
| 	case err := <-errc: | ||||
| 		t.Errorf("peer returned: %v", err) | ||||
| 	case <-time.After(2 * time.Second): | ||||
| 		t.Errorf("receive timeout") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPeerProtoReadLargeMsg(t *testing.T) { | ||||
| 	defer testlog(t).detach() | ||||
| 
 | ||||
| 	msgsize := uint32(10 * 1024 * 1024) | ||||
| 	done := make(chan struct{}) | ||||
| 	proto := Protocol{ | ||||
| 		Name:   "a", | ||||
| 		Length: 5, | ||||
| 		Run: func(peer *Peer, rw MsgReadWriter) error { | ||||
| 			msg, err := rw.ReadMsg() | ||||
| 			if err != nil { | ||||
| 				t.Errorf("read error: %v", err) | ||||
| 			} | ||||
| 			if msg.Size != msgsize+4 { | ||||
| 				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) | ||||
| 			} | ||||
| 			msg.Discard() | ||||
| 			close(done) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	net, peer, errc := testPeer([]Protocol{proto}) | ||||
| 	defer net.Close() | ||||
| 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||
| 
 | ||||
| 	writeMsg(net, NewMsg(18, make([]byte, msgsize))) | ||||
| 	select { | ||||
| 	case <-done: | ||||
| 	case err := <-errc: | ||||
| 		t.Errorf("peer returned: %v", err) | ||||
| 	case <-time.After(2 * time.Second): | ||||
| 		t.Errorf("receive timeout") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPeerProtoEncodeMsg(t *testing.T) { | ||||
| 	defer testlog(t).detach() | ||||
| 
 | ||||
| 	proto := Protocol{ | ||||
| 		Name:   "a", | ||||
| 		Length: 2, | ||||
| 		Run: func(peer *Peer, rw MsgReadWriter) error { | ||||
| 			if err := rw.EncodeMsg(2); err == nil { | ||||
| 				t.Error("expected error for out-of-range msg code, got nil") | ||||
| 			} | ||||
| 			if err := rw.EncodeMsg(1); err != nil { | ||||
| 				t.Errorf("write error: %v", err) | ||||
| 			} | ||||
| 			return nil | ||||
| 		}, | ||||
| 	} | ||||
| 	net, peer, _ := testPeer([]Protocol{proto}) | ||||
| 	defer net.Close() | ||||
| 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||
| 
 | ||||
| 	bufr := bufio.NewReader(net) | ||||
| 	msg, err := readMsg(bufr) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("read error: %v", err) | ||||
| 	} | ||||
| 	if msg.Code != 17 { | ||||
| 		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPeerWrite(t *testing.T) { | ||||
| 	defer testlog(t).detach() | ||||
| 
 | ||||
| 	net, peer, peerErr := testPeer([]Protocol{discard}) | ||||
| 	defer net.Close() | ||||
| 	peer.startSubprotocols([]Cap{discard.cap()}) | ||||
| 
 | ||||
| 	// test write errors
 | ||||
| 	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { | ||||
| 		t.Errorf("expected error for unknown protocol, got nil") | ||||
| 	} | ||||
| 	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { | ||||
| 		t.Errorf("expected error for out-of-range msg code, got nil") | ||||
| 	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { | ||||
| 		t.Errorf("wrong error for out-of-range msg code, got %#v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// setup for reading the message on the other end
 | ||||
| 	read := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		bufr := bufio.NewReader(net) | ||||
| 		msg, err := readMsg(bufr) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("read error: %v", err) | ||||
| 		} else if msg.Code != 16 { | ||||
| 			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) | ||||
| 		} | ||||
| 		msg.Discard() | ||||
| 		close(read) | ||||
| 	}() | ||||
| 
 | ||||
| 	// test succcessful write
 | ||||
| 	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { | ||||
| 		t.Errorf("expect no error for known protocol: %v", err) | ||||
| 	} | ||||
| 	select { | ||||
| 	case <-read: | ||||
| 	case err := <-peerErr: | ||||
| 		t.Fatalf("peer stopped: %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPeerActivity(t *testing.T) { | ||||
| 	// shorten inactivityTimeout while this test is running
 | ||||
| 	oldT := inactivityTimeout | ||||
| 	defer func() { inactivityTimeout = oldT }() | ||||
| 	inactivityTimeout = 20 * time.Millisecond | ||||
| 
 | ||||
| 	net, peer, peerErr := testPeer([]Protocol{discard}) | ||||
| 	defer net.Close() | ||||
| 	peer.startSubprotocols([]Cap{discard.cap()}) | ||||
| 
 | ||||
| 	sub := peer.activity.Subscribe(time.Time{}) | ||||
| 	defer sub.Unsubscribe() | ||||
| 
 | ||||
| 	for i := 0; i < 6; i++ { | ||||
| 		writeMsg(net, NewMsg(16)) | ||||
| 		select { | ||||
| 		case <-sub.Chan(): | ||||
| 		case <-time.After(inactivityTimeout / 2): | ||||
| 			t.Fatal("no event within ", inactivityTimeout/2) | ||||
| 		case err := <-peerErr: | ||||
| 			t.Fatal("peer error", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	select { | ||||
| 	case <-time.After(inactivityTimeout * 2): | ||||
| 	case <-sub.Chan(): | ||||
| 		t.Fatal("got activity event while connection was inactive") | ||||
| 	case err := <-peerErr: | ||||
| 		t.Fatal("peer error", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNewPeer(t *testing.T) { | ||||
| 	id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") | ||||
| 	caps := []Cap{{"foo", 2}, {"bar", 3}} | ||||
| 	p := NewPeer(id, caps) | ||||
| 	if !reflect.DeepEqual(p.Caps(), caps) { | ||||
| 		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) | ||||
| 	} | ||||
| 	if p.Identity() != id { | ||||
| 		t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) | ||||
| 	} | ||||
| 	// Should not hang.
 | ||||
| 	p.Disconnect(DiscAlreadyConnected) | ||||
| } | ||||
|  | ||||
							
								
								
									
										501
									
								
								p2p/protocol.go
									
									
									
									
									
								
							
							
						
						
									
										501
									
								
								p2p/protocol.go
									
									
									
									
									
								
							| @ -2,277 +2,294 @@ package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"sort" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/ethutil" | ||||
| ) | ||||
| 
 | ||||
| type Protocol interface { | ||||
| 	Start() | ||||
| 	Stop() | ||||
| 	HandleIn(*Msg, chan *Msg) | ||||
| 	HandleOut(*Msg) bool | ||||
| 	Offset() MsgCode | ||||
| 	Name() string | ||||
| // Protocol represents a P2P subprotocol implementation.
 | ||||
| type Protocol struct { | ||||
| 	// Name should contain the official protocol name,
 | ||||
| 	// often a three-letter word.
 | ||||
| 	Name string | ||||
| 
 | ||||
| 	// Version should contain the version number of the protocol.
 | ||||
| 	Version uint | ||||
| 
 | ||||
| 	// Length should contain the number of message codes used
 | ||||
| 	// by the protocol.
 | ||||
| 	Length uint64 | ||||
| 
 | ||||
| 	// Run is called in a new groutine when the protocol has been
 | ||||
| 	// negotiated with a peer. It should read and write messages from
 | ||||
| 	// rw. The Payload for each message must be fully consumed.
 | ||||
| 	//
 | ||||
| 	// The peer connection is closed when Start returns. It should return
 | ||||
| 	// any protocol-level error (such as an I/O error) that is
 | ||||
| 	// encountered.
 | ||||
| 	Run func(peer *Peer, rw MsgReadWriter) error | ||||
| } | ||||
| 
 | ||||
| func (p Protocol) cap() Cap { | ||||
| 	return Cap{p.Name, p.Version} | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
| 	P2PVersion      = 0 | ||||
| 	pingTimeout     = 2 | ||||
| 	pingGracePeriod = 2 | ||||
| 	baseProtocolVersion    = 2 | ||||
| 	baseProtocolLength     = uint64(16) | ||||
| 	baseProtocolMaxMsgSize = 10 * 1024 * 1024 | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	HandshakeMsg = iota | ||||
| 	DiscMsg | ||||
| 	PingMsg | ||||
| 	PongMsg | ||||
| 	GetPeersMsg | ||||
| 	PeersMsg | ||||
| 	offset = 16 | ||||
| 	// devp2p message codes
 | ||||
| 	handshakeMsg = 0x00 | ||||
| 	discMsg      = 0x01 | ||||
| 	pingMsg      = 0x02 | ||||
| 	pongMsg      = 0x03 | ||||
| 	getPeersMsg  = 0x04 | ||||
| 	peersMsg     = 0x05 | ||||
| ) | ||||
| 
 | ||||
| type ProtocolState uint8 | ||||
| 
 | ||||
| const ( | ||||
| 	nullState = iota | ||||
| 	handshakeReceived | ||||
| ) | ||||
| 
 | ||||
| type DiscReason byte | ||||
| 
 | ||||
| const ( | ||||
| 	// Values are given explicitly instead of by iota because these values are
 | ||||
| 	// defined by the wire protocol spec; it is easier for humans to ensure
 | ||||
| 	// correctness when values are explicit.
 | ||||
| 	DiscRequested           = 0x00 | ||||
| 	DiscNetworkError        = 0x01 | ||||
| 	DiscProtocolError       = 0x02 | ||||
| 	DiscUselessPeer         = 0x03 | ||||
| 	DiscTooManyPeers        = 0x04 | ||||
| 	DiscAlreadyConnected    = 0x05 | ||||
| 	DiscIncompatibleVersion = 0x06 | ||||
| 	DiscInvalidIdentity     = 0x07 | ||||
| 	DiscQuitting            = 0x08 | ||||
| 	DiscUnexpectedIdentity  = 0x09 | ||||
| 	DiscSelf                = 0x0a | ||||
| 	DiscReadTimeout         = 0x0b | ||||
| 	DiscSubprotocolError    = 0x10 | ||||
| ) | ||||
| 
 | ||||
| var discReasonToString = map[DiscReason]string{ | ||||
| 	DiscRequested:           "Disconnect requested", | ||||
| 	DiscNetworkError:        "Network error", | ||||
| 	DiscProtocolError:       "Breach of protocol", | ||||
| 	DiscUselessPeer:         "Useless peer", | ||||
| 	DiscTooManyPeers:        "Too many peers", | ||||
| 	DiscAlreadyConnected:    "Already connected", | ||||
| 	DiscIncompatibleVersion: "Incompatible P2P protocol version", | ||||
| 	DiscInvalidIdentity:     "Invalid node identity", | ||||
| 	DiscQuitting:            "Client quitting", | ||||
| 	DiscUnexpectedIdentity:  "Unexpected identity", | ||||
| 	DiscSelf:                "Connected to self", | ||||
| 	DiscReadTimeout:         "Read timeout", | ||||
| 	DiscSubprotocolError:    "Subprotocol error", | ||||
| // handshake is the structure of a handshake list.
 | ||||
| type handshake struct { | ||||
| 	Version    uint64 | ||||
| 	ID         string | ||||
| 	Caps       []Cap | ||||
| 	ListenPort uint64 | ||||
| 	NodeID     []byte | ||||
| } | ||||
| 
 | ||||
| func (d DiscReason) String() string { | ||||
| 	if len(discReasonToString) < int(d) { | ||||
| 		return "Unknown" | ||||
| func (h *handshake) String() string { | ||||
| 	return h.ID | ||||
| } | ||||
| func (h *handshake) Pubkey() []byte { | ||||
| 	return h.NodeID | ||||
| } | ||||
| 
 | ||||
| // Cap is the structure of a peer capability.
 | ||||
| type Cap struct { | ||||
| 	Name    string | ||||
| 	Version uint | ||||
| } | ||||
| 
 | ||||
| func (cap Cap) RlpData() interface{} { | ||||
| 	return []interface{}{cap.Name, cap.Version} | ||||
| } | ||||
| 
 | ||||
| type capsByName []Cap | ||||
| 
 | ||||
| func (cs capsByName) Len() int           { return len(cs) } | ||||
| func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } | ||||
| func (cs capsByName) Swap(i, j int)      { cs[i], cs[j] = cs[j], cs[i] } | ||||
| 
 | ||||
| type baseProtocol struct { | ||||
| 	rw   MsgReadWriter | ||||
| 	peer *Peer | ||||
| } | ||||
| 
 | ||||
| func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { | ||||
| 	bp := &baseProtocol{rw, peer} | ||||
| 	if err := bp.doHandshake(rw); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return discReasonToString[d] | ||||
| } | ||||
| 
 | ||||
| type BaseProtocol struct { | ||||
| 	peer      *Peer | ||||
| 	state     ProtocolState | ||||
| 	stateLock sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| func NewBaseProtocol(peer *Peer) *BaseProtocol { | ||||
| 	self := &BaseProtocol{ | ||||
| 		peer: peer, | ||||
| 	} | ||||
| 
 | ||||
| 	return self | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) Start() { | ||||
| 	if self.peer != nil { | ||||
| 		self.peer.Write("", self.peer.Server().Handshake()) | ||||
| 		go self.peer.Messenger().PingPong( | ||||
| 			pingTimeout*time.Second, | ||||
| 			pingGracePeriod*time.Second, | ||||
| 			self.Ping, | ||||
| 			self.Timeout, | ||||
| 		) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) Stop() { | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) Ping() { | ||||
| 	msg, _ := NewMsg(PingMsg) | ||||
| 	self.peer.Write("", msg) | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) Timeout() { | ||||
| 	self.peerError(PingTimeout, "") | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) Name() string { | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) Offset() MsgCode { | ||||
| 	return offset | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) CheckState(state ProtocolState) bool { | ||||
| 	self.stateLock.RLock() | ||||
| 	self.stateLock.RUnlock() | ||||
| 	if self.state != state { | ||||
| 		return false | ||||
| 	} else { | ||||
| 		return true | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { | ||||
| 	if msg.Code() == HandshakeMsg { | ||||
| 		self.handleHandshake(msg) | ||||
| 	} else { | ||||
| 		if !self.CheckState(handshakeReceived) { | ||||
| 			self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) | ||||
| 			close(response) | ||||
| 			return | ||||
| 	// run main loop
 | ||||
| 	quit := make(chan error, 1) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			if err := bp.handle(rw); err != nil { | ||||
| 				quit <- err | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		switch msg.Code() { | ||||
| 		case DiscMsg: | ||||
| 			logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) | ||||
| 			self.peer.Server().PeerDisconnect() <- DisconnectRequest{ | ||||
| 				addr:   self.peer.Address, | ||||
| 				reason: DiscRequested, | ||||
| 	}() | ||||
| 	return bp.loop(quit) | ||||
| } | ||||
| 
 | ||||
| var pingTimeout = 2 * time.Second | ||||
| 
 | ||||
| func (bp *baseProtocol) loop(quit <-chan error) error { | ||||
| 	ping := time.NewTimer(pingTimeout) | ||||
| 	activity := bp.peer.activity.Subscribe(time.Time{}) | ||||
| 	lastActive := time.Time{} | ||||
| 	defer ping.Stop() | ||||
| 	defer activity.Unsubscribe() | ||||
| 
 | ||||
| 	getPeersTick := time.NewTicker(10 * time.Second) | ||||
| 	defer getPeersTick.Stop() | ||||
| 	err := bp.rw.EncodeMsg(getPeersMsg) | ||||
| 
 | ||||
| 	for err == nil { | ||||
| 		select { | ||||
| 		case err = <-quit: | ||||
| 			return err | ||||
| 		case <-getPeersTick.C: | ||||
| 			err = bp.rw.EncodeMsg(getPeersMsg) | ||||
| 		case event := <-activity.Chan(): | ||||
| 			ping.Reset(pingTimeout) | ||||
| 			lastActive = event.(time.Time) | ||||
| 		case t := <-ping.C: | ||||
| 			if lastActive.Add(pingTimeout * 2).Before(t) { | ||||
| 				err = newPeerError(errPingTimeout, "") | ||||
| 			} else if lastActive.Add(pingTimeout).Before(t) { | ||||
| 				err = bp.rw.EncodeMsg(pingMsg) | ||||
| 			} | ||||
| 		case PingMsg: | ||||
| 			out, _ := NewMsg(PongMsg) | ||||
| 			response <- out | ||||
| 		case PongMsg: | ||||
| 		case GetPeersMsg: | ||||
| 			// Peer asked for list of connected peers
 | ||||
| 			if out, err := self.peer.Server().PeersMessage(); err != nil { | ||||
| 				response <- out | ||||
| 			} | ||||
| 		case PeersMsg: | ||||
| 			self.handlePeers(msg) | ||||
| 		default: | ||||
| 			self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) | ||||
| 		} | ||||
| 	} | ||||
| 	close(response) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { | ||||
| 	// somewhat overly paranoid
 | ||||
| 	allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { | ||||
| 	err := NewPeerError(errorCode, format, v...) | ||||
| 	logger.Warnln(err) | ||||
| 	fmt.Println(self.peer, err) | ||||
| 	if self.peer != nil { | ||||
| 		self.peer.PeerErrorChan() <- err | ||||
| func (bp *baseProtocol) handle(rw MsgReadWriter) error { | ||||
| 	msg, err := rw.ReadMsg() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) handlePeers(msg *Msg) { | ||||
| 	it := msg.Data().NewIterator() | ||||
| 	for it.Next() { | ||||
| 		ip := net.IP(it.Value().Get(0).Bytes()) | ||||
| 		port := it.Value().Get(1).Uint() | ||||
| 		address := &net.TCPAddr{IP: ip, Port: int(port)} | ||||
| 		go self.peer.Server().PeerConnect(address) | ||||
| 	if msg.Size > baseProtocolMaxMsgSize { | ||||
| 		return newPeerError(errMisc, "message too big") | ||||
| 	} | ||||
| 	// make sure that the payload has been fully consumed
 | ||||
| 	defer msg.Discard() | ||||
| 
 | ||||
| 	switch msg.Code { | ||||
| 	case handshakeMsg: | ||||
| 		return newPeerError(errProtocolBreach, "extra handshake received") | ||||
| 
 | ||||
| 	case discMsg: | ||||
| 		var reason DiscReason | ||||
| 		if err := msg.Decode(&reason); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		bp.peer.Disconnect(reason) | ||||
| 		return nil | ||||
| 
 | ||||
| 	case pingMsg: | ||||
| 		return bp.rw.EncodeMsg(pongMsg) | ||||
| 
 | ||||
| 	case pongMsg: | ||||
| 
 | ||||
| 	case getPeersMsg: | ||||
| 		peers := bp.peerList() | ||||
| 		// this is dangerous. the spec says that we should _delay_
 | ||||
| 		// sending the response if no new information is available.
 | ||||
| 		// this means that would need to send a response later when
 | ||||
| 		// new peers become available.
 | ||||
| 		//
 | ||||
| 		// TODO: add event mechanism to notify baseProtocol for new peers
 | ||||
| 		if len(peers) > 0 { | ||||
| 			return bp.rw.EncodeMsg(peersMsg, peers) | ||||
| 		} | ||||
| 
 | ||||
| 	case peersMsg: | ||||
| 		var peers []*peerAddr | ||||
| 		if err := msg.Decode(&peers); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		for _, addr := range peers { | ||||
| 			bp.peer.Debugf("received peer suggestion: %v", addr) | ||||
| 			bp.peer.newPeerAddr <- addr | ||||
| 		} | ||||
| 
 | ||||
| 	default: | ||||
| 		return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (self *BaseProtocol) handleHandshake(msg *Msg) { | ||||
| 	self.stateLock.Lock() | ||||
| 	defer self.stateLock.Unlock() | ||||
| 	if self.state != nullState { | ||||
| 		self.peerError(ProtocolBreach, "extra handshake") | ||||
| 		return | ||||
| func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { | ||||
| 	// send our handshake
 | ||||
| 	if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	c := msg.Data() | ||||
| 	// read and handle remote handshake
 | ||||
| 	msg, err := rw.ReadMsg() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if msg.Code != handshakeMsg { | ||||
| 		return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) | ||||
| 	} | ||||
| 	if msg.Size > baseProtocolMaxMsgSize { | ||||
| 		return newPeerError(errMisc, "message too big") | ||||
| 	} | ||||
| 
 | ||||
| 	var hs handshake | ||||
| 	if err := msg.Decode(&hs); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// validate handshake info
 | ||||
| 	if hs.Version != baseProtocolVersion { | ||||
| 		return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", | ||||
| 			baseProtocolVersion, hs.Version) | ||||
| 	} | ||||
| 	if len(hs.NodeID) == 0 { | ||||
| 		return newPeerError(errPubkeyMissing, "") | ||||
| 	} | ||||
| 	if len(hs.NodeID) != 64 { | ||||
| 		return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) | ||||
| 	} | ||||
| 	if da := bp.peer.dialAddr; da != nil { | ||||
| 		// verify that the peer we wanted to connect to
 | ||||
| 		// actually holds the target public key.
 | ||||
| 		if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { | ||||
| 			return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") | ||||
| 		} | ||||
| 	} | ||||
| 	pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) | ||||
| 	if err := bp.peer.pubkeyHook(pa); err != nil { | ||||
| 		return newPeerError(errPubkeyForbidden, "%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: remove Caps with empty name
 | ||||
| 
 | ||||
| 	var addr *peerAddr | ||||
| 	if hs.ListenPort != 0 { | ||||
| 		addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) | ||||
| 		addr.Port = hs.ListenPort | ||||
| 	} | ||||
| 	bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) | ||||
| 	bp.peer.startSubprotocols(hs.Caps) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (bp *baseProtocol) handshakeMsg() Msg { | ||||
| 	var ( | ||||
| 		p2pVersion = c.Get(0).Uint() | ||||
| 		id         = c.Get(1).Str() | ||||
| 		caps       = c.Get(2) | ||||
| 		port       = c.Get(3).Uint() | ||||
| 		pubkey     = c.Get(4).Bytes() | ||||
| 		port uint64 | ||||
| 		caps []interface{} | ||||
| 	) | ||||
| 	fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) | ||||
| 
 | ||||
| 	// Check correctness of p2p protocol version
 | ||||
| 	if p2pVersion != P2PVersion { | ||||
| 		self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) | ||||
| 		return | ||||
| 	if bp.peer.ourListenAddr != nil { | ||||
| 		port = bp.peer.ourListenAddr.Port | ||||
| 	} | ||||
| 
 | ||||
| 	// Handle the pub key (validation, uniqueness)
 | ||||
| 	if len(pubkey) == 0 { | ||||
| 		self.peerError(PubkeyMissing, "not supplied in handshake.") | ||||
| 		return | ||||
| 	for _, proto := range bp.peer.protocols { | ||||
| 		caps = append(caps, proto.cap()) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(pubkey) != 64 { | ||||
| 		self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Self connect detection
 | ||||
| 	if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { | ||||
| 		self.peerError(PubkeyForbidden, "not allowed to connect to self") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// register pubkey on server. this also sets the pubkey on the peer (need lock)
 | ||||
| 	if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { | ||||
| 		self.peerError(PubkeyForbidden, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// check port
 | ||||
| 	if self.peer.Inbound { | ||||
| 		uint16port := uint16(port) | ||||
| 		if self.peer.Port > 0 && self.peer.Port != uint16port { | ||||
| 			self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) | ||||
| 			return | ||||
| 		} else { | ||||
| 			self.peer.Port = uint16port | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	capsIt := caps.NewIterator() | ||||
| 	for capsIt.Next() { | ||||
| 		cap := capsIt.Value().Str() | ||||
| 		self.peer.Caps = append(self.peer.Caps, cap) | ||||
| 	} | ||||
| 	sort.Strings(self.peer.Caps) | ||||
| 	self.peer.Messenger().AddProtocols(self.peer.Caps) | ||||
| 
 | ||||
| 	self.peer.Id = id | ||||
| 
 | ||||
| 	self.state = handshakeReceived | ||||
| 
 | ||||
| 	//p.ethereum.PushPeer(p)
 | ||||
| 	// p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
 | ||||
| 	return | ||||
| 	return NewMsg(handshakeMsg, | ||||
| 		baseProtocolVersion, | ||||
| 		bp.peer.ourID.String(), | ||||
| 		caps, | ||||
| 		port, | ||||
| 		bp.peer.ourID.Pubkey()[1:], | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { | ||||
| 	peers := bp.peer.otherPeers() | ||||
| 	ds := make([]ethutil.RlpEncodable, 0, len(peers)) | ||||
| 	for _, p := range peers { | ||||
| 		p.infolock.Lock() | ||||
| 		addr := p.listenAddr | ||||
| 		p.infolock.Unlock() | ||||
| 		// filter out this peer and peers that are not listening or
 | ||||
| 		// have not completed the handshake.
 | ||||
| 		// TODO: track previously sent peers and exclude them as well.
 | ||||
| 		if p == bp.peer || addr == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		ds = append(ds, addr) | ||||
| 	} | ||||
| 	ourAddr := bp.peer.ourListenAddr | ||||
| 	if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { | ||||
| 		ds = append(ds, ourAddr) | ||||
| 	} | ||||
| 	return ds | ||||
| } | ||||
|  | ||||
							
								
								
									
										825
									
								
								p2p/server.go
									
									
									
									
									
								
							
							
						
						
									
										825
									
								
								p2p/server.go
									
									
									
									
									
								
							| @ -2,21 +2,420 @@ package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	logpkg "github.com/ethereum/go-ethereum/logger" | ||||
| 	"github.com/ethereum/go-ethereum/logger" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	outboundAddressPoolSize = 10 | ||||
| 	disconnectGracePeriod   = 2 | ||||
| 	outboundAddressPoolSize   = 500 | ||||
| 	defaultDialTimeout        = 10 * time.Second | ||||
| 	portMappingUpdateInterval = 15 * time.Minute | ||||
| 	portMappingTimeout        = 20 * time.Minute | ||||
| ) | ||||
| 
 | ||||
| var srvlog = logger.NewLogger("P2P Server") | ||||
| 
 | ||||
| // Server manages all peer connections.
 | ||||
| //
 | ||||
| // The fields of Server are used as configuration parameters.
 | ||||
| // You should set them before starting the Server. Fields may not be
 | ||||
| // modified while the server is running.
 | ||||
| type Server struct { | ||||
| 	// This field must be set to a valid client identity.
 | ||||
| 	Identity ClientIdentity | ||||
| 
 | ||||
| 	// MaxPeers is the maximum number of peers that can be
 | ||||
| 	// connected. It must be greater than zero.
 | ||||
| 	MaxPeers int | ||||
| 
 | ||||
| 	// Protocols should contain the protocols supported
 | ||||
| 	// by the server. Matching protocols are launched for
 | ||||
| 	// each peer.
 | ||||
| 	Protocols []Protocol | ||||
| 
 | ||||
| 	// If Blacklist is set to a non-nil value, the given Blacklist
 | ||||
| 	// is used to verify peer connections.
 | ||||
| 	Blacklist Blacklist | ||||
| 
 | ||||
| 	// If ListenAddr is set to a non-nil address, the server
 | ||||
| 	// will listen for incoming connections.
 | ||||
| 	//
 | ||||
| 	// If the port is zero, the operating system will pick a port. The
 | ||||
| 	// ListenAddr field will be updated with the actual address when
 | ||||
| 	// the server is started.
 | ||||
| 	ListenAddr string | ||||
| 
 | ||||
| 	// If set to a non-nil value, the given NAT port mapper
 | ||||
| 	// is used to make the listening port available to the
 | ||||
| 	// Internet.
 | ||||
| 	NAT NAT | ||||
| 
 | ||||
| 	// If Dialer is set to a non-nil value, the given Dialer
 | ||||
| 	// is used to dial outbound peer connections.
 | ||||
| 	Dialer *net.Dialer | ||||
| 
 | ||||
| 	// If NoDial is true, the server will not dial any peers.
 | ||||
| 	NoDial bool | ||||
| 
 | ||||
| 	// Hook for testing. This is useful because we can inhibit
 | ||||
| 	// the whole protocol stack.
 | ||||
| 	newPeerFunc peerFunc | ||||
| 
 | ||||
| 	lock      sync.RWMutex | ||||
| 	running   bool | ||||
| 	listener  net.Listener | ||||
| 	laddr     *net.TCPAddr // real listen addr
 | ||||
| 	peers     []*Peer | ||||
| 	peerSlots chan int | ||||
| 	peerCount int | ||||
| 
 | ||||
| 	quit           chan struct{} | ||||
| 	wg             sync.WaitGroup | ||||
| 	peerConnect    chan *peerAddr | ||||
| 	peerDisconnect chan *Peer | ||||
| } | ||||
| 
 | ||||
| // NAT is implemented by NAT traversal methods.
 | ||||
| type NAT interface { | ||||
| 	GetExternalAddress() (net.IP, error) | ||||
| 	AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error | ||||
| 	DeletePortMapping(protocol string, extport, intport int) error | ||||
| 
 | ||||
| 	// Should return name of the method.
 | ||||
| 	String() string | ||||
| } | ||||
| 
 | ||||
| type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer | ||||
| 
 | ||||
| // Peers returns all connected peers.
 | ||||
| func (srv *Server) Peers() (peers []*Peer) { | ||||
| 	srv.lock.RLock() | ||||
| 	defer srv.lock.RUnlock() | ||||
| 	for _, peer := range srv.peers { | ||||
| 		if peer != nil { | ||||
| 			peers = append(peers, peer) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // PeerCount returns the number of connected peers.
 | ||||
| func (srv *Server) PeerCount() int { | ||||
| 	srv.lock.RLock() | ||||
| 	defer srv.lock.RUnlock() | ||||
| 	return srv.peerCount | ||||
| } | ||||
| 
 | ||||
| // SuggestPeer injects an address into the outbound address pool.
 | ||||
| func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { | ||||
| 	select { | ||||
| 	case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: | ||||
| 	default: // don't block
 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Broadcast sends an RLP-encoded message to all connected peers.
 | ||||
| // This method is deprecated and will be removed later.
 | ||||
| func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { | ||||
| 	var payload []byte | ||||
| 	if data != nil { | ||||
| 		payload = encodePayload(data...) | ||||
| 	} | ||||
| 	srv.lock.RLock() | ||||
| 	defer srv.lock.RUnlock() | ||||
| 	for _, peer := range srv.peers { | ||||
| 		if peer != nil { | ||||
| 			var msg = Msg{Code: code} | ||||
| 			if data != nil { | ||||
| 				msg.Payload = bytes.NewReader(payload) | ||||
| 				msg.Size = uint32(len(payload)) | ||||
| 			} | ||||
| 			peer.writeProtoMsg(protocol, msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Start starts running the server.
 | ||||
| // Servers can be re-used and started again after stopping.
 | ||||
| func (srv *Server) Start() (err error) { | ||||
| 	srv.lock.Lock() | ||||
| 	defer srv.lock.Unlock() | ||||
| 	if srv.running { | ||||
| 		return errors.New("server already running") | ||||
| 	} | ||||
| 	srvlog.Infoln("Starting Server") | ||||
| 
 | ||||
| 	// initialize fields
 | ||||
| 	if srv.Identity == nil { | ||||
| 		return fmt.Errorf("Server.Identity must be set to a non-nil identity") | ||||
| 	} | ||||
| 	if srv.MaxPeers <= 0 { | ||||
| 		return fmt.Errorf("Server.MaxPeers must be > 0") | ||||
| 	} | ||||
| 	srv.quit = make(chan struct{}) | ||||
| 	srv.peers = make([]*Peer, srv.MaxPeers) | ||||
| 	srv.peerSlots = make(chan int, srv.MaxPeers) | ||||
| 	srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) | ||||
| 	srv.peerDisconnect = make(chan *Peer) | ||||
| 	if srv.newPeerFunc == nil { | ||||
| 		srv.newPeerFunc = newServerPeer | ||||
| 	} | ||||
| 	if srv.Blacklist == nil { | ||||
| 		srv.Blacklist = NewBlacklist() | ||||
| 	} | ||||
| 	if srv.Dialer == nil { | ||||
| 		srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} | ||||
| 	} | ||||
| 
 | ||||
| 	if srv.ListenAddr != "" { | ||||
| 		if err := srv.startListening(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	if !srv.NoDial { | ||||
| 		srv.wg.Add(1) | ||||
| 		go srv.dialLoop() | ||||
| 	} | ||||
| 	if srv.NoDial && srv.ListenAddr == "" { | ||||
| 		srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") | ||||
| 	} | ||||
| 
 | ||||
| 	// make all slots available
 | ||||
| 	for i := range srv.peers { | ||||
| 		srv.peerSlots <- i | ||||
| 	} | ||||
| 	// note: discLoop is not part of WaitGroup
 | ||||
| 	go srv.discLoop() | ||||
| 	srv.running = true | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) startListening() error { | ||||
| 	listener, err := net.Listen("tcp", srv.ListenAddr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	srv.ListenAddr = listener.Addr().String() | ||||
| 	srv.laddr = listener.Addr().(*net.TCPAddr) | ||||
| 	srv.listener = listener | ||||
| 	srv.wg.Add(1) | ||||
| 	go srv.listenLoop() | ||||
| 	if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { | ||||
| 		srv.wg.Add(1) | ||||
| 		go srv.natLoop(srv.laddr.Port) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Stop terminates the server and all active peer connections.
 | ||||
| // It blocks until all active connections have been closed.
 | ||||
| func (srv *Server) Stop() { | ||||
| 	srv.lock.Lock() | ||||
| 	if !srv.running { | ||||
| 		srv.lock.Unlock() | ||||
| 		return | ||||
| 	} | ||||
| 	srv.running = false | ||||
| 	srv.lock.Unlock() | ||||
| 
 | ||||
| 	srvlog.Infoln("Stopping server") | ||||
| 	if srv.listener != nil { | ||||
| 		// this unblocks listener Accept
 | ||||
| 		srv.listener.Close() | ||||
| 	} | ||||
| 	close(srv.quit) | ||||
| 	for _, peer := range srv.Peers() { | ||||
| 		peer.Disconnect(DiscQuitting) | ||||
| 	} | ||||
| 	srv.wg.Wait() | ||||
| 
 | ||||
| 	// wait till they actually disconnect
 | ||||
| 	// this is checked by claiming all peerSlots.
 | ||||
| 	// slots become available as the peers disconnect.
 | ||||
| 	for i := 0; i < cap(srv.peerSlots); i++ { | ||||
| 		<-srv.peerSlots | ||||
| 	} | ||||
| 	// terminate discLoop
 | ||||
| 	close(srv.peerDisconnect) | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) discLoop() { | ||||
| 	for peer := range srv.peerDisconnect { | ||||
| 		// peer has just disconnected. free up its slot.
 | ||||
| 		srvlog.Infof("%v is gone", peer) | ||||
| 		srv.peerSlots <- peer.slot | ||||
| 		srv.lock.Lock() | ||||
| 		srv.peers[peer.slot] = nil | ||||
| 		srv.lock.Unlock() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // main loop for adding connections via listening
 | ||||
| func (srv *Server) listenLoop() { | ||||
| 	defer srv.wg.Done() | ||||
| 
 | ||||
| 	srvlog.Infoln("Listening on", srv.listener.Addr()) | ||||
| 	for { | ||||
| 		select { | ||||
| 		case slot := <-srv.peerSlots: | ||||
| 			conn, err := srv.listener.Accept() | ||||
| 			if err != nil { | ||||
| 				srv.peerSlots <- slot | ||||
| 				return | ||||
| 			} | ||||
| 			srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) | ||||
| 			srv.addPeer(conn, nil, slot) | ||||
| 		case <-srv.quit: | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) natLoop(port int) { | ||||
| 	defer srv.wg.Done() | ||||
| 	for { | ||||
| 		srv.updatePortMapping(port) | ||||
| 		select { | ||||
| 		case <-time.After(portMappingUpdateInterval): | ||||
| 			// one more round
 | ||||
| 		case <-srv.quit: | ||||
| 			srv.removePortMapping(port) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) updatePortMapping(port int) { | ||||
| 	srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) | ||||
| 	err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) | ||||
| 	if err != nil { | ||||
| 		srvlog.Errorln("Port mapping error:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	extip, err := srv.NAT.GetExternalAddress() | ||||
| 	if err != nil { | ||||
| 		srvlog.Errorln("Error getting external IP:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	srv.lock.Lock() | ||||
| 	extaddr := *(srv.listener.Addr().(*net.TCPAddr)) | ||||
| 	extaddr.IP = extip | ||||
| 	srvlog.Infoln("Mapped port, external addr is", &extaddr) | ||||
| 	srv.laddr = &extaddr | ||||
| 	srv.lock.Unlock() | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) removePortMapping(port int) { | ||||
| 	srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) | ||||
| 	srv.NAT.DeletePortMapping("tcp", port, port) | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) dialLoop() { | ||||
| 	defer srv.wg.Done() | ||||
| 	var ( | ||||
| 		suggest chan *peerAddr | ||||
| 		slot    *int | ||||
| 		slots   = srv.peerSlots | ||||
| 	) | ||||
| 	for { | ||||
| 		select { | ||||
| 		case i := <-slots: | ||||
| 			// we need a peer in slot i, slot reserved
 | ||||
| 			slot = &i | ||||
| 			// now we can watch for candidate peers in the next loop
 | ||||
| 			suggest = srv.peerConnect | ||||
| 			// do not consume more until candidate peer is found
 | ||||
| 			slots = nil | ||||
| 
 | ||||
| 		case desc := <-suggest: | ||||
| 			// candidate peer found, will dial out asyncronously
 | ||||
| 			// if connection fails slot will be released
 | ||||
| 			go srv.dialPeer(desc, *slot) | ||||
| 			// we can watch if more peers needed in the next loop
 | ||||
| 			slots = srv.peerSlots | ||||
| 			// until then we dont care about candidate peers
 | ||||
| 			suggest = nil | ||||
| 
 | ||||
| 		case <-srv.quit: | ||||
| 			// give back the currently reserved slot
 | ||||
| 			if slot != nil { | ||||
| 				srv.peerSlots <- *slot | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // connect to peer via dial out
 | ||||
| func (srv *Server) dialPeer(desc *peerAddr, slot int) { | ||||
| 	srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) | ||||
| 	conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) | ||||
| 	if err != nil { | ||||
| 		srvlog.Errorf("Dial error: %v", err) | ||||
| 		srv.peerSlots <- slot | ||||
| 		return | ||||
| 	} | ||||
| 	go srv.addPeer(conn, desc, slot) | ||||
| } | ||||
| 
 | ||||
| // creates the new peer object and inserts it into its slot
 | ||||
| func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { | ||||
| 	srv.lock.Lock() | ||||
| 	defer srv.lock.Unlock() | ||||
| 	if !srv.running { | ||||
| 		conn.Close() | ||||
| 		srv.peerSlots <- slot // release slot
 | ||||
| 		return nil | ||||
| 	} | ||||
| 	peer := srv.newPeerFunc(srv, conn, desc) | ||||
| 	peer.slot = slot | ||||
| 	srv.peers[slot] = peer | ||||
| 	srv.peerCount++ | ||||
| 	go func() { peer.loop(); srv.peerDisconnect <- peer }() | ||||
| 	return peer | ||||
| } | ||||
| 
 | ||||
| // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
 | ||||
| func (srv *Server) removePeer(peer *Peer) { | ||||
| 	srv.lock.Lock() | ||||
| 	defer srv.lock.Unlock() | ||||
| 	srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) | ||||
| 	if srv.peers[peer.slot] != peer { | ||||
| 		srvlog.Warnln("Invalid peer to remove:", peer) | ||||
| 		return | ||||
| 	} | ||||
| 	// remove from list and index
 | ||||
| 	srv.peerCount-- | ||||
| 	srv.peers[peer.slot] = nil | ||||
| 	// release slot to signal need for a new peer, last!
 | ||||
| 	srv.peerSlots <- peer.slot | ||||
| } | ||||
| 
 | ||||
| func (srv *Server) verifyPeer(addr *peerAddr) error { | ||||
| 	if srv.Blacklist.Exists(addr.Pubkey) { | ||||
| 		return errors.New("blacklisted") | ||||
| 	} | ||||
| 	if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { | ||||
| 		return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") | ||||
| 	} | ||||
| 	srv.lock.RLock() | ||||
| 	defer srv.lock.RUnlock() | ||||
| 	for _, peer := range srv.peers { | ||||
| 		if peer != nil { | ||||
| 			id := peer.Identity() | ||||
| 			if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { | ||||
| 				return errors.New("already connected") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type Blacklist interface { | ||||
| 	Get([]byte) (bool, error) | ||||
| 	Put([]byte) error | ||||
| @ -66,419 +465,3 @@ func (self *BlacklistMap) Delete(pubkey []byte) error { | ||||
| 	delete(self.blacklist, string(pubkey)) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type Server struct { | ||||
| 	network   Network | ||||
| 	listening bool //needed?
 | ||||
| 	dialing   bool //needed?
 | ||||
| 	closed    bool | ||||
| 	identity  ClientIdentity | ||||
| 	addr      net.Addr | ||||
| 	port      uint16 | ||||
| 	protocols []string | ||||
| 
 | ||||
| 	quit      chan chan bool | ||||
| 	peersLock sync.RWMutex | ||||
| 
 | ||||
| 	maxPeers   int | ||||
| 	peers      []*Peer | ||||
| 	peerSlots  chan int | ||||
| 	peersTable map[string]int | ||||
| 	peersMsg   *Msg | ||||
| 	peerCount  int | ||||
| 
 | ||||
| 	peerConnect    chan net.Addr | ||||
| 	peerDisconnect chan DisconnectRequest | ||||
| 	blacklist      Blacklist | ||||
| 	handlers       Handlers | ||||
| } | ||||
| 
 | ||||
| var logger = logpkg.NewLogger("P2P") | ||||
| 
 | ||||
| func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { | ||||
| 	// get alphabetical list of protocol names from handlers map
 | ||||
| 	protocols := []string{} | ||||
| 	for protocol := range handlers { | ||||
| 		protocols = append(protocols, protocol) | ||||
| 	} | ||||
| 	sort.Strings(protocols) | ||||
| 
 | ||||
| 	_, port, _ := net.SplitHostPort(addr.String()) | ||||
| 	intport, _ := strconv.Atoi(port) | ||||
| 
 | ||||
| 	self := &Server{ | ||||
| 		// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
 | ||||
| 		network:   network, | ||||
| 		identity:  identity, | ||||
| 		addr:      addr, | ||||
| 		port:      uint16(intport), | ||||
| 		protocols: protocols, | ||||
| 
 | ||||
| 		quit: make(chan chan bool), | ||||
| 
 | ||||
| 		maxPeers:   maxPeers, | ||||
| 		peers:      make([]*Peer, maxPeers), | ||||
| 		peerSlots:  make(chan int, maxPeers), | ||||
| 		peersTable: make(map[string]int), | ||||
| 
 | ||||
| 		peerConnect:    make(chan net.Addr, outboundAddressPoolSize), | ||||
| 		peerDisconnect: make(chan DisconnectRequest), | ||||
| 		blacklist:      blacklist, | ||||
| 
 | ||||
| 		handlers: handlers, | ||||
| 	} | ||||
| 	for i := 0; i < maxPeers; i++ { | ||||
| 		self.peerSlots <- i // fill up with indexes
 | ||||
| 	} | ||||
| 	return self | ||||
| } | ||||
| 
 | ||||
| func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { | ||||
| 	addr, err = self.network.NewAddr(host, port) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { | ||||
| 	addr, err = self.network.ParseAddr(address) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *Server) ClientIdentity() ClientIdentity { | ||||
| 	return self.identity | ||||
| } | ||||
| 
 | ||||
| func (self *Server) PeersMessage() (msg *Msg, err error) { | ||||
| 	// TODO: memoize and reset when peers change
 | ||||
| 	self.peersLock.RLock() | ||||
| 	defer self.peersLock.RUnlock() | ||||
| 	msg = self.peersMsg | ||||
| 	if msg == nil { | ||||
| 		var peerData []interface{} | ||||
| 		for _, i := range self.peersTable { | ||||
| 			peer := self.peers[i] | ||||
| 			peerData = append(peerData, peer.Encode()) | ||||
| 		} | ||||
| 		if len(peerData) == 0 { | ||||
| 			err = fmt.Errorf("no peers") | ||||
| 		} else { | ||||
| 			msg, err = NewMsg(PeersMsg, peerData...) | ||||
| 			self.peersMsg = msg //memoize
 | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *Server) Peers() (peers []*Peer) { | ||||
| 	self.peersLock.RLock() | ||||
| 	defer self.peersLock.RUnlock() | ||||
| 	for _, peer := range self.peers { | ||||
| 		if peer != nil { | ||||
| 			peers = append(peers, peer) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *Server) PeerCount() int { | ||||
| 	self.peersLock.RLock() | ||||
| 	defer self.peersLock.RUnlock() | ||||
| 	return self.peerCount | ||||
| } | ||||
| 
 | ||||
| var getPeersMsg, _ = NewMsg(GetPeersMsg) | ||||
| 
 | ||||
| func (self *Server) PeerConnect(addr net.Addr) { | ||||
| 	// TODO: should buffer, filter and uniq
 | ||||
| 	// send GetPeersMsg if not blocking
 | ||||
| 	select { | ||||
| 	case self.peerConnect <- addr: // not enough peers
 | ||||
| 		self.Broadcast("", getPeersMsg) | ||||
| 	default: // we dont care
 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *Server) PeerDisconnect() chan DisconnectRequest { | ||||
| 	return self.peerDisconnect | ||||
| } | ||||
| 
 | ||||
| func (self *Server) Blacklist() Blacklist { | ||||
| 	return self.blacklist | ||||
| } | ||||
| 
 | ||||
| func (self *Server) Handlers() Handlers { | ||||
| 	return self.handlers | ||||
| } | ||||
| 
 | ||||
| func (self *Server) Broadcast(protocol string, msg *Msg) { | ||||
| 	self.peersLock.RLock() | ||||
| 	defer self.peersLock.RUnlock() | ||||
| 	for _, peer := range self.peers { | ||||
| 		if peer != nil { | ||||
| 			peer.Write(protocol, msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Start the server
 | ||||
| func (self *Server) Start(listen bool, dial bool) { | ||||
| 	self.network.Start() | ||||
| 	if listen { | ||||
| 		listener, err := self.network.Listener(self.addr) | ||||
| 		if err != nil { | ||||
| 			logger.Warnf("Error initializing listener: %v", err) | ||||
| 			logger.Warnf("Connection listening disabled") | ||||
| 			self.listening = false | ||||
| 		} else { | ||||
| 			self.listening = true | ||||
| 			logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) | ||||
| 			go self.inboundPeerHandler(listener) | ||||
| 		} | ||||
| 	} | ||||
| 	if dial { | ||||
| 		dialer, err := self.network.Dialer(self.addr) | ||||
| 		if err != nil { | ||||
| 			logger.Warnf("Error initializing dialer: %v", err) | ||||
| 			logger.Warnf("Connection dialout disabled") | ||||
| 			self.dialing = false | ||||
| 		} else { | ||||
| 			self.dialing = true | ||||
| 			logger.Infoln("Dial peers watching outbound address pool") | ||||
| 			go self.outboundPeerHandler(dialer) | ||||
| 		} | ||||
| 	} | ||||
| 	logger.Infoln("server started") | ||||
| } | ||||
| 
 | ||||
| func (self *Server) Stop() { | ||||
| 	logger.Infoln("server stopping...") | ||||
| 	// // quit one loop if dialing
 | ||||
| 	if self.dialing { | ||||
| 		logger.Infoln("stop dialout...") | ||||
| 		dialq := make(chan bool) | ||||
| 		self.quit <- dialq | ||||
| 		<-dialq | ||||
| 		fmt.Println("quit another") | ||||
| 	} | ||||
| 	// quit the other loop if listening
 | ||||
| 	if self.listening { | ||||
| 		logger.Infoln("stop listening...") | ||||
| 		listenq := make(chan bool) | ||||
| 		self.quit <- listenq | ||||
| 		<-listenq | ||||
| 		fmt.Println("quit one") | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Println("quit waited") | ||||
| 
 | ||||
| 	logger.Infoln("stopping peers...") | ||||
| 	peers := []net.Addr{} | ||||
| 	self.peersLock.RLock() | ||||
| 	self.closed = true | ||||
| 	for _, peer := range self.peers { | ||||
| 		if peer != nil { | ||||
| 			peers = append(peers, peer.Address) | ||||
| 		} | ||||
| 	} | ||||
| 	self.peersLock.RUnlock() | ||||
| 	for _, address := range peers { | ||||
| 		go self.removePeer(DisconnectRequest{ | ||||
| 			addr:   address, | ||||
| 			reason: DiscQuitting, | ||||
| 		}) | ||||
| 	} | ||||
| 	// wait till they actually disconnect
 | ||||
| 	// this is checked by draining the peerSlots (slots are released back if a peer is removed)
 | ||||
| 	i := 0 | ||||
| 	fmt.Println("draining peers") | ||||
| 
 | ||||
| FOR: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case slot := <-self.peerSlots: | ||||
| 			i++ | ||||
| 			fmt.Printf("%v: found slot %v", i, slot) | ||||
| 			if i == self.maxPeers { | ||||
| 				break FOR | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	logger.Infoln("server stopped") | ||||
| } | ||||
| 
 | ||||
| // main loop for adding connections via listening
 | ||||
| func (self *Server) inboundPeerHandler(listener net.Listener) { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case slot := <-self.peerSlots: | ||||
| 			go self.connectInboundPeer(listener, slot) | ||||
| 		case errc := <-self.quit: | ||||
| 			listener.Close() | ||||
| 			fmt.Println("quit listenloop") | ||||
| 			errc <- true | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // main loop for adding outbound peers based on peerConnect address pool
 | ||||
| // this same loop handles peer disconnect requests as well
 | ||||
| func (self *Server) outboundPeerHandler(dialer Dialer) { | ||||
| 	// addressChan initially set to nil (only watches peerConnect if we need more peers)
 | ||||
| 	var addressChan chan net.Addr | ||||
| 	slots := self.peerSlots | ||||
| 	var slot *int | ||||
| 	for { | ||||
| 		select { | ||||
| 		case i := <-slots: | ||||
| 			// we need a peer in slot i, slot reserved
 | ||||
| 			slot = &i | ||||
| 			// now we can watch for candidate peers in the next loop
 | ||||
| 			addressChan = self.peerConnect | ||||
| 			// do not consume more until candidate peer is found
 | ||||
| 			slots = nil | ||||
| 		case address := <-addressChan: | ||||
| 			// candidate peer found, will dial out asyncronously
 | ||||
| 			// if connection fails slot will be released
 | ||||
| 			go self.connectOutboundPeer(dialer, address, *slot) | ||||
| 			// we can watch if more peers needed in the next loop
 | ||||
| 			slots = self.peerSlots | ||||
| 			// until then we dont care about candidate peers
 | ||||
| 			addressChan = nil | ||||
| 		case request := <-self.peerDisconnect: | ||||
| 			go self.removePeer(request) | ||||
| 		case errc := <-self.quit: | ||||
| 			if addressChan != nil && slot != nil { | ||||
| 				self.peerSlots <- *slot | ||||
| 			} | ||||
| 			fmt.Println("quit dialloop") | ||||
| 			errc <- true | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // check if peer address already connected
 | ||||
| func (self *Server) connected(address net.Addr) (err error) { | ||||
| 	self.peersLock.RLock() | ||||
| 	defer self.peersLock.RUnlock() | ||||
| 	// fmt.Printf("address: %v\n", address)
 | ||||
| 	slot, found := self.peersTable[address.String()] | ||||
| 	if found { | ||||
| 		err = fmt.Errorf("already connected as peer %v (%v)", slot, address) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // connect to peer via listener.Accept()
 | ||||
| func (self *Server) connectInboundPeer(listener net.Listener, slot int) { | ||||
| 	var address net.Addr | ||||
| 	conn, err := listener.Accept() | ||||
| 	if err == nil { | ||||
| 		address = conn.RemoteAddr() | ||||
| 		err = self.connected(address) | ||||
| 		if err != nil { | ||||
| 			conn.Close() | ||||
| 		} | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		logger.Debugln(err) | ||||
| 		self.peerSlots <- slot | ||||
| 	} else { | ||||
| 		fmt.Printf("adding %v\n", address) | ||||
| 		go self.addPeer(conn, address, true, slot) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // connect to peer via dial out
 | ||||
| func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { | ||||
| 	var conn net.Conn | ||||
| 	err := self.connected(address) | ||||
| 	if err == nil { | ||||
| 		conn, err = dialer.Dial(address.Network(), address.String()) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		logger.Debugln(err) | ||||
| 		self.peerSlots <- slot | ||||
| 	} else { | ||||
| 		go self.addPeer(conn, address, false, slot) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // creates the new peer object and inserts it into its slot
 | ||||
| func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { | ||||
| 	self.peersLock.Lock() | ||||
| 	defer self.peersLock.Unlock() | ||||
| 	if self.closed { | ||||
| 		fmt.Println("oopsy, not no longer need peer") | ||||
| 		conn.Close()           //oopsy our bad
 | ||||
| 		self.peerSlots <- slot // release slot
 | ||||
| 	} else { | ||||
| 		peer := NewPeer(conn, address, inbound, self) | ||||
| 		self.peers[slot] = peer | ||||
| 		self.peersTable[address.String()] = slot | ||||
| 		self.peerCount++ | ||||
| 		// reset peersmsg
 | ||||
| 		self.peersMsg = nil | ||||
| 		fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) | ||||
| 		peer.Start() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
 | ||||
| func (self *Server) removePeer(request DisconnectRequest) { | ||||
| 	self.peersLock.Lock() | ||||
| 
 | ||||
| 	address := request.addr | ||||
| 	slot := self.peersTable[address.String()] | ||||
| 	peer := self.peers[slot] | ||||
| 	fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) | ||||
| 	if peer == nil { | ||||
| 		logger.Debugf("already removed peer on %v", address) | ||||
| 		self.peersLock.Unlock() | ||||
| 		return | ||||
| 	} | ||||
| 	// remove from list and index
 | ||||
| 	self.peerCount-- | ||||
| 	self.peers[slot] = nil | ||||
| 	delete(self.peersTable, address.String()) | ||||
| 	// reset peersmsg
 | ||||
| 	self.peersMsg = nil | ||||
| 	fmt.Printf("removed peer %v (slot %v)\n", peer, slot) | ||||
| 	self.peersLock.Unlock() | ||||
| 
 | ||||
| 	// sending disconnect message
 | ||||
| 	disconnectMsg, _ := NewMsg(DiscMsg, request.reason) | ||||
| 	peer.Write("", disconnectMsg) | ||||
| 	// be nice and wait
 | ||||
| 	time.Sleep(disconnectGracePeriod * time.Second) | ||||
| 	// switch off peer and close connections etc.
 | ||||
| 	fmt.Println("stopping peer") | ||||
| 	peer.Stop() | ||||
| 	fmt.Println("stopped peer") | ||||
| 	// release slot to signal need for a new peer, last!
 | ||||
| 	self.peerSlots <- slot | ||||
| } | ||||
| 
 | ||||
| // fix handshake message to push to peers
 | ||||
| func (self *Server) Handshake() *Msg { | ||||
| 	fmt.Println(self.identity.Pubkey()[1:]) | ||||
| 	msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) | ||||
| 	return msg | ||||
| } | ||||
| 
 | ||||
| func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { | ||||
| 	// Check for blacklisting
 | ||||
| 	if self.blacklist.Exists(pubkey) { | ||||
| 		return fmt.Errorf("blacklisted") | ||||
| 	} | ||||
| 
 | ||||
| 	self.peersLock.RLock() | ||||
| 	defer self.peersLock.RUnlock() | ||||
| 	for _, peer := range self.peers { | ||||
| 		if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { | ||||
| 			return fmt.Errorf("already connected") | ||||
| 		} | ||||
| 	} | ||||
| 	candidate.Pubkey = pubkey | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @ -2,207 +2,160 @@ package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type TestNetwork struct { | ||||
| 	connections map[string]*TestNetworkConnection | ||||
| 	dialer      Dialer | ||||
| 	maxinbound  int | ||||
| } | ||||
| 
 | ||||
| func NewTestNetwork(maxinbound int) *TestNetwork { | ||||
| 	connections := make(map[string]*TestNetworkConnection) | ||||
| 	return &TestNetwork{ | ||||
| 		connections: connections, | ||||
| 		dialer:      &TestDialer{connections}, | ||||
| 		maxinbound:  maxinbound, | ||||
| func startTestServer(t *testing.T, pf peerFunc) *Server { | ||||
| 	server := &Server{ | ||||
| 		Identity:    NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), | ||||
| 		MaxPeers:    10, | ||||
| 		ListenAddr:  "127.0.0.1:0", | ||||
| 		newPeerFunc: pf, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { | ||||
| 	return self.dialer, nil | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { | ||||
| 	return &TestListener{ | ||||
| 		connections: self.connections, | ||||
| 		addr:        addr, | ||||
| 		max:         self.maxinbound, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetwork) Start() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type TestAddr struct { | ||||
| 	name string | ||||
| } | ||||
| 
 | ||||
| func (self *TestAddr) String() string { | ||||
| 	return self.name | ||||
| } | ||||
| 
 | ||||
| func (*TestAddr) Network() string { | ||||
| 	return "test" | ||||
| } | ||||
| 
 | ||||
| type TestDialer struct { | ||||
| 	connections map[string]*TestNetworkConnection | ||||
| } | ||||
| 
 | ||||
| func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { | ||||
| 	address := &TestAddr{addr} | ||||
| 	tconn := NewTestNetworkConnection(address) | ||||
| 	self.connections[addr] = tconn | ||||
| 	conn = net.Conn(tconn) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type TestListener struct { | ||||
| 	connections map[string]*TestNetworkConnection | ||||
| 	addr        net.Addr | ||||
| 	max         int | ||||
| 	i           int | ||||
| } | ||||
| 
 | ||||
| func (self *TestListener) Accept() (conn net.Conn, err error) { | ||||
| 	self.i++ | ||||
| 	if self.i > self.max { | ||||
| 		err = fmt.Errorf("no more") | ||||
| 	} else { | ||||
| 		addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} | ||||
| 		tconn := NewTestNetworkConnection(addr) | ||||
| 		key := tconn.RemoteAddr().String() | ||||
| 		self.connections[key] = tconn | ||||
| 		conn = net.Conn(tconn) | ||||
| 		fmt.Printf("accepted connection from: %v \n", addr) | ||||
| 	if err := server.Start(); err != nil { | ||||
| 		t.Fatalf("Could not start server: %v", err) | ||||
| 	} | ||||
| 	return | ||||
| 	return server | ||||
| } | ||||
| 
 | ||||
| func (self *TestListener) Close() error { | ||||
| 	return nil | ||||
| } | ||||
| func TestServerListen(t *testing.T) { | ||||
| 	defer testlog(t).detach() | ||||
| 
 | ||||
| func (self *TestListener) Addr() net.Addr { | ||||
| 	return self.addr | ||||
| } | ||||
| 
 | ||||
| func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { | ||||
| 	network = NewTestNetwork(1) | ||||
| 	addr := &TestAddr{"test:30303"} | ||||
| 	identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") | ||||
| 	maxPeers := 2 | ||||
| 	if handlers == nil { | ||||
| 		handlers = make(Handlers) | ||||
| 	} | ||||
| 	blackist := NewBlacklist() | ||||
| 	server = New(network, addr, identity, handlers, maxPeers, blackist) | ||||
| 	fmt.Println(server.identity.Pubkey()) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func TestServerListener(t *testing.T) { | ||||
| 	network, server := SetupTestServer(nil) | ||||
| 	server.Start(true, false) | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	server.Stop() | ||||
| 	peer1, ok := network.connections["inboundpeer-1"] | ||||
| 	if !ok { | ||||
| 		t.Error("not found inbound peer 1") | ||||
| 	} else { | ||||
| 		fmt.Printf("out: %v\n", peer1.Out) | ||||
| 		if len(peer1.Out) != 2 { | ||||
| 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) | ||||
| 	// start the test server
 | ||||
| 	connected := make(chan *Peer) | ||||
| 	srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { | ||||
| 		if conn == nil { | ||||
| 			t.Error("peer func called with nil conn") | ||||
| 		} | ||||
| 	} | ||||
| 		if dialAddr != nil { | ||||
| 			t.Error("peer func called with non-nil dialAddr") | ||||
| 		} | ||||
| 		peer := newPeer(conn, nil, dialAddr) | ||||
| 		connected <- peer | ||||
| 		return peer | ||||
| 	}) | ||||
| 	defer close(connected) | ||||
| 	defer srv.Stop() | ||||
| 
 | ||||
| 	// dial the test server
 | ||||
| 	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("could not dial: %v", err) | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 
 | ||||
| 	select { | ||||
| 	case peer := <-connected: | ||||
| 		if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { | ||||
| 			t.Errorf("peer started with wrong conn: got %v, want %v", | ||||
| 				peer.conn.LocalAddr(), conn.RemoteAddr()) | ||||
| 		} | ||||
| 	case <-time.After(1 * time.Second): | ||||
| 		t.Error("server did not accept within one second") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestServerDialer(t *testing.T) { | ||||
| 	network, server := SetupTestServer(nil) | ||||
| 	server.Start(false, true) | ||||
| 	server.peerConnect <- &TestAddr{"outboundpeer-1"} | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	server.Stop() | ||||
| 	peer1, ok := network.connections["outboundpeer-1"] | ||||
| 	if !ok { | ||||
| 		t.Error("not found outbound peer 1") | ||||
| 	} else { | ||||
| 		fmt.Printf("out: %v\n", peer1.Out) | ||||
| 		if len(peer1.Out) != 2 { | ||||
| 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) | ||||
| func TestServerDial(t *testing.T) { | ||||
| 	defer testlog(t).detach() | ||||
| 
 | ||||
| 	// run a fake TCP server to handle the connection.
 | ||||
| 	listener, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("could not setup listener: %v") | ||||
| 	} | ||||
| 	defer listener.Close() | ||||
| 	accepted := make(chan net.Conn) | ||||
| 	go func() { | ||||
| 		conn, err := listener.Accept() | ||||
| 		if err != nil { | ||||
| 			t.Error("acccept error:", err) | ||||
| 		} | ||||
| 		conn.Close() | ||||
| 		accepted <- conn | ||||
| 	}() | ||||
| 
 | ||||
| 	// start the test server
 | ||||
| 	connected := make(chan *Peer) | ||||
| 	srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { | ||||
| 		if conn == nil { | ||||
| 			t.Error("peer func called with nil conn") | ||||
| 		} | ||||
| 		peer := newPeer(conn, nil, dialAddr) | ||||
| 		connected <- peer | ||||
| 		return peer | ||||
| 	}) | ||||
| 	defer close(connected) | ||||
| 	defer srv.Stop() | ||||
| 
 | ||||
| 	// tell the server to connect.
 | ||||
| 	connAddr := newPeerAddr(listener.Addr(), nil) | ||||
| 	srv.peerConnect <- connAddr | ||||
| 
 | ||||
| 	select { | ||||
| 	case conn := <-accepted: | ||||
| 		select { | ||||
| 		case peer := <-connected: | ||||
| 			if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { | ||||
| 				t.Errorf("peer started with wrong conn: got %v, want %v", | ||||
| 					peer.conn.RemoteAddr(), conn.LocalAddr()) | ||||
| 			} | ||||
| 			if peer.dialAddr != connAddr { | ||||
| 				t.Errorf("peer started with wrong dialAddr: got %v, want %v", | ||||
| 					peer.dialAddr, connAddr) | ||||
| 			} | ||||
| 		case <-time.After(1 * time.Second): | ||||
| 			t.Error("server did not launch peer within one second") | ||||
| 		} | ||||
| 
 | ||||
| 	case <-time.After(1 * time.Second): | ||||
| 		t.Error("server did not connect within one second") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestServerBroadcast(t *testing.T) { | ||||
| 	handlers := make(Handlers) | ||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} | ||||
| 	handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } | ||||
| 	network, server := SetupTestServer(handlers) | ||||
| 	server.Start(true, true) | ||||
| 	server.peerConnect <- &TestAddr{"outboundpeer-1"} | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	msg, _ := NewMsg(0) | ||||
| 	server.Broadcast("", msg) | ||||
| 	packet := Packet(0, 0) | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	server.Stop() | ||||
| 	peer1, ok := network.connections["outboundpeer-1"] | ||||
| 	if !ok { | ||||
| 		t.Error("not found outbound peer 1") | ||||
| 	} else { | ||||
| 		fmt.Printf("out: %v\n", peer1.Out) | ||||
| 		if len(peer1.Out) != 3 { | ||||
| 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) | ||||
| 		} else { | ||||
| 			if bytes.Compare(peer1.Out[1], packet) != 0 { | ||||
| 				t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	peer2, ok := network.connections["inboundpeer-1"] | ||||
| 	if !ok { | ||||
| 		t.Error("not found inbound peer 2") | ||||
| 	} else { | ||||
| 		fmt.Printf("out: %v\n", peer2.Out) | ||||
| 		if len(peer1.Out) != 3 { | ||||
| 			t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) | ||||
| 		} else { | ||||
| 			if bytes.Compare(peer2.Out[1], packet) != 0 { | ||||
| 				t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 	defer testlog(t).detach() | ||||
| 	var connected sync.WaitGroup | ||||
| 	srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { | ||||
| 		peer := newPeer(c, []Protocol{discard}, dialAddr) | ||||
| 		peer.startSubprotocols([]Cap{discard.cap()}) | ||||
| 		connected.Done() | ||||
| 		return peer | ||||
| 	}) | ||||
| 	defer srv.Stop() | ||||
| 
 | ||||
| func TestServerPeersMessage(t *testing.T) { | ||||
| 	handlers := make(Handlers) | ||||
| 	_, server := SetupTestServer(handlers) | ||||
| 	server.Start(true, true) | ||||
| 	defer server.Stop() | ||||
| 	server.peerConnect <- &TestAddr{"outboundpeer-1"} | ||||
| 	time.Sleep(10 * time.Millisecond) | ||||
| 	peersMsg, err := server.PeersMessage() | ||||
| 	fmt.Println(peersMsg) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expect no error, got %v", err) | ||||
| 	// dial a bunch of conns
 | ||||
| 	var conns = make([]net.Conn, 8) | ||||
| 	connected.Add(len(conns)) | ||||
| 	deadline := time.Now().Add(3 * time.Second) | ||||
| 	dialer := &net.Dialer{Deadline: deadline} | ||||
| 	for i := range conns { | ||||
| 		conn, err := dialer.Dial("tcp", srv.ListenAddr) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("conn %d: dial error: %v", i, err) | ||||
| 		} | ||||
| 		defer conn.Close() | ||||
| 		conn.SetDeadline(deadline) | ||||
| 		conns[i] = conn | ||||
| 	} | ||||
| 	if c := server.PeerCount(); c != 2 { | ||||
| 		t.Errorf("expect 2 peers, got %v", c) | ||||
| 	connected.Wait() | ||||
| 
 | ||||
| 	// broadcast one message
 | ||||
| 	srv.Broadcast("discard", 0, "foo") | ||||
| 	goldbuf := new(bytes.Buffer) | ||||
| 	writeMsg(goldbuf, NewMsg(16, "foo")) | ||||
| 	golden := goldbuf.Bytes() | ||||
| 
 | ||||
| 	// check that the message has been written everywhere
 | ||||
| 	for i, conn := range conns { | ||||
| 		buf := make([]byte, len(golden)) | ||||
| 		if _, err := io.ReadFull(conn, buf); err != nil { | ||||
| 			t.Errorf("conn %d: read error: %v", i, err) | ||||
| 		} else if !bytes.Equal(buf, golden) { | ||||
| 			t.Errorf("conn %d: msg mismatch\ngot:  %x\nwant: %x", i, buf, golden) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										28
									
								
								p2p/testlog_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								p2p/testlog_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,28 @@ | ||||
| package p2p | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/logger" | ||||
| ) | ||||
| 
 | ||||
| type testLogger struct{ t *testing.T } | ||||
| 
 | ||||
| func testlog(t *testing.T) testLogger { | ||||
| 	logger.Reset() | ||||
| 	l := testLogger{t} | ||||
| 	logger.AddLogSystem(l) | ||||
| 	return l | ||||
| } | ||||
| 
 | ||||
| func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } | ||||
| func (testLogger) SetLogLevel(logger.LogLevel)  {} | ||||
| 
 | ||||
| func (l testLogger) LogPrint(level logger.LogLevel, msg string) { | ||||
| 	l.t.Logf("%s", msg) | ||||
| } | ||||
| 
 | ||||
| func (testLogger) detach() { | ||||
| 	logger.Flush() | ||||
| 	logger.Reset() | ||||
| } | ||||
							
								
								
									
										40
									
								
								p2p/testpoc7.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								p2p/testpoc7.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,40 @@ | ||||
| // +build none
 | ||||
| 
 | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"github.com/ethereum/go-ethereum/logger" | ||||
| 	"github.com/ethereum/go-ethereum/p2p" | ||||
| 	"github.com/obscuren/secp256k1-go" | ||||
| ) | ||||
| 
 | ||||
| func main() { | ||||
| 	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) | ||||
| 
 | ||||
| 	pub, _ := secp256k1.GenerateKeyPair() | ||||
| 	srv := p2p.Server{ | ||||
| 		MaxPeers:   10, | ||||
| 		Identity:   p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), | ||||
| 		ListenAddr: ":30303", | ||||
| 		NAT:        p2p.PMP(net.ParseIP("10.0.0.1")), | ||||
| 	} | ||||
| 	if err := srv.Start(); err != nil { | ||||
| 		fmt.Println("could not start server:", err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| 
 | ||||
| 	// add seed peers
 | ||||
| 	seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") | ||||
| 	if err != nil { | ||||
| 		fmt.Println("couldn't resolve:", err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| 	srv.SuggestPeer(seed.IP, seed.Port, nil) | ||||
| 
 | ||||
| 	select {} | ||||
| } | ||||
| @ -1,6 +1,7 @@ | ||||
| package rlp | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/binary" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| @ -24,8 +25,9 @@ type Decoder interface { | ||||
| 	DecodeRLP(*Stream) error | ||||
| } | ||||
| 
 | ||||
| // Decode parses RLP-encoded data from r and stores the result
 | ||||
| // in the value pointed to by val. Val must be a non-nil pointer.
 | ||||
| // Decode parses RLP-encoded data from r and stores the result in the
 | ||||
| // value pointed to by val. Val must be a non-nil pointer. If r does
 | ||||
| // not implement ByteReader, Decode will do its own buffering.
 | ||||
| //
 | ||||
| // Decode uses the following type-dependent decoding rules:
 | ||||
| //
 | ||||
| @ -66,10 +68,19 @@ type Decoder interface { | ||||
| //
 | ||||
| // Non-empty interface types are not supported, nor are bool, float32,
 | ||||
| // float64, maps, channel types and functions.
 | ||||
| func Decode(r ByteReader, val interface{}) error { | ||||
| func Decode(r io.Reader, val interface{}) error { | ||||
| 	return NewStream(r).Decode(val) | ||||
| } | ||||
| 
 | ||||
| type decodeError struct { | ||||
| 	msg string | ||||
| 	typ reflect.Type | ||||
| } | ||||
| 
 | ||||
| func (err decodeError) Error() string { | ||||
| 	return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) | ||||
| } | ||||
| 
 | ||||
| func makeNumDecoder(typ reflect.Type) decoder { | ||||
| 	kind := typ.Kind() | ||||
| 	switch { | ||||
| @ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder { | ||||
| } | ||||
| 
 | ||||
| func decodeInt(s *Stream, val reflect.Value) error { | ||||
| 	num, err := s.uint(val.Type().Bits()) | ||||
| 	if err != nil { | ||||
| 	typ := val.Type() | ||||
| 	num, err := s.uint(typ.Bits()) | ||||
| 	if err == errUintOverflow { | ||||
| 		return decodeError{"input string too long", typ} | ||||
| 	} else if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	val.SetInt(int64(num)) | ||||
| @ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error { | ||||
| } | ||||
| 
 | ||||
| func decodeUint(s *Stream, val reflect.Value) error { | ||||
| 	num, err := s.uint(val.Type().Bits()) | ||||
| 	if err != nil { | ||||
| 	typ := val.Type() | ||||
| 	num, err := s.uint(typ.Bits()) | ||||
| 	if err == errUintOverflow { | ||||
| 		return decodeError{"input string too big", typ} | ||||
| 	} else if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	val.SetUint(num) | ||||
| @ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro | ||||
| 	i := 0 | ||||
| 	for { | ||||
| 		if i > maxelem { | ||||
| 			return fmt.Errorf("rlp: input List has more than %d elements", maxelem) | ||||
| 			return decodeError{"input list has too many elements", val.Type()} | ||||
| 		} | ||||
| 		if val.Kind() == reflect.Slice { | ||||
| 			// grow slice if necessary
 | ||||
| @ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error { | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array") | ||||
| 
 | ||||
| func decodeByteArray(s *Stream, val reflect.Value) error { | ||||
| 	kind, size, err := s.Kind() | ||||
| 	if err != nil { | ||||
| @ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { | ||||
| 	switch kind { | ||||
| 	case Byte: | ||||
| 		if val.Len() == 0 { | ||||
| 			return errStringDoesntFitArray | ||||
| 			return decodeError{"input string too big", val.Type()} | ||||
| 		} | ||||
| 		bv, _ := s.Uint() | ||||
| 		val.Index(0).SetUint(bv) | ||||
| 		zero(val, 1) | ||||
| 	case String: | ||||
| 		if uint64(val.Len()) < size { | ||||
| 			return errStringDoesntFitArray | ||||
| 			return decodeError{"input string too big", val.Type()} | ||||
| 		} | ||||
| 		slice := val.Slice(0, int(size)).Interface().([]byte) | ||||
| 		if err := s.readFull(slice); err != nil { | ||||
| @ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { | ||||
| 			} | ||||
| 		} | ||||
| 		if err = s.ListEnd(); err == errNotAtEOL { | ||||
| 			err = errors.New("rlp: input List has too many elements") | ||||
| 			err = decodeError{"input list has too many elements", typ} | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| @ -432,8 +447,23 @@ type Stream struct { | ||||
| 
 | ||||
| type listpos struct{ pos, size uint64 } | ||||
| 
 | ||||
| func NewStream(r ByteReader) *Stream { | ||||
| 	return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} | ||||
| // NewStream creates a new stream reading from r.
 | ||||
| // If r does not implement ByteReader, the Stream will
 | ||||
| // introduce its own buffering.
 | ||||
| func NewStream(r io.Reader) *Stream { | ||||
| 	s := new(Stream) | ||||
| 	s.Reset(r) | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // NewListStream creates a new stream that pretends to be positioned
 | ||||
| // at an encoded list of the given length.
 | ||||
| func NewListStream(r io.Reader, len uint64) *Stream { | ||||
| 	s := new(Stream) | ||||
| 	s.Reset(r) | ||||
| 	s.kind = List | ||||
| 	s.size = len | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // Bytes reads an RLP string and returns its contents as a byte slice.
 | ||||
| @ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| var errUintOverflow = errors.New("rlp: uint overflow") | ||||
| 
 | ||||
| // Uint reads an RLP string of up to 8 bytes and returns its contents
 | ||||
| // as an unsigned integer. If the input does not contain an RLP string, the
 | ||||
| // returned error will be ErrExpectedString.
 | ||||
| @ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) { | ||||
| 		return uint64(s.byteval), nil | ||||
| 	case String: | ||||
| 		if size > uint64(maxbits/8) { | ||||
| 			return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) | ||||
| 			return 0, errUintOverflow | ||||
| 		} | ||||
| 		return s.readUint(byte(size)) | ||||
| 	default: | ||||
| @ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error { | ||||
| 	return info.decoder(s, rval.Elem()) | ||||
| } | ||||
| 
 | ||||
| // Reset discards any information about the current decoding context
 | ||||
| // and starts reading from r. If r does not also implement ByteReader,
 | ||||
| // Stream will do its own buffering.
 | ||||
| func (s *Stream) Reset(r io.Reader) { | ||||
| 	bufr, ok := r.(ByteReader) | ||||
| 	if !ok { | ||||
| 		bufr = bufio.NewReader(r) | ||||
| 	} | ||||
| 	s.r = bufr | ||||
| 	s.stack = s.stack[:0] | ||||
| 	s.size = 0 | ||||
| 	s.kind = -1 | ||||
| 	if s.uintbuf == nil { | ||||
| 		s.uintbuf = make([]byte, 8) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Kind returns the kind and size of the next value in the
 | ||||
| // input stream.
 | ||||
| //
 | ||||
|  | ||||
| @ -3,7 +3,6 @@ package rlp | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math/big" | ||||
| @ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNewListStream(t *testing.T) { | ||||
| 	ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3) | ||||
| 	if k, size, err := ls.Kind(); k != List || size != 3 || err != nil { | ||||
| 		t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err) | ||||
| 	} | ||||
| 	if size, err := ls.List(); size != 3 || err != nil { | ||||
| 		t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err) | ||||
| 	} | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		if val, err := ls.Uint(); val != 1 || err != nil { | ||||
| 			t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err) | ||||
| 		} | ||||
| 	} | ||||
| 	if err := ls.ListEnd(); err != nil { | ||||
| 		t.Errorf("ListEnd() returned %v, expected (3, nil)", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestStreamErrors(t *testing.T) { | ||||
| 	type calls []string | ||||
| 	tests := []struct { | ||||
| @ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) { | ||||
| 		{"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, | ||||
| 		{"81", calls{"Uint"}, io.ErrUnexpectedEOF}, | ||||
| 		{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, | ||||
| 		{"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, | ||||
| 		{"89000000000000000001", calls{"Uint"}, errUintOverflow}, | ||||
| 		{"00", calls{"List"}, ErrExpectedList}, | ||||
| 		{"80", calls{"List"}, ErrExpectedList}, | ||||
| 		{"C0", calls{"List", "Uint"}, EOL}, | ||||
| @ -163,7 +180,7 @@ type decodeTest struct { | ||||
| 	input string | ||||
| 	ptr   interface{} | ||||
| 	value interface{} | ||||
| 	error error | ||||
| 	error string | ||||
| } | ||||
| 
 | ||||
| type simplestruct struct { | ||||
| @ -196,8 +213,8 @@ var decodeTests = []decodeTest{ | ||||
| 	{input: "820505", ptr: new(uint32), value: uint32(0x0505)}, | ||||
| 	{input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, | ||||
| 	{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, | ||||
| 	{input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, | ||||
| 	{input: "C0", ptr: new(uint32), error: ErrExpectedString}, | ||||
| 	{input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"}, | ||||
| 	{input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()}, | ||||
| 
 | ||||
| 	// slices
 | ||||
| 	{input: "C0", ptr: new([]int), value: []int{}}, | ||||
| @ -206,7 +223,7 @@ var decodeTests = []decodeTest{ | ||||
| 	// arrays
 | ||||
| 	{input: "C0", ptr: new([5]int), value: [5]int{}}, | ||||
| 	{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, | ||||
| 	{input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, | ||||
| 	{input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"}, | ||||
| 
 | ||||
| 	// byte slices
 | ||||
| 	{input: "01", ptr: new([]byte), value: []byte{1}}, | ||||
| @ -214,7 +231,7 @@ var decodeTests = []decodeTest{ | ||||
| 	{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, | ||||
| 	{input: "C0", ptr: new([]byte), value: []byte{}}, | ||||
| 	{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, | ||||
| 	{input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, | ||||
| 	{input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"}, | ||||
| 
 | ||||
| 	// byte arrays
 | ||||
| 	{input: "01", ptr: new([5]byte), value: [5]byte{1}}, | ||||
| @ -222,9 +239,9 @@ var decodeTests = []decodeTest{ | ||||
| 	{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, | ||||
| 	{input: "C0", ptr: new([5]byte), value: [5]byte{}}, | ||||
| 	{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, | ||||
| 	{input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, | ||||
| 	{input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, | ||||
| 	{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, | ||||
| 	{input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"}, | ||||
| 	{input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"}, | ||||
| 	{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, | ||||
| 
 | ||||
| 	// byte array reuse (should be zeroed)
 | ||||
| 	{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, | ||||
| @ -237,25 +254,25 @@ var decodeTests = []decodeTest{ | ||||
| 	// zero sized byte arrays
 | ||||
| 	{input: "80", ptr: new([0]byte), value: [0]byte{}}, | ||||
| 	{input: "C0", ptr: new([0]byte), value: [0]byte{}}, | ||||
| 	{input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, | ||||
| 	{input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, | ||||
| 	{input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, | ||||
| 	{input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, | ||||
| 
 | ||||
| 	// strings
 | ||||
| 	{input: "00", ptr: new(string), value: "\000"}, | ||||
| 	{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, | ||||
| 	{input: "C0", ptr: new(string), error: ErrExpectedString}, | ||||
| 	{input: "C0", ptr: new(string), error: ErrExpectedString.Error()}, | ||||
| 
 | ||||
| 	// big ints
 | ||||
| 	{input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, | ||||
| 	{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, | ||||
| 	{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
 | ||||
| 	{input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, | ||||
| 	{input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()}, | ||||
| 
 | ||||
| 	// structs
 | ||||
| 	{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, | ||||
| 	{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, | ||||
| 	{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, | ||||
| 	{input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, | ||||
| 	{input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"}, | ||||
| 	{ | ||||
| 		input: "C501C302C103", | ||||
| 		ptr:   new(recstruct), | ||||
| @ -286,20 +303,20 @@ var decodeTests = []decodeTest{ | ||||
| 
 | ||||
| func intp(i int) *int { return &i } | ||||
| 
 | ||||
| func TestDecode(t *testing.T) { | ||||
| func runTests(t *testing.T, decode func([]byte, interface{}) error) { | ||||
| 	for i, test := range decodeTests { | ||||
| 		input, err := hex.DecodeString(test.input) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("test %d: invalid hex input %q", i, test.input) | ||||
| 			continue | ||||
| 		} | ||||
| 		err = Decode(bytes.NewReader(input), test.ptr) | ||||
| 		if err != nil && test.error == nil { | ||||
| 		err = decode(input, test.ptr) | ||||
| 		if err != nil && test.error == "" { | ||||
| 			t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", | ||||
| 				i, err, test.ptr, test.input) | ||||
| 			continue | ||||
| 		} | ||||
| 		if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { | ||||
| 		if test.error != "" && fmt.Sprint(err) != test.error { | ||||
| 			t.Errorf("test %d: Decode error mismatch\ngot  %v\nwant %v\ndecoding into %T\ninput %q", | ||||
| 				i, err, test.error, test.ptr, test.input) | ||||
| 			continue | ||||
| @ -312,6 +329,40 @@ func TestDecode(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDecodeWithByteReader(t *testing.T) { | ||||
| 	runTests(t, func(input []byte, into interface{}) error { | ||||
| 		return Decode(bytes.NewReader(input), into) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // dumbReader reads from a byte slice but does not
 | ||||
| // implement ReadByte.
 | ||||
| type dumbReader []byte | ||||
| 
 | ||||
| func (r *dumbReader) Read(buf []byte) (n int, err error) { | ||||
| 	if len(*r) == 0 { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
| 	n = copy(buf, *r) | ||||
| 	*r = (*r)[n:] | ||||
| 	return n, nil | ||||
| } | ||||
| 
 | ||||
| func TestDecodeWithNonByteReader(t *testing.T) { | ||||
| 	runTests(t, func(input []byte, into interface{}) error { | ||||
| 		r := dumbReader(input) | ||||
| 		return Decode(&r, into) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func TestDecodeStreamReset(t *testing.T) { | ||||
| 	s := NewStream(nil) | ||||
| 	runTests(t, func(input []byte, into interface{}) error { | ||||
| 		s.Reset(bytes.NewReader(input)) | ||||
| 		return s.Decode(into) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| type testDecoder struct{ called bool } | ||||
| 
 | ||||
| func (t *testDecoder) DecodeRLP(s *Stream) error { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user