diff --git a/eth/filters/api.go b/eth/filters/api.go index 834513262..584f55afd 100644 --- a/eth/filters/api.go +++ b/eth/filters/api.go @@ -17,7 +17,6 @@ package filters import ( - "encoding/hex" "encoding/json" "errors" "fmt" @@ -28,6 +27,7 @@ import ( "golang.org/x/net/context" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -459,52 +459,28 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { if raw.Addresses != nil { // raw.Address can contain a single address or an array of addresses - var addresses []common.Address - if strAddrs, ok := raw.Addresses.([]interface{}); ok { - for i, addr := range strAddrs { + switch rawAddr := raw.Addresses.(type) { + case []interface{}: + for i, addr := range rawAddr { if strAddr, ok := addr.(string); ok { - if len(strAddr) >= 2 && strAddr[0] == '0' && (strAddr[1] == 'x' || strAddr[1] == 'X') { - strAddr = strAddr[2:] - } - if decAddr, err := hex.DecodeString(strAddr); err == nil { - addresses = append(addresses, common.BytesToAddress(decAddr)) - } else { - return fmt.Errorf("invalid address given") + addr, err := decodeAddress(strAddr) + if err != nil { + return fmt.Errorf("invalid address at index %d: %v", i, err) } + args.Addresses = append(args.Addresses, addr) } else { - return fmt.Errorf("invalid address on index %d", i) + return fmt.Errorf("non-string address at index %d", i) } } - } else if singleAddr, ok := raw.Addresses.(string); ok { - if len(singleAddr) >= 2 && singleAddr[0] == '0' && (singleAddr[1] == 'x' || singleAddr[1] == 'X') { - singleAddr = singleAddr[2:] + case string: + addr, err := decodeAddress(rawAddr) + if err != nil { + return fmt.Errorf("invalid address: %v", err) } - if decAddr, err := hex.DecodeString(singleAddr); err == nil { - addresses = append(addresses, common.BytesToAddress(decAddr)) - } else { - return fmt.Errorf("invalid address given") - } - } else { - return errors.New("invalid address(es) given") + args.Addresses = []common.Address{addr} + default: + return errors.New("invalid addresses in query") } - args.Addresses = addresses - } - - // helper function which parses a string to a topic hash - topicConverter := func(raw string) (common.Hash, error) { - if len(raw) == 0 { - return common.Hash{}, nil - } - if len(raw) >= 2 && raw[0] == '0' && (raw[1] == 'x' || raw[1] == 'X') { - raw = raw[2:] - } - if len(raw) != 2*common.HashLength { - return common.Hash{}, errors.New("invalid topic(s)") - } - if decAddr, err := hex.DecodeString(raw); err == nil { - return common.BytesToHash(decAddr), nil - } - return common.Hash{}, errors.New("invalid topic(s)") } // topics is an array consisting of strings and/or arrays of strings. @@ -512,20 +488,25 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { if len(raw.Topics) > 0 { args.Topics = make([][]common.Hash, len(raw.Topics)) for i, t := range raw.Topics { - if t == nil { // ignore topic when matching logs + switch topic := t.(type) { + case nil: + // ignore topic when matching logs args.Topics[i] = []common.Hash{common.Hash{}} - } else if topic, ok := t.(string); ok { // match specific topic - top, err := topicConverter(topic) + + case string: + // match specific topic + top, err := decodeTopic(topic) if err != nil { return err } args.Topics[i] = []common.Hash{top} - } else if topics, ok := t.([]interface{}); ok { // or case e.g. [null, "topic0", "topic1"] - for _, rawTopic := range topics { + case []interface{}: + // or case e.g. [null, "topic0", "topic1"] + for _, rawTopic := range topic { if rawTopic == nil { args.Topics[i] = append(args.Topics[i], common.Hash{}) } else if topic, ok := rawTopic.(string); ok { - parsed, err := topicConverter(topic) + parsed, err := decodeTopic(topic) if err != nil { return err } @@ -534,7 +515,7 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { return fmt.Errorf("invalid topic(s)") } } - } else { + default: return fmt.Errorf("invalid topic(s)") } } @@ -542,3 +523,19 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { return nil } + +func decodeAddress(s string) (common.Address, error) { + b, err := hexutil.Decode(s) + if err == nil && len(b) != common.AddressLength { + err = fmt.Errorf("hex has invalid length %d after decoding", len(b)) + } + return common.BytesToAddress(b), err +} + +func decodeTopic(s string) (common.Hash, error) { + b, err := hexutil.Decode(s) + if err == nil && len(b) != common.HashLength { + err = fmt.Errorf("hex has invalid length %d after decoding", len(b)) + } + return common.BytesToHash(b), err +}