This commit is contained in:
obscuren 2015-01-05 17:10:42 +01:00
parent b0854fbff5
commit 6abf8ef78f
41 changed files with 2329 additions and 1203 deletions

View File

@ -59,6 +59,8 @@ var (
DumpNumber int DumpNumber int
VmType int VmType int
ImportChain string ImportChain string
SHH bool
Dial bool
) )
// flags specific to cli client // flags specific to cli client
@ -94,6 +96,8 @@ func Init() {
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server") flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)") flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.BoolVar(&UseSeed, "seed", true, "seed peers") flag.BoolVar(&UseSeed, "seed", true, "seed peers")
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key") flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)") flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given") flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
@ -105,7 +109,7 @@ func Init() {
flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0") flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0")
flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false") flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false")
flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block") flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block")
flag.StringVar(&ImportChain, "chain", "", "Imports fiven chain") flag.StringVar(&ImportChain, "chain", "", "Imports given chain")
flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]") flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]")
flag.StringVar(&DumpHash, "hash", "", "specify arg in hex") flag.StringVar(&DumpHash, "hash", "", "specify arg in hex")

View File

@ -64,10 +64,14 @@ func main() {
NATType: PMPGateway, NATType: PMPGateway,
PMPGateway: PMPGateway, PMPGateway: PMPGateway,
KeyRing: KeyRing, KeyRing: KeyRing,
Shh: SHH,
Dial: Dial,
}) })
if err != nil { if err != nil {
clilogger.Fatalln(err) clilogger.Fatalln(err)
} }
utils.KeyTasks(ethereum.KeyManager(), KeyRing, GenAddr, SecretFile, ExportDir, NonInteractive) utils.KeyTasks(ethereum.KeyManager(), KeyRing, GenAddr, SecretFile, ExportDir, NonInteractive)
if Dump { if Dump {
@ -112,13 +116,6 @@ func main() {
return return
} }
// better reworked as cases
if StartJsConsole {
InitJsConsole(ethereum)
} else if len(InputFile) > 0 {
ExecJsFile(ethereum, InputFile)
}
if StartRpc { if StartRpc {
utils.StartRpc(ethereum, RpcPort) utils.StartRpc(ethereum, RpcPort)
} }
@ -129,6 +126,11 @@ func main() {
utils.StartEthereum(ethereum, UseSeed) utils.StartEthereum(ethereum, UseSeed)
if StartJsConsole {
InitJsConsole(ethereum)
} else if len(InputFile) > 0 {
ExecJsFile(ethereum, InputFile)
}
// this blocks the thread // this blocks the thread
ethereum.WaitForShutdown() ethereum.WaitForShutdown()
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"gopkg.in/fatih/set.v0"
) )
var txplogger = logger.NewLogger("TXP") var txplogger = logger.NewLogger("TXP")
@ -38,7 +37,7 @@ type TxPool struct {
quit chan bool quit chan bool
// The actual pool // The actual pool
//pool *list.List //pool *list.List
pool *set.Set txs map[string]*types.Transaction
SecondaryProcessor TxProcessor SecondaryProcessor TxProcessor
@ -49,21 +48,19 @@ type TxPool struct {
func NewTxPool(eventMux *event.TypeMux) *TxPool { func NewTxPool(eventMux *event.TypeMux) *TxPool {
return &TxPool{ return &TxPool{
pool: set.New(), txs: make(map[string]*types.Transaction),
queueChan: make(chan *types.Transaction, txPoolQueueSize), queueChan: make(chan *types.Transaction, txPoolQueueSize),
quit: make(chan bool), quit: make(chan bool),
eventMux: eventMux, eventMux: eventMux,
} }
} }
func (pool *TxPool) addTransaction(tx *types.Transaction) {
pool.pool.Add(tx)
// Broadcast the transaction to the rest of the peers
pool.eventMux.Post(TxPreEvent{tx})
}
func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error {
hash := tx.Hash()
if pool.txs[string(hash)] != nil {
return fmt.Errorf("Known transaction (%x)", hash[0:4])
}
if len(tx.To()) != 0 && len(tx.To()) != 20 { if len(tx.To()) != 0 && len(tx.To()) != 20 {
return fmt.Errorf("Invalid recipient. len = %d", len(tx.To())) return fmt.Errorf("Invalid recipient. len = %d", len(tx.To()))
} }
@ -95,18 +92,17 @@ func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error {
return nil return nil
} }
func (self *TxPool) Add(tx *types.Transaction) error { func (self *TxPool) addTx(tx *types.Transaction) {
hash := tx.Hash() self.txs[string(tx.Hash())] = tx
if self.pool.Has(tx) { }
return fmt.Errorf("Known transaction (%x)", hash[0:4])
}
func (self *TxPool) Add(tx *types.Transaction) error {
err := self.ValidateTransaction(tx) err := self.ValidateTransaction(tx)
if err != nil { if err != nil {
return err return err
} }
self.addTransaction(tx) self.addTx(tx)
var to string var to string
if len(tx.To()) > 0 { if len(tx.To()) > 0 {
@ -124,7 +120,7 @@ func (self *TxPool) Add(tx *types.Transaction) error {
} }
func (self *TxPool) Size() int { func (self *TxPool) Size() int {
return self.pool.Size() return len(self.txs)
} }
func (self *TxPool) AddTransactions(txs []*types.Transaction) { func (self *TxPool) AddTransactions(txs []*types.Transaction) {
@ -137,43 +133,39 @@ func (self *TxPool) AddTransactions(txs []*types.Transaction) {
} }
} }
func (pool *TxPool) GetTransactions() []*types.Transaction { func (self *TxPool) GetTransactions() (txs types.Transactions) {
txList := make([]*types.Transaction, pool.Size()) txs = make(types.Transactions, self.Size())
i := 0 i := 0
pool.pool.Each(func(v interface{}) bool { for _, tx := range self.txs {
txList[i] = v.(*types.Transaction) txs[i] = tx
i++ i++
}
return true return
})
return txList
} }
func (pool *TxPool) RemoveInvalid(query StateQuery) { func (pool *TxPool) RemoveInvalid(query StateQuery) {
var removedTxs types.Transactions var removedTxs types.Transactions
pool.pool.Each(func(v interface{}) bool { for _, tx := range pool.txs {
tx := v.(*types.Transaction)
sender := query.GetAccount(tx.From()) sender := query.GetAccount(tx.From())
err := pool.ValidateTransaction(tx) err := pool.ValidateTransaction(tx)
if err != nil || sender.Nonce >= tx.Nonce() { if err != nil || sender.Nonce >= tx.Nonce() {
removedTxs = append(removedTxs, tx) removedTxs = append(removedTxs, tx)
} }
}
return true
})
pool.RemoveSet(removedTxs) pool.RemoveSet(removedTxs)
} }
func (self *TxPool) RemoveSet(txs types.Transactions) { func (self *TxPool) RemoveSet(txs types.Transactions) {
for _, tx := range txs { for _, tx := range txs {
self.pool.Remove(tx) delete(self.txs, string(tx.Hash()))
} }
} }
func (pool *TxPool) Flush() []*types.Transaction { func (pool *TxPool) Flush() []*types.Transaction {
txList := pool.GetTransactions() txList := pool.GetTransactions()
pool.pool.Clear() pool.txs = make(map[string]*types.Transaction)
return txList return txList
} }

View File

@ -67,10 +67,13 @@ func (self *Header) HashNoNonce() []byte {
} }
type Block struct { type Block struct {
header *Header // Preset Hash for mock
uncles []*Header HeaderHash []byte
transactions Transactions ParentHeaderHash []byte
Td *big.Int header *Header
uncles []*Header
transactions Transactions
Td *big.Int
receipts Receipts receipts Receipts
Reward *big.Int Reward *big.Int
@ -99,41 +102,19 @@ func NewBlockWithHeader(header *Header) *Block {
} }
func (self *Block) DecodeRLP(s *rlp.Stream) error { func (self *Block) DecodeRLP(s *rlp.Stream) error {
if _, err := s.List(); err != nil { var extblock struct {
Header *Header
Txs []*Transaction
Uncles []*Header
TD *big.Int // optional
}
if err := s.Decode(&extblock); err != nil {
return err return err
} }
self.header = extblock.Header
var header Header self.uncles = extblock.Uncles
if err := s.Decode(&header); err != nil { self.transactions = extblock.Txs
return err self.Td = extblock.TD
}
var transactions []*Transaction
if err := s.Decode(&transactions); err != nil {
return err
}
var uncleHeaders []*Header
if err := s.Decode(&uncleHeaders); err != nil {
return err
}
var tdBytes []byte
if err := s.Decode(&tdBytes); err != nil {
// If this block comes from the network that's fine. If loaded from disk it should be there
// Blocks don't store their Td when propagated over the network
} else {
self.Td = ethutil.BigD(tdBytes)
}
if err := s.ListEnd(); err != nil {
return err
}
self.header = &header
self.uncles = uncleHeaders
self.transactions = transactions
return nil return nil
} }
@ -189,23 +170,35 @@ func (self *Block) RlpDataForStorage() interface{} {
// Header accessors (add as you need them) // Header accessors (add as you need them)
func (self *Block) Number() *big.Int { return self.header.Number } func (self *Block) Number() *big.Int { return self.header.Number }
func (self *Block) NumberU64() uint64 { return self.header.Number.Uint64() } func (self *Block) NumberU64() uint64 { return self.header.Number.Uint64() }
func (self *Block) ParentHash() []byte { return self.header.ParentHash }
func (self *Block) Bloom() []byte { return self.header.Bloom } func (self *Block) Bloom() []byte { return self.header.Bloom }
func (self *Block) Coinbase() []byte { return self.header.Coinbase } func (self *Block) Coinbase() []byte { return self.header.Coinbase }
func (self *Block) Time() int64 { return int64(self.header.Time) } func (self *Block) Time() int64 { return int64(self.header.Time) }
func (self *Block) GasLimit() *big.Int { return self.header.GasLimit } func (self *Block) GasLimit() *big.Int { return self.header.GasLimit }
func (self *Block) GasUsed() *big.Int { return self.header.GasUsed } func (self *Block) GasUsed() *big.Int { return self.header.GasUsed }
func (self *Block) Hash() []byte { return self.header.Hash() }
func (self *Block) Trie() *ptrie.Trie { return ptrie.New(self.header.Root, ethutil.Config.Db) } func (self *Block) Trie() *ptrie.Trie { return ptrie.New(self.header.Root, ethutil.Config.Db) }
func (self *Block) SetRoot(root []byte) { self.header.Root = root }
func (self *Block) State() *state.StateDB { return state.New(self.Trie()) } func (self *Block) State() *state.StateDB { return state.New(self.Trie()) }
func (self *Block) Size() ethutil.StorageSize { return ethutil.StorageSize(len(ethutil.Encode(self))) } func (self *Block) Size() ethutil.StorageSize { return ethutil.StorageSize(len(ethutil.Encode(self))) }
func (self *Block) SetRoot(root []byte) { self.header.Root = root }
// Implement block.Pow // Implement pow.Block
func (self *Block) Difficulty() *big.Int { return self.header.Difficulty } func (self *Block) Difficulty() *big.Int { return self.header.Difficulty }
func (self *Block) N() []byte { return self.header.Nonce } func (self *Block) N() []byte { return self.header.Nonce }
func (self *Block) HashNoNonce() []byte { func (self *Block) HashNoNonce() []byte { return self.header.HashNoNonce() }
return crypto.Sha3(ethutil.Encode(self.header.rlpData(false)))
func (self *Block) Hash() []byte {
if self.HeaderHash != nil {
return self.HeaderHash
} else {
return self.header.Hash()
}
}
func (self *Block) ParentHash() []byte {
if self.ParentHeaderHash != nil {
return self.ParentHeaderHash
} else {
return self.header.ParentHash
}
} }
func (self *Block) String() string { func (self *Block) String() string {

View File

@ -36,6 +36,9 @@ type Config struct {
NATType string NATType string
PMPGateway string PMPGateway string
Shh bool
Dial bool
KeyManager *crypto.KeyManager KeyManager *crypto.KeyManager
} }
@ -130,11 +133,13 @@ func New(config *Config) (*Ethereum, error) {
insertChain := eth.chainManager.InsertChain insertChain := eth.chainManager.InsertChain
eth.blockPool = NewBlockPool(hasBlock, insertChain, ezp.Verify) eth.blockPool = NewBlockPool(hasBlock, insertChain, ezp.Verify)
// Start services
eth.txPool.Start()
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool) ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()} protocols := []p2p.Protocol{ethProto}
if config.Shh {
eth.whisper = whisper.New()
protocols = append(protocols, eth.whisper.Protocol())
}
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway) nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway)
if err != nil { if err != nil {
@ -142,12 +147,16 @@ func New(config *Config) (*Ethereum, error) {
} }
eth.net = &p2p.Server{ eth.net = &p2p.Server{
Identity: clientId, Identity: clientId,
MaxPeers: config.MaxPeers, MaxPeers: config.MaxPeers,
Protocols: protocols, Protocols: protocols,
ListenAddr: ":" + config.Port, Blacklist: eth.blacklist,
Blacklist: eth.blacklist, NAT: nat,
NAT: nat, NoDial: !config.Dial,
}
if len(config.Port) > 0 {
eth.net.ListenAddr = ":" + config.Port
} }
return eth, nil return eth, nil
@ -219,8 +228,14 @@ func (s *Ethereum) Start(seed bool) error {
if err != nil { if err != nil {
return err return err
} }
// Start services
s.txPool.Start()
s.blockPool.Start() s.blockPool.Start()
s.whisper.Start()
if s.whisper != nil {
s.whisper.Start()
}
// broadcast transactions // broadcast transactions
s.txSub = s.eventMux.Subscribe(core.TxPreEvent{}) s.txSub = s.eventMux.Subscribe(core.TxPreEvent{})
@ -268,7 +283,9 @@ func (s *Ethereum) Stop() {
s.txPool.Stop() s.txPool.Stop()
s.eventMux.Stop() s.eventMux.Stop()
s.blockPool.Stop() s.blockPool.Stop()
s.whisper.Stop() if s.whisper != nil {
s.whisper.Stop()
}
logger.Infoln("Server stopped") logger.Infoln("Server stopped")
close(s.shutdownChan) close(s.shutdownChan)
@ -285,16 +302,16 @@ func (self *Ethereum) txBroadcastLoop() {
// automatically stops if unsubscribe // automatically stops if unsubscribe
for obj := range self.txSub.Chan() { for obj := range self.txSub.Chan() {
event := obj.(core.TxPreEvent) event := obj.(core.TxPreEvent)
self.net.Broadcast("eth", TxMsg, []interface{}{event.Tx.RlpData()}) self.net.Broadcast("eth", TxMsg, event.Tx.RlpData())
} }
} }
func (self *Ethereum) blockBroadcastLoop() { func (self *Ethereum) blockBroadcastLoop() {
// automatically stops if unsubscribe // automatically stops if unsubscribe
for obj := range self.txSub.Chan() { for obj := range self.blockSub.Chan() {
switch ev := obj.(type) { switch ev := obj.(type) {
case core.NewMinedBlockEvent: case core.NewMinedBlockEvent:
self.net.Broadcast("eth", NewBlockMsg, ev.Block.RlpData()) self.net.Broadcast("eth", NewBlockMsg, ev.Block.RlpData(), ev.Block.Td)
} }
} }
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -52,18 +52,17 @@ func ProtocolError(code int, format string, params ...interface{}) (err *protoco
} }
func (self protocolError) Error() (message string) { func (self protocolError) Error() (message string) {
message = self.message if len(message) == 0 {
if message == "" { var ok bool
message, ok := errorToString[self.Code] self.message, ok = errorToString[self.Code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
if self.format != "" { if self.format != "" {
message += ": " + fmt.Sprintf(self.format, self.params...) self.message += ": " + fmt.Sprintf(self.format, self.params...)
} }
self.message = message
} }
return return self.message
} }
func (self *protocolError) Fatal() bool { func (self *protocolError) Fatal() bool {

View File

@ -3,7 +3,7 @@ package eth
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math" "io"
"math/big" "math/big"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
@ -95,14 +95,13 @@ func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPoo
blockPool: blockPool, blockPool: blockPool,
rw: rw, rw: rw,
peer: peer, peer: peer,
id: (string)(peer.Identity().Pubkey()), id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]),
} }
err = self.handleStatus() err = self.handleStatus()
if err == nil { if err == nil {
for { for {
err = self.handle() err = self.handle()
if err != nil { if err != nil {
fmt.Println(err)
self.blockPool.RemovePeer(self.id) self.blockPool.RemovePeer(self.id)
break break
} }
@ -117,7 +116,7 @@ func (self *ethProtocol) handle() error {
return err return err
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
} }
// make sure that the payload has been fully consumed // make sure that the payload has been fully consumed
defer msg.Discard() defer msg.Discard()
@ -125,76 +124,87 @@ func (self *ethProtocol) handle() error {
switch msg.Code { switch msg.Code {
case StatusMsg: case StatusMsg:
return ProtocolError(ErrExtraStatusMsg, "") return self.protoError(ErrExtraStatusMsg, "")
case TxMsg: case TxMsg:
// TODO: rework using lazy RLP stream // TODO: rework using lazy RLP stream
var txs []*types.Transaction var txs []*types.Transaction
if err := msg.Decode(&txs); err != nil { if err := msg.Decode(&txs); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
self.txPool.AddTransactions(txs) self.txPool.AddTransactions(txs)
case GetBlockHashesMsg: case GetBlockHashesMsg:
var request getBlockHashesMsgData var request getBlockHashesMsgData
if err := msg.Decode(&request); err != nil { if err := msg.Decode(&request); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "->msg %v: %v", msg, err)
} }
hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount) hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount)
return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...) return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...)
case BlockHashesMsg: case BlockHashesMsg:
// TODO: redo using lazy decode , this way very inefficient on known chains // TODO: redo using lazy decode , this way very inefficient on known chains
msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) msgStream := rlp.NewStream(msg.Payload)
var err error var err error
var i int
iter := func() (hash []byte, ok bool) { iter := func() (hash []byte, ok bool) {
hash, err = msgStream.Bytes() hash, err = msgStream.Bytes()
if err == nil { if err == nil {
i++
ok = true ok = true
} else {
if err != io.EOF {
self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err)
}
} }
return return
} }
self.blockPool.AddBlockHashes(iter, self.id) self.blockPool.AddBlockHashes(iter, self.id)
if err != nil && err != rlp.EOL {
return ProtocolError(ErrDecode, "%v", err)
}
case GetBlocksMsg: case GetBlocksMsg:
var blockHashes [][]byte msgStream := rlp.NewStream(msg.Payload)
if err := msg.Decode(&blockHashes); err != nil {
return ProtocolError(ErrDecode, "%v", err)
}
max := int(math.Min(float64(len(blockHashes)), blockHashesBatchSize))
var blocks []interface{} var blocks []interface{}
for i, hash := range blockHashes { var i int
if i >= max { for {
break i++
var hash []byte
if err := msgStream.Decode(&hash); err != nil {
if err == io.EOF {
break
} else {
return self.protoError(ErrDecode, "msg %v: %v", msg, err)
}
} }
block := self.chainManager.GetBlock(hash) block := self.chainManager.GetBlock(hash)
if block != nil { if block != nil {
blocks = append(blocks, block.RlpData()) blocks = append(blocks, block)
}
if i == blockHashesBatchSize {
break
} }
} }
return self.rw.EncodeMsg(BlocksMsg, blocks...) return self.rw.EncodeMsg(BlocksMsg, blocks...)
case BlocksMsg: case BlocksMsg:
msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) msgStream := rlp.NewStream(msg.Payload)
for { for {
var block *types.Block var block types.Block
if err := msgStream.Decode(&block); err != nil { if err := msgStream.Decode(&block); err != nil {
if err == rlp.EOL { if err == io.EOF {
break break
} else { } else {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
} }
self.blockPool.AddBlock(block, self.id) self.blockPool.AddBlock(&block, self.id)
} }
case NewBlockMsg: case NewBlockMsg:
var request newBlockMsgData var request newBlockMsgData
if err := msg.Decode(&request); err != nil { if err := msg.Decode(&request); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
hash := request.Block.Hash() hash := request.Block.Hash()
// to simplify backend interface adding a new block // to simplify backend interface adding a new block
@ -202,12 +212,12 @@ func (self *ethProtocol) handle() error {
// (or selected as new best peer) // (or selected as new best peer)
if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) { if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) {
called := true called := true
iter := func() (hash []byte, ok bool) { iter := func() ([]byte, bool) {
if called { if called {
called = false called = false
return hash, true return hash, true
} else { } else {
return return nil, false
} }
} }
self.blockPool.AddBlockHashes(iter, self.id) self.blockPool.AddBlockHashes(iter, self.id)
@ -215,14 +225,14 @@ func (self *ethProtocol) handle() error {
} }
default: default:
return ProtocolError(ErrInvalidMsgCode, "%v", msg.Code) return self.protoError(ErrInvalidMsgCode, "%v", msg.Code)
} }
return nil return nil
} }
type statusMsgData struct { type statusMsgData struct {
ProtocolVersion uint ProtocolVersion uint32
NetworkId uint NetworkId uint32
TD *big.Int TD *big.Int
CurrentBlock []byte CurrentBlock []byte
GenesisBlock []byte GenesisBlock []byte
@ -253,56 +263,56 @@ func (self *ethProtocol) handleStatus() error {
} }
if msg.Code != StatusMsg { if msg.Code != StatusMsg {
return ProtocolError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) return self.protoError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
} }
var status statusMsgData var status statusMsgData
if err := msg.Decode(&status); err != nil { if err := msg.Decode(&status); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
_, _, genesisBlock := self.chainManager.Status() _, _, genesisBlock := self.chainManager.Status()
if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 { if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 {
return ProtocolError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock) return self.protoError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock)
} }
if status.NetworkId != NetworkId { if status.NetworkId != NetworkId {
return ProtocolError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId) return self.protoError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId)
} }
if ProtocolVersion != status.ProtocolVersion { if ProtocolVersion != status.ProtocolVersion {
return ProtocolError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion) return self.protoError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion)
} }
self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4]) self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4])
//self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect)
self.peer.Infoln("AddPeer(IGNORED)")
return nil return nil
} }
func (self *ethProtocol) requestBlockHashes(from []byte) error { func (self *ethProtocol) requestBlockHashes(from []byte) error {
self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4]) self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4])
return self.rw.EncodeMsg(GetBlockHashesMsg, from, blockHashesBatchSize) return self.rw.EncodeMsg(GetBlockHashesMsg, interface{}(from), uint64(blockHashesBatchSize))
} }
func (self *ethProtocol) requestBlocks(hashes [][]byte) error { func (self *ethProtocol) requestBlocks(hashes [][]byte) error {
self.peer.Debugf("fetching %v blocks", len(hashes)) self.peer.Debugf("fetching %v blocks", len(hashes))
return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)) return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)...)
} }
func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) { func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) {
err = ProtocolError(code, format, params...) err = ProtocolError(code, format, params...)
if err.Fatal() { if err.Fatal() {
self.peer.Errorln(err) self.peer.Errorln("err %v", err)
// disconnect
} else { } else {
self.peer.Debugln(err) self.peer.Debugf("fyi %v", err)
} }
return return
} }
@ -310,10 +320,10 @@ func (self *ethProtocol) protoError(code int, format string, params ...interface
func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) { func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) {
err := ProtocolError(code, format, params...) err := ProtocolError(code, format, params...)
if err.Fatal() { if err.Fatal() {
self.peer.Errorln(err) self.peer.Errorln("err %v", err)
// disconnect // disconnect
} else { } else {
self.peer.Debugln(err) self.peer.Debugf("fyi %v", err)
} }
} }

View File

@ -1,35 +1,48 @@
package eth package eth
import ( import (
"bytes"
"io" "io"
"log"
"math/big" "math/big"
"os"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethutil"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
) )
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
type testMsgReadWriter struct { type testMsgReadWriter struct {
in chan p2p.Msg in chan p2p.Msg
out chan p2p.Msg out []p2p.Msg
} }
func (self *testMsgReadWriter) In(msg p2p.Msg) { func (self *testMsgReadWriter) In(msg p2p.Msg) {
self.in <- msg self.in <- msg
} }
func (self *testMsgReadWriter) Out(msg p2p.Msg) { func (self *testMsgReadWriter) Out() (msg p2p.Msg, ok bool) {
self.in <- msg if len(self.out) > 0 {
msg = self.out[0]
self.out = self.out[1:]
ok = true
}
return
} }
func (self *testMsgReadWriter) WriteMsg(msg p2p.Msg) error { func (self *testMsgReadWriter) WriteMsg(msg p2p.Msg) error {
self.out <- msg self.out = append(self.out, msg)
return nil return nil
} }
func (self *testMsgReadWriter) EncodeMsg(code uint64, data ...interface{}) error { func (self *testMsgReadWriter) EncodeMsg(code uint64, data ...interface{}) error {
return self.WriteMsg(p2p.NewMsg(code, data)) return self.WriteMsg(p2p.NewMsg(code, data...))
} }
func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) { func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) {
@ -40,145 +53,83 @@ func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) {
return msg, nil return msg, nil
} }
func errorCheck(t *testing.T, expCode int, err error) { type testTxPool struct {
perr, ok := err.(*protocolError)
if ok && perr != nil {
if code := perr.Code; code != expCode {
ok = false
}
}
if !ok {
t.Errorf("expected error code %v, got %v", ErrNoStatusMsg, err)
}
}
type TestBackend struct {
getTransactions func() []*types.Transaction getTransactions func() []*types.Transaction
addTransactions func(txs []*types.Transaction) addTransactions func(txs []*types.Transaction)
getBlockHashes func(hash []byte, amount uint32) (hashes [][]byte)
addBlockHashes func(next func() ([]byte, bool), peerId string)
getBlock func(hash []byte) *types.Block
addBlock func(block *types.Block, peerId string) (err error)
addPeer func(td *big.Int, currentBlock []byte, peerId string, requestHashes func([]byte) error, requestBlocks func([][]byte) error, invalidBlock func(error)) (best bool)
removePeer func(peerId string)
status func() (td *big.Int, currentBlock []byte, genesisBlock []byte)
} }
func (self *TestBackend) GetTransactions() (txs []*types.Transaction) { type testChainManager struct {
if self.getTransactions != nil { getBlockHashes func(hash []byte, amount uint64) (hashes [][]byte)
txs = self.getTransactions() getBlock func(hash []byte) *types.Block
} status func() (td *big.Int, currentBlock []byte, genesisBlock []byte)
return
} }
func (self *TestBackend) AddTransactions(txs []*types.Transaction) { type testBlockPool struct {
addBlockHashes func(next func() ([]byte, bool), peerId string)
addBlock func(block *types.Block, peerId string) (err error)
addPeer func(td *big.Int, currentBlock []byte, peerId string, requestHashes func([]byte) error, requestBlocks func([][]byte) error, peerError func(int, string, ...interface{})) (best bool)
removePeer func(peerId string)
}
// func (self *testTxPool) GetTransactions() (txs []*types.Transaction) {
// if self.getTransactions != nil {
// txs = self.getTransactions()
// }
// return
// }
func (self *testTxPool) AddTransactions(txs []*types.Transaction) {
if self.addTransactions != nil { if self.addTransactions != nil {
self.addTransactions(txs) self.addTransactions(txs)
} }
} }
func (self *TestBackend) GetBlockHashes(hash []byte, amount uint32) (hashes [][]byte) { func (self *testChainManager) GetBlockHashesFromHash(hash []byte, amount uint64) (hashes [][]byte) {
if self.getBlockHashes != nil { if self.getBlockHashes != nil {
hashes = self.getBlockHashes(hash, amount) hashes = self.getBlockHashes(hash, amount)
} }
return return
} }
<<<<<<< HEAD func (self *testChainManager) Status() (td *big.Int, currentBlock []byte, genesisBlock []byte) {
<<<<<<< HEAD
func (self *TestBackend) AddBlockHashes(next func() ([]byte, bool), peerId string) {
if self.addBlockHashes != nil {
self.addBlockHashes(next, peerId)
}
}
=======
func (self *TestBackend) AddHash(hash []byte, peer *p2p.Peer) (more bool) {
if self.addHash != nil {
more = self.addHash(hash, peer)
=======
func (self *TestBackend) AddBlockHashes(next func() ([]byte, bool), peerId string) {
if self.addBlockHashes != nil {
self.addBlockHashes(next, peerId)
>>>>>>> eth protocol changes
}
}
<<<<<<< HEAD
>>>>>>> initial commit for eth-p2p integration
=======
>>>>>>> eth protocol changes
func (self *TestBackend) GetBlock(hash []byte) (block *types.Block) {
if self.getBlock != nil {
block = self.getBlock(hash)
}
return
}
<<<<<<< HEAD
<<<<<<< HEAD
func (self *TestBackend) AddBlock(block *types.Block, peerId string) (err error) {
if self.addBlock != nil {
err = self.addBlock(block, peerId)
=======
func (self *TestBackend) AddBlock(td *big.Int, block *types.Block, peer *p2p.Peer) (fetchHashes bool, err error) {
if self.addBlock != nil {
fetchHashes, err = self.addBlock(td, block, peer)
>>>>>>> initial commit for eth-p2p integration
=======
func (self *TestBackend) AddBlock(block *types.Block, peerId string) (err error) {
if self.addBlock != nil {
err = self.addBlock(block, peerId)
>>>>>>> eth protocol changes
}
return
}
<<<<<<< HEAD
<<<<<<< HEAD
func (self *TestBackend) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, invalidBlock func(error)) (best bool) {
if self.addPeer != nil {
best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, invalidBlock)
=======
func (self *TestBackend) AddPeer(td *big.Int, currentBlock []byte, peer *p2p.Peer) (fetchHashes bool) {
if self.addPeer != nil {
fetchHashes = self.addPeer(td, currentBlock, peer)
>>>>>>> initial commit for eth-p2p integration
=======
func (self *TestBackend) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, invalidBlock func(error)) (best bool) {
if self.addPeer != nil {
best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, invalidBlock)
>>>>>>> eth protocol changes
}
return
}
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> eth protocol changes
func (self *TestBackend) RemovePeer(peerId string) {
if self.removePeer != nil {
self.removePeer(peerId)
}
}
<<<<<<< HEAD
=======
>>>>>>> initial commit for eth-p2p integration
=======
>>>>>>> eth protocol changes
func (self *TestBackend) Status() (td *big.Int, currentBlock []byte, genesisBlock []byte) {
if self.status != nil { if self.status != nil {
td, currentBlock, genesisBlock = self.status() td, currentBlock, genesisBlock = self.status()
} }
return return
} }
<<<<<<< HEAD func (self *testChainManager) GetBlock(hash []byte) (block *types.Block) {
<<<<<<< HEAD if self.getBlock != nil {
======= block = self.getBlock(hash)
>>>>>>> eth protocol changes }
return
}
func (self *testBlockPool) AddBlockHashes(next func() ([]byte, bool), peerId string) {
if self.addBlockHashes != nil {
self.addBlockHashes(next, peerId)
}
}
func (self *testBlockPool) AddBlock(block *types.Block, peerId string) {
if self.addBlock != nil {
self.addBlock(block, peerId)
}
}
func (self *testBlockPool) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, peerError func(int, string, ...interface{})) (best bool) {
if self.addPeer != nil {
best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, peerError)
}
return
}
func (self *testBlockPool) RemovePeer(peerId string) {
if self.removePeer != nil {
self.removePeer(peerId)
}
}
// TODO: refactor this into p2p/client_identity // TODO: refactor this into p2p/client_identity
type peerId struct { type peerId struct {
pubkey []byte pubkey []byte
@ -201,32 +152,119 @@ func testPeer() *p2p.Peer {
return p2p.NewPeer(&peerId{}, []p2p.Cap{}) return p2p.NewPeer(&peerId{}, []p2p.Cap{})
} }
func TestErrNoStatusMsg(t *testing.T) { type ethProtocolTester struct {
<<<<<<< HEAD quit chan error
======= rw *testMsgReadWriter // p2p.MsgReadWriter
func TestEth(t *testing.T) { txPool *testTxPool // txPool
>>>>>>> initial commit for eth-p2p integration chainManager *testChainManager // chainManager
======= blockPool *testBlockPool // blockPool
>>>>>>> eth protocol changes t *testing.T
quit := make(chan bool) }
rw := &testMsgReadWriter{make(chan p2p.Msg, 10), make(chan p2p.Msg, 10)}
testBackend := &TestBackend{} func newEth(t *testing.T) *ethProtocolTester {
var err error return &ethProtocolTester{
go func() { quit: make(chan error),
<<<<<<< HEAD rw: &testMsgReadWriter{in: make(chan p2p.Msg, 10)},
<<<<<<< HEAD txPool: &testTxPool{},
err = runEthProtocol(testBackend, testPeer(), rw) chainManager: &testChainManager{},
======= blockPool: &testBlockPool{},
err = runEthProtocol(testBackend, nil, rw) t: t,
>>>>>>> initial commit for eth-p2p integration }
======= }
err = runEthProtocol(testBackend, testPeer(), rw)
>>>>>>> eth protocol changes func (self *ethProtocolTester) reset() {
close(quit) self.rw = &testMsgReadWriter{in: make(chan p2p.Msg, 10)}
}() self.quit = make(chan error)
statusMsg := p2p.NewMsg(4) }
rw.In(statusMsg)
<-quit func (self *ethProtocolTester) checkError(expCode int, delay time.Duration) (err error) {
errorCheck(t, ErrNoStatusMsg, err) var timer = time.After(delay)
// read(t, remote, []byte("hello, world"), nil) select {
case err = <-self.quit:
case <-timer:
self.t.Errorf("no error after %v, expected %v", delay, expCode)
return
}
perr, ok := err.(*protocolError)
if ok && perr != nil {
if code := perr.Code; code != expCode {
self.t.Errorf("expected protocol error (code %v), got %v (%v)", expCode, code, err)
}
} else {
self.t.Errorf("expected protocol error (code %v), got %v", expCode, err)
}
return
}
func (self *ethProtocolTester) In(msg p2p.Msg) {
self.rw.In(msg)
}
func (self *ethProtocolTester) Out() (p2p.Msg, bool) {
return self.rw.Out()
}
func (self *ethProtocolTester) checkMsg(i int, code uint64, val interface{}) (msg p2p.Msg) {
if i >= len(self.rw.out) {
self.t.Errorf("expected at least %v msgs, got %v", i, len(self.rw.out))
return
}
msg = self.rw.out[i]
if msg.Code != code {
self.t.Errorf("expected msg code %v, got %v", code, msg.Code)
}
if val != nil {
if err := msg.Decode(val); err != nil {
self.t.Errorf("rlp encoding error: %v", err)
}
}
return
}
func (self *ethProtocolTester) run() {
err := runEthProtocol(self.txPool, self.chainManager, self.blockPool, testPeer(), self.rw)
self.quit <- err
}
func TestStatusMsgErrors(t *testing.T) {
logInit()
eth := newEth(t)
td := ethutil.Big1
currentBlock := []byte{1}
genesis := []byte{2}
eth.chainManager.status = func() (*big.Int, []byte, []byte) { return td, currentBlock, genesis }
go eth.run()
statusMsg := p2p.NewMsg(4)
eth.In(statusMsg)
delay := 1 * time.Second
eth.checkError(ErrNoStatusMsg, delay)
var status statusMsgData
eth.checkMsg(0, StatusMsg, &status) // first outgoing msg should be StatusMsg
if status.TD.Cmp(td) != 0 ||
status.ProtocolVersion != ProtocolVersion ||
status.NetworkId != NetworkId ||
status.TD.Cmp(td) != 0 ||
bytes.Compare(status.CurrentBlock, currentBlock) != 0 ||
bytes.Compare(status.GenesisBlock, genesis) != 0 {
t.Errorf("incorrect outgoing status")
}
eth.reset()
go eth.run()
statusMsg = p2p.NewMsg(0, uint32(48), uint32(0), td, currentBlock, genesis)
eth.In(statusMsg)
eth.checkError(ErrProtocolVersionMismatch, delay)
eth.reset()
go eth.run()
statusMsg = p2p.NewMsg(0, uint32(49), uint32(1), td, currentBlock, genesis)
eth.In(statusMsg)
eth.checkError(ErrNetworkIdMismatch, delay)
eth.reset()
go eth.run()
statusMsg = p2p.NewMsg(0, uint32(49), uint32(0), td, currentBlock, []byte{3})
eth.In(statusMsg)
eth.checkError(ErrGenesisBlockMismatch, delay)
} }

27
eth/test/README.md Normal file
View File

@ -0,0 +1,27 @@
= Integration tests for eth protocol and blockpool
This is a simple suite of tests to fire up a local test node with peers to test blockchain synchronisation and download.
The scripts call ethereum (assumed to be compiled in go-ethereum root).
To run a test:
. run.sh 00 02
Without arguments, all tests are run.
Peers are launched with preloaded imported chains. In order to prevent them from synchronizing with each other they are set with `-dial=false` and `-maxpeer 1` options. They log into `/tmp/eth.test/nodes/XX` where XX is the last two digits of their port.
Chains to import can be bootstrapped by letting nodes mine for some time. This is done with
. bootstrap.sh
Only the relative timing and forks matter so they should work if the bootstrap script is rerun.
The reference blockchain of tests are soft links to these import chains and check at the end of a test run.
Connecting to peers and exporting blockchain is scripted with JS files executed by the JSRE, see `tests/XX.sh`.
Each test is set with a timeout. This may vary on different computers so adjust sensibly.
If you kill a test before it completes, do not forget to kill all the background processes, since they will impact the result. Use:
killall ethereum

9
eth/test/bootstrap.sh Normal file
View File

@ -0,0 +1,9 @@
#!/bin/bash
# bootstrap chains - used to regenerate tests/chains/*.chain
mkdir -p chains
bash ./mine.sh 00 10
bash ./mine.sh 01 5 00
bash ./mine.sh 02 10 00
bash ./mine.sh 03 5 02
bash ./mine.sh 04 10 02

BIN
eth/test/chains/00.chain Executable file

Binary file not shown.

BIN
eth/test/chains/01.chain Executable file

Binary file not shown.

BIN
eth/test/chains/02.chain Executable file

Binary file not shown.

BIN
eth/test/chains/03.chain Executable file

Binary file not shown.

BIN
eth/test/chains/04.chain Executable file

Binary file not shown.

20
eth/test/mine.sh Normal file
View File

@ -0,0 +1,20 @@
#!/bin/bash
# bash ./mine.sh node_id timeout(sec) [basechain]
ETH=../../ethereum
MINE="$ETH -datadir tmp/nodes/$1 -seed=false -port '' -shh=false -id test$1"
rm -rf tmp/nodes/$1
echo "Creating chain $1..."
if [[ "" != "$3" ]]; then
CHAIN="chains/$3.chain"
CHAINARG="-chain $CHAIN"
$MINE -mine $CHAINARG -loglevel 3 | grep 'importing'
fi
$MINE -mine -loglevel 0 &
PID=$!
sleep $2
kill $PID
$MINE -loglevel 3 <(echo "eth.export(\"chains/$1.chain\")") > /tmp/eth.test/mine.tmp &
PID=$!
sleep 1
kill $PID
cat /tmp/eth.test/mine.tmp | grep 'exporting'

53
eth/test/run.sh Normal file
View File

@ -0,0 +1,53 @@
#!/bin/bash
# bash run.sh (testid0 testid1 ...)
# runs tests tests/testid0.sh tests/testid1.sh ...
# without arguments, it runs all tests
. tests/common.sh
TESTS=
if [ "$#" -eq 0 ]; then
for NAME in tests/??.sh; do
i=`basename $NAME .sh`
TESTS="$TESTS $i"
done
else
TESTS=$@
fi
ETH=../../ethereum
DIR="/tmp/eth.test/nodes"
TIMEOUT=10
mkdir -p $DIR/js
echo "running tests $TESTS"
for NAME in $TESTS; do
PIDS=
CHAIN="tests/$NAME.chain"
JSFILE="$DIR/js/$NAME.js"
CHAIN_TEST="$DIR/$NAME/chain"
echo "RUN: test $NAME"
cat tests/common.js > $JSFILE
. tests/$NAME.sh
sleep $TIMEOUT
echo "timeout after $TIMEOUT seconds: killing $PIDS"
kill $PIDS
if [ -r "$CHAIN" ]; then
if diff $CHAIN $CHAIN_TEST >/dev/null ; then
echo "chain ok: $CHAIN=$CHAIN_TEST"
else
echo "FAIL: chains differ: expected $CHAIN ; got $CHAIN_TEST"
continue
fi
fi
ERRORS=$DIR/errors
if [ -r "$ERRORS" ]; then
echo "FAIL: "
cat $ERRORS
else
echo PASS
fi
done

1
eth/test/tests/00.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/01.chain

13
eth/test/tests/00.sh Normal file
View File

@ -0,0 +1,13 @@
#!/bin/bash
TIMEOUT=4
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(1000)
eth.export("$CHAIN_TEST");
EOF
peer 11 01
test_node $NAME "" -loglevel 5 $JSFILE

1
eth/test/tests/01.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/02.chain

18
eth/test/tests/01.sh Normal file
View File

@ -0,0 +1,18 @@
#!/bin/bash
TIMEOUT=5
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
log("added peer localhost:30311");
sleep(1000);
log("added peer localhost:30312");
eth.addPeer("localhost:30312");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01
peer 12 02
test_node $NAME "" -loglevel 5 $JSFILE

1
eth/test/tests/02.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/01.chain

19
eth/test/tests/02.sh Normal file
View File

@ -0,0 +1,19 @@
#!/bin/bash
TIMEOUT=6
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01
peer 12 02
P13ID=$PID
test_node $NAME "" -loglevel 5 $JSFILE
sleep 0.5
kill $P13ID

1
eth/test/tests/03.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/12k.chain

14
eth/test/tests/03.sh Normal file
View File

@ -0,0 +1,14 @@
#!/bin/bash
TIMEOUT=35
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(30000);
eth.export("$CHAIN_TEST");
EOF
peer 11 12k
sleep 2
test_node $NAME "" -loglevel 5 $JSFILE

17
eth/test/tests/04.sh Normal file
View File

@ -0,0 +1,17 @@
#!/bin/bash
TIMEOUT=15
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
sleep(13000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01 -mine
peer 12 02
test_node $NAME "" -loglevel 5 $JSFILE
sleep 6
cat $DIR/$NAME/debug.log | grep 'best peer'

20
eth/test/tests/05.sh Normal file
View File

@ -0,0 +1,20 @@
#!/bin/bash
TIMEOUT=60
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
eth.addPeer("localhost:30313");
eth.addPeer("localhost:30314");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01 -mine
peer 12 02 -mine
peer 13 03
peer 14 04
test_node $NAME "" -loglevel 5 $JSFILE

9
eth/test/tests/common.js Normal file
View File

@ -0,0 +1,9 @@
function log(text) {
console.log("[JS TEST SCRIPT] " + text);
}
function sleep(seconds) {
var now = new Date().getTime();
while(new Date().getTime() < now + seconds){}
}

20
eth/test/tests/common.sh Normal file
View File

@ -0,0 +1,20 @@
#!/bin/bash
# launched by run.sh
function test_node {
rm -rf $DIR/$1
ARGS="-datadir $DIR/$1 -debug debug -seed=false -shh=false -id test$1"
if [ "" != "$2" ]; then
chain="chains/$2.chain"
echo "import chain $chain"
$ETH $ARGS -loglevel 3 -chain $chain | grep CLI |grep import
fi
echo "starting test node $1 with extra args ${@:3}"
$ETH $ARGS -port 303$1 ${@:3} &
PID=$!
PIDS="$PIDS $PID"
}
function peer {
test_node $@ -loglevel 5 -logfile debug.log -maxpeer 1 -dial=false
}

View File

@ -7,7 +7,7 @@ import (
) )
func TestClientIdentity(t *testing.T) { func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
clientString := clientIdentity.String() clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected { if clientString != expected {

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
@ -49,7 +50,14 @@ func encodePayload(params ...interface{}) []byte {
// For the decoding rules, please see package rlp. // For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error { func (msg Msg) Decode(val interface{}) error {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
return s.Decode(val) if err := s.Decode(val); err != nil {
return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
}
return nil
}
func (msg Msg) String() string {
return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size)
} }
// Discard reads any remaining payload data into a black hole. // Discard reads any remaining payload data into a black hole.

View File

@ -45,8 +45,8 @@ func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port) return fmt.Sprintf("%v:%d", d.IP, d.Port)
} }
func (d peerAddr) RlpData() interface{} { func (d *peerAddr) RlpData() interface{} {
return []interface{}{d.IP, d.Port, d.Pubkey} return []interface{}{string(d.IP), d.Port, d.Pubkey}
} }
// Peer represents a remote peer. // Peer represents a remote peer.
@ -426,7 +426,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
} }
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data)) return rw.WriteMsg(NewMsg(code, data...))
} }
func (rw *proto) ReadMsg() (Msg, error) { func (rw *proto) ReadMsg() (Msg, error) {
@ -460,3 +460,25 @@ func (r *eofSignal) Read(buf []byte) (int, error) {
} }
return n, err return n, err
} }
func (peer *Peer) PeerList() []interface{} {
peers := peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -30,9 +30,8 @@ var discard = Protocol{
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
peer := newPeer(conn1, protos, nil) peer := newPeer(conn1, protos, nil)
peer.ourID = id peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil } peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
@ -130,7 +129,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
if err := rw.EncodeMsg(2); err == nil { if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil") t.Error("expected error for out-of-range msg code, got nil")
} }
if err := rw.EncodeMsg(1); err != nil { if err := rw.EncodeMsg(1, "foo", "bar"); err != nil {
t.Errorf("write error: %v", err) t.Errorf("write error: %v", err)
} }
return nil return nil
@ -148,6 +147,13 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
if msg.Code != 17 { if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
} }
var data []string
if err := msg.Decode(&data); err != nil {
t.Errorf("payload decode error: %v", err)
}
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
}
} }
func TestPeerWrite(t *testing.T) { func TestPeerWrite(t *testing.T) {
@ -226,8 +232,8 @@ func TestPeerActivity(t *testing.T) {
} }
func TestNewPeer(t *testing.T) { func TestNewPeer(t *testing.T) {
id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
caps := []Cap{{"foo", 2}, {"bar", 3}} caps := []Cap{{"foo", 2}, {"bar", 3}}
id := &peerId{}
p := NewPeer(id, caps) p := NewPeer(id, caps)
if !reflect.DeepEqual(p.Caps(), caps) { if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)

View File

@ -3,8 +3,6 @@ package p2p
import ( import (
"bytes" "bytes"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
@ -89,20 +87,25 @@ type baseProtocol struct {
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer} bp := &baseProtocol{rw, peer}
if err := bp.doHandshake(rw); err != nil { errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err return err
} }
// run main loop // run main loop
quit := make(chan error, 1)
go func() { go func() {
for { for {
if err := bp.handle(rw); err != nil { if err := bp.handle(rw); err != nil {
quit <- err errc <- err
break break
} }
} }
}() }()
return bp.loop(quit) return bp.loop(errc)
} }
var pingTimeout = 2 * time.Second var pingTimeout = 2 * time.Second
@ -166,7 +169,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
case pongMsg: case pongMsg:
case getPeersMsg: case getPeersMsg:
peers := bp.peerList() peers := bp.peer.PeerList()
// this is dangerous. the spec says that we should _delay_ // this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available. // sending the response if no new information is available.
// this means that would need to send a response later when // this means that would need to send a response later when
@ -174,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
// //
// TODO: add event mechanism to notify baseProtocol for new peers // TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 { if len(peers) > 0 {
return bp.rw.EncodeMsg(peersMsg, peers) return bp.rw.EncodeMsg(peersMsg, peers...)
} }
case peersMsg: case peersMsg:
@ -193,14 +196,9 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
return nil return nil
} }
func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { func (bp *baseProtocol) readHandshake() error {
// send our handshake
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
return err
}
// read and handle remote handshake // read and handle remote handshake
msg, err := rw.ReadMsg() msg, err := bp.rw.ReadMsg()
if err != nil { if err != nil {
return err return err
} }
@ -210,12 +208,10 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
if msg.Size > baseProtocolMaxMsgSize { if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big") return newPeerError(errMisc, "message too big")
} }
var hs handshake var hs handshake
if err := msg.Decode(&hs); err != nil { if err := msg.Decode(&hs); err != nil {
return err return err
} }
// validate handshake info // validate handshake info
if hs.Version != baseProtocolVersion { if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
@ -238,9 +234,7 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
if err := bp.peer.pubkeyHook(pa); err != nil { if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err) return newPeerError(errPubkeyForbidden, "%v", err)
} }
// TODO: remove Caps with empty name // TODO: remove Caps with empty name
var addr *peerAddr var addr *peerAddr
if hs.ListenPort != 0 { if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
@ -270,25 +264,3 @@ func (bp *baseProtocol) handshakeMsg() Msg {
bp.peer.ourID.Pubkey()[1:], bp.peer.ourID.Pubkey()[1:],
) )
} }
func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
peers := bp.peer.otherPeers()
ds := make([]ethutil.RlpEncodable, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -2,12 +2,89 @@ package p2p
import ( import (
"fmt" "fmt"
"net"
"reflect"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
) )
type peerId struct {
pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
cannedPeerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
var ownAddr *peerAddr = &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
// run matcher, close pipe when addresses have arrived
addrChan := make(chan *peerAddr, len(cannedPeerList))
go func() {
for _, want := range cannedPeerList {
got := <-addrChan
t.Logf("got peer: %+v", got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %#v, want %#v", got, want)
}
}
close(addrChan)
var own []*peerAddr
var got *peerAddr
for got = range addrChan {
own = append(own, got)
}
if len(own) != 1 || !reflect.DeepEqual(ownAddr, own[0]) {
t.Errorf("mismatch: peers own address is incorrectly or not given, got %v, want %#v", ownAddr)
}
rw2.Close()
}()
// run first peer
peer1 := newTestPeer()
peer1.ourListenAddr = ownAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(cannedPeerList))
for i, addr := range cannedPeerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
go runBaseProtocol(peer1, rw1)
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
}
func TestBaseProtocolDisconnect(t *testing.T) { func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) peer := NewPeer(&peerId{}, nil)
peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil } peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
@ -32,6 +109,7 @@ func TestBaseProtocolDisconnect(t *testing.T) {
if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil {
t.Error(err) t.Error(err)
} }
close(done) close(done)
}() }()

View File

@ -113,9 +113,11 @@ func (srv *Server) PeerCount() int {
// SuggestPeer injects an address into the outbound address pool. // SuggestPeer injects an address into the outbound address pool.
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
addr := &peerAddr{ip, uint64(port), nodeID}
select { select {
case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: case srv.peerConnect <- addr:
default: // don't block default: // don't block
srvlog.Warnf("peer suggestion %v ignored", addr)
} }
} }
@ -258,6 +260,7 @@ func (srv *Server) listenLoop() {
for { for {
select { select {
case slot := <-srv.peerSlots: case slot := <-srv.peerSlots:
srvlog.Debugf("grabbed slot %v for listening", slot)
conn, err := srv.listener.Accept() conn, err := srv.listener.Accept()
if err != nil { if err != nil {
srv.peerSlots <- slot srv.peerSlots <- slot
@ -330,6 +333,7 @@ func (srv *Server) dialLoop() {
case desc := <-suggest: case desc := <-suggest:
// candidate peer found, will dial out asyncronously // candidate peer found, will dial out asyncronously
// if connection fails slot will be released // if connection fails slot will be released
srvlog.Infof("dial %v (%v)", desc, *slot)
go srv.dialPeer(desc, *slot) go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop // we can watch if more peers needed in the next loop
slots = srv.peerSlots slots = srv.peerSlots

View File

@ -11,7 +11,7 @@ import (
func startTestServer(t *testing.T, pf peerFunc) *Server { func startTestServer(t *testing.T, pf peerFunc) *Server {
server := &Server{ server := &Server{
Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), Identity: &peerId{},
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
newPeerFunc: pf, newPeerFunc: pf,

View File

@ -76,22 +76,37 @@ func Decode(r io.Reader, val interface{}) error {
type decodeError struct { type decodeError struct {
msg string msg string
typ reflect.Type typ reflect.Type
ctx []string
} }
func (err decodeError) Error() string { func (err *decodeError) Error() string {
return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) ctx := ""
if len(err.ctx) > 0 {
ctx = ", decoding into "
for i := len(err.ctx) - 1; i >= 0; i-- {
ctx += err.ctx[i]
}
}
return fmt.Sprintf("rlp: %s for %v%s", err.msg, err.typ, ctx)
} }
func wrapStreamError(err error, typ reflect.Type) error { func wrapStreamError(err error, typ reflect.Type) error {
switch err { switch err {
case ErrExpectedList: case ErrExpectedList:
return decodeError{"expected input list", typ} return &decodeError{msg: "expected input list", typ: typ}
case ErrExpectedString: case ErrExpectedString:
return decodeError{"expected input string or byte", typ} return &decodeError{msg: "expected input string or byte", typ: typ}
case errUintOverflow: case errUintOverflow:
return decodeError{"input string too long", typ} return &decodeError{msg: "input string too long", typ: typ}
case errNotAtEOL: case errNotAtEOL:
return decodeError{"input list has too many elements", typ} return &decodeError{msg: "input list has too many elements", typ: typ}
}
return err
}
func addErrorContext(err error, ctx string) error {
if decErr, ok := err.(*decodeError); ok {
decErr.ctx = append(decErr.ctx, ctx)
} }
return err return err
} }
@ -180,13 +195,13 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
return nil, err return nil, err
} }
if typ.Kind() == reflect.Array { isArray := typ.Kind() == reflect.Array
return func(s *Stream, val reflect.Value) error {
return decodeListArray(s, val, etypeinfo.decoder)
}, nil
}
return func(s *Stream, val reflect.Value) error { return func(s *Stream, val reflect.Value) error {
return decodeListSlice(s, val, etypeinfo.decoder) if isArray {
return decodeListArray(s, val, etypeinfo.decoder)
} else {
return decodeListSlice(s, val, etypeinfo.decoder)
}
}, nil }, nil
} }
@ -219,7 +234,7 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
if err := elemdec(s, val.Index(i)); err == EOL { if err := elemdec(s, val.Index(i)); err == EOL {
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, fmt.Sprint("[", i, "]"))
} }
} }
if i < val.Len() { if i < val.Len() {
@ -248,7 +263,7 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
if err := elemdec(s, val.Index(i)); err == EOL { if err := elemdec(s, val.Index(i)); err == EOL {
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, fmt.Sprint("[", i, "]"))
} }
} }
if i < vlen { if i < vlen {
@ -280,14 +295,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
switch kind { switch kind {
case Byte: case Byte:
if val.Len() == 0 { if val.Len() == 0 {
return decodeError{"input string too long", val.Type()} return &decodeError{msg: "input string too long", typ: val.Type()}
} }
bv, _ := s.Uint() bv, _ := s.Uint()
val.Index(0).SetUint(bv) val.Index(0).SetUint(bv)
zero(val, 1) zero(val, 1)
case String: case String:
if uint64(val.Len()) < size { if uint64(val.Len()) < size {
return decodeError{"input string too long", val.Type()} return &decodeError{msg: "input string too long", typ: val.Type()}
} }
slice := val.Slice(0, int(size)).Interface().([]byte) slice := val.Slice(0, int(size)).Interface().([]byte)
if err := s.readFull(slice); err != nil { if err := s.readFull(slice); err != nil {
@ -334,7 +349,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
// too few elements. leave the rest at their zero value. // too few elements. leave the rest at their zero value.
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, "."+typ.Field(f.index).Name)
} }
} }
return wrapStreamError(s.ListEnd(), typ) return wrapStreamError(s.ListEnd(), typ)
@ -599,7 +614,13 @@ func (s *Stream) Decode(val interface{}) error {
if err != nil { if err != nil {
return err return err
} }
return info.decoder(s, rval.Elem())
err = info.decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
// add decode target type to error so context has more meaning
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
}
return err
} }
// Reset discards any information about the current decoding context // Reset discards any information about the current decoding context

View File

@ -231,7 +231,12 @@ var decodeTests = []decodeTest{
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
{input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C0", ptr: new([]byte), value: []byte{}},
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
{input: "C3820102", ptr: new([]byte), error: "rlp: input string too long for uint8"},
{
input: "C3820102",
ptr: new([]byte),
error: "rlp: input string too long for uint8, decoding into ([]uint8)[0]",
},
// byte arrays // byte arrays
{input: "01", ptr: new([5]byte), value: [5]byte{1}}, {input: "01", ptr: new([5]byte), value: [5]byte{1}},
@ -239,9 +244,22 @@ var decodeTests = []decodeTest{
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
{input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}},
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
{input: "C3820102", ptr: new([5]byte), error: "rlp: input string too long for uint8"},
{input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"}, {
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, input: "C3820102",
ptr: new([5]byte),
error: "rlp: input string too long for uint8, decoding into ([5]uint8)[0]",
},
{
input: "86010203040506",
ptr: new([5]byte),
error: "rlp: input string too long for [5]uint8",
},
{
input: "850101",
ptr: new([5]byte),
error: io.ErrUnexpectedEOF.Error(),
},
// byte array reuse (should be zeroed) // byte array reuse (should be zeroed)
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
@ -272,13 +290,23 @@ var decodeTests = []decodeTest{
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"},
{ {
input: "C501C302C103", input: "C501C302C103",
ptr: new(recstruct), ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
}, },
{
input: "C3010101",
ptr: new(simplestruct),
error: "rlp: input list has too many elements for rlp.simplestruct",
},
{
input: "C501C3C00000",
ptr: new(recstruct),
error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
},
// pointers // pointers
{input: "00", ptr: new(*uint), value: (*uint)(nil)}, {input: "00", ptr: new(*uint), value: (*uint)(nil)},
{input: "80", ptr: new(*uint), value: (*uint)(nil)}, {input: "80", ptr: new(*uint), value: (*uint)(nil)},