diff --git a/les/client_handler.go b/les/client_handler.go index d7ca1c54f..6de576696 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/light" @@ -37,6 +38,7 @@ import ( // responses. type clientHandler struct { ulc *ulc + forkFilter forkid.Filter checkpoint *params.TrustedCheckpoint fetcher *lightFetcher downloader *downloader.Downloader @@ -49,6 +51,7 @@ type clientHandler struct { func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { handler := &clientHandler{ + forkFilter: forkid.NewFilter(backend.blockchain), checkpoint: checkpoint, backend: backend, closeCh: make(chan struct{}), @@ -103,7 +106,8 @@ func (h *clientHandler) handle(p *serverPeer) error { p.Log().Debug("Light Ethereum peer connected", "name", p.Name()) // Execute the LES handshake - if err := p.Handshake(h.backend.blockchain.Genesis().Hash()); err != nil { + forkid := forkid.NewID(h.backend.blockchain.Config(), h.backend.genesis, h.backend.blockchain.CurrentHeader().Number.Uint64()) + if err := p.Handshake(h.backend.blockchain.Genesis().Hash(), forkid, h.forkFilter); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err } @@ -154,8 +158,8 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { var deliverMsg *Msg // Handle the message depending on its contents - switch msg.Code { - case AnnounceMsg: + switch { + case msg.Code == AnnounceMsg: p.Log().Trace("Received announce message") var req announceData if err := msg.Decode(&req); err != nil { @@ -188,7 +192,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { p.updateHead(req.Hash, req.Number, req.Td) h.fetcher.announce(p, &req) } - case BlockHeadersMsg: + case msg.Code == BlockHeadersMsg: p.Log().Trace("Received block header response message") var resp struct { ReqID, BV uint64 @@ -220,7 +224,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { } } } - case BlockBodiesMsg: + case msg.Code == BlockBodiesMsg: p.Log().Trace("Received block bodies response") var resp struct { ReqID, BV uint64 @@ -236,7 +240,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case CodeMsg: + case msg.Code == CodeMsg: p.Log().Trace("Received code response") var resp struct { ReqID, BV uint64 @@ -252,7 +256,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case ReceiptsMsg: + case msg.Code == ReceiptsMsg: p.Log().Trace("Received receipts response") var resp struct { ReqID, BV uint64 @@ -268,7 +272,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Receipts, } - case ProofsV2Msg: + case msg.Code == ProofsV2Msg: p.Log().Trace("Received les/2 proofs response") var resp struct { ReqID, BV uint64 @@ -284,7 +288,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case HelperTrieProofsMsg: + case msg.Code == HelperTrieProofsMsg: p.Log().Trace("Received helper trie proof response") var resp struct { ReqID, BV uint64 @@ -300,7 +304,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case TxStatusMsg: + case msg.Code == TxStatusMsg: p.Log().Trace("Received tx status response") var resp struct { ReqID, BV uint64 @@ -316,11 +320,11 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Status, } - case StopMsg: + case msg.Code == StopMsg && p.version >= lpv3: p.freeze() h.backend.retriever.frozen(p) p.Log().Debug("Service stopped") - case ResumeMsg: + case msg.Code == ResumeMsg && p.version >= lpv3: var bv uint64 if err := msg.Decode(&bv); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) diff --git a/les/peer.go b/les/peer.go index 2b0117bed..31ee8f7f0 100644 --- a/les/peer.go +++ b/les/peer.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" @@ -246,7 +247,7 @@ func (p *peerCommons) sendReceiveHandshake(sendList keyValueList) (keyValueList, // network IDs, difficulties, head and genesis blocks. Besides the basic handshake // fields, server and client can exchange and resolve some specified fields through // two callback functions. -func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, sendCallback func(*keyValueList), recvCallback func(keyValueMap) error) error { +func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter, sendCallback func(*keyValueList), recvCallback func(keyValueMap) error) error { p.lock.Lock() defer p.lock.Unlock() @@ -262,6 +263,12 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g send = send.add("headNum", headNum) send = send.add("genesisHash", genesis) + // If the protocol version is beyond les4, then pass the forkID + // as well. Check http://eips.ethereum.org/EIPS/eip-2124 for more + // spec detail. + if p.version >= lpv4 { + send = send.add("forkID", forkID) + } // Add client-specified or server-specified fields if sendCallback != nil { sendCallback(&send) @@ -295,6 +302,16 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g if int(rVersion) != p.version { return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version) } + // Check forkID if the protocol version is beyond the les4 + if p.version >= lpv4 { + var forkID forkid.ID + if err := recv.get("forkID", &forkID); err != nil { + return err + } + if err := forkFilter(forkID); err != nil { + return errResp(ErrForkIDRejected, "%v", err) + } + } if recvCallback != nil { return recvCallback(recv) } @@ -561,10 +578,10 @@ func (p *serverPeer) updateHead(hash common.Hash, number uint64, td *big.Int) { // Handshake executes the les protocol handshake, negotiating version number, // network IDs and genesis blocks. -func (p *serverPeer) Handshake(genesis common.Hash) error { +func (p *serverPeer) Handshake(genesis common.Hash, forkid forkid.ID, forkFilter forkid.Filter) error { // Note: there is no need to share local head with a server but older servers still // require these fields so we announce zero values. - return p.handshake(common.Big0, common.Hash{}, 0, genesis, func(lists *keyValueList) { + return p.handshake(common.Big0, common.Hash{}, 0, genesis, forkid, forkFilter, func(lists *keyValueList) { // Add some client-specific handshake fields // // Enable signed announcement randomly even the server is not trusted. @@ -944,11 +961,11 @@ func (p *clientPeer) freezeClient() { // Handshake executes the les protocol handshake, negotiating version number, // network IDs, difficulties, head and genesis blocks. -func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error { +func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter, server *LesServer) error { // Note: clientPeer.headInfo should contain the last head announced to the client by us. // The values announced in the handshake are dummy values for compatibility reasons and should be ignored. p.headInfo = blockInfo{Hash: head, Number: headNum, Td: td} - return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) { + return p.handshake(td, head, headNum, genesis, forkID, forkFilter, func(lists *keyValueList) { // Add some information which services server can offer. if !server.config.UltraLightOnlyAnnounce { *lists = (*lists).add("serveHeaders", nil) diff --git a/les/peer_test.go b/les/peer_test.go index 6d3c7f975..d6551ce6b 100644 --- a/les/peer_test.go +++ b/les/peer_test.go @@ -26,8 +26,13 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/params" ) type testServerPeerSub struct { @@ -91,6 +96,14 @@ func TestPeerSubscription(t *testing.T) { checkPeers(sub.unregCh) } +type fakeChain struct{} + +func (f *fakeChain) Config() *params.ChainConfig { return params.MainnetChainConfig } +func (f *fakeChain) Genesis() *types.Block { + return core.DefaultGenesisBlock().ToBlock(rawdb.NewMemoryDatabase()) +} +func (f *fakeChain) CurrentHeader() *types.Header { return &types.Header{Number: big.NewInt(10000000)} } + func TestHandshake(t *testing.T) { // Create a message pipe to communicate through app, net := p2p.MsgPipe() @@ -110,15 +123,21 @@ func TestHandshake(t *testing.T) { head = common.HexToHash("deadbeef") headNum = uint64(10) genesis = common.HexToHash("cafebabe") + + chain1, chain2 = &fakeChain{}, &fakeChain{} + forkID1 = forkid.NewID(chain1.Config(), chain1.Genesis().Hash(), chain1.CurrentHeader().Number.Uint64()) + forkID2 = forkid.NewID(chain2.Config(), chain2.Genesis().Hash(), chain2.CurrentHeader().Number.Uint64()) + filter1, filter2 = forkid.NewFilter(chain1), forkid.NewFilter(chain2) ) + go func() { - errCh1 <- peer1.handshake(td, head, headNum, genesis, func(list *keyValueList) { + errCh1 <- peer1.handshake(td, head, headNum, genesis, forkID1, filter1, func(list *keyValueList) { var announceType uint64 = announceTypeSigned *list = (*list).add("announceType", announceType) }, nil) }() go func() { - errCh2 <- peer2.handshake(td, head, headNum, genesis, nil, func(recv keyValueMap) error { + errCh2 <- peer2.handshake(td, head, headNum, genesis, forkID2, filter2, nil, func(recv keyValueMap) error { var reqType uint64 err := recv.get("announceType", &reqType) if err != nil { diff --git a/les/protocol.go b/les/protocol.go index 19a9561ce..aebe0f2c0 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -34,6 +34,7 @@ import ( const ( lpv2 = 2 lpv3 = 3 + lpv4 = 4 ) // Supported versions of the les protocol (first is primary) @@ -44,7 +45,7 @@ var ( ) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24} +var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24, lpv4: 24} const ( NetworkId = 1 @@ -150,6 +151,7 @@ const ( ErrInvalidResponse ErrTooManyTimeouts ErrMissingKey + ErrForkIDRejected ) func (e errCode) String() string { @@ -172,6 +174,7 @@ var errorToString = map[int]string{ ErrInvalidResponse: "Invalid response", ErrTooManyTimeouts: "Too many request timeouts", ErrMissingKey: "Key missing from list", + ErrForkIDRejected: "ForkID rejected", } // announceData is the network packet for the block announcements. diff --git a/les/server_handler.go b/les/server_handler.go index c0600b368..f965e3fc6 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -66,6 +67,7 @@ var ( // serverHandler is responsible for serving light client and process // all incoming light requests. type serverHandler struct { + forkFilter forkid.Filter blockchain *core.BlockChain chainDb ethdb.Database txpool *core.TxPool @@ -81,6 +83,7 @@ type serverHandler struct { func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler { handler := &serverHandler{ + forkFilter: forkid.NewFilter(blockchain), server: server, blockchain: blockchain, chainDb: chainDb, @@ -121,8 +124,9 @@ func (h *serverHandler) handle(p *clientPeer) error { hash = head.Hash() number = head.Number.Uint64() td = h.blockchain.GetTd(hash, number) + forkID = forkid.NewID(h.blockchain.Config(), h.blockchain.Genesis().Hash(), h.blockchain.CurrentBlock().NumberU64()) ) - if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil { + if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), forkID, h.forkFilter, h.server); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err }