diff --git a/chain/vm/runtime.go b/chain/vm/runtime.go index 879608277..3fb5879f4 100644 --- a/chain/vm/runtime.go +++ b/chain/vm/runtime.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "fmt" + "io" "github.com/filecoin-project/go-address" "github.com/filecoin-project/specs-actors/actors/abi" @@ -55,6 +56,32 @@ type Runtime struct { allowInternal bool } +type safeCBORMarshaler struct { + m cbg.CBORMarshaler +} + +func (s *safeCBORMarshaler) MarshalCBOR(w io.Writer) error { + if err := s.m.MarshalCBOR(w); err != nil { + panic(aerrors.Newf(exitcode.ErrSerialization,"failed to marshal cbor object %s", err)) + } + + return nil +} + +type safeCBORUnmarshaler struct { + m cbg.CBORUnmarshaler +} + +func (s *safeCBORUnmarshaler) UnmarshalCBOR(r io.Reader) error { + if err := s.m.UnmarshalCBOR(r); err != nil { + panic(aerrors.Newf(exitcode.ErrSerialization,"failed to unmarshal cbor object %s", err)) + } + + return nil +} + +var _ cbg.CBORUnmarshaler = &safeCBORUnmarshaler{} + func (rt *Runtime) TotalFilCircSupply() abi.TokenAmount { total := types.FromFil(build.TotalFilecoin) @@ -106,7 +133,7 @@ type notFoundErr interface { } func (rs *Runtime) Get(c cid.Cid, o vmr.CBORUnmarshaler) bool { - if err := rs.cst.Get(context.TODO(), c, o); err != nil { + if err := rs.cst.Get(context.TODO(), c, &safeCBORUnmarshaler{o}); err != nil { var nfe notFoundErr if xerrors.As(err, &nfe) && nfe.IsNotFound() { return false @@ -118,7 +145,7 @@ func (rs *Runtime) Get(c cid.Cid, o vmr.CBORUnmarshaler) bool { } func (rs *Runtime) Put(x vmr.CBORMarshaler) cid.Cid { - c, err := rs.cst.Put(context.TODO(), x) + c, err := rs.cst.Put(context.TODO(), &safeCBORMarshaler{x}) if err != nil { panic(aerrors.Fatalf("failed to put cbor object: %s", err)) }