diff --git a/chain/invoker.go b/chain/invoker.go index e9acb6a0c..0eaba5948 100644 --- a/chain/invoker.go +++ b/chain/invoker.go @@ -8,6 +8,7 @@ import ( actors "github.com/filecoin-project/go-lotus/chain/actors" "github.com/filecoin-project/go-lotus/chain/types" "github.com/ipfs/go-cid" + cbor "github.com/ipfs/go-ipld-cbor" ) type invoker struct { @@ -50,15 +51,10 @@ func (inv *invoker) register(c cid.Cid, instance Invokee) { inv.builtInCode[c] = code } -type unmarshalCBOR interface { - UnmarshalCBOR([]byte) (int, error) -} - type Invokee interface { Exports() []interface{} } -var tUnmarhsalCBOR = reflect.TypeOf((*unmarshalCBOR)(nil)).Elem() var tVMContext = reflect.TypeOf((*types.VMContext)(nil)).Elem() var tError = reflect.TypeOf((*error)(nil)).Elem() @@ -89,12 +85,8 @@ func (*invoker) transform(instance Invokee) (nativeCode, error) { return nil, newErr("second argument should be types.VMContext") } - if !t.In(2).Implements(tUnmarhsalCBOR) { - return nil, newErr("parameter doesn't implement UnmarshalCBOR") - } - if t.In(2).Kind() != reflect.Ptr { - return nil, newErr("parameter has to be a pointer") + return nil, newErr("parameter has to be a pointer to parameter") } if t.NumOut() != 2 { @@ -118,7 +110,7 @@ func (*invoker) transform(instance Invokee) (nativeCode, error) { param := reflect.New(paramT) inBytes := in[2].Interface().([]byte) - _, err := param.Interface().(unmarshalCBOR).UnmarshalCBOR(inBytes) + err := cbor.DecodeInto(inBytes, param.Interface()) if err != nil { return []reflect.Value{ reflect.ValueOf(types.InvokeRet{}), @@ -127,6 +119,7 @@ func (*invoker) transform(instance Invokee) (nativeCode, error) { reflect.ValueOf(&err).Elem(), } } + return meth.Call([]reflect.Value{ in[0], in[1], param, }) diff --git a/chain/invoker_test.go b/chain/invoker_test.go index d2514b55c..e6f64aed7 100644 --- a/chain/invoker_test.go +++ b/chain/invoker_test.go @@ -1,9 +1,9 @@ package chain import ( - "errors" "testing" + cbor "github.com/ipfs/go-ipld-cbor" "github.com/stretchr/testify/assert" "github.com/filecoin-project/go-lotus/chain/types" @@ -11,19 +11,11 @@ import ( type basicContract struct{} type basicParams struct { - b byte + B byte } -func (b *basicParams) UnmarshalCBOR(in []byte) (int, error) { - b.b = in[0] - return 1, nil -} - -type badParam struct { -} - -func (b *badParam) UnmarshalCBOR(in []byte) (int, error) { - return -1, errors.New("some error") +func init() { + cbor.RegisterCborType(basicParams{}) } func (b basicContract) Exports() []interface{} { @@ -45,18 +37,20 @@ func (b basicContract) Exports() []interface{} { func (basicContract) InvokeSomething0(act *types.Actor, vmctx types.VMContext, params *basicParams) (types.InvokeRet, error) { return types.InvokeRet{ - ReturnCode: params.b, + ReturnCode: params.B, }, nil } func (basicContract) BadParam(act *types.Actor, vmctx types.VMContext, - params *badParam) (types.InvokeRet, error) { - panic("should not execute") + params *basicParams) (types.InvokeRet, error) { + return types.InvokeRet{ + ReturnCode: 255, + }, nil } func (basicContract) InvokeSomething10(act *types.Actor, vmctx types.VMContext, params *basicParams) (types.InvokeRet, error) { return types.InvokeRet{ - ReturnCode: params.b + 10, + ReturnCode: params.B + 10, }, nil } @@ -64,14 +58,26 @@ func TestInvokerBasic(t *testing.T) { inv := invoker{} code, err := inv.transform(basicContract{}) assert.NoError(t, err) - ret, err := code[0](nil, nil, []byte{1}) - assert.NoError(t, err) - assert.Equal(t, byte(1), ret.ReturnCode, "return code should be 1") - ret, err = code[10](nil, &VMContext{}, []byte{2}) - assert.NoError(t, err) - assert.Equal(t, byte(12), ret.ReturnCode, "return code should be 1") + { + bParam, err := cbor.DumpObject(basicParams{B: 1}) + assert.NoError(t, err) - ret, err = code[1](nil, &VMContext{}, []byte{2}) + ret, err := code[0](nil, &VMContext{}, bParam) + assert.NoError(t, err) + assert.Equal(t, byte(1), ret.ReturnCode, "return code should be 1") + } + + { + bParam, err := cbor.DumpObject(basicParams{B: 2}) + assert.NoError(t, err) + + ret, err := code[10](nil, &VMContext{}, bParam) + assert.NoError(t, err) + assert.Equal(t, byte(12), ret.ReturnCode, "return code should be 12") + } + + _, err = code[1](nil, &VMContext{}, []byte{0}) assert.Error(t, err) + }