diff --git a/pkg/beaconclient/consensus.go b/pkg/beaconclient/consensus.go index 209dbbd..9c0bcb6 100644 --- a/pkg/beaconclient/consensus.go +++ b/pkg/beaconclient/consensus.go @@ -19,12 +19,14 @@ type Root common.Root type Eth1Data common.Eth1Data type SignedBeaconBlock struct { + spec *common.Spec bellatrix *bellatrix.SignedBeaconBlock altair *altair.SignedBeaconBlock phase0 *phase0.SignedBeaconBlock } type BeaconBlock struct { + spec *common.Spec bellatrix *bellatrix.BeaconBlock altair *altair.BeaconBlock phase0 *phase0.BeaconBlock @@ -37,15 +39,17 @@ type BeaconBlockBody struct { } type BeaconState struct { + spec *common.Spec bellatrix *bellatrix.BeaconState altair *altair.BeaconState phase0 *phase0.BeaconState } func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error { + spec := chooseSpec(s.spec) + var bellatrix bellatrix.SignedBeaconBlock - decodingReader := codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) - err := bellatrix.Deserialize(configs.Mainnet, decodingReader) + err := bellatrix.Deserialize(spec, makeDecodingReader(ssz)) if nil == err { s.bellatrix = &bellatrix s.altair = nil @@ -55,8 +59,7 @@ func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error { } var altair altair.SignedBeaconBlock - decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) - err = altair.Deserialize(configs.Mainnet, decodingReader) + err = altair.Deserialize(spec, makeDecodingReader(ssz)) if nil == err { s.bellatrix = nil s.altair = &altair @@ -66,8 +69,7 @@ func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error { } var phase0 phase0.SignedBeaconBlock - decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) - err = phase0.Deserialize(configs.Mainnet, decodingReader) + err = phase0.Deserialize(spec, makeDecodingReader(ssz)) if nil == err { s.bellatrix = nil s.altair = nil @@ -85,18 +87,19 @@ func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error { } func (s *SignedBeaconBlock) MarshalSSZ() ([]byte, error) { + spec := chooseSpec(s.spec) var err error var buf bytes.Buffer encodingWriter := codec.NewEncodingWriter(&buf) if s.IsBellatrix() { - err = s.bellatrix.Serialize(configs.Mainnet, encodingWriter) + err = s.bellatrix.Serialize(spec, encodingWriter) } if s.IsAltair() { - err = s.altair.Serialize(configs.Mainnet, encodingWriter) + err = s.altair.Serialize(spec, encodingWriter) } if s.IsPhase0() { - err = s.phase0.Serialize(configs.Mainnet, encodingWriter) + err = s.phase0.Serialize(spec, encodingWriter) } if err != nil { @@ -148,15 +151,15 @@ func (s *SignedBeaconBlock) Signature() [96]byte { func (s *SignedBeaconBlock) Block() *BeaconBlock { if s.IsBellatrix() { - return &BeaconBlock{bellatrix: &s.bellatrix.Message} + return &BeaconBlock{bellatrix: &s.bellatrix.Message, spec: s.spec} } if s.IsAltair() { - return &BeaconBlock{altair: &s.altair.Message} + return &BeaconBlock{altair: &s.altair.Message, spec: s.spec} } if s.IsPhase0() { - return &BeaconBlock{phase0: &s.phase0.Message} + return &BeaconBlock{phase0: &s.phase0.Message, spec: s.spec} } return nil @@ -263,25 +266,28 @@ func (b *BeaconBlockBody) Eth1Data() Eth1Data { } func (b *BeaconBlock) HashTreeRoot() Root { + spec := chooseSpec(b.spec) + if b.IsBellatrix() { - return Root(b.bellatrix.HashTreeRoot(configs.Mainnet, tree.Hash)) + return Root(b.bellatrix.HashTreeRoot(spec, tree.Hash)) } if b.IsAltair() { - return Root(b.altair.HashTreeRoot(configs.Mainnet, tree.Hash)) + return Root(b.altair.HashTreeRoot(spec, tree.Hash)) } if b.IsPhase0() { - return Root(b.phase0.HashTreeRoot(configs.Mainnet, tree.Hash)) + return Root(b.phase0.HashTreeRoot(spec, tree.Hash)) } return Root{} } func (s *BeaconState) UnmarshalSSZ(ssz []byte) error { + spec := chooseSpec(s.spec) + var bellatrix bellatrix.BeaconState - decodingReader := codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) - err := bellatrix.Deserialize(configs.Mainnet, decodingReader) + err := bellatrix.Deserialize(spec, makeDecodingReader(ssz)) if nil == err { s.bellatrix = &bellatrix s.altair = nil @@ -291,8 +297,7 @@ func (s *BeaconState) UnmarshalSSZ(ssz []byte) error { } var altair altair.BeaconState - decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) - err = altair.Deserialize(configs.Mainnet, decodingReader) + err = altair.Deserialize(spec, makeDecodingReader(ssz)) if nil == err { s.bellatrix = nil s.altair = &altair @@ -302,8 +307,7 @@ func (s *BeaconState) UnmarshalSSZ(ssz []byte) error { } var phase0 phase0.BeaconState - decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) - err = phase0.Deserialize(configs.Mainnet, decodingReader) + err = phase0.Deserialize(spec, makeDecodingReader(ssz)) if nil == err { s.bellatrix = nil s.altair = nil @@ -321,16 +325,17 @@ func (s *BeaconState) UnmarshalSSZ(ssz []byte) error { } func (s *BeaconState) MarshalSSZ() ([]byte, error) { + spec := chooseSpec(s.spec) var err error var buf bytes.Buffer encodingWriter := codec.NewEncodingWriter(&buf) if s.IsBellatrix() { - err = s.bellatrix.Serialize(configs.Mainnet, encodingWriter) + err = s.bellatrix.Serialize(spec, encodingWriter) } else if s.IsAltair() { - err = s.altair.Serialize(configs.Mainnet, encodingWriter) + err = s.altair.Serialize(spec, encodingWriter) } else if s.IsPhase0() { - err = s.phase0.Serialize(configs.Mainnet, encodingWriter) + err = s.phase0.Serialize(spec, encodingWriter) } else { err = errors.New("BeaconState not set") } @@ -371,17 +376,19 @@ func (s *BeaconState) Slot() Slot { return 0 } -func (b *BeaconState) HashTreeRoot() Root { - if b.IsBellatrix() { - return Root(b.bellatrix.HashTreeRoot(configs.Mainnet, tree.Hash)) +func (s *BeaconState) HashTreeRoot() Root { + spec := chooseSpec(s.spec) + + if s.IsBellatrix() { + return Root(s.bellatrix.HashTreeRoot(spec, tree.Hash)) } - if b.IsAltair() { - return Root(b.altair.HashTreeRoot(configs.Mainnet, tree.Hash)) + if s.IsAltair() { + return Root(s.altair.HashTreeRoot(spec, tree.Hash)) } - if b.IsPhase0() { - return Root(b.phase0.HashTreeRoot(configs.Mainnet, tree.Hash)) + if s.IsPhase0() { + return Root(s.phase0.HashTreeRoot(spec, tree.Hash)) } return Root{} @@ -398,3 +405,14 @@ func (s *BeaconState) GetAltair() *altair.BeaconState { func (s *BeaconState) GetPhase0() *phase0.BeaconState { return s.phase0 } + +func chooseSpec(spec *common.Spec) *common.Spec { + if nil == spec { + return configs.Mainnet + } + return spec +} + +func makeDecodingReader(ssz []byte) *codec.DecodingReader { + return codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz))) +}