From ac9c03f910a91db95db1ed7b20c50b49faefbaae Mon Sep 17 00:00:00 2001 From: Marius van der Wijden Date: Mon, 20 Apr 2020 09:01:04 +0200 Subject: [PATCH] accounts/abi: Prevent recalculation of internal fields (#20895) * accounts/abi: prevent recalculation of ID, Sig and String * accounts/abi: fixed unpacking of no values * accounts/abi: multiple fixes to arguments * accounts/abi: refactored methodName and eventName This commit moves the complicated logic of how we assign method names and event names if they already exist into their own functions for better readability. * accounts/abi: prevent recalculation of internal In this commit, I changed the way we calculate the string representations, sig representations and the id's of methods. Before that these fields would be recalculated everytime someone called .Sig() .String() or .ID() on a method or an event. Additionally this commit fixes issue #20856 as we assign names to inputs with no name (input with name "" becomes "arg0") * accounts/abi: added unnamed event params test * accounts/abi: fixed rebasing errors in method sig * accounts/abi: fixed rebasing errors in method sig * accounts/abi: addressed comments * accounts/abi: added FunctionType enumeration * accounts/abi/bind: added test for unnamed arguments * accounts/abi: improved readability in NewMethod, nitpicks * accounts/abi: method/eventName -> overloadedMethodName --- accounts/abi/abi.go | 124 +++++++++++-------------------- accounts/abi/abi_test.go | 77 +++++++++++-------- accounts/abi/argument.go | 55 ++++++-------- accounts/abi/bind/base.go | 4 +- accounts/abi/bind/bind.go | 4 +- accounts/abi/bind/bind_test.go | 9 ++- accounts/abi/event.go | 81 ++++++++++++-------- accounts/abi/event_test.go | 4 +- accounts/abi/method.go | 119 +++++++++++++++++++---------- accounts/abi/method_test.go | 2 +- accounts/abi/pack_test.go | 14 ++-- accounts/abi/reflect.go | 17 ++--- signer/fourbyte/abi.go | 4 +- signer/fourbyte/fourbyte_test.go | 4 +- 14 files changed, 279 insertions(+), 239 deletions(-) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 4b88a52ce..b9a34a77a 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -76,7 +76,7 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { return nil, err } // Pack up the method ID too if not a constructor and return - return append(method.ID(), arguments...), nil + return append(method.ID, arguments...), nil } // Unpack output in v according to the abi specification @@ -139,59 +139,17 @@ func (abi *ABI) UnmarshalJSON(data []byte) error { for _, field := range fields { switch field.Type { case "constructor": - abi.Constructor = Method{ - Inputs: field.Inputs, - - // Note for constructor the `StateMutability` can only - // be payable or nonpayable according to the output of - // compiler. So constant is always false. - StateMutability: field.StateMutability, - - // Legacy fields, keep them for backward compatibility - Constant: field.Constant, - Payable: field.Payable, - } + abi.Constructor = NewMethod("", "", Constructor, field.StateMutability, field.Constant, field.Payable, field.Inputs, nil) case "function": - name := field.Name - _, ok := abi.Methods[name] - for idx := 0; ok; idx++ { - name = fmt.Sprintf("%s%d", field.Name, idx) - _, ok = abi.Methods[name] - } - abi.Methods[name] = Method{ - Name: name, - RawName: field.Name, - StateMutability: field.StateMutability, - Inputs: field.Inputs, - Outputs: field.Outputs, - - // Legacy fields, keep them for backward compatibility - Constant: field.Constant, - Payable: field.Payable, - } + name := abi.overloadedMethodName(field.Name) + abi.Methods[name] = NewMethod(name, field.Name, Function, field.StateMutability, field.Constant, field.Payable, field.Inputs, field.Outputs) case "fallback": // New introduced function type in v0.6.0, check more detail // here https://solidity.readthedocs.io/en/v0.6.0/contracts.html#fallback-function if abi.HasFallback() { return errors.New("only single fallback is allowed") } - abi.Fallback = Method{ - Name: "", - RawName: "", - - // The `StateMutability` can only be payable or nonpayable, - // so the constant is always false. - StateMutability: field.StateMutability, - IsFallback: true, - - // Fallback doesn't have any input or output - Inputs: nil, - Outputs: nil, - - // Legacy fields, keep them for backward compatibility - Constant: field.Constant, - Payable: field.Payable, - } + abi.Fallback = NewMethod("", "", Fallback, field.StateMutability, field.Constant, field.Payable, nil, nil) case "receive": // New introduced function type in v0.6.0, check more detail // here https://solidity.readthedocs.io/en/v0.6.0/contracts.html#fallback-function @@ -201,41 +159,47 @@ func (abi *ABI) UnmarshalJSON(data []byte) error { if field.StateMutability != "payable" { return errors.New("the statemutability of receive can only be payable") } - abi.Receive = Method{ - Name: "", - RawName: "", - - // The `StateMutability` can only be payable, so constant - // is always true while payable is always false. - StateMutability: field.StateMutability, - IsReceive: true, - - // Receive doesn't have any input or output - Inputs: nil, - Outputs: nil, - - // Legacy fields, keep them for backward compatibility - Constant: field.Constant, - Payable: field.Payable, - } + abi.Receive = NewMethod("", "", Receive, field.StateMutability, field.Constant, field.Payable, nil, nil) case "event": - name := field.Name - _, ok := abi.Events[name] - for idx := 0; ok; idx++ { - name = fmt.Sprintf("%s%d", field.Name, idx) - _, ok = abi.Events[name] - } - abi.Events[name] = Event{ - Name: name, - RawName: field.Name, - Anonymous: field.Anonymous, - Inputs: field.Inputs, - } + name := abi.overloadedEventName(field.Name) + abi.Events[name] = NewEvent(name, field.Name, field.Anonymous, field.Inputs) + default: + return fmt.Errorf("abi: could not recognize type %v of field %v", field.Type, field.Name) } } return nil } +// overloadedMethodName returns the next available name for a given function. +// Needed since solidity allows for function overload. +// +// e.g. if the abi contains Methods send, send1 +// overloadedMethodName would return send2 for input send. +func (abi *ABI) overloadedMethodName(rawName string) string { + name := rawName + _, ok := abi.Methods[name] + for idx := 0; ok; idx++ { + name = fmt.Sprintf("%s%d", rawName, idx) + _, ok = abi.Methods[name] + } + return name +} + +// overloadedEventName returns the next available name for a given event. +// Needed since solidity allows for event overload. +// +// e.g. if the abi contains events received, received1 +// overloadedEventName would return received2 for input received. +func (abi *ABI) overloadedEventName(rawName string) string { + name := rawName + _, ok := abi.Events[name] + for idx := 0; ok; idx++ { + name = fmt.Sprintf("%s%d", rawName, idx) + _, ok = abi.Events[name] + } + return name +} + // MethodById looks up a method by the 4-byte id // returns nil if none found func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { @@ -243,7 +207,7 @@ func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { return nil, fmt.Errorf("data too short (%d bytes) for abi method lookup", len(sigdata)) } for _, method := range abi.Methods { - if bytes.Equal(method.ID(), sigdata[:4]) { + if bytes.Equal(method.ID, sigdata[:4]) { return &method, nil } } @@ -254,7 +218,7 @@ func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { // ABI and returns nil if none found. func (abi *ABI) EventByID(topic common.Hash) (*Event, error) { for _, event := range abi.Events { - if bytes.Equal(event.ID().Bytes(), topic.Bytes()) { + if bytes.Equal(event.ID.Bytes(), topic.Bytes()) { return &event, nil } } @@ -263,10 +227,10 @@ func (abi *ABI) EventByID(topic common.Hash) (*Event, error) { // HasFallback returns an indicator whether a fallback function is included. func (abi *ABI) HasFallback() bool { - return abi.Fallback.IsFallback + return abi.Fallback.Type == Fallback } // HasReceive returns an indicator whether a receive function is included. func (abi *ABI) HasReceive() bool { - return abi.Receive.IsReceive + return abi.Receive.Type == Receive } diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 352006cf5..509040e5d 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -58,20 +58,14 @@ const jsondata2 = ` func TestReader(t *testing.T) { Uint256, _ := NewType("uint256", "", nil) - exp := ABI{ + abi := ABI{ Methods: map[string]Method{ - "balance": { - "balance", "balance", "view", false, false, false, false, nil, nil, - }, - "send": { - "send", "send", "", false, false, false, false, []Argument{ - {"amount", Uint256, false}, - }, nil, - }, + "balance": NewMethod("balance", "balance", Function, "view", false, false, nil, nil), + "send": NewMethod("send", "send", Function, "", false, false, []Argument{{"amount", Uint256, false}}, nil), }, } - abi, err := JSON(strings.NewReader(jsondata)) + exp, err := JSON(strings.NewReader(jsondata)) if err != nil { t.Error(err) } @@ -173,22 +167,22 @@ func TestTestSlice(t *testing.T) { func TestMethodSignature(t *testing.T) { String, _ := NewType("string", "", nil) - m := Method{"foo", "foo", "", false, false, false, false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil} + m := NewMethod("foo", "foo", Function, "", false, false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil) exp := "foo(string,string)" - if m.Sig() != exp { - t.Error("signature mismatch", exp, "!=", m.Sig()) + if m.Sig != exp { + t.Error("signature mismatch", exp, "!=", m.Sig) } idexp := crypto.Keccak256([]byte(exp))[:4] - if !bytes.Equal(m.ID(), idexp) { - t.Errorf("expected ids to match %x != %x", m.ID(), idexp) + if !bytes.Equal(m.ID, idexp) { + t.Errorf("expected ids to match %x != %x", m.ID, idexp) } uintt, _ := NewType("uint256", "", nil) - m = Method{"foo", "foo", "", false, false, false, false, []Argument{{"bar", uintt, false}}, nil} + m = NewMethod("foo", "foo", Function, "", false, false, []Argument{{"bar", uintt, false}}, nil) exp = "foo(uint256)" - if m.Sig() != exp { - t.Error("signature mismatch", exp, "!=", m.Sig()) + if m.Sig != exp { + t.Error("signature mismatch", exp, "!=", m.Sig) } // Method with tuple arguments @@ -204,10 +198,10 @@ func TestMethodSignature(t *testing.T) { {Name: "y", Type: "int256"}, }}, }) - m = Method{"foo", "foo", "", false, false, false, false, []Argument{{"s", s, false}, {"bar", String, false}}, nil} + m = NewMethod("foo", "foo", Function, "", false, false, []Argument{{"s", s, false}, {"bar", String, false}}, nil) exp = "foo((int256,int256[],(int256,int256)[],(int256,int256)[2]),string)" - if m.Sig() != exp { - t.Error("signature mismatch", exp, "!=", m.Sig()) + if m.Sig != exp { + t.Error("signature mismatch", exp, "!=", m.Sig) } } @@ -219,12 +213,12 @@ func TestOverloadedMethodSignature(t *testing.T) { } check := func(name string, expect string, method bool) { if method { - if abi.Methods[name].Sig() != expect { - t.Fatalf("The signature of overloaded method mismatch, want %s, have %s", expect, abi.Methods[name].Sig()) + if abi.Methods[name].Sig != expect { + t.Fatalf("The signature of overloaded method mismatch, want %s, have %s", expect, abi.Methods[name].Sig) } } else { - if abi.Events[name].Sig() != expect { - t.Fatalf("The signature of overloaded event mismatch, want %s, have %s", expect, abi.Events[name].Sig()) + if abi.Events[name].Sig != expect { + t.Fatalf("The signature of overloaded event mismatch, want %s, have %s", expect, abi.Events[name].Sig) } } } @@ -921,13 +915,13 @@ func TestABI_MethodById(t *testing.T) { } for name, m := range abi.Methods { a := fmt.Sprintf("%v", m) - m2, err := abi.MethodById(m.ID()) + m2, err := abi.MethodById(m.ID) if err != nil { t.Fatalf("Failed to look up ABI method: %v", err) } b := fmt.Sprintf("%v", m2) if a != b { - t.Errorf("Method %v (id %x) not 'findable' by id in ABI", name, m.ID()) + t.Errorf("Method %v (id %x) not 'findable' by id in ABI", name, m.ID) } } // Also test empty @@ -995,8 +989,8 @@ func TestABI_EventById(t *testing.T) { t.Errorf("We should find a event for topic %s, test #%d", topicID.Hex(), testnum) } - if event.ID() != topicID { - t.Errorf("Event id %s does not match topic %s, test #%d", event.ID().Hex(), topicID.Hex(), testnum) + if event.ID != topicID { + t.Errorf("Event id %s does not match topic %s, test #%d", event.ID.Hex(), topicID.Hex(), testnum) } unknowntopicID := crypto.Keccak256Hash([]byte("unknownEvent")) @@ -1051,3 +1045,28 @@ func TestDoubleDuplicateMethodNames(t *testing.T) { t.Fatalf("Should not have found extra method") } } + +// TestUnnamedEventParam checks that an event with unnamed parameters is +// correctly handled +// The test runs the abi of the following contract. +// contract TestEvent { +// event send(uint256, uint256); +// } +func TestUnnamedEventParam(t *testing.T) { + abiJSON := `[{ "anonymous": false, "inputs": [{ "indexed": false,"internalType": "uint256", "name": "","type": "uint256"},{"indexed": false,"internalType": "uint256","name": "","type": "uint256"}],"name": "send","type": "event"}]` + contractAbi, err := JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + + event, ok := contractAbi.Events["send"] + if !ok { + t.Fatalf("Could not find event") + } + if event.Inputs[0].Name != "arg0" { + t.Fatalf("Could not find input") + } + if event.Inputs[1].Name != "arg1" { + t.Fatalf("Could not find input") + } +} diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go index 7f7f50586..27af0d8a6 100644 --- a/accounts/abi/argument.go +++ b/accounts/abi/argument.go @@ -92,9 +92,8 @@ func (arguments Arguments) Unpack(v interface{}, data []byte) error { if len(data) == 0 { if len(arguments) != 0 { return fmt.Errorf("abi: attempting to unmarshall an empty string while arguments are expected") - } else { - return nil // Nothing to unmarshal, return } + return nil // Nothing to unmarshal, return } // make sure the passed value is arguments pointer if reflect.Ptr != reflect.ValueOf(v).Kind() { @@ -104,6 +103,9 @@ func (arguments Arguments) Unpack(v interface{}, data []byte) error { if err != nil { return err } + if len(marshalledValues) == 0 { + return fmt.Errorf("abi: Unpack(no-values unmarshalled %T)", v) + } if arguments.isTuple() { return arguments.unpackTuple(v, marshalledValues) } @@ -112,18 +114,24 @@ func (arguments Arguments) Unpack(v interface{}, data []byte) error { // UnpackIntoMap performs the operation hexdata -> mapping of argument name to argument value func (arguments Arguments) UnpackIntoMap(v map[string]interface{}, data []byte) error { + // Make sure map is not nil + if v == nil { + return fmt.Errorf("abi: cannot unpack into a nil map") + } if len(data) == 0 { if len(arguments) != 0 { return fmt.Errorf("abi: attempting to unmarshall an empty string while arguments are expected") - } else { - return nil // Nothing to unmarshal, return } + return nil // Nothing to unmarshal, return } marshalledValues, err := arguments.UnpackValues(data) if err != nil { return err } - return arguments.unpackIntoMap(v, marshalledValues) + for i, arg := range arguments.NonIndexed() { + v[arg.Name] = marshalledValues[i] + } + return nil } // unpack sets the unmarshalled value to go format. @@ -195,19 +203,6 @@ func unpack(t *Type, dst interface{}, src interface{}) error { return nil } -// unpackIntoMap unpacks marshalledValues into the provided map[string]interface{} -func (arguments Arguments) unpackIntoMap(v map[string]interface{}, marshalledValues []interface{}) error { - // Make sure map is not nil - if v == nil { - return fmt.Errorf("abi: cannot unpack into a nil map") - } - - for i, arg := range arguments.NonIndexed() { - v[arg.Name] = marshalledValues[i] - } - return nil -} - // unpackAtomic unpacks ( hexdata -> go ) a single value func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues interface{}) error { if arguments.LengthNonIndexed() == 0 { @@ -233,30 +228,28 @@ func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues interfac // unpackTuple unpacks ( hexdata -> go ) a batch of values. func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error { var ( - value = reflect.ValueOf(v).Elem() - typ = value.Type() - kind = value.Kind() + value = reflect.ValueOf(v).Elem() + typ = value.Type() + kind = value.Kind() + nonIndexedArgs = arguments.NonIndexed() ) - if err := requireUnpackKind(value, typ, kind, arguments); err != nil { + if err := requireUnpackKind(value, len(nonIndexedArgs), arguments); err != nil { return err } // If the interface is a struct, get of abi->struct_field mapping var abi2struct map[string]string if kind == reflect.Struct { - var ( - argNames []string - err error - ) - for _, arg := range arguments.NonIndexed() { - argNames = append(argNames, arg.Name) + argNames := make([]string, len(nonIndexedArgs)) + for i, arg := range nonIndexedArgs { + argNames[i] = arg.Name } - abi2struct, err = mapArgNamesToStructFields(argNames, value) - if err != nil { + var err error + if abi2struct, err = mapArgNamesToStructFields(argNames, value); err != nil { return err } } - for i, arg := range arguments.NonIndexed() { + for i, arg := range nonIndexedArgs { switch kind { case reflect.Struct: field := value.FieldByName(abi2struct[arg.Name]) diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index e69e3afa5..1d6811d74 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -264,7 +264,7 @@ func (c *BoundContract) FilterLogs(opts *FilterOpts, name string, query ...[]int opts = new(FilterOpts) } // Append the event selector to the query parameters and construct the topic set - query = append([][]interface{}{{c.abi.Events[name].ID()}}, query...) + query = append([][]interface{}{{c.abi.Events[name].ID}}, query...) topics, err := makeTopics(query...) if err != nil { @@ -313,7 +313,7 @@ func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]inter opts = new(WatchOpts) } // Append the event selector to the query parameters and construct the topic set - query = append([][]interface{}{{c.abi.Events[name].ID()}}, query...) + query = append([][]interface{}{{c.abi.Events[name].ID}}, query...) topics, err := makeTopics(query...) if err != nil { diff --git a/accounts/abi/bind/bind.go b/accounts/abi/bind/bind.go index c98f8b4d4..4c6a9e9ce 100644 --- a/accounts/abi/bind/bind.go +++ b/accounts/abi/bind/bind.go @@ -639,9 +639,9 @@ func formatMethod(method abi.Method, structs map[string]*tmplStruct) string { state = state + " " } identity := fmt.Sprintf("function %v", method.RawName) - if method.IsFallback { + if method.Type == abi.Fallback { identity = "fallback" - } else if method.IsReceive { + } else if method.Type == abi.Receive { identity = "receive" } return fmt.Sprintf("%s(%v) %sreturns(%v)", identity, strings.Join(inputs, ", "), state, strings.Join(outputs, ", ")) diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go index 7add7110a..a5f08499d 100644 --- a/accounts/abi/bind/bind_test.go +++ b/accounts/abi/bind/bind_test.go @@ -199,7 +199,8 @@ var bindTests = []struct { {"type":"event","name":"indexed","inputs":[{"name":"addr","type":"address","indexed":true},{"name":"num","type":"int256","indexed":true}]}, {"type":"event","name":"mixed","inputs":[{"name":"addr","type":"address","indexed":true},{"name":"num","type":"int256"}]}, {"type":"event","name":"anonymous","anonymous":true,"inputs":[]}, - {"type":"event","name":"dynamic","inputs":[{"name":"idxStr","type":"string","indexed":true},{"name":"idxDat","type":"bytes","indexed":true},{"name":"str","type":"string"},{"name":"dat","type":"bytes"}]} + {"type":"event","name":"dynamic","inputs":[{"name":"idxStr","type":"string","indexed":true},{"name":"idxDat","type":"bytes","indexed":true},{"name":"str","type":"string"},{"name":"dat","type":"bytes"}]}, + {"type":"event","name":"unnamed","inputs":[{"name":"","type":"uint256","indexed": true},{"name":"","type":"uint256","indexed":true}]} ] `}, ` @@ -249,6 +250,12 @@ var bindTests = []struct { fmt.Println(event.Addr) // Make sure the reconstructed indexed fields are present fmt.Println(res, str, dat, hash, err) + + oit, err := e.FilterUnnamed(nil, []*big.Int{}, []*big.Int{}) + + arg0 := oit.Event.Arg0 // Make sure unnamed arguments are handled correctly + arg1 := oit.Event.Arg1 // Make sure unnamed arguments are handled correctly + fmt.Println(arg0, arg1) } // Run a tiny reflection test to ensure disallowed methods don't appear if _, ok := reflect.TypeOf(&EventChecker{}).MethodByName("FilterAnonymous"); ok { diff --git a/accounts/abi/event.go b/accounts/abi/event.go index f1474813a..f1e5398f7 100644 --- a/accounts/abi/event.go +++ b/accounts/abi/event.go @@ -42,36 +42,59 @@ type Event struct { RawName string Anonymous bool Inputs Arguments + str string + // Sig contains the string signature according to the ABI spec. + // e.g. event foo(uint32 a, int b) = "foo(uint32,int256)" + // Please note that "int" is substitute for its canonical representation "int256" + Sig string + // ID returns the canonical representation of the event's signature used by the + // abi definition to identify event names and types. + ID common.Hash +} + +// NewEvent creates a new Event. +// It sanitizes the input arguments to remove unnamed arguments. +// It also precomputes the id, signature and string representation +// of the event. +func NewEvent(name, rawName string, anonymous bool, inputs Arguments) Event { + // sanitize inputs to remove inputs without names + // and precompute string and sig representation. + names := make([]string, len(inputs)) + types := make([]string, len(inputs)) + for i, input := range inputs { + if input.Name == "" { + inputs[i] = Argument{ + Name: fmt.Sprintf("arg%d", i), + Indexed: input.Indexed, + Type: input.Type, + } + } else { + inputs[i] = input + } + // string representation + names[i] = fmt.Sprintf("%v %v", input.Type, inputs[i].Name) + if input.Indexed { + names[i] = fmt.Sprintf("%v indexed %v", input.Type, inputs[i].Name) + } + // sig representation + types[i] = input.Type.String() + } + + str := fmt.Sprintf("event %v(%v)", rawName, strings.Join(names, ", ")) + sig := fmt.Sprintf("%v(%v)", rawName, strings.Join(types, ",")) + id := common.BytesToHash(crypto.Keccak256([]byte(sig))) + + return Event{ + Name: name, + RawName: rawName, + Anonymous: anonymous, + Inputs: inputs, + str: str, + Sig: sig, + ID: id, + } } func (e Event) String() string { - inputs := make([]string, len(e.Inputs)) - for i, input := range e.Inputs { - inputs[i] = fmt.Sprintf("%v %v", input.Type, input.Name) - if input.Indexed { - inputs[i] = fmt.Sprintf("%v indexed %v", input.Type, input.Name) - } - } - return fmt.Sprintf("event %v(%v)", e.RawName, strings.Join(inputs, ", ")) -} - -// Sig returns the event string signature according to the ABI spec. -// -// Example -// -// event foo(uint32 a, int b) = "foo(uint32,int256)" -// -// Please note that "int" is substitute for its canonical representation "int256" -func (e Event) Sig() string { - types := make([]string, len(e.Inputs)) - for i, input := range e.Inputs { - types[i] = input.Type.String() - } - return fmt.Sprintf("%v(%v)", e.RawName, strings.Join(types, ",")) -} - -// ID returns the canonical representation of the event's signature used by the -// abi definition to identify event names and types. -func (e Event) ID() common.Hash { - return common.BytesToHash(crypto.Keccak256([]byte(e.Sig()))) + return e.str } diff --git a/accounts/abi/event_test.go b/accounts/abi/event_test.go index 090b9217d..28da4c502 100644 --- a/accounts/abi/event_test.go +++ b/accounts/abi/event_test.go @@ -104,8 +104,8 @@ func TestEventId(t *testing.T) { } for name, event := range abi.Events { - if event.ID() != test.expectations[name] { - t.Errorf("expected id to be %x, got %x", test.expectations[name], event.ID()) + if event.ID != test.expectations[name] { + t.Errorf("expected id to be %x, got %x", test.expectations[name], event.ID) } } } diff --git a/accounts/abi/method.go b/accounts/abi/method.go index 217c3d2e6..37c7af65e 100644 --- a/accounts/abi/method.go +++ b/accounts/abi/method.go @@ -23,6 +23,24 @@ import ( "github.com/ethereum/go-ethereum/crypto" ) +// FunctionType represents different types of functions a contract might have. +type FunctionType int + +const ( + // Constructor represents the constructor of the contract. + // The constructor function is called while deploying a contract. + Constructor FunctionType = iota + // Fallback represents the fallback function. + // This function is executed if no other function matches the given function + // signature and no receive function is specified. + Fallback + // Receive represents the receive function. + // This function is executed on plain Ether transfers. + Receive + // Function represents a normal function. + Function +) + // Method represents a callable given a `Name` and whether the method is a constant. // If the method is `Const` no transaction needs to be created for this // particular Method call. It can easily be simulated using a local VM. @@ -44,6 +62,10 @@ type Method struct { Name string RawName string // RawName is the raw method name parsed from ABI + // Type indicates whether the method is a + // special fallback introduced in solidity v0.6.0 + Type FunctionType + // StateMutability indicates the mutability state of method, // the default value is nonpayable. It can be empty if the abi // is generated by legacy compiler. @@ -53,69 +75,84 @@ type Method struct { Constant bool Payable bool - // The following two flags indicates whether the method is a - // special fallback introduced in solidity v0.6.0 - IsFallback bool - IsReceive bool - Inputs Arguments Outputs Arguments + str string + // Sig returns the methods string signature according to the ABI spec. + // e.g. function foo(uint32 a, int b) = "foo(uint32,int256)" + // Please note that "int" is substitute for its canonical representation "int256" + Sig string + // ID returns the canonical representation of the method's signature used by the + // abi definition to identify method names and types. + ID []byte } -// Sig returns the methods string signature according to the ABI spec. -// -// Example -// -// function foo(uint32 a, int b) = "foo(uint32,int256)" -// -// Please note that "int" is substitute for its canonical representation "int256" -func (method Method) Sig() string { - // Short circuit if the method is special. Fallback - // and Receive don't have signature at all. - if method.IsFallback || method.IsReceive { - return "" - } - types := make([]string, len(method.Inputs)) - for i, input := range method.Inputs { +// NewMethod creates a new Method. +// A method should always be created using NewMethod. +// It also precomputes the sig representation and the string representation +// of the method. +func NewMethod(name string, rawName string, funType FunctionType, mutability string, isConst, isPayable bool, inputs Arguments, outputs Arguments) Method { + var ( + types = make([]string, len(inputs)) + inputNames = make([]string, len(inputs)) + outputNames = make([]string, len(outputs)) + ) + for i, input := range inputs { + inputNames[i] = fmt.Sprintf("%v %v", input.Type, input.Name) types[i] = input.Type.String() } - return fmt.Sprintf("%v(%v)", method.RawName, strings.Join(types, ",")) -} - -func (method Method) String() string { - inputs := make([]string, len(method.Inputs)) - for i, input := range method.Inputs { - inputs[i] = fmt.Sprintf("%v %v", input.Type, input.Name) - } - outputs := make([]string, len(method.Outputs)) - for i, output := range method.Outputs { - outputs[i] = output.Type.String() + for i, output := range outputs { + outputNames[i] = output.Type.String() if len(output.Name) > 0 { - outputs[i] += fmt.Sprintf(" %v", output.Name) + outputNames[i] += fmt.Sprintf(" %v", output.Name) } } + // calculate the signature and method id. Note only function + // has meaningful signature and id. + var ( + sig string + id []byte + ) + if funType == Function { + sig = fmt.Sprintf("%v(%v)", rawName, strings.Join(types, ",")) + id = crypto.Keccak256([]byte(sig))[:4] + } // Extract meaningful state mutability of solidity method. // If it's default value, never print it. - state := method.StateMutability + state := mutability if state == "nonpayable" { state = "" } if state != "" { state = state + " " } - identity := fmt.Sprintf("function %v", method.RawName) - if method.IsFallback { + identity := fmt.Sprintf("function %v", rawName) + if funType == Fallback { identity = "fallback" - } else if method.IsReceive { + } else if funType == Receive { identity = "receive" + } else if funType == Constructor { + identity = "constructor" + } + str := fmt.Sprintf("%v(%v) %sreturns(%v)", identity, strings.Join(inputNames, ", "), state, strings.Join(outputNames, ", ")) + + return Method{ + Name: name, + RawName: rawName, + Type: funType, + StateMutability: mutability, + Constant: isConst, + Payable: isPayable, + Inputs: inputs, + Outputs: outputs, + str: str, + Sig: sig, + ID: id, } - return fmt.Sprintf("%v(%v) %sreturns(%v)", identity, strings.Join(inputs, ", "), state, strings.Join(outputs, ", ")) } -// ID returns the canonical representation of the method's signature used by the -// abi definition to identify method names and types. -func (method Method) ID() []byte { - return crypto.Keccak256([]byte(method.Sig()))[:4] +func (method Method) String() string { + return method.str } // IsConstant returns the indicator whether the method is read-only. diff --git a/accounts/abi/method_test.go b/accounts/abi/method_test.go index ea176bf4e..395a52896 100644 --- a/accounts/abi/method_test.go +++ b/accounts/abi/method_test.go @@ -137,7 +137,7 @@ func TestMethodSig(t *testing.T) { } for _, test := range cases { - got := abi.Methods[test.method].Sig() + got := abi.Methods[test.method].Sig if got != test.expect { t.Errorf("expected string to be %s, got %s", test.expect, got) } diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go index cf649b480..69f739a12 100644 --- a/accounts/abi/pack_test.go +++ b/accounts/abi/pack_test.go @@ -634,7 +634,7 @@ func TestMethodPack(t *testing.T) { t.Fatal(err) } - sig := abi.Methods["slice"].ID() + sig := abi.Methods["slice"].ID sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -648,7 +648,7 @@ func TestMethodPack(t *testing.T) { } var addrA, addrB = common.Address{1}, common.Address{2} - sig = abi.Methods["sliceAddress"].ID() + sig = abi.Methods["sliceAddress"].ID sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) @@ -663,7 +663,7 @@ func TestMethodPack(t *testing.T) { } var addrC, addrD = common.Address{3}, common.Address{4} - sig = abi.Methods["sliceMultiAddress"].ID() + sig = abi.Methods["sliceMultiAddress"].ID sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -681,7 +681,7 @@ func TestMethodPack(t *testing.T) { t.Errorf("expected %x got %x", sig, packed) } - sig = abi.Methods["slice256"].ID() + sig = abi.Methods["slice256"].ID sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -695,7 +695,7 @@ func TestMethodPack(t *testing.T) { } a := [2][2]*big.Int{{big.NewInt(1), big.NewInt(1)}, {big.NewInt(2), big.NewInt(0)}} - sig = abi.Methods["nestedArray"].ID() + sig = abi.Methods["nestedArray"].ID sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -712,7 +712,7 @@ func TestMethodPack(t *testing.T) { t.Errorf("expected %x got %x", sig, packed) } - sig = abi.Methods["nestedArray2"].ID() + sig = abi.Methods["nestedArray2"].ID sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{0x80}, 32)...) @@ -728,7 +728,7 @@ func TestMethodPack(t *testing.T) { t.Errorf("expected %x got %x", sig, packed) } - sig = abi.Methods["nestedSlice"].ID() + sig = abi.Methods["nestedSlice"].ID sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{0x02}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...) diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go index 73ca8fa2b..1b2246e6f 100644 --- a/accounts/abi/reflect.go +++ b/accounts/abi/reflect.go @@ -118,18 +118,16 @@ func requireAssignable(dst, src reflect.Value) error { } // requireUnpackKind verifies preconditions for unpacking `args` into `kind` -func requireUnpackKind(v reflect.Value, t reflect.Type, k reflect.Kind, - args Arguments) error { - - switch k { +func requireUnpackKind(v reflect.Value, minLength int, args Arguments) error { + switch v.Kind() { case reflect.Struct: case reflect.Slice, reflect.Array: - if minLen := args.LengthNonIndexed(); v.Len() < minLen { + if v.Len() < minLength { return fmt.Errorf("abi: insufficient number of elements in the list/array for unpack, want %d, got %d", - minLen, v.Len()) + minLength, v.Len()) } default: - return fmt.Errorf("abi: cannot unmarshal tuple into %v", t) + return fmt.Errorf("abi: cannot unmarshal tuple into %v", v.Type()) } return nil } @@ -156,9 +154,8 @@ func mapArgNamesToStructFields(argNames []string, value reflect.Value) (map[stri continue } // skip fields that have no abi:"" tag. - var ok bool - var tagName string - if tagName, ok = typ.Field(i).Tag.Lookup("abi"); !ok { + tagName, ok := typ.Field(i).Tag.Lookup("abi") + if !ok { continue } // check if tag is empty. diff --git a/signer/fourbyte/abi.go b/signer/fourbyte/abi.go index 796086d41..007204c0f 100644 --- a/signer/fourbyte/abi.go +++ b/signer/fourbyte/abi.go @@ -140,7 +140,7 @@ func parseCallData(calldata []byte, abidata string) (*decodedCallData, error) { return nil, fmt.Errorf("signature %q matches, but arguments mismatch: %v", method.String(), err) } // Everything valid, assemble the call infos for the signer - decoded := decodedCallData{signature: method.Sig(), name: method.RawName} + decoded := decodedCallData{signature: method.Sig, name: method.RawName} for i := 0; i < len(method.Inputs); i++ { decoded.inputs = append(decoded.inputs, decodedArgument{ soltype: method.Inputs[i], @@ -158,7 +158,7 @@ func parseCallData(calldata []byte, abidata string) (*decodedCallData, error) { if !bytes.Equal(encoded, argdata) { was := common.Bytes2Hex(encoded) exp := common.Bytes2Hex(argdata) - return nil, fmt.Errorf("WARNING: Supplied data is stuffed with extra data. \nWant %s\nHave %s\nfor method %v", exp, was, method.Sig()) + return nil, fmt.Errorf("WARNING: Supplied data is stuffed with extra data. \nWant %s\nHave %s\nfor method %v", exp, was, method.Sig) } return &decoded, nil } diff --git a/signer/fourbyte/fourbyte_test.go b/signer/fourbyte/fourbyte_test.go index cdbd7ef73..cf54c9b9c 100644 --- a/signer/fourbyte/fourbyte_test.go +++ b/signer/fourbyte/fourbyte_test.go @@ -48,8 +48,8 @@ func TestEmbeddedDatabase(t *testing.T) { t.Errorf("Failed to get method by id (%s): %v", id, err) continue } - if m.Sig() != selector { - t.Errorf("Selector mismatch: have %v, want %v", m.Sig(), selector) + if m.Sig != selector { + t.Errorf("Selector mismatch: have %v, want %v", m.Sig, selector) } } }