Merge pull request #3325 from fjl/p2p-netrestrict
Prevent relay of invalid IPs, add --netrestrict
This commit is contained in:
		
						commit
						d1a95c643e
					
				@ -29,6 +29,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discv5"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
@ -39,6 +40,7 @@ func main() {
 | 
			
		||||
		nodeKeyFile = flag.String("nodekey", "", "private key filename")
 | 
			
		||||
		nodeKeyHex  = flag.String("nodekeyhex", "", "private key as hex (for testing)")
 | 
			
		||||
		natdesc     = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
 | 
			
		||||
		netrestrict = flag.String("netrestrict", "", "restrict network communication to the given IP networks (CIDR masks)")
 | 
			
		||||
		runv5       = flag.Bool("v5", false, "run a v5 topic discovery bootnode")
 | 
			
		||||
 | 
			
		||||
		nodeKey *ecdsa.PrivateKey
 | 
			
		||||
@ -81,12 +83,20 @@ func main() {
 | 
			
		||||
		os.Exit(0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var restrictList *netutil.Netlist
 | 
			
		||||
	if *netrestrict != "" {
 | 
			
		||||
		restrictList, err = netutil.ParseNetlist(*netrestrict)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			utils.Fatalf("-netrestrict: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if *runv5 {
 | 
			
		||||
		if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil {
 | 
			
		||||
		if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil {
 | 
			
		||||
			utils.Fatalf("%v", err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil {
 | 
			
		||||
		if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil {
 | 
			
		||||
			utils.Fatalf("%v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -96,6 +96,7 @@ func init() {
 | 
			
		||||
		utils.BootnodesFlag,
 | 
			
		||||
		utils.KeyStoreDirFlag,
 | 
			
		||||
		utils.ListenPortFlag,
 | 
			
		||||
		utils.NetrestrictFlag,
 | 
			
		||||
		utils.MaxPeersFlag,
 | 
			
		||||
		utils.NATFlag,
 | 
			
		||||
		utils.NodeKeyFileFlag,
 | 
			
		||||
 | 
			
		||||
@ -148,6 +148,7 @@ participating.
 | 
			
		||||
		utils.NatspecEnabledFlag,
 | 
			
		||||
		utils.NoDiscoverFlag,
 | 
			
		||||
		utils.DiscoveryV5Flag,
 | 
			
		||||
		utils.NetrestrictFlag,
 | 
			
		||||
		utils.NodeKeyFileFlag,
 | 
			
		||||
		utils.NodeKeyHexFlag,
 | 
			
		||||
		utils.RPCEnabledFlag,
 | 
			
		||||
 | 
			
		||||
@ -45,6 +45,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discv5"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/params"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/pow"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/rpc"
 | 
			
		||||
@ -366,10 +367,16 @@ var (
 | 
			
		||||
		Name:  "v5disc",
 | 
			
		||||
		Usage: "Enables the experimental RLPx V5 (Topic Discovery) mechanism",
 | 
			
		||||
	}
 | 
			
		||||
	NetrestrictFlag = cli.StringFlag{
 | 
			
		||||
		Name:  "netrestrict",
 | 
			
		||||
		Usage: "Restricts network communication to the given IP networks (CIDR masks)",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	WhisperEnabledFlag = cli.BoolFlag{
 | 
			
		||||
		Name:  "shh",
 | 
			
		||||
		Usage: "Enable Whisper",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ATM the url is left to the user and deployment to
 | 
			
		||||
	JSpathFlag = cli.StringFlag{
 | 
			
		||||
		Name:  "jspath",
 | 
			
		||||
@ -693,6 +700,14 @@ func MakeNode(ctx *cli.Context, name, gitCommit string) *node.Node {
 | 
			
		||||
		config.MaxPeers = 0
 | 
			
		||||
		config.ListenAddr = ":0"
 | 
			
		||||
	}
 | 
			
		||||
	if netrestrict := ctx.GlobalString(NetrestrictFlag.Name); netrestrict != "" {
 | 
			
		||||
		list, err := netutil.ParseNetlist(netrestrict)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			Fatalf("Option %q: %v", NetrestrictFlag.Name, err)
 | 
			
		||||
		}
 | 
			
		||||
		config.NetRestrict = list
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	stack, err := node.New(config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Fatalf("Failed to create the protocol stack: %v", err)
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discv5"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
@ -103,6 +104,10 @@ type Config struct {
 | 
			
		||||
	// Listener address for the V5 discovery protocol UDP traffic.
 | 
			
		||||
	DiscoveryV5Addr string
 | 
			
		||||
 | 
			
		||||
	// Restrict communication to white listed IP networks.
 | 
			
		||||
	// The whitelist only applies when non-nil.
 | 
			
		||||
	NetRestrict *netutil.Netlist
 | 
			
		||||
 | 
			
		||||
	// BootstrapNodes used to establish connectivity with the rest of the network.
 | 
			
		||||
	BootstrapNodes []*discover.Node
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -165,6 +165,7 @@ func (n *Node) Start() error {
 | 
			
		||||
		TrustedNodes:     n.config.TrusterNodes(),
 | 
			
		||||
		NodeDatabase:     n.config.NodeDB(),
 | 
			
		||||
		ListenAddr:       n.config.ListenAddr,
 | 
			
		||||
		NetRestrict:      n.config.NetRestrict,
 | 
			
		||||
		NAT:              n.config.NAT,
 | 
			
		||||
		Dialer:           n.config.Dialer,
 | 
			
		||||
		NoDial:           n.config.NoDial,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										45
									
								
								p2p/dial.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								p2p/dial.go
									
									
									
									
									
								
							@ -19,6 +19,7 @@ package p2p
 | 
			
		||||
import (
 | 
			
		||||
	"container/heap"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
@ -26,6 +27,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger/glog"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@ -48,6 +50,7 @@ const (
 | 
			
		||||
type dialstate struct {
 | 
			
		||||
	maxDynDials int
 | 
			
		||||
	ntab        discoverTable
 | 
			
		||||
	netrestrict *netutil.Netlist
 | 
			
		||||
 | 
			
		||||
	lookupRunning bool
 | 
			
		||||
	dialing       map[discover.NodeID]connFlag
 | 
			
		||||
@ -100,10 +103,11 @@ type waitExpireTask struct {
 | 
			
		||||
	time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
 | 
			
		||||
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
 | 
			
		||||
	s := &dialstate{
 | 
			
		||||
		maxDynDials: maxdyn,
 | 
			
		||||
		ntab:        ntab,
 | 
			
		||||
		netrestrict: netrestrict,
 | 
			
		||||
		static:      make(map[discover.NodeID]*dialTask),
 | 
			
		||||
		dialing:     make(map[discover.NodeID]connFlag),
 | 
			
		||||
		randomNodes: make([]*discover.Node, maxdyn/2),
 | 
			
		||||
@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) {
 | 
			
		||||
 | 
			
		||||
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
 | 
			
		||||
	var newtasks []task
 | 
			
		||||
	isDialing := func(id discover.NodeID) bool {
 | 
			
		||||
		_, found := s.dialing[id]
 | 
			
		||||
		return found || peers[id] != nil || s.hist.contains(id)
 | 
			
		||||
	}
 | 
			
		||||
	addDial := func(flag connFlag, n *discover.Node) bool {
 | 
			
		||||
		if isDialing(n.ID) {
 | 
			
		||||
		if err := s.checkDial(n, peers); err != nil {
 | 
			
		||||
			glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err)
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		s.dialing[n.ID] = flag
 | 
			
		||||
@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 | 
			
		||||
 | 
			
		||||
	// Create dials for static nodes if they are not connected.
 | 
			
		||||
	for id, t := range s.static {
 | 
			
		||||
		if !isDialing(id) {
 | 
			
		||||
		err := s.checkDial(t.dest, peers)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case errNotWhitelisted, errSelf:
 | 
			
		||||
			glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err)
 | 
			
		||||
			delete(s.static, t.dest.ID)
 | 
			
		||||
		case nil:
 | 
			
		||||
			s.dialing[id] = t.flags
 | 
			
		||||
			newtasks = append(newtasks, t)
 | 
			
		||||
		}
 | 
			
		||||
@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 | 
			
		||||
	return newtasks
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	errSelf             = errors.New("is self")
 | 
			
		||||
	errAlreadyDialing   = errors.New("already dialing")
 | 
			
		||||
	errAlreadyConnected = errors.New("already connected")
 | 
			
		||||
	errRecentlyDialed   = errors.New("recently dialed")
 | 
			
		||||
	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
 | 
			
		||||
	_, dialing := s.dialing[n.ID]
 | 
			
		||||
	switch {
 | 
			
		||||
	case dialing:
 | 
			
		||||
		return errAlreadyDialing
 | 
			
		||||
	case peers[n.ID] != nil:
 | 
			
		||||
		return errAlreadyConnected
 | 
			
		||||
	case s.ntab != nil && n.ID == s.ntab.Self().ID:
 | 
			
		||||
		return errSelf
 | 
			
		||||
	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
 | 
			
		||||
		return errNotWhitelisted
 | 
			
		||||
	case s.hist.contains(n.ID):
 | 
			
		||||
		return errRecentlyDialed
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *dialstate) taskDone(t task, now time.Time) {
 | 
			
		||||
	switch t := t.(type) {
 | 
			
		||||
	case *dialTask:
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,7 @@ import (
 | 
			
		||||
 | 
			
		||||
	"github.com/davecgh/go-spew/spew"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf,
 | 
			
		||||
// This test checks that dynamic dials are launched from discovery results.
 | 
			
		||||
func TestDialStateDynDial(t *testing.T) {
 | 
			
		||||
	runDialTest(t, dialtest{
 | 
			
		||||
		init: newDialState(nil, fakeTable{}, 5),
 | 
			
		||||
		init: newDialState(nil, fakeTable{}, 5, nil),
 | 
			
		||||
		rounds: []round{
 | 
			
		||||
			// A discovery query is launched.
 | 
			
		||||
			{
 | 
			
		||||
@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	runDialTest(t, dialtest{
 | 
			
		||||
		init: newDialState(nil, table, 10),
 | 
			
		||||
		init: newDialState(nil, table, 10, nil),
 | 
			
		||||
		rounds: []round{
 | 
			
		||||
			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
 | 
			
		||||
			{
 | 
			
		||||
@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This test checks that candidates that do not match the netrestrict list are not dialed.
 | 
			
		||||
func TestDialStateNetRestrict(t *testing.T) {
 | 
			
		||||
	// This table always returns the same random nodes
 | 
			
		||||
	// in the order given below.
 | 
			
		||||
	table := fakeTable{
 | 
			
		||||
		{ID: uintID(1), IP: net.ParseIP("127.0.0.1")},
 | 
			
		||||
		{ID: uintID(2), IP: net.ParseIP("127.0.0.2")},
 | 
			
		||||
		{ID: uintID(3), IP: net.ParseIP("127.0.0.3")},
 | 
			
		||||
		{ID: uintID(4), IP: net.ParseIP("127.0.0.4")},
 | 
			
		||||
		{ID: uintID(5), IP: net.ParseIP("127.0.2.5")},
 | 
			
		||||
		{ID: uintID(6), IP: net.ParseIP("127.0.2.6")},
 | 
			
		||||
		{ID: uintID(7), IP: net.ParseIP("127.0.2.7")},
 | 
			
		||||
		{ID: uintID(8), IP: net.ParseIP("127.0.2.8")},
 | 
			
		||||
	}
 | 
			
		||||
	restrict := new(netutil.Netlist)
 | 
			
		||||
	restrict.Add("127.0.2.0/24")
 | 
			
		||||
 | 
			
		||||
	runDialTest(t, dialtest{
 | 
			
		||||
		init: newDialState(nil, table, 10, restrict),
 | 
			
		||||
		rounds: []round{
 | 
			
		||||
			{
 | 
			
		||||
				new: []task{
 | 
			
		||||
					&dialTask{flags: dynDialedConn, dest: table[4]},
 | 
			
		||||
					&discoverTask{},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This test checks that static dials are launched.
 | 
			
		||||
func TestDialStateStaticDial(t *testing.T) {
 | 
			
		||||
	wantStatic := []*discover.Node{
 | 
			
		||||
@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	runDialTest(t, dialtest{
 | 
			
		||||
		init: newDialState(wantStatic, fakeTable{}, 0),
 | 
			
		||||
		init: newDialState(wantStatic, fakeTable{}, 0, nil),
 | 
			
		||||
		rounds: []round{
 | 
			
		||||
			// Static dials are launched for the nodes that
 | 
			
		||||
			// aren't yet connected.
 | 
			
		||||
@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	runDialTest(t, dialtest{
 | 
			
		||||
		init: newDialState(wantStatic, fakeTable{}, 0),
 | 
			
		||||
		init: newDialState(wantStatic, fakeTable{}, 0, nil),
 | 
			
		||||
		rounds: []round{
 | 
			
		||||
			// Static dials are launched for the nodes that
 | 
			
		||||
			// aren't yet connected.
 | 
			
		||||
@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) {
 | 
			
		||||
func TestDialResolve(t *testing.T) {
 | 
			
		||||
	resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444)
 | 
			
		||||
	table := &resolveMock{answer: resolved}
 | 
			
		||||
	state := newDialState(nil, table, 0)
 | 
			
		||||
	state := newDialState(nil, table, 0, nil)
 | 
			
		||||
 | 
			
		||||
	// Check that the task is generated with an incomplete ID.
 | 
			
		||||
	dest := discover.NewNode(uintID(1), nil, 0, 0)
 | 
			
		||||
 | 
			
		||||
@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) {
 | 
			
		||||
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
 | 
			
		||||
	n = new(Node)
 | 
			
		||||
	n.sha = hashAtDistance(base, ld)
 | 
			
		||||
	n.IP = net.IP{10, 0, 2, byte(ld)}
 | 
			
		||||
	copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
 | 
			
		||||
	return n
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger/glog"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/rlp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -126,8 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
 | 
			
		||||
	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func nodeFromRPC(rn rpcNode) (*Node, error) {
 | 
			
		||||
	// TODO: don't accept localhost, LAN addresses from internet hosts
 | 
			
		||||
func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
 | 
			
		||||
	if rn.UDP <= 1024 {
 | 
			
		||||
		return nil, errors.New("low port")
 | 
			
		||||
	}
 | 
			
		||||
	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
 | 
			
		||||
		return nil, errors.New("not contained in netrestrict whitelist")
 | 
			
		||||
	}
 | 
			
		||||
	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 | 
			
		||||
	err := n.validateComplete()
 | 
			
		||||
	return n, err
 | 
			
		||||
@ -151,6 +160,7 @@ type conn interface {
 | 
			
		||||
// udp implements the RPC protocol.
 | 
			
		||||
type udp struct {
 | 
			
		||||
	conn        conn
 | 
			
		||||
	netrestrict *netutil.Netlist
 | 
			
		||||
	priv        *ecdsa.PrivateKey
 | 
			
		||||
	ourEndpoint rpcEndpoint
 | 
			
		||||
 | 
			
		||||
@ -201,7 +211,7 @@ type reply struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
 | 
			
		||||
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) {
 | 
			
		||||
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
 | 
			
		||||
	addr, err := net.ResolveUDPAddr("udp", laddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@ -210,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	tab, _, err := newUDP(priv, conn, natm, nodeDBPath)
 | 
			
		||||
	tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@ -218,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
 | 
			
		||||
	return tab, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) {
 | 
			
		||||
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
 | 
			
		||||
	udp := &udp{
 | 
			
		||||
		conn:       c,
 | 
			
		||||
		priv:       priv,
 | 
			
		||||
		closing:    make(chan struct{}),
 | 
			
		||||
		gotreply:   make(chan reply),
 | 
			
		||||
		addpending: make(chan *pending),
 | 
			
		||||
		conn:        c,
 | 
			
		||||
		priv:        priv,
 | 
			
		||||
		netrestrict: netrestrict,
 | 
			
		||||
		closing:     make(chan struct{}),
 | 
			
		||||
		gotreply:    make(chan reply),
 | 
			
		||||
		addpending:  make(chan *pending),
 | 
			
		||||
	}
 | 
			
		||||
	realaddr := c.LocalAddr().(*net.UDPAddr)
 | 
			
		||||
	if natm != nil {
 | 
			
		||||
@ -281,9 +292,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
 | 
			
		||||
		reply := r.(*neighbors)
 | 
			
		||||
		for _, rn := range reply.Nodes {
 | 
			
		||||
			nreceived++
 | 
			
		||||
			if n, err := nodeFromRPC(rn); err == nil {
 | 
			
		||||
				nodes = append(nodes, n)
 | 
			
		||||
			n, err := t.nodeFromRPC(toaddr, rn)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			nodes = append(nodes, n)
 | 
			
		||||
		}
 | 
			
		||||
		return nreceived >= bucketSize
 | 
			
		||||
	})
 | 
			
		||||
@ -479,13 +493,6 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte,
 | 
			
		||||
	return packet, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isTemporaryError(err error) bool {
 | 
			
		||||
	tempErr, ok := err.(interface {
 | 
			
		||||
		Temporary() bool
 | 
			
		||||
	})
 | 
			
		||||
	return ok && tempErr.Temporary() || isPacketTooBig(err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
 | 
			
		||||
func (t *udp) readLoop() {
 | 
			
		||||
	defer t.conn.Close()
 | 
			
		||||
@ -495,7 +502,7 @@ func (t *udp) readLoop() {
 | 
			
		||||
	buf := make([]byte, 1280)
 | 
			
		||||
	for {
 | 
			
		||||
		nbytes, from, err := t.conn.ReadFromUDP(buf)
 | 
			
		||||
		if isTemporaryError(err) {
 | 
			
		||||
		if netutil.IsTemporaryError(err) {
 | 
			
		||||
			// Ignore temporary read errors.
 | 
			
		||||
			glog.V(logger.Debug).Infof("Temporary read error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
@ -602,6 +609,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
 | 
			
		||||
	// Send neighbors in chunks with at most maxNeighbors per packet
 | 
			
		||||
	// to stay below the 1280 byte limit.
 | 
			
		||||
	for i, n := range closest {
 | 
			
		||||
		if netutil.CheckRelayIP(from.IP, n.IP) != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		p.Nodes = append(p.Nodes, nodeToRPC(n))
 | 
			
		||||
		if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
 | 
			
		||||
			t.send(from, neighborsPacket, p)
 | 
			
		||||
 | 
			
		||||
@ -43,56 +43,6 @@ func init() {
 | 
			
		||||
	spew.Config.DisableMethods = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This test checks that isPacketTooBig correctly identifies
 | 
			
		||||
// errors that result from receiving a UDP packet larger
 | 
			
		||||
// than the supplied receive buffer.
 | 
			
		||||
func TestIsPacketTooBig(t *testing.T) {
 | 
			
		||||
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer listener.Close()
 | 
			
		||||
	sender, err := net.Dial("udp", listener.LocalAddr().String())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer sender.Close()
 | 
			
		||||
 | 
			
		||||
	sendN := 1800
 | 
			
		||||
	recvN := 300
 | 
			
		||||
	for i := 0; i < 20; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			buf := make([]byte, sendN)
 | 
			
		||||
			for i := range buf {
 | 
			
		||||
				buf[i] = byte(i)
 | 
			
		||||
			}
 | 
			
		||||
			sender.Write(buf)
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		buf := make([]byte, recvN)
 | 
			
		||||
		listener.SetDeadline(time.Now().Add(1 * time.Second))
 | 
			
		||||
		n, _, err := listener.ReadFrom(buf)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !isPacketTooBig(err) {
 | 
			
		||||
				t.Fatal("unexpected read error:", spew.Sdump(err))
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if n != recvN {
 | 
			
		||||
			t.Fatalf("short read: %d, want %d", n, recvN)
 | 
			
		||||
		}
 | 
			
		||||
		for i := range buf {
 | 
			
		||||
			if buf[i] != byte(i) {
 | 
			
		||||
				t.Fatalf("error in pattern")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// shared test variables
 | 
			
		||||
var (
 | 
			
		||||
	futureExp          = uint64(time.Now().Add(10 * time.Hour).Unix())
 | 
			
		||||
@ -118,9 +68,9 @@ func newUDPTest(t *testing.T) *udpTest {
 | 
			
		||||
		pipe:       newpipe(),
 | 
			
		||||
		localkey:   newkey(),
 | 
			
		||||
		remotekey:  newkey(),
 | 
			
		||||
		remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
 | 
			
		||||
		remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
 | 
			
		||||
	}
 | 
			
		||||
	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "")
 | 
			
		||||
	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
 | 
			
		||||
	return test
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -362,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
 | 
			
		||||
	// check that the sent neighbors are all returned by findnode
 | 
			
		||||
	select {
 | 
			
		||||
	case result := <-resultc:
 | 
			
		||||
		if !reflect.DeepEqual(result, list) {
 | 
			
		||||
			t.Errorf("neighbors mismatch:\n  got:  %v\n  want: %v", result, list)
 | 
			
		||||
		want := append(list[:2], list[3:]...)
 | 
			
		||||
		if !reflect.DeepEqual(result, want) {
 | 
			
		||||
			t.Errorf("neighbors mismatch:\n  got:  %v\n  want: %v", result, want)
 | 
			
		||||
		}
 | 
			
		||||
	case err := <-errc:
 | 
			
		||||
		t.Errorf("findnode error: %v", err)
 | 
			
		||||
 | 
			
		||||
@ -31,6 +31,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger/glog"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/rlp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -45,6 +46,7 @@ const (
 | 
			
		||||
	bucketRefreshInterval = 1 * time.Minute
 | 
			
		||||
	seedCount             = 30
 | 
			
		||||
	seedMaxAge            = 5 * 24 * time.Hour
 | 
			
		||||
	lowPort               = 1024
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const testTopic = "foo"
 | 
			
		||||
@ -62,8 +64,9 @@ func debugLog(s string) {
 | 
			
		||||
 | 
			
		||||
// Network manages the table and all protocol interaction.
 | 
			
		||||
type Network struct {
 | 
			
		||||
	db   *nodeDB // database of known nodes
 | 
			
		||||
	conn transport
 | 
			
		||||
	db          *nodeDB // database of known nodes
 | 
			
		||||
	conn        transport
 | 
			
		||||
	netrestrict *netutil.Netlist
 | 
			
		||||
 | 
			
		||||
	closed           chan struct{}          // closed when loop is done
 | 
			
		||||
	closeReq         chan struct{}          // 'request to close'
 | 
			
		||||
@ -132,7 +135,7 @@ type timeoutEvent struct {
 | 
			
		||||
	node *Node
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) {
 | 
			
		||||
func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
 | 
			
		||||
	ourID := PubkeyID(&ourPubkey)
 | 
			
		||||
 | 
			
		||||
	var db *nodeDB
 | 
			
		||||
@ -147,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d
 | 
			
		||||
	net := &Network{
 | 
			
		||||
		db:               db,
 | 
			
		||||
		conn:             conn,
 | 
			
		||||
		netrestrict:      netrestrict,
 | 
			
		||||
		tab:              tab,
 | 
			
		||||
		topictab:         newTopicTable(db, tab.self),
 | 
			
		||||
		ticketStore:      newTicketStore(),
 | 
			
		||||
@ -684,16 +688,22 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node {
 | 
			
		||||
	return n
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) {
 | 
			
		||||
func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
 | 
			
		||||
	if rn.ID == net.tab.self.ID {
 | 
			
		||||
		return nil, errors.New("is self")
 | 
			
		||||
	}
 | 
			
		||||
	if rn.UDP <= lowPort {
 | 
			
		||||
		return nil, errors.New("low port")
 | 
			
		||||
	}
 | 
			
		||||
	n = net.nodes[rn.ID]
 | 
			
		||||
	if n == nil {
 | 
			
		||||
		// We haven't seen this node before.
 | 
			
		||||
		n, err = nodeFromRPC(rn)
 | 
			
		||||
		n.state = unknown
 | 
			
		||||
		n, err = nodeFromRPC(sender, rn)
 | 
			
		||||
		if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
 | 
			
		||||
			return n, errors.New("not contained in netrestrict whitelist")
 | 
			
		||||
		}
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			n.state = unknown
 | 
			
		||||
			net.nodes[n.ID] = n
 | 
			
		||||
		}
 | 
			
		||||
		return n, err
 | 
			
		||||
@ -1095,7 +1105,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
 | 
			
		||||
		net.conn.sendNeighbours(n, results)
 | 
			
		||||
		return n.state, nil
 | 
			
		||||
	case neighborsPacket:
 | 
			
		||||
		err := net.handleNeighboursPacket(n, pkt.data.(*neighbors))
 | 
			
		||||
		err := net.handleNeighboursPacket(n, pkt)
 | 
			
		||||
		return n.state, err
 | 
			
		||||
	case neighboursTimeout:
 | 
			
		||||
		if n.pendingNeighbours != nil {
 | 
			
		||||
@ -1182,17 +1192,18 @@ func rlpHash(x interface{}) (h common.Hash) {
 | 
			
		||||
	return h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error {
 | 
			
		||||
func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
 | 
			
		||||
	if n.pendingNeighbours == nil {
 | 
			
		||||
		return errNoQuery
 | 
			
		||||
	}
 | 
			
		||||
	net.abortTimedEvent(n, neighboursTimeout)
 | 
			
		||||
 | 
			
		||||
	req := pkt.data.(*neighbors)
 | 
			
		||||
	nodes := make([]*Node, len(req.Nodes))
 | 
			
		||||
	for i, rn := range req.Nodes {
 | 
			
		||||
		nn, err := net.internNodeFromNeighbours(rn)
 | 
			
		||||
		nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err)
 | 
			
		||||
			glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		nodes[i] = nn
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func TestNetwork_Lookup(t *testing.T) {
 | 
			
		||||
	key, _ := crypto.GenerateKey()
 | 
			
		||||
	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "")
 | 
			
		||||
	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) {
 | 
			
		||||
	// 	t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
 | 
			
		||||
	// }
 | 
			
		||||
	// seed table with initial node (otherwise lookup will terminate immediately)
 | 
			
		||||
	seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)}
 | 
			
		||||
	seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)}
 | 
			
		||||
	if err := network.SetFallbackNodes(seeds); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) {
 | 
			
		||||
func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) {
 | 
			
		||||
	// current log distance is encoded in port number
 | 
			
		||||
	// fmt.Println("findnode query at dist", toaddr.Port)
 | 
			
		||||
	if to.UDP == 0 {
 | 
			
		||||
		panic("query to node at distance 0")
 | 
			
		||||
	if to.UDP <= lowPort {
 | 
			
		||||
		panic("query to node at or below distance 0")
 | 
			
		||||
	}
 | 
			
		||||
	next := to.UDP - 1
 | 
			
		||||
	var result []rpcNode
 | 
			
		||||
	for i, id := range tn.dists[to.UDP] {
 | 
			
		||||
		result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1)))
 | 
			
		||||
	for i, id := range tn.dists[to.UDP-lowPort] {
 | 
			
		||||
		result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
 | 
			
		||||
	}
 | 
			
		||||
	injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
 | 
			
		||||
}
 | 
			
		||||
@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha
 | 
			
		||||
		// ignored
 | 
			
		||||
	case findnodeHashPacket:
 | 
			
		||||
		// current log distance is encoded in port number
 | 
			
		||||
		// fmt.Println("findnode query at dist", toaddr.Port)
 | 
			
		||||
		if to.UDP == 0 {
 | 
			
		||||
			panic("query to node at distance 0")
 | 
			
		||||
		// fmt.Println("findnode query at dist", toaddr.Port-lowPort)
 | 
			
		||||
		if to.UDP <= lowPort {
 | 
			
		||||
			panic("query to node at or below  distance 0")
 | 
			
		||||
		}
 | 
			
		||||
		next := to.UDP - 1
 | 
			
		||||
		var result []rpcNode
 | 
			
		||||
		for i, id := range tn.dists[to.UDP] {
 | 
			
		||||
			result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1)))
 | 
			
		||||
		for i, id := range tn.dists[to.UDP-lowPort] {
 | 
			
		||||
			result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
 | 
			
		||||
		}
 | 
			
		||||
		injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
 | 
			
		||||
	default:
 | 
			
		||||
@ -328,8 +328,11 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int,
 | 
			
		||||
	panic("sendTopicRegister called")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*preminedTestnet) Close()                  {}
 | 
			
		||||
func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) }
 | 
			
		||||
func (*preminedTestnet) Close() {}
 | 
			
		||||
 | 
			
		||||
func (*preminedTestnet) localAddr() *net.UDPAddr {
 | 
			
		||||
	return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mine generates a testnet struct literal with nodes at
 | 
			
		||||
// various distances to the given target.
 | 
			
		||||
 | 
			
		||||
@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network {
 | 
			
		||||
	addr := &net.UDPAddr{IP: ip, Port: 30303}
 | 
			
		||||
 | 
			
		||||
	transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
 | 
			
		||||
	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>")
 | 
			
		||||
	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic("cannot launch new node: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger/glog"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/rlp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
 | 
			
		||||
	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func nodeFromRPC(rn rpcNode) (*Node, error) {
 | 
			
		||||
	// TODO: don't accept localhost, LAN addresses from internet hosts
 | 
			
		||||
func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
 | 
			
		||||
	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 | 
			
		||||
	err := n.validateComplete()
 | 
			
		||||
	return n, err
 | 
			
		||||
@ -235,12 +238,12 @@ type udp struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
 | 
			
		||||
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) {
 | 
			
		||||
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
 | 
			
		||||
	transport, err := listenUDP(priv, laddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath)
 | 
			
		||||
	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	for i, result := range nodes {
 | 
			
		||||
		if netutil.CheckRelayIP(remote.IP, result.IP) != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		p.Nodes = append(p.Nodes, nodeToRPC(result))
 | 
			
		||||
		if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
 | 
			
		||||
			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
 | 
			
		||||
@ -385,7 +391,7 @@ func (t *udp) readLoop() {
 | 
			
		||||
	buf := make([]byte, 1280)
 | 
			
		||||
	for {
 | 
			
		||||
		nbytes, from, err := t.conn.ReadFromUDP(buf)
 | 
			
		||||
		if isTemporaryError(err) {
 | 
			
		||||
		if netutil.IsTemporaryError(err) {
 | 
			
		||||
			// Ignore temporary read errors.
 | 
			
		||||
			glog.V(logger.Debug).Infof("Temporary read error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
@ -398,13 +404,6 @@ func (t *udp) readLoop() {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isTemporaryError(err error) bool {
 | 
			
		||||
	tempErr, ok := err.(interface {
 | 
			
		||||
		Temporary() bool
 | 
			
		||||
	})
 | 
			
		||||
	return ok && tempErr.Temporary() || isPacketTooBig(err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
 | 
			
		||||
	pkt := ingressPacket{remoteAddr: from}
 | 
			
		||||
	if err := decodePacket(buf, &pkt); err != nil {
 | 
			
		||||
 | 
			
		||||
@ -36,56 +36,6 @@ func init() {
 | 
			
		||||
	spew.Config.DisableMethods = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This test checks that isPacketTooBig correctly identifies
 | 
			
		||||
// errors that result from receiving a UDP packet larger
 | 
			
		||||
// than the supplied receive buffer.
 | 
			
		||||
func TestIsPacketTooBig(t *testing.T) {
 | 
			
		||||
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer listener.Close()
 | 
			
		||||
	sender, err := net.Dial("udp", listener.LocalAddr().String())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer sender.Close()
 | 
			
		||||
 | 
			
		||||
	sendN := 1800
 | 
			
		||||
	recvN := 300
 | 
			
		||||
	for i := 0; i < 20; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			buf := make([]byte, sendN)
 | 
			
		||||
			for i := range buf {
 | 
			
		||||
				buf[i] = byte(i)
 | 
			
		||||
			}
 | 
			
		||||
			sender.Write(buf)
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		buf := make([]byte, recvN)
 | 
			
		||||
		listener.SetDeadline(time.Now().Add(1 * time.Second))
 | 
			
		||||
		n, _, err := listener.ReadFrom(buf)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !isPacketTooBig(err) {
 | 
			
		||||
				t.Fatal("unexpected read error:", spew.Sdump(err))
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if n != recvN {
 | 
			
		||||
			t.Fatalf("short read: %d, want %d", n, recvN)
 | 
			
		||||
		}
 | 
			
		||||
		for i := range buf {
 | 
			
		||||
			if buf[i] != byte(i) {
 | 
			
		||||
				t.Fatalf("error in pattern")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// shared test variables
 | 
			
		||||
var (
 | 
			
		||||
	futureExp          = uint64(time.Now().Add(10 * time.Hour).Unix())
 | 
			
		||||
 | 
			
		||||
@ -1,40 +0,0 @@
 | 
			
		||||
// Copyright 2016 The go-ethereum Authors
 | 
			
		||||
// This file is part of the go-ethereum library.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is free software: you can redistribute it and/or modify
 | 
			
		||||
// it under the terms of the GNU Lesser General Public License as published by
 | 
			
		||||
// the Free Software Foundation, either version 3 of the License, or
 | 
			
		||||
// (at your option) any later version.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is distributed in the hope that it will be useful,
 | 
			
		||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 | 
			
		||||
// GNU Lesser General Public License for more details.
 | 
			
		||||
//
 | 
			
		||||
// You should have received a copy of the GNU Lesser General Public License
 | 
			
		||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
 | 
			
		||||
//+build windows
 | 
			
		||||
 | 
			
		||||
package discv5
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"syscall"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const _WSAEMSGSIZE = syscall.Errno(10040)
 | 
			
		||||
 | 
			
		||||
// reports whether err indicates that a UDP packet didn't
 | 
			
		||||
// fit the receive buffer. On Windows, WSARecvFrom returns
 | 
			
		||||
// code WSAEMSGSIZE and no data if this happens.
 | 
			
		||||
func isPacketTooBig(err error) bool {
 | 
			
		||||
	if opErr, ok := err.(*net.OpError); ok {
 | 
			
		||||
		if scErr, ok := opErr.Err.(*os.SyscallError); ok {
 | 
			
		||||
			return scErr.Err == _WSAEMSGSIZE
 | 
			
		||||
		}
 | 
			
		||||
		return opErr.Err == _WSAEMSGSIZE
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
@ -14,13 +14,12 @@
 | 
			
		||||
// You should have received a copy of the GNU Lesser General Public License
 | 
			
		||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
 | 
			
		||||
//+build !windows
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
package discv5
 | 
			
		||||
 | 
			
		||||
// reports whether err indicates that a UDP packet didn't
 | 
			
		||||
// fit the receive buffer. There is no such error on
 | 
			
		||||
// non-Windows platforms.
 | 
			
		||||
func isPacketTooBig(err error) bool {
 | 
			
		||||
	return false
 | 
			
		||||
// IsTemporaryError checks whether the given error should be considered temporary.
 | 
			
		||||
func IsTemporaryError(err error) bool {
 | 
			
		||||
	tempErr, ok := err.(interface {
 | 
			
		||||
		Temporary() bool
 | 
			
		||||
	})
 | 
			
		||||
	return ok && tempErr.Temporary() || isPacketTooBig(err)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										73
									
								
								p2p/netutil/error_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								p2p/netutil/error_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,73 @@
 | 
			
		||||
// Copyright 2016 The go-ethereum Authors
 | 
			
		||||
// This file is part of the go-ethereum library.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is free software: you can redistribute it and/or modify
 | 
			
		||||
// it under the terms of the GNU Lesser General Public License as published by
 | 
			
		||||
// the Free Software Foundation, either version 3 of the License, or
 | 
			
		||||
// (at your option) any later version.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is distributed in the hope that it will be useful,
 | 
			
		||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 | 
			
		||||
// GNU Lesser General Public License for more details.
 | 
			
		||||
//
 | 
			
		||||
// You should have received a copy of the GNU Lesser General Public License
 | 
			
		||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// This test checks that isPacketTooBig correctly identifies
 | 
			
		||||
// errors that result from receiving a UDP packet larger
 | 
			
		||||
// than the supplied receive buffer.
 | 
			
		||||
func TestIsPacketTooBig(t *testing.T) {
 | 
			
		||||
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer listener.Close()
 | 
			
		||||
	sender, err := net.Dial("udp", listener.LocalAddr().String())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer sender.Close()
 | 
			
		||||
 | 
			
		||||
	sendN := 1800
 | 
			
		||||
	recvN := 300
 | 
			
		||||
	for i := 0; i < 20; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			buf := make([]byte, sendN)
 | 
			
		||||
			for i := range buf {
 | 
			
		||||
				buf[i] = byte(i)
 | 
			
		||||
			}
 | 
			
		||||
			sender.Write(buf)
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		buf := make([]byte, recvN)
 | 
			
		||||
		listener.SetDeadline(time.Now().Add(1 * time.Second))
 | 
			
		||||
		n, _, err := listener.ReadFrom(buf)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !isPacketTooBig(err) {
 | 
			
		||||
				t.Fatalf("unexpected read error: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if n != recvN {
 | 
			
		||||
			t.Fatalf("short read: %d, want %d", n, recvN)
 | 
			
		||||
		}
 | 
			
		||||
		for i := range buf {
 | 
			
		||||
			if buf[i] != byte(i) {
 | 
			
		||||
				t.Fatalf("error in pattern")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										166
									
								
								p2p/netutil/net.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								p2p/netutil/net.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,166 @@
 | 
			
		||||
// Copyright 2016 The go-ethereum Authors
 | 
			
		||||
// This file is part of the go-ethereum library.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is free software: you can redistribute it and/or modify
 | 
			
		||||
// it under the terms of the GNU Lesser General Public License as published by
 | 
			
		||||
// the Free Software Foundation, either version 3 of the License, or
 | 
			
		||||
// (at your option) any later version.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is distributed in the hope that it will be useful,
 | 
			
		||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 | 
			
		||||
// GNU Lesser General Public License for more details.
 | 
			
		||||
//
 | 
			
		||||
// You should have received a copy of the GNU Lesser General Public License
 | 
			
		||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
 | 
			
		||||
// Package netutil contains extensions to the net package.
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var lan4, lan6, special4, special6 Netlist
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	// Lists from RFC 5735, RFC 5156,
 | 
			
		||||
	// https://www.iana.org/assignments/iana-ipv4-special-registry/
 | 
			
		||||
	lan4.Add("0.0.0.0/8")              // "This" network
 | 
			
		||||
	lan4.Add("10.0.0.0/8")             // Private Use
 | 
			
		||||
	lan4.Add("172.16.0.0/12")          // Private Use
 | 
			
		||||
	lan4.Add("192.168.0.0/16")         // Private Use
 | 
			
		||||
	lan6.Add("fe80::/10")              // Link-Local
 | 
			
		||||
	lan6.Add("fc00::/7")               // Unique-Local
 | 
			
		||||
	special4.Add("192.0.0.0/29")       // IPv4 Service Continuity
 | 
			
		||||
	special4.Add("192.0.0.9/32")       // PCP Anycast
 | 
			
		||||
	special4.Add("192.0.0.170/32")     // NAT64/DNS64 Discovery
 | 
			
		||||
	special4.Add("192.0.0.171/32")     // NAT64/DNS64 Discovery
 | 
			
		||||
	special4.Add("192.0.2.0/24")       // TEST-NET-1
 | 
			
		||||
	special4.Add("192.31.196.0/24")    // AS112
 | 
			
		||||
	special4.Add("192.52.193.0/24")    // AMT
 | 
			
		||||
	special4.Add("192.88.99.0/24")     // 6to4 Relay Anycast
 | 
			
		||||
	special4.Add("192.175.48.0/24")    // AS112
 | 
			
		||||
	special4.Add("198.18.0.0/15")      // Device Benchmark Testing
 | 
			
		||||
	special4.Add("198.51.100.0/24")    // TEST-NET-2
 | 
			
		||||
	special4.Add("203.0.113.0/24")     // TEST-NET-3
 | 
			
		||||
	special4.Add("255.255.255.255/32") // Limited Broadcast
 | 
			
		||||
 | 
			
		||||
	// http://www.iana.org/assignments/iana-ipv6-special-registry/
 | 
			
		||||
	special6.Add("100::/64")
 | 
			
		||||
	special6.Add("2001::/32")
 | 
			
		||||
	special6.Add("2001:1::1/128")
 | 
			
		||||
	special6.Add("2001:2::/48")
 | 
			
		||||
	special6.Add("2001:3::/32")
 | 
			
		||||
	special6.Add("2001:4:112::/48")
 | 
			
		||||
	special6.Add("2001:5::/32")
 | 
			
		||||
	special6.Add("2001:10::/28")
 | 
			
		||||
	special6.Add("2001:20::/28")
 | 
			
		||||
	special6.Add("2001:db8::/32")
 | 
			
		||||
	special6.Add("2002::/16")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Netlist is a list of IP networks.
 | 
			
		||||
type Netlist []net.IPNet
 | 
			
		||||
 | 
			
		||||
// ParseNetlist parses a comma-separated list of CIDR masks.
 | 
			
		||||
// Whitespace and extra commas are ignored.
 | 
			
		||||
func ParseNetlist(s string) (*Netlist, error) {
 | 
			
		||||
	ws := strings.NewReplacer(" ", "", "\n", "", "\t", "")
 | 
			
		||||
	masks := strings.Split(ws.Replace(s), ",")
 | 
			
		||||
	l := make(Netlist, 0)
 | 
			
		||||
	for _, mask := range masks {
 | 
			
		||||
		if mask == "" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		_, n, err := net.ParseCIDR(mask)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		l = append(l, *n)
 | 
			
		||||
	}
 | 
			
		||||
	return &l, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
 | 
			
		||||
// intended to be used for setting up static lists.
 | 
			
		||||
func (l *Netlist) Add(cidr string) {
 | 
			
		||||
	_, n, err := net.ParseCIDR(cidr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	*l = append(*l, *n)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains reports whether the given IP is contained in the list.
 | 
			
		||||
func (l *Netlist) Contains(ip net.IP) bool {
 | 
			
		||||
	if l == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	for _, net := range *l {
 | 
			
		||||
		if net.Contains(ip) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsLAN reports whether an IP is a local network address.
 | 
			
		||||
func IsLAN(ip net.IP) bool {
 | 
			
		||||
	if ip.IsLoopback() {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if v4 := ip.To4(); v4 != nil {
 | 
			
		||||
		return lan4.Contains(v4)
 | 
			
		||||
	}
 | 
			
		||||
	return lan6.Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsSpecialNetwork reports whether an IP is located in a special-use network range
 | 
			
		||||
// This includes broadcast, multicast and documentation addresses.
 | 
			
		||||
func IsSpecialNetwork(ip net.IP) bool {
 | 
			
		||||
	if ip.IsMulticast() {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if v4 := ip.To4(); v4 != nil {
 | 
			
		||||
		return special4.Contains(v4)
 | 
			
		||||
	}
 | 
			
		||||
	return special6.Contains(ip)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	errInvalid     = errors.New("invalid IP")
 | 
			
		||||
	errUnspecified = errors.New("zero address")
 | 
			
		||||
	errSpecial     = errors.New("special network")
 | 
			
		||||
	errLoopback    = errors.New("loopback address from non-loopback host")
 | 
			
		||||
	errLAN         = errors.New("LAN address from WAN host")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CheckRelayIP reports whether an IP relayed from the given sender IP
 | 
			
		||||
// is a valid connection target.
 | 
			
		||||
//
 | 
			
		||||
// There are four rules:
 | 
			
		||||
//   - Special network addresses are never valid.
 | 
			
		||||
//   - Loopback addresses are OK if relayed by a loopback host.
 | 
			
		||||
//   - LAN addresses are OK if relayed by a LAN host.
 | 
			
		||||
//   - All other addresses are always acceptable.
 | 
			
		||||
func CheckRelayIP(sender, addr net.IP) error {
 | 
			
		||||
	if len(addr) != net.IPv4len && len(addr) != net.IPv6len {
 | 
			
		||||
		return errInvalid
 | 
			
		||||
	}
 | 
			
		||||
	if addr.IsUnspecified() {
 | 
			
		||||
		return errUnspecified
 | 
			
		||||
	}
 | 
			
		||||
	if IsSpecialNetwork(addr) {
 | 
			
		||||
		return errSpecial
 | 
			
		||||
	}
 | 
			
		||||
	if addr.IsLoopback() && !sender.IsLoopback() {
 | 
			
		||||
		return errLoopback
 | 
			
		||||
	}
 | 
			
		||||
	if IsLAN(addr) && !IsLAN(sender) {
 | 
			
		||||
		return errLAN
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										173
									
								
								p2p/netutil/net_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								p2p/netutil/net_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,173 @@
 | 
			
		||||
// Copyright 2016 The go-ethereum Authors
 | 
			
		||||
// This file is part of the go-ethereum library.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is free software: you can redistribute it and/or modify
 | 
			
		||||
// it under the terms of the GNU Lesser General Public License as published by
 | 
			
		||||
// the Free Software Foundation, either version 3 of the License, or
 | 
			
		||||
// (at your option) any later version.
 | 
			
		||||
//
 | 
			
		||||
// The go-ethereum library is distributed in the hope that it will be useful,
 | 
			
		||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 | 
			
		||||
// GNU Lesser General Public License for more details.
 | 
			
		||||
//
 | 
			
		||||
// You should have received a copy of the GNU Lesser General Public License
 | 
			
		||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/davecgh/go-spew/spew"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestParseNetlist(t *testing.T) {
 | 
			
		||||
	var tests = []struct {
 | 
			
		||||
		input    string
 | 
			
		||||
		wantErr  error
 | 
			
		||||
		wantList *Netlist
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			input:    "",
 | 
			
		||||
			wantList: &Netlist{},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			input:    "127.0.0.0/8",
 | 
			
		||||
			wantErr:  nil,
 | 
			
		||||
			wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			input:   "127.0.0.0/44",
 | 
			
		||||
			wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			input: "127.0.0.0/16, 23.23.23.23/24,",
 | 
			
		||||
			wantList: &Netlist{
 | 
			
		||||
				{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)},
 | 
			
		||||
				{IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		l, err := ParseNetlist(test.input)
 | 
			
		||||
		if !reflect.DeepEqual(err, test.wantErr) {
 | 
			
		||||
			t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if !reflect.DeepEqual(l, test.wantList) {
 | 
			
		||||
			spew.Dump(l)
 | 
			
		||||
			spew.Dump(test.wantList)
 | 
			
		||||
			t.Errorf("%q: got %v, want %v", test.input, l, test.wantList)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNilNetListContains(t *testing.T) {
 | 
			
		||||
	var list *Netlist
 | 
			
		||||
	checkContains(t, list.Contains, nil, []string{"1.2.3.4"})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIsLAN(t *testing.T) {
 | 
			
		||||
	checkContains(t, IsLAN,
 | 
			
		||||
		[]string{ // included
 | 
			
		||||
			"0.0.0.0",
 | 
			
		||||
			"0.2.0.8",
 | 
			
		||||
			"127.0.0.1",
 | 
			
		||||
			"10.0.1.1",
 | 
			
		||||
			"10.22.0.3",
 | 
			
		||||
			"172.31.252.251",
 | 
			
		||||
			"192.168.1.4",
 | 
			
		||||
			"fe80::f4a1:8eff:fec5:9d9d",
 | 
			
		||||
			"febf::ab32:2233",
 | 
			
		||||
			"fc00::4",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{ // excluded
 | 
			
		||||
			"192.0.2.1",
 | 
			
		||||
			"1.0.0.0",
 | 
			
		||||
			"172.32.0.1",
 | 
			
		||||
			"fec0::2233",
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIsSpecialNetwork(t *testing.T) {
 | 
			
		||||
	checkContains(t, IsSpecialNetwork,
 | 
			
		||||
		[]string{ // included
 | 
			
		||||
			"192.0.2.1",
 | 
			
		||||
			"192.0.2.44",
 | 
			
		||||
			"2001:db8:85a3:8d3:1319:8a2e:370:7348",
 | 
			
		||||
			"255.255.255.255",
 | 
			
		||||
			"224.0.0.22", // IPv4 multicast
 | 
			
		||||
			"ff05::1:3",  // IPv6 multicast
 | 
			
		||||
		},
 | 
			
		||||
		[]string{ // excluded
 | 
			
		||||
			"192.0.3.1",
 | 
			
		||||
			"1.0.0.0",
 | 
			
		||||
			"172.32.0.1",
 | 
			
		||||
			"fec0::2233",
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) {
 | 
			
		||||
	for _, s := range inc {
 | 
			
		||||
		if !fn(parseIP(s)) {
 | 
			
		||||
			t.Error("returned false for included address", s)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for _, s := range exc {
 | 
			
		||||
		if fn(parseIP(s)) {
 | 
			
		||||
			t.Error("returned true for excluded address", s)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseIP(s string) net.IP {
 | 
			
		||||
	ip := net.ParseIP(s)
 | 
			
		||||
	if ip == nil {
 | 
			
		||||
		panic("invalid " + s)
 | 
			
		||||
	}
 | 
			
		||||
	return ip
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCheckRelayIP(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		sender, addr string
 | 
			
		||||
		want         error
 | 
			
		||||
	}{
 | 
			
		||||
		{"127.0.0.1", "0.0.0.0", errUnspecified},
 | 
			
		||||
		{"192.168.0.1", "0.0.0.0", errUnspecified},
 | 
			
		||||
		{"23.55.1.242", "0.0.0.0", errUnspecified},
 | 
			
		||||
		{"127.0.0.1", "255.255.255.255", errSpecial},
 | 
			
		||||
		{"192.168.0.1", "255.255.255.255", errSpecial},
 | 
			
		||||
		{"23.55.1.242", "255.255.255.255", errSpecial},
 | 
			
		||||
		{"192.168.0.1", "127.0.2.19", errLoopback},
 | 
			
		||||
		{"23.55.1.242", "192.168.0.1", errLAN},
 | 
			
		||||
 | 
			
		||||
		{"127.0.0.1", "127.0.2.19", nil},
 | 
			
		||||
		{"127.0.0.1", "192.168.0.1", nil},
 | 
			
		||||
		{"127.0.0.1", "23.55.1.242", nil},
 | 
			
		||||
		{"192.168.0.1", "192.168.0.1", nil},
 | 
			
		||||
		{"192.168.0.1", "23.55.1.242", nil},
 | 
			
		||||
		{"23.55.1.242", "23.55.1.242", nil},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		err := CheckRelayIP(parseIP(test.sender), parseIP(test.addr))
 | 
			
		||||
		if err != test.want {
 | 
			
		||||
			t.Errorf("%s from %s: got %q, want %q", test.addr, test.sender, err, test.want)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkCheckRelayIP(b *testing.B) {
 | 
			
		||||
	sender := parseIP("23.55.1.242")
 | 
			
		||||
	addr := parseIP("23.55.1.2")
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		CheckRelayIP(sender, addr)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -16,9 +16,9 @@
 | 
			
		||||
 | 
			
		||||
//+build !windows
 | 
			
		||||
 | 
			
		||||
package discover
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
// reports whether err indicates that a UDP packet didn't
 | 
			
		||||
// isPacketTooBig reports whether err indicates that a UDP packet didn't
 | 
			
		||||
// fit the receive buffer. There is no such error on
 | 
			
		||||
// non-Windows platforms.
 | 
			
		||||
func isPacketTooBig(err error) bool {
 | 
			
		||||
@ -16,7 +16,7 @@
 | 
			
		||||
 | 
			
		||||
//+build windows
 | 
			
		||||
 | 
			
		||||
package discover
 | 
			
		||||
package netutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
@ -26,7 +26,7 @@ import (
 | 
			
		||||
 | 
			
		||||
const _WSAEMSGSIZE = syscall.Errno(10040)
 | 
			
		||||
 | 
			
		||||
// reports whether err indicates that a UDP packet didn't
 | 
			
		||||
// isPacketTooBig reports whether err indicates that a UDP packet didn't
 | 
			
		||||
// fit the receive buffer. On Windows, WSARecvFrom returns
 | 
			
		||||
// code WSAEMSGSIZE and no data if this happens.
 | 
			
		||||
func isPacketTooBig(err error) bool {
 | 
			
		||||
@ -30,6 +30,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discv5"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/nat"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@ -101,6 +102,11 @@ type Config struct {
 | 
			
		||||
	// allowed to connect, even above the peer limit.
 | 
			
		||||
	TrustedNodes []*discover.Node
 | 
			
		||||
 | 
			
		||||
	// Connectivity can be restricted to certain IP networks.
 | 
			
		||||
	// If this option is set to a non-nil value, only hosts which match one of the
 | 
			
		||||
	// IP networks contained in the list are considered.
 | 
			
		||||
	NetRestrict *netutil.Netlist
 | 
			
		||||
 | 
			
		||||
	// NodeDatabase is the path to the database containing the previously seen
 | 
			
		||||
	// live nodes in the network.
 | 
			
		||||
	NodeDatabase string
 | 
			
		||||
@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) {
 | 
			
		||||
 | 
			
		||||
	// node table
 | 
			
		||||
	if srv.Discovery {
 | 
			
		||||
		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
 | 
			
		||||
		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if srv.DiscoveryV5 {
 | 
			
		||||
		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase)
 | 
			
		||||
		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) {
 | 
			
		||||
	if !srv.Discovery {
 | 
			
		||||
		dynPeers = 0
 | 
			
		||||
	}
 | 
			
		||||
	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers)
 | 
			
		||||
	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict)
 | 
			
		||||
 | 
			
		||||
	// handshake
 | 
			
		||||
	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
 | 
			
		||||
@ -634,8 +640,19 @@ func (srv *Server) listenLoop() {
 | 
			
		||||
			}
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Reject connections that do not match NetRestrict.
 | 
			
		||||
		if srv.NetRestrict != nil {
 | 
			
		||||
			if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
 | 
			
		||||
				glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr())
 | 
			
		||||
				fd.Close()
 | 
			
		||||
				slots <- struct{}{}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		fd = newMeteredConn(fd, true)
 | 
			
		||||
		glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
 | 
			
		||||
		glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr())
 | 
			
		||||
 | 
			
		||||
		// Spawn the handler. It will give the slot back when the connection
 | 
			
		||||
		// has been established.
 | 
			
		||||
 | 
			
		||||
@ -26,6 +26,7 @@ import (
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/logger/glog"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/p2p/netutil"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/swarm/network/kademlia"
 | 
			
		||||
	"github.com/ethereum/go-ethereum/swarm/storage"
 | 
			
		||||
)
 | 
			
		||||
@ -288,6 +289,10 @@ func newNodeRecord(addr *peerAddr) *kademlia.NodeRecord {
 | 
			
		||||
func (self *Hive) HandlePeersMsg(req *peersMsgData, from *peer) {
 | 
			
		||||
	var nrs []*kademlia.NodeRecord
 | 
			
		||||
	for _, p := range req.Peers {
 | 
			
		||||
		if err := netutil.CheckRelayIP(from.remoteAddr.IP, p.IP); err != nil {
 | 
			
		||||
			glog.V(logger.Detail).Infof("invalid peer IP %v from %v: %v", from.remoteAddr.IP, p.IP, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		nrs = append(nrs, newNodeRecord(p))
 | 
			
		||||
	}
 | 
			
		||||
	self.kad.Add(nrs)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user