p2p, p2p/discover, p2p/discv5: add IP network restriction feature
The p2p packages can now be configured to restrict all communication to a certain subset of IP networks. This feature is meant to be used for private networks.
This commit is contained in:
		
							parent
							
								
									e46bda5093
								
							
						
					
					
						commit
						a47341cf96
					
				
							
								
								
									
										45
									
								
								p2p/dial.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								p2p/dial.go
									
									
									
									
									
								
							| @ -19,6 +19,7 @@ package p2p | |||||||
| import ( | import ( | ||||||
| 	"container/heap" | 	"container/heap" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"time" | 	"time" | ||||||
| @ -26,6 +27,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -48,6 +50,7 @@ const ( | |||||||
| type dialstate struct { | type dialstate struct { | ||||||
| 	maxDynDials int | 	maxDynDials int | ||||||
| 	ntab        discoverTable | 	ntab        discoverTable | ||||||
|  | 	netrestrict *netutil.Netlist | ||||||
| 
 | 
 | ||||||
| 	lookupRunning bool | 	lookupRunning bool | ||||||
| 	dialing       map[discover.NodeID]connFlag | 	dialing       map[discover.NodeID]connFlag | ||||||
| @ -100,10 +103,11 @@ type waitExpireTask struct { | |||||||
| 	time.Duration | 	time.Duration | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { | func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { | ||||||
| 	s := &dialstate{ | 	s := &dialstate{ | ||||||
| 		maxDynDials: maxdyn, | 		maxDynDials: maxdyn, | ||||||
| 		ntab:        ntab, | 		ntab:        ntab, | ||||||
|  | 		netrestrict: netrestrict, | ||||||
| 		static:      make(map[discover.NodeID]*dialTask), | 		static:      make(map[discover.NodeID]*dialTask), | ||||||
| 		dialing:     make(map[discover.NodeID]connFlag), | 		dialing:     make(map[discover.NodeID]connFlag), | ||||||
| 		randomNodes: make([]*discover.Node, maxdyn/2), | 		randomNodes: make([]*discover.Node, maxdyn/2), | ||||||
| @ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) { | |||||||
| 
 | 
 | ||||||
| func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { | func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { | ||||||
| 	var newtasks []task | 	var newtasks []task | ||||||
| 	isDialing := func(id discover.NodeID) bool { |  | ||||||
| 		_, found := s.dialing[id] |  | ||||||
| 		return found || peers[id] != nil || s.hist.contains(id) |  | ||||||
| 	} |  | ||||||
| 	addDial := func(flag connFlag, n *discover.Node) bool { | 	addDial := func(flag connFlag, n *discover.Node) bool { | ||||||
| 		if isDialing(n.ID) { | 		if err := s.checkDial(n, peers); err != nil { | ||||||
|  | 			glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err) | ||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 		s.dialing[n.ID] = flag | 		s.dialing[n.ID] = flag | ||||||
| @ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now | |||||||
| 
 | 
 | ||||||
| 	// Create dials for static nodes if they are not connected.
 | 	// Create dials for static nodes if they are not connected.
 | ||||||
| 	for id, t := range s.static { | 	for id, t := range s.static { | ||||||
| 		if !isDialing(id) { | 		err := s.checkDial(t.dest, peers) | ||||||
|  | 		switch err { | ||||||
|  | 		case errNotWhitelisted, errSelf: | ||||||
|  | 			glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err) | ||||||
|  | 			delete(s.static, t.dest.ID) | ||||||
|  | 		case nil: | ||||||
| 			s.dialing[id] = t.flags | 			s.dialing[id] = t.flags | ||||||
| 			newtasks = append(newtasks, t) | 			newtasks = append(newtasks, t) | ||||||
| 		} | 		} | ||||||
| @ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now | |||||||
| 	return newtasks | 	return newtasks | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var ( | ||||||
|  | 	errSelf             = errors.New("is self") | ||||||
|  | 	errAlreadyDialing   = errors.New("already dialing") | ||||||
|  | 	errAlreadyConnected = errors.New("already connected") | ||||||
|  | 	errRecentlyDialed   = errors.New("recently dialed") | ||||||
|  | 	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist") | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error { | ||||||
|  | 	_, dialing := s.dialing[n.ID] | ||||||
|  | 	switch { | ||||||
|  | 	case dialing: | ||||||
|  | 		return errAlreadyDialing | ||||||
|  | 	case peers[n.ID] != nil: | ||||||
|  | 		return errAlreadyConnected | ||||||
|  | 	case s.ntab != nil && n.ID == s.ntab.Self().ID: | ||||||
|  | 		return errSelf | ||||||
|  | 	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP): | ||||||
|  | 		return errNotWhitelisted | ||||||
|  | 	case s.hist.contains(n.ID): | ||||||
|  | 		return errRecentlyDialed | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (s *dialstate) taskDone(t task, now time.Time) { | func (s *dialstate) taskDone(t task, now time.Time) { | ||||||
| 	switch t := t.(type) { | 	switch t := t.(type) { | ||||||
| 	case *dialTask: | 	case *dialTask: | ||||||
|  | |||||||
| @ -25,6 +25,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/davecgh/go-spew/spew" | 	"github.com/davecgh/go-spew/spew" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func init() { | func init() { | ||||||
| @ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, | |||||||
| // This test checks that dynamic dials are launched from discovery results.
 | // This test checks that dynamic dials are launched from discovery results.
 | ||||||
| func TestDialStateDynDial(t *testing.T) { | func TestDialStateDynDial(t *testing.T) { | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(nil, fakeTable{}, 5), | 		init: newDialState(nil, fakeTable{}, 5, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// A discovery query is launched.
 | 			// A discovery query is launched.
 | ||||||
| 			{ | 			{ | ||||||
| @ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(nil, table, 10), | 		init: newDialState(nil, table, 10, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
 | 			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
 | ||||||
| 			{ | 			{ | ||||||
| @ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // This test checks that candidates that do not match the netrestrict list are not dialed.
 | ||||||
|  | func TestDialStateNetRestrict(t *testing.T) { | ||||||
|  | 	// This table always returns the same random nodes
 | ||||||
|  | 	// in the order given below.
 | ||||||
|  | 	table := fakeTable{ | ||||||
|  | 		{ID: uintID(1), IP: net.ParseIP("127.0.0.1")}, | ||||||
|  | 		{ID: uintID(2), IP: net.ParseIP("127.0.0.2")}, | ||||||
|  | 		{ID: uintID(3), IP: net.ParseIP("127.0.0.3")}, | ||||||
|  | 		{ID: uintID(4), IP: net.ParseIP("127.0.0.4")}, | ||||||
|  | 		{ID: uintID(5), IP: net.ParseIP("127.0.2.5")}, | ||||||
|  | 		{ID: uintID(6), IP: net.ParseIP("127.0.2.6")}, | ||||||
|  | 		{ID: uintID(7), IP: net.ParseIP("127.0.2.7")}, | ||||||
|  | 		{ID: uintID(8), IP: net.ParseIP("127.0.2.8")}, | ||||||
|  | 	} | ||||||
|  | 	restrict := new(netutil.Netlist) | ||||||
|  | 	restrict.Add("127.0.2.0/24") | ||||||
|  | 
 | ||||||
|  | 	runDialTest(t, dialtest{ | ||||||
|  | 		init: newDialState(nil, table, 10, restrict), | ||||||
|  | 		rounds: []round{ | ||||||
|  | 			{ | ||||||
|  | 				new: []task{ | ||||||
|  | 					&dialTask{flags: dynDialedConn, dest: table[4]}, | ||||||
|  | 					&discoverTask{}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // This test checks that static dials are launched.
 | // This test checks that static dials are launched.
 | ||||||
| func TestDialStateStaticDial(t *testing.T) { | func TestDialStateStaticDial(t *testing.T) { | ||||||
| 	wantStatic := []*discover.Node{ | 	wantStatic := []*discover.Node{ | ||||||
| @ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(wantStatic, fakeTable{}, 0), | 		init: newDialState(wantStatic, fakeTable{}, 0, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// Static dials are launched for the nodes that
 | 			// Static dials are launched for the nodes that
 | ||||||
| 			// aren't yet connected.
 | 			// aren't yet connected.
 | ||||||
| @ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(wantStatic, fakeTable{}, 0), | 		init: newDialState(wantStatic, fakeTable{}, 0, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// Static dials are launched for the nodes that
 | 			// Static dials are launched for the nodes that
 | ||||||
| 			// aren't yet connected.
 | 			// aren't yet connected.
 | ||||||
| @ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) { | |||||||
| func TestDialResolve(t *testing.T) { | func TestDialResolve(t *testing.T) { | ||||||
| 	resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) | 	resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) | ||||||
| 	table := &resolveMock{answer: resolved} | 	table := &resolveMock{answer: resolved} | ||||||
| 	state := newDialState(nil, table, 0) | 	state := newDialState(nil, table, 0, nil) | ||||||
| 
 | 
 | ||||||
| 	// Check that the task is generated with an incomplete ID.
 | 	// Check that the task is generated with an incomplete ID.
 | ||||||
| 	dest := discover.NewNode(uintID(1), nil, 0, 0) | 	dest := discover.NewNode(uintID(1), nil, 0, 0) | ||||||
|  | |||||||
| @ -127,13 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { | |||||||
| 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} | 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { | func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { | ||||||
| 	if rn.UDP <= 1024 { | 	if rn.UDP <= 1024 { | ||||||
| 		return nil, errors.New("low port") | 		return nil, errors.New("low port") | ||||||
| 	} | 	} | ||||||
| 	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { | 	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { | ||||||
|  | 		return nil, errors.New("not contained in netrestrict whitelist") | ||||||
|  | 	} | ||||||
| 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) | 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) | ||||||
| 	err := n.validateComplete() | 	err := n.validateComplete() | ||||||
| 	return n, err | 	return n, err | ||||||
| @ -157,6 +160,7 @@ type conn interface { | |||||||
| // udp implements the RPC protocol.
 | // udp implements the RPC protocol.
 | ||||||
| type udp struct { | type udp struct { | ||||||
| 	conn        conn | 	conn        conn | ||||||
|  | 	netrestrict *netutil.Netlist | ||||||
| 	priv        *ecdsa.PrivateKey | 	priv        *ecdsa.PrivateKey | ||||||
| 	ourEndpoint rpcEndpoint | 	ourEndpoint rpcEndpoint | ||||||
| 
 | 
 | ||||||
| @ -207,7 +211,7 @@ type reply struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListenUDP returns a new table that listens for UDP packets on laddr.
 | // ListenUDP returns a new table that listens for UDP packets on laddr.
 | ||||||
| func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { | func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { | ||||||
| 	addr, err := net.ResolveUDPAddr("udp", laddr) | 	addr, err := net.ResolveUDPAddr("udp", laddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -216,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	tab, _, err := newUDP(priv, conn, natm, nodeDBPath) | 	tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @ -224,10 +228,11 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP | |||||||
| 	return tab, nil | 	return tab, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) { | func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { | ||||||
| 	udp := &udp{ | 	udp := &udp{ | ||||||
| 		conn:        c, | 		conn:        c, | ||||||
| 		priv:        priv, | 		priv:        priv, | ||||||
|  | 		netrestrict: netrestrict, | ||||||
| 		closing:     make(chan struct{}), | 		closing:     make(chan struct{}), | ||||||
| 		gotreply:    make(chan reply), | 		gotreply:    make(chan reply), | ||||||
| 		addpending:  make(chan *pending), | 		addpending:  make(chan *pending), | ||||||
| @ -287,7 +292,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node | |||||||
| 		reply := r.(*neighbors) | 		reply := r.(*neighbors) | ||||||
| 		for _, rn := range reply.Nodes { | 		for _, rn := range reply.Nodes { | ||||||
| 			nreceived++ | 			nreceived++ | ||||||
| 			n, err := nodeFromRPC(toaddr, rn) | 			n, err := t.nodeFromRPC(toaddr, rn) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) | 				glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) | ||||||
| 				continue | 				continue | ||||||
|  | |||||||
| @ -70,7 +70,7 @@ func newUDPTest(t *testing.T) *udpTest { | |||||||
| 		remotekey:  newkey(), | 		remotekey:  newkey(), | ||||||
| 		remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, | 		remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, | ||||||
| 	} | 	} | ||||||
| 	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") | 	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil) | ||||||
| 	return test | 	return test | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -31,6 +31,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| 	"github.com/ethereum/go-ethereum/rlp" | 	"github.com/ethereum/go-ethereum/rlp" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -65,6 +66,7 @@ func debugLog(s string) { | |||||||
| type Network struct { | type Network struct { | ||||||
| 	db          *nodeDB // database of known nodes
 | 	db          *nodeDB // database of known nodes
 | ||||||
| 	conn        transport | 	conn        transport | ||||||
|  | 	netrestrict *netutil.Netlist | ||||||
| 
 | 
 | ||||||
| 	closed           chan struct{}          // closed when loop is done
 | 	closed           chan struct{}          // closed when loop is done
 | ||||||
| 	closeReq         chan struct{}          // 'request to close'
 | 	closeReq         chan struct{}          // 'request to close'
 | ||||||
| @ -133,7 +135,7 @@ type timeoutEvent struct { | |||||||
| 	node *Node | 	node *Node | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) { | func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { | ||||||
| 	ourID := PubkeyID(&ourPubkey) | 	ourID := PubkeyID(&ourPubkey) | ||||||
| 
 | 
 | ||||||
| 	var db *nodeDB | 	var db *nodeDB | ||||||
| @ -148,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d | |||||||
| 	net := &Network{ | 	net := &Network{ | ||||||
| 		db:               db, | 		db:               db, | ||||||
| 		conn:             conn, | 		conn:             conn, | ||||||
|  | 		netrestrict:      netrestrict, | ||||||
| 		tab:              tab, | 		tab:              tab, | ||||||
| 		topictab:         newTopicTable(db, tab.self), | 		topictab:         newTopicTable(db, tab.self), | ||||||
| 		ticketStore:      newTicketStore(), | 		ticketStore:      newTicketStore(), | ||||||
| @ -696,6 +699,9 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n | |||||||
| 	if n == nil { | 	if n == nil { | ||||||
| 		// We haven't seen this node before.
 | 		// We haven't seen this node before.
 | ||||||
| 		n, err = nodeFromRPC(sender, rn) | 		n, err = nodeFromRPC(sender, rn) | ||||||
|  | 		if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) { | ||||||
|  | 			return n, errors.New("not contained in netrestrict whitelist") | ||||||
|  | 		} | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			n.state = unknown | 			n.state = unknown | ||||||
| 			net.nodes[n.ID] = n | 			net.nodes[n.ID] = n | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| func TestNetwork_Lookup(t *testing.T) { | func TestNetwork_Lookup(t *testing.T) { | ||||||
| 	key, _ := crypto.GenerateKey() | 	key, _ := crypto.GenerateKey() | ||||||
| 	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "") | 	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network { | |||||||
| 	addr := &net.UDPAddr{IP: ip, Port: 30303} | 	addr := &net.UDPAddr{IP: ip, Port: 30303} | ||||||
| 
 | 
 | ||||||
| 	transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} | 	transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} | ||||||
| 	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>") | 	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic("cannot launch new node: " + err.Error()) | 		panic("cannot launch new node: " + err.Error()) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -238,12 +238,12 @@ type udp struct { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListenUDP returns a new table that listens for UDP packets on laddr.
 | // ListenUDP returns a new table that listens for UDP packets on laddr.
 | ||||||
| func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) { | func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { | ||||||
| 	transport, err := listenUDP(priv, laddr) | 	transport, err := listenUDP(priv, laddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath) | 	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -30,6 +30,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discv5" | 	"github.com/ethereum/go-ethereum/p2p/discv5" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -101,6 +102,11 @@ type Config struct { | |||||||
| 	// allowed to connect, even above the peer limit.
 | 	// allowed to connect, even above the peer limit.
 | ||||||
| 	TrustedNodes []*discover.Node | 	TrustedNodes []*discover.Node | ||||||
| 
 | 
 | ||||||
|  | 	// Connectivity can be restricted to certain IP networks.
 | ||||||
|  | 	// If this option is set to a non-nil value, only hosts which match one of the
 | ||||||
|  | 	// IP networks contained in the list are considered.
 | ||||||
|  | 	NetRestrict *netutil.Netlist | ||||||
|  | 
 | ||||||
| 	// NodeDatabase is the path to the database containing the previously seen
 | 	// NodeDatabase is the path to the database containing the previously seen
 | ||||||
| 	// live nodes in the network.
 | 	// live nodes in the network.
 | ||||||
| 	NodeDatabase string | 	NodeDatabase string | ||||||
| @ -356,7 +362,7 @@ func (srv *Server) Start() (err error) { | |||||||
| 
 | 
 | ||||||
| 	// node table
 | 	// node table
 | ||||||
| 	if srv.Discovery { | 	if srv.Discovery { | ||||||
| 		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) | 		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @ -367,7 +373,7 @@ func (srv *Server) Start() (err error) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if srv.DiscoveryV5 { | 	if srv.DiscoveryV5 { | ||||||
| 		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase)
 | 		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
 | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @ -381,7 +387,7 @@ func (srv *Server) Start() (err error) { | |||||||
| 	if !srv.Discovery { | 	if !srv.Discovery { | ||||||
| 		dynPeers = 0 | 		dynPeers = 0 | ||||||
| 	} | 	} | ||||||
| 	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) | 	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict) | ||||||
| 
 | 
 | ||||||
| 	// handshake
 | 	// handshake
 | ||||||
| 	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} | 	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} | ||||||
| @ -634,8 +640,19 @@ func (srv *Server) listenLoop() { | |||||||
| 			} | 			} | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		// Reject connections that do not match NetRestrict.
 | ||||||
|  | 		if srv.NetRestrict != nil { | ||||||
|  | 			if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { | ||||||
|  | 				glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr()) | ||||||
|  | 				fd.Close() | ||||||
|  | 				slots <- struct{}{} | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		fd = newMeteredConn(fd, true) | 		fd = newMeteredConn(fd, true) | ||||||
| 		glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) | 		glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr()) | ||||||
| 
 | 
 | ||||||
| 		// Spawn the handler. It will give the slot back when the connection
 | 		// Spawn the handler. It will give the slot back when the connection
 | ||||||
| 		// has been established.
 | 		// has been established.
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user