Minor refactor

This commit is contained in:
Thomas E Lackey 2022-09-08 10:59:47 -05:00
parent dceb81866f
commit 765480b12c

View File

@ -19,12 +19,14 @@ type Root common.Root
type Eth1Data common.Eth1Data
type SignedBeaconBlock struct {
spec *common.Spec
bellatrix *bellatrix.SignedBeaconBlock
altair *altair.SignedBeaconBlock
phase0 *phase0.SignedBeaconBlock
}
type BeaconBlock struct {
spec *common.Spec
bellatrix *bellatrix.BeaconBlock
altair *altair.BeaconBlock
phase0 *phase0.BeaconBlock
@ -37,15 +39,17 @@ type BeaconBlockBody struct {
}
type BeaconState struct {
spec *common.Spec
bellatrix *bellatrix.BeaconState
altair *altair.BeaconState
phase0 *phase0.BeaconState
}
func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error {
spec := chooseSpec(s.spec)
var bellatrix bellatrix.SignedBeaconBlock
decodingReader := codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
err := bellatrix.Deserialize(configs.Mainnet, decodingReader)
err := bellatrix.Deserialize(spec, makeDecodingReader(ssz))
if nil == err {
s.bellatrix = &bellatrix
s.altair = nil
@ -55,8 +59,7 @@ func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error {
}
var altair altair.SignedBeaconBlock
decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
err = altair.Deserialize(configs.Mainnet, decodingReader)
err = altair.Deserialize(spec, makeDecodingReader(ssz))
if nil == err {
s.bellatrix = nil
s.altair = &altair
@ -66,8 +69,7 @@ func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error {
}
var phase0 phase0.SignedBeaconBlock
decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
err = phase0.Deserialize(configs.Mainnet, decodingReader)
err = phase0.Deserialize(spec, makeDecodingReader(ssz))
if nil == err {
s.bellatrix = nil
s.altair = nil
@ -85,18 +87,19 @@ func (s *SignedBeaconBlock) UnmarshalSSZ(ssz []byte) error {
}
func (s *SignedBeaconBlock) MarshalSSZ() ([]byte, error) {
spec := chooseSpec(s.spec)
var err error
var buf bytes.Buffer
encodingWriter := codec.NewEncodingWriter(&buf)
if s.IsBellatrix() {
err = s.bellatrix.Serialize(configs.Mainnet, encodingWriter)
err = s.bellatrix.Serialize(spec, encodingWriter)
}
if s.IsAltair() {
err = s.altair.Serialize(configs.Mainnet, encodingWriter)
err = s.altair.Serialize(spec, encodingWriter)
}
if s.IsPhase0() {
err = s.phase0.Serialize(configs.Mainnet, encodingWriter)
err = s.phase0.Serialize(spec, encodingWriter)
}
if err != nil {
@ -148,15 +151,15 @@ func (s *SignedBeaconBlock) Signature() [96]byte {
func (s *SignedBeaconBlock) Block() *BeaconBlock {
if s.IsBellatrix() {
return &BeaconBlock{bellatrix: &s.bellatrix.Message}
return &BeaconBlock{bellatrix: &s.bellatrix.Message, spec: s.spec}
}
if s.IsAltair() {
return &BeaconBlock{altair: &s.altair.Message}
return &BeaconBlock{altair: &s.altair.Message, spec: s.spec}
}
if s.IsPhase0() {
return &BeaconBlock{phase0: &s.phase0.Message}
return &BeaconBlock{phase0: &s.phase0.Message, spec: s.spec}
}
return nil
@ -263,25 +266,28 @@ func (b *BeaconBlockBody) Eth1Data() Eth1Data {
}
func (b *BeaconBlock) HashTreeRoot() Root {
spec := chooseSpec(b.spec)
if b.IsBellatrix() {
return Root(b.bellatrix.HashTreeRoot(configs.Mainnet, tree.Hash))
return Root(b.bellatrix.HashTreeRoot(spec, tree.Hash))
}
if b.IsAltair() {
return Root(b.altair.HashTreeRoot(configs.Mainnet, tree.Hash))
return Root(b.altair.HashTreeRoot(spec, tree.Hash))
}
if b.IsPhase0() {
return Root(b.phase0.HashTreeRoot(configs.Mainnet, tree.Hash))
return Root(b.phase0.HashTreeRoot(spec, tree.Hash))
}
return Root{}
}
func (s *BeaconState) UnmarshalSSZ(ssz []byte) error {
spec := chooseSpec(s.spec)
var bellatrix bellatrix.BeaconState
decodingReader := codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
err := bellatrix.Deserialize(configs.Mainnet, decodingReader)
err := bellatrix.Deserialize(spec, makeDecodingReader(ssz))
if nil == err {
s.bellatrix = &bellatrix
s.altair = nil
@ -291,8 +297,7 @@ func (s *BeaconState) UnmarshalSSZ(ssz []byte) error {
}
var altair altair.BeaconState
decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
err = altair.Deserialize(configs.Mainnet, decodingReader)
err = altair.Deserialize(spec, makeDecodingReader(ssz))
if nil == err {
s.bellatrix = nil
s.altair = &altair
@ -302,8 +307,7 @@ func (s *BeaconState) UnmarshalSSZ(ssz []byte) error {
}
var phase0 phase0.BeaconState
decodingReader = codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
err = phase0.Deserialize(configs.Mainnet, decodingReader)
err = phase0.Deserialize(spec, makeDecodingReader(ssz))
if nil == err {
s.bellatrix = nil
s.altair = nil
@ -321,16 +325,17 @@ func (s *BeaconState) UnmarshalSSZ(ssz []byte) error {
}
func (s *BeaconState) MarshalSSZ() ([]byte, error) {
spec := chooseSpec(s.spec)
var err error
var buf bytes.Buffer
encodingWriter := codec.NewEncodingWriter(&buf)
if s.IsBellatrix() {
err = s.bellatrix.Serialize(configs.Mainnet, encodingWriter)
err = s.bellatrix.Serialize(spec, encodingWriter)
} else if s.IsAltair() {
err = s.altair.Serialize(configs.Mainnet, encodingWriter)
err = s.altair.Serialize(spec, encodingWriter)
} else if s.IsPhase0() {
err = s.phase0.Serialize(configs.Mainnet, encodingWriter)
err = s.phase0.Serialize(spec, encodingWriter)
} else {
err = errors.New("BeaconState not set")
}
@ -371,17 +376,19 @@ func (s *BeaconState) Slot() Slot {
return 0
}
func (b *BeaconState) HashTreeRoot() Root {
if b.IsBellatrix() {
return Root(b.bellatrix.HashTreeRoot(configs.Mainnet, tree.Hash))
func (s *BeaconState) HashTreeRoot() Root {
spec := chooseSpec(s.spec)
if s.IsBellatrix() {
return Root(s.bellatrix.HashTreeRoot(spec, tree.Hash))
}
if b.IsAltair() {
return Root(b.altair.HashTreeRoot(configs.Mainnet, tree.Hash))
if s.IsAltair() {
return Root(s.altair.HashTreeRoot(spec, tree.Hash))
}
if b.IsPhase0() {
return Root(b.phase0.HashTreeRoot(configs.Mainnet, tree.Hash))
if s.IsPhase0() {
return Root(s.phase0.HashTreeRoot(spec, tree.Hash))
}
return Root{}
@ -398,3 +405,14 @@ func (s *BeaconState) GetAltair() *altair.BeaconState {
func (s *BeaconState) GetPhase0() *phase0.BeaconState {
return s.phase0
}
func chooseSpec(spec *common.Spec) *common.Spec {
if nil == spec {
return configs.Mainnet
}
return spec
}
func makeDecodingReader(ssz []byte) *codec.DecodingReader {
return codec.NewDecodingReader(bytes.NewReader(ssz), uint64(len(ssz)))
}