From a47341cf96498332e2f0f67c1a6456c67831a5d0 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 22 Nov 2016 20:51:59 +0100 Subject: [PATCH] p2p, p2p/discover, p2p/discv5: add IP network restriction feature The p2p packages can now be configured to restrict all communication to a certain subset of IP networks. This feature is meant to be used for private networks. --- p2p/dial.go | 45 +++++++++++++++++++++++++++++++++------- p2p/dial_test.go | 41 +++++++++++++++++++++++++++++++----- p2p/discover/udp.go | 25 +++++++++++++--------- p2p/discover/udp_test.go | 2 +- p2p/discv5/net.go | 12 ++++++++--- p2p/discv5/net_test.go | 2 +- p2p/discv5/sim_test.go | 2 +- p2p/discv5/udp.go | 4 ++-- p2p/server.go | 25 ++++++++++++++++++---- 9 files changed, 124 insertions(+), 34 deletions(-) diff --git a/p2p/dial.go b/p2p/dial.go index 691b8539e..57fba136a 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -19,6 +19,7 @@ package p2p import ( "container/heap" "crypto/rand" + "errors" "fmt" "net" "time" @@ -26,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( @@ -48,6 +50,7 @@ const ( type dialstate struct { maxDynDials int ntab discoverTable + netrestrict *netutil.Netlist lookupRunning bool dialing map[discover.NodeID]connFlag @@ -100,10 +103,11 @@ type waitExpireTask struct { time.Duration } -func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { +func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { s := &dialstate{ maxDynDials: maxdyn, ntab: ntab, + netrestrict: netrestrict, static: make(map[discover.NodeID]*dialTask), dialing: make(map[discover.NodeID]connFlag), randomNodes: make([]*discover.Node, maxdyn/2), @@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) { func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { var newtasks []task - isDialing := func(id discover.NodeID) bool { - _, found := s.dialing[id] - return found || peers[id] != nil || s.hist.contains(id) - } addDial := func(flag connFlag, n *discover.Node) bool { - if isDialing(n.ID) { + if err := s.checkDial(n, peers); err != nil { + glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err) return false } s.dialing[n.ID] = flag @@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now // Create dials for static nodes if they are not connected. for id, t := range s.static { - if !isDialing(id) { + err := s.checkDial(t.dest, peers) + switch err { + case errNotWhitelisted, errSelf: + glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err) + delete(s.static, t.dest.ID) + case nil: s.dialing[id] = t.flags newtasks = append(newtasks, t) } @@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now return newtasks } +var ( + errSelf = errors.New("is self") + errAlreadyDialing = errors.New("already dialing") + errAlreadyConnected = errors.New("already connected") + errRecentlyDialed = errors.New("recently dialed") + errNotWhitelisted = errors.New("not contained in netrestrict whitelist") +) + +func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error { + _, dialing := s.dialing[n.ID] + switch { + case dialing: + return errAlreadyDialing + case peers[n.ID] != nil: + return errAlreadyConnected + case s.ntab != nil && n.ID == s.ntab.Self().ID: + return errSelf + case s.netrestrict != nil && !s.netrestrict.Contains(n.IP): + return errNotWhitelisted + case s.hist.contains(n.ID): + return errRecentlyDialed + } + return nil +} + func (s *dialstate) taskDone(t task, now time.Time) { switch t := t.(type) { case *dialTask: diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 05d9b7562..c850233db 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -25,6 +25,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/netutil" ) func init() { @@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { runDialTest(t, dialtest{ - init: newDialState(nil, fakeTable{}, 5), + init: newDialState(nil, fakeTable{}, 5, nil), rounds: []round{ // A discovery query is launched. { @@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(nil, table, 10), + init: newDialState(nil, table, 10, nil), rounds: []round{ // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. { @@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) { }) } +// This test checks that candidates that do not match the netrestrict list are not dialed. +func TestDialStateNetRestrict(t *testing.T) { + // This table always returns the same random nodes + // in the order given below. + table := fakeTable{ + {ID: uintID(1), IP: net.ParseIP("127.0.0.1")}, + {ID: uintID(2), IP: net.ParseIP("127.0.0.2")}, + {ID: uintID(3), IP: net.ParseIP("127.0.0.3")}, + {ID: uintID(4), IP: net.ParseIP("127.0.0.4")}, + {ID: uintID(5), IP: net.ParseIP("127.0.2.5")}, + {ID: uintID(6), IP: net.ParseIP("127.0.2.6")}, + {ID: uintID(7), IP: net.ParseIP("127.0.2.7")}, + {ID: uintID(8), IP: net.ParseIP("127.0.2.8")}, + } + restrict := new(netutil.Netlist) + restrict.Add("127.0.2.0/24") + + runDialTest(t, dialtest{ + init: newDialState(nil, table, 10, restrict), + rounds: []round{ + { + new: []task{ + &dialTask{flags: dynDialedConn, dest: table[4]}, + &discoverTask{}, + }, + }, + }, + }) +} + // This test checks that static dials are launched. func TestDialStateStaticDial(t *testing.T) { wantStatic := []*discover.Node{ @@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(wantStatic, fakeTable{}, 0), + init: newDialState(wantStatic, fakeTable{}, 0, nil), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(wantStatic, fakeTable{}, 0), + init: newDialState(wantStatic, fakeTable{}, 0, nil), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) { func TestDialResolve(t *testing.T) { resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) table := &resolveMock{answer: resolved} - state := newDialState(nil, table, 0) + state := newDialState(nil, table, 0, nil) // Check that the task is generated with an incomplete ID. dest := discover.NewNode(uintID(1), nil, 0, 0) diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 6a2c91317..e09c63ffb 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -127,13 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} } -func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { +func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { if rn.UDP <= 1024 { return nil, errors.New("low port") } if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { return nil, err } + if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { + return nil, errors.New("not contained in netrestrict whitelist") + } n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) err := n.validateComplete() return n, err @@ -157,6 +160,7 @@ type conn interface { // udp implements the RPC protocol. type udp struct { conn conn + netrestrict *netutil.Netlist priv *ecdsa.PrivateKey ourEndpoint rpcEndpoint @@ -207,7 +211,7 @@ type reply struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { addr, err := net.ResolveUDPAddr("udp", laddr) if err != nil { return nil, err @@ -216,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP if err != nil { return nil, err } - tab, _, err := newUDP(priv, conn, natm, nodeDBPath) + tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) if err != nil { return nil, err } @@ -224,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) { +func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { udp := &udp{ - conn: c, - priv: priv, - closing: make(chan struct{}), - gotreply: make(chan reply), - addpending: make(chan *pending), + conn: c, + priv: priv, + netrestrict: netrestrict, + closing: make(chan struct{}), + gotreply: make(chan reply), + addpending: make(chan *pending), } realaddr := c.LocalAddr().(*net.UDPAddr) if natm != nil { @@ -287,7 +292,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node reply := r.(*neighbors) for _, rn := range reply.Nodes { nreceived++ - n, err := nodeFromRPC(toaddr, rn) + n, err := t.nodeFromRPC(toaddr, rn) if err != nil { glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) continue diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 495d41ef2..53cfac6f9 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -70,7 +70,7 @@ func newUDPTest(t *testing.T) *udpTest { remotekey: newkey(), remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") + test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil) return test } diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index b7e4a0bee..d1c48904e 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -31,6 +31,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -63,8 +64,9 @@ func debugLog(s string) { // Network manages the table and all protocol interaction. type Network struct { - db *nodeDB // database of known nodes - conn transport + db *nodeDB // database of known nodes + conn transport + netrestrict *netutil.Netlist closed chan struct{} // closed when loop is done closeReq chan struct{} // 'request to close' @@ -133,7 +135,7 @@ type timeoutEvent struct { node *Node } -func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) { +func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { ourID := PubkeyID(&ourPubkey) var db *nodeDB @@ -148,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d net := &Network{ db: db, conn: conn, + netrestrict: netrestrict, tab: tab, topictab: newTopicTable(db, tab.self), ticketStore: newTicketStore(), @@ -696,6 +699,9 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n if n == nil { // We haven't seen this node before. n, err = nodeFromRPC(sender, rn) + if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) { + return n, errors.New("not contained in netrestrict whitelist") + } if err == nil { n.state = unknown net.nodes[n.ID] = n diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go index c8d2558f1..327457c7c 100644 --- a/p2p/discv5/net_test.go +++ b/p2p/discv5/net_test.go @@ -28,7 +28,7 @@ import ( func TestNetwork_Lookup(t *testing.T) { key, _ := crypto.GenerateKey() - network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "") + network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil) if err != nil { t.Fatal(err) } diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go index 2e232fbaa..cb64d7fa0 100644 --- a/p2p/discv5/sim_test.go +++ b/p2p/discv5/sim_test.go @@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network { addr := &net.UDPAddr{IP: ip, Port: 30303} transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} - net, err := newNetwork(transport, key.PublicKey, nil, "") + net, err := newNetwork(transport, key.PublicKey, nil, "", nil) if err != nil { panic("cannot launch new node: " + err.Error()) } diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 396f438a2..a6114e032 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -238,12 +238,12 @@ type udp struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) { +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { transport, err := listenUDP(priv, laddr) if err != nil { return nil, err } - net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath) + net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) if err != nil { return nil, err } diff --git a/p2p/server.go b/p2p/server.go index 7381127dc..cf9672e2d 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( @@ -101,6 +102,11 @@ type Config struct { // allowed to connect, even above the peer limit. TrustedNodes []*discover.Node + // Connectivity can be restricted to certain IP networks. + // If this option is set to a non-nil value, only hosts which match one of the + // IP networks contained in the list are considered. + NetRestrict *netutil.Netlist + // NodeDatabase is the path to the database containing the previously seen // live nodes in the network. NodeDatabase string @@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) { // node table if srv.Discovery { - ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) + ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) if err != nil { return err } @@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) { } if srv.DiscoveryV5 { - ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase) + ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase) if err != nil { return err } @@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) { if !srv.Discovery { dynPeers = 0 } - dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) + dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} @@ -634,8 +640,19 @@ func (srv *Server) listenLoop() { } break } + + // Reject connections that do not match NetRestrict. + if srv.NetRestrict != nil { + if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { + glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr()) + fd.Close() + slots <- struct{}{} + continue + } + } + fd = newMeteredConn(fd, true) - glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) + glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr()) // Spawn the handler. It will give the slot back when the connection // has been established.