diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 1a2405740..102c7c2d1 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) { func nodeAtDistance(base common.Hash, ld int) (n *Node) { n = new(Node) n.sha = hashAtDistance(base, ld) + n.IP = net.IP{10, 0, 2, byte(ld)} copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID return n } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index b48f16229..6a2c91317 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -126,8 +127,13 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} } -func nodeFromRPC(rn rpcNode) (*Node, error) { - // TODO: don't accept localhost, LAN addresses from internet hosts +func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { + if rn.UDP <= 1024 { + return nil, errors.New("low port") + } + if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + return nil, err + } n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) err := n.validateComplete() return n, err @@ -281,9 +287,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node reply := r.(*neighbors) for _, rn := range reply.Nodes { nreceived++ - if n, err := nodeFromRPC(rn); err == nil { - nodes = append(nodes, n) + n, err := nodeFromRPC(toaddr, rn) + if err != nil { + glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) + continue } + nodes = append(nodes, n) } return nreceived >= bucketSize }) @@ -595,6 +604,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. for i, n := range closest { + if netutil.CheckRelayIP(from.IP, n.IP) != nil { + continue + } p.Nodes = append(p.Nodes, nodeToRPC(n)) if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { t.send(from, neighborsPacket, p) diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 861bf48ca..495d41ef2 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -68,7 +68,7 @@ func newUDPTest(t *testing.T) *udpTest { pipe: newpipe(), localkey: newkey(), remotekey: newkey(), - remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303}, + remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") return test @@ -312,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) { // check that the sent neighbors are all returned by findnode select { case result := <-resultc: - if !reflect.DeepEqual(result, list) { - t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list) + want := append(list[:2], list[3:]...) + if !reflect.DeepEqual(result, want) { + t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, want) } case err := <-errc: t.Errorf("findnode error: %v", err) diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index 7ad6f1e5b..b7e4a0bee 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -45,6 +45,7 @@ const ( bucketRefreshInterval = 1 * time.Minute seedCount = 30 seedMaxAge = 5 * 24 * time.Hour + lowPort = 1024 ) const testTopic = "foo" @@ -684,16 +685,19 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node { return n } -func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) { +func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) { if rn.ID == net.tab.self.ID { return nil, errors.New("is self") } + if rn.UDP <= lowPort { + return nil, errors.New("low port") + } n = net.nodes[rn.ID] if n == nil { // We haven't seen this node before. - n, err = nodeFromRPC(rn) - n.state = unknown + n, err = nodeFromRPC(sender, rn) if err == nil { + n.state = unknown net.nodes[n.ID] = n } return n, err @@ -1095,7 +1099,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) net.conn.sendNeighbours(n, results) return n.state, nil case neighborsPacket: - err := net.handleNeighboursPacket(n, pkt.data.(*neighbors)) + err := net.handleNeighboursPacket(n, pkt) return n.state, err case neighboursTimeout: if n.pendingNeighbours != nil { @@ -1182,17 +1186,18 @@ func rlpHash(x interface{}) (h common.Hash) { return h } -func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error { +func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { if n.pendingNeighbours == nil { return errNoQuery } net.abortTimedEvent(n, neighboursTimeout) + req := pkt.data.(*neighbors) nodes := make([]*Node, len(req.Nodes)) for i, rn := range req.Nodes { - nn, err := net.internNodeFromNeighbours(rn) + nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn) if err != nil { - glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err) + glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err) continue } nodes[i] = nn diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go index 422daa33b..c8d2558f1 100644 --- a/p2p/discv5/net_test.go +++ b/p2p/discv5/net_test.go @@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) { // t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) // } // seed table with initial node (otherwise lookup will terminate immediately) - seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)} + seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)} if err := network.SetFallbackNodes(seeds); err != nil { t.Fatal(err) } @@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) { func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) { // current log distance is encoded in port number // fmt.Println("findnode query at dist", toaddr.Port) - if to.UDP == 0 { - panic("query to node at distance 0") + if to.UDP <= lowPort { + panic("query to node at or below distance 0") } next := to.UDP - 1 var result []rpcNode - for i, id := range tn.dists[to.UDP] { - result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) + for i, id := range tn.dists[to.UDP-lowPort] { + result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort))) } injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) } @@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha // ignored case findnodeHashPacket: // current log distance is encoded in port number - // fmt.Println("findnode query at dist", toaddr.Port) - if to.UDP == 0 { - panic("query to node at distance 0") + // fmt.Println("findnode query at dist", toaddr.Port-lowPort) + if to.UDP <= lowPort { + panic("query to node at or below distance 0") } next := to.UDP - 1 var result []rpcNode - for i, id := range tn.dists[to.UDP] { - result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) + for i, id := range tn.dists[to.UDP-lowPort] { + result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort))) } injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) default: @@ -328,8 +328,11 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int, panic("sendTopicRegister called") } -func (*preminedTestnet) Close() {} -func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) } +func (*preminedTestnet) Close() {} + +func (*preminedTestnet) localAddr() *net.UDPAddr { + return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000} +} // mine generates a testnet struct literal with nodes at // various distances to the given target. diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index f86fb4d81..396f438a2 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool { return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP) } -func nodeFromRPC(rn rpcNode) (*Node, error) { - // TODO: don't accept localhost, LAN addresses from internet hosts +func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { + if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + return nil, err + } n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) err := n.validateComplete() return n, err @@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) return } for i, result := range nodes { + if netutil.CheckRelayIP(remote.IP, result.IP) != nil { + continue + } p.Nodes = append(p.Nodes, nodeToRPC(result)) if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)