// Copyright 2022 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. package main import ( "bytes" "fmt" "go/format" "go/types" "sort" "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct" ) // buildContext keeps the data needed for make*Op. type buildContext struct { topType *types.Named // the type we're creating methods for encoderIface *types.Interface decoderIface *types.Interface rawValueType *types.Named typeToStructCache map[types.Type]*rlpstruct.Type } func newBuildContext(packageRLP *types.Package) *buildContext { enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying() dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying() rawv := packageRLP.Scope().Lookup("RawValue").Type() return &buildContext{ typeToStructCache: make(map[types.Type]*rlpstruct.Type), encoderIface: enc.(*types.Interface), decoderIface: dec.(*types.Interface), rawValueType: rawv.(*types.Named), } } func (bctx *buildContext) isEncoder(typ types.Type) bool { return types.Implements(typ, bctx.encoderIface) } func (bctx *buildContext) isDecoder(typ types.Type) bool { return types.Implements(typ, bctx.decoderIface) } // typeToStructType converts typ to rlpstruct.Type. func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { if prev := bctx.typeToStructCache[typ]; prev != nil { return prev // short-circuit for recursive types. } // Resolve named types to their underlying type, but keep the name. name := types.TypeString(typ, nil) for { utype := typ.Underlying() if utype == typ { break } typ = utype } // Create the type and store it in cache. t := &rlpstruct.Type{ Name: name, Kind: typeReflectKind(typ), IsEncoder: bctx.isEncoder(typ), IsDecoder: bctx.isDecoder(typ), } bctx.typeToStructCache[typ] = t // Assign element type. switch typ.(type) { case *types.Array, *types.Slice, *types.Pointer: etype := typ.(interface{ Elem() types.Type }).Elem() t.Elem = bctx.typeToStructType(etype) } return t } // genContext is passed to the gen* methods of op when generating // the output code. It tracks packages to be imported by the output // file and assigns unique names of temporary variables. type genContext struct { inPackage *types.Package imports map[string]struct{} tempCounter int } func newGenContext(inPackage *types.Package) *genContext { return &genContext{ inPackage: inPackage, imports: make(map[string]struct{}), } } func (ctx *genContext) temp() string { v := fmt.Sprintf("_tmp%d", ctx.tempCounter) ctx.tempCounter++ return v } func (ctx *genContext) resetTemp() { ctx.tempCounter = 0 } func (ctx *genContext) addImport(path string) { if path == ctx.inPackage.Path() { return // avoid importing the package that we're generating in. } // TODO: renaming? ctx.imports[path] = struct{}{} } // importsList returns all packages that need to be imported. func (ctx *genContext) importsList() []string { imp := make([]string, 0, len(ctx.imports)) for k := range ctx.imports { imp = append(imp, k) } sort.Strings(imp) return imp } // qualify is the types.Qualifier used for printing types. func (ctx *genContext) qualify(pkg *types.Package) string { if pkg.Path() == ctx.inPackage.Path() { return "" } ctx.addImport(pkg.Path()) // TODO: renaming? return pkg.Name() } type op interface { // genWrite creates the encoder. The generated code should write v, // which is any Go expression, to the rlp.EncoderBuffer 'w'. genWrite(ctx *genContext, v string) string // genDecode creates the decoder. The generated code should read // a value from the rlp.Stream 'dec' and store it to dst. genDecode(ctx *genContext) (string, string) } // basicOp handles basic types bool, uint*, string. type basicOp struct { typ types.Type writeMethod string // calle write the value writeArgType types.Type // parameter type of writeMethod decMethod string decResultType types.Type // return type of decMethod decUseBitSize bool // if true, result bit size is appended to decMethod } func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) { op := basicOp{typ: typ} kind := typ.Kind() switch { case kind == types.Bool: op.writeMethod = "WriteBool" op.writeArgType = types.Typ[types.Bool] op.decMethod = "Bool" op.decResultType = types.Typ[types.Bool] case kind >= types.Uint8 && kind <= types.Uint64: op.writeMethod = "WriteUint64" op.writeArgType = types.Typ[types.Uint64] op.decMethod = "Uint" op.decResultType = typ op.decUseBitSize = true case kind == types.String: op.writeMethod = "WriteString" op.writeArgType = types.Typ[types.String] op.decMethod = "String" op.decResultType = types.Typ[types.String] default: return nil, fmt.Errorf("unhandled basic type: %v", typ) } return op, nil } func (*buildContext) makeByteSliceOp(typ *types.Slice) op { if !isByte(typ.Elem()) { panic("non-byte slice type in makeByteSliceOp") } bslice := types.NewSlice(types.Typ[types.Uint8]) return basicOp{ typ: typ, writeMethod: "WriteBytes", writeArgType: bslice, decMethod: "Bytes", decResultType: bslice, } } func (bctx *buildContext) makeRawValueOp() op { bslice := types.NewSlice(types.Typ[types.Uint8]) return basicOp{ typ: bctx.rawValueType, writeMethod: "Write", writeArgType: bslice, decMethod: "Raw", decResultType: bslice, } } func (op basicOp) writeNeedsConversion() bool { return !types.AssignableTo(op.typ, op.writeArgType) } func (op basicOp) decodeNeedsConversion() bool { return !types.AssignableTo(op.decResultType, op.typ) } func (op basicOp) genWrite(ctx *genContext, v string) string { if op.writeNeedsConversion() { v = fmt.Sprintf("%s(%s)", op.writeArgType, v) } return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v) } func (op basicOp) genDecode(ctx *genContext) (string, string) { var ( resultV = ctx.temp() result = resultV method = op.decMethod ) if op.decUseBitSize { // Note: For now, this only works for platform-independent integer // sizes. makeBasicOp forbids the platform-dependent types. var sizes types.StdSizes method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8) } // Call the decoder method. var b bytes.Buffer fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method) fmt.Fprintf(&b, "if err != nil { return err }\n") if op.decodeNeedsConversion() { conv := ctx.temp() fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV) result = conv } return result, b.String() } // byteArrayOp handles [...]byte. type byteArrayOp struct { typ types.Type name types.Type // name != typ for named byte array types (e.g. common.Address) } func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp { nt := types.Type(name) if name == nil { nt = typ } return byteArrayOp{typ, nt} } func (op byteArrayOp) genWrite(ctx *genContext, v string) string { return fmt.Sprintf("w.WriteBytes(%s[:])\n", v) } func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { var resultV = ctx.temp() var b bytes.Buffer fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify)) fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV) return resultV, b.String() } // bigIntOp handles big.Int. // This exists because big.Int has it's own decoder operation on rlp.Stream, // but the decode method returns *big.Int, so it needs to be dereferenced. type bigIntOp struct { pointer bool } func (op bigIntOp) genWrite(ctx *genContext, v string) string { var b bytes.Buffer fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v) fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n") fmt.Fprintf(&b, "}\n") dst := v if !op.pointer { dst = "&" + v } fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst) // Wrap with nil check. if op.pointer { code := b.String() b.Reset() fmt.Fprintf(&b, "if %s == nil {\n", v) fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") fmt.Fprintf(&b, "} else {\n") fmt.Fprint(&b, code) fmt.Fprintf(&b, "}\n") } return b.String() } func (op bigIntOp) genDecode(ctx *genContext) (string, string) { var resultV = ctx.temp() var b bytes.Buffer fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV) fmt.Fprintf(&b, "if err != nil { return err }\n") result := resultV if !op.pointer { result = "(*" + resultV + ")" } return result, b.String() } // uint256Op handles "github.com/holiman/uint256".Int type uint256Op struct { pointer bool } func (op uint256Op) genWrite(ctx *genContext, v string) string { var b bytes.Buffer dst := v if !op.pointer { dst = "&" + v } fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst) // Wrap with nil check. if op.pointer { code := b.String() b.Reset() fmt.Fprintf(&b, "if %s == nil {\n", v) fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") fmt.Fprintf(&b, "} else {\n") fmt.Fprint(&b, code) fmt.Fprintf(&b, "}\n") } return b.String() } func (op uint256Op) genDecode(ctx *genContext) (string, string) { ctx.addImport("github.com/holiman/uint256") var b bytes.Buffer resultV := ctx.temp() fmt.Fprintf(&b, "var %s uint256.Int\n", resultV) fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV) result := resultV if op.pointer { result = "&" + resultV } return result, b.String() } // encoderDecoderOp handles rlp.Encoder and rlp.Decoder. // In order to be used with this, the type must implement both interfaces. // This restriction may be lifted in the future by creating separate ops for // encoding and decoding. type encoderDecoderOp struct { typ types.Type } func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string { return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v) } func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) { // DecodeRLP must have pointer receiver, and this is verified in makeOp. etyp := op.typ.(*types.Pointer).Elem() var resultV = ctx.temp() var b bytes.Buffer fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify)) fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV) return resultV, b.String() } // ptrOp handles pointer types. type ptrOp struct { elemTyp types.Type elem op nilOK bool nilValue rlpstruct.NilKind } func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) { elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{}) if err != nil { return nil, err } op := ptrOp{elemTyp: elemTyp, elem: elemOp} // Determine nil value. if tags.NilOK { op.nilOK = true op.nilValue = tags.NilKind } else { styp := bctx.typeToStructType(elemTyp) op.nilValue = styp.DefaultNilValue() } return op, nil } func (op ptrOp) genWrite(ctx *genContext, v string) string { // Note: in writer functions, accesses to v are read-only, i.e. v is any Go // expression. To make all accesses work through the pointer, we substitute // v with (*v). This is required for most accesses including `v`, `call(v)`, // and `v[index]` on slices. // // For `v.field` and `v[:]` on arrays, the dereference operation is not required. var vv string _, isStruct := op.elem.(structOp) _, isByteArray := op.elem.(byteArrayOp) if isStruct || isByteArray { vv = v } else { vv = fmt.Sprintf("(*%s)", v) } var b bytes.Buffer fmt.Fprintf(&b, "if %s == nil {\n", v) fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue) fmt.Fprintf(&b, "} else {\n") fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv)) fmt.Fprintf(&b, "}\n") return b.String() } func (op ptrOp) genDecode(ctx *genContext) (string, string) { result, code := op.elem.genDecode(ctx) if !op.nilOK { // If nil pointers are not allowed, we can just decode the element. return "&" + result, code } // nil is allowed, so check the kind and size first. // If size is zero and kind matches the nilKind of the type, // the value decodes as a nil pointer. var ( resultV = ctx.temp() kindV = ctx.temp() sizeV = ctx.temp() wantKind string ) if op.nilValue == rlpstruct.NilKindList { wantKind = "rlp.List" } else { wantKind = "rlp.String" } var b bytes.Buffer fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify)) fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV) fmt.Fprintf(&b, " return err\n") fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind) fmt.Fprint(&b, code) fmt.Fprintf(&b, " %s = &%s\n", resultV, result) fmt.Fprintf(&b, "}\n") return resultV, b.String() } // structOp handles struct types. type structOp struct { named *types.Named typ *types.Struct fields []*structField optionalFields []*structField } type structField struct { name string typ types.Type elem op } func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) { // Convert fields to []rlpstruct.Field. var allStructFields []rlpstruct.Field for i := 0; i < typ.NumFields(); i++ { f := typ.Field(i) allStructFields = append(allStructFields, rlpstruct.Field{ Name: f.Name(), Exported: f.Exported(), Index: i, Tag: typ.Tag(i), Type: *bctx.typeToStructType(f.Type()), }) } // Filter/validate fields. fields, tags, err := rlpstruct.ProcessFields(allStructFields) if err != nil { return nil, err } // Create field ops. var op = structOp{named: named, typ: typ} for i, field := range fields { // Advanced struct tags are not supported yet. tag := tags[i] if err := checkUnsupportedTags(field.Name, tag); err != nil { return nil, err } typ := typ.Field(field.Index).Type() elem, err := bctx.makeOp(nil, typ, tags[i]) if err != nil { return nil, fmt.Errorf("field %s: %v", field.Name, err) } f := &structField{name: field.Name, typ: typ, elem: elem} if tag.Optional { op.optionalFields = append(op.optionalFields, f) } else { op.fields = append(op.fields, f) } } return op, nil } func checkUnsupportedTags(field string, tag rlpstruct.Tags) error { if tag.Tail { return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field) } return nil } func (op structOp) genWrite(ctx *genContext, v string) string { var b bytes.Buffer var listMarker = ctx.temp() fmt.Fprintf(&b, "%s := w.List()\n", listMarker) for _, field := range op.fields { selector := v + "." + field.name fmt.Fprint(&b, field.elem.genWrite(ctx, selector)) } op.writeOptionalFields(&b, ctx, v) fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) return b.String() } func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) { if len(op.optionalFields) == 0 { return } // First check zero-ness of all optional fields. var zeroV = make([]string, len(op.optionalFields)) for i, field := range op.optionalFields { selector := v + "." + field.name zeroV[i] = ctx.temp() fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify)) } // Now write the fields. for i, field := range op.optionalFields { selector := v + "." + field.name cond := "" for j := i; j < len(op.optionalFields); j++ { if j > i { cond += " || " } cond += zeroV[j] } fmt.Fprintf(b, "if %s {\n", cond) fmt.Fprint(b, field.elem.genWrite(ctx, selector)) fmt.Fprintf(b, "}\n") } } func (op structOp) genDecode(ctx *genContext) (string, string) { // Get the string representation of the type. // Here, named types are handled separately because the output // would contain a copy of the struct definition otherwise. var typeName string if op.named != nil { typeName = types.TypeString(op.named, ctx.qualify) } else { typeName = types.TypeString(op.typ, ctx.qualify) } // Create struct object. var resultV = ctx.temp() var b bytes.Buffer fmt.Fprintf(&b, "var %s %s\n", resultV, typeName) // Decode fields. fmt.Fprintf(&b, "{\n") fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") for _, field := range op.fields { result, code := field.elem.genDecode(ctx) fmt.Fprintf(&b, "// %s:\n", field.name) fmt.Fprint(&b, code) fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result) } op.decodeOptionalFields(&b, ctx, resultV) fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") fmt.Fprintf(&b, "}\n") return resultV, b.String() } func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) { var suffix bytes.Buffer for _, field := range op.optionalFields { result, code := field.elem.genDecode(ctx) fmt.Fprintf(b, "// %s:\n", field.name) fmt.Fprintf(b, "if dec.MoreDataInList() {\n") fmt.Fprint(b, code) fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result) fmt.Fprintf(&suffix, "}\n") } suffix.WriteTo(b) } // sliceOp handles slice types. type sliceOp struct { typ *types.Slice elemOp op } func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) { elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{}) if err != nil { return nil, err } return sliceOp{typ: typ, elemOp: elemOp}, nil } func (op sliceOp) genWrite(ctx *genContext, v string) string { var ( listMarker = ctx.temp() // holds return value of w.List() iterElemV = ctx.temp() // iteration variable elemCode = op.elemOp.genWrite(ctx, iterElemV) ) var b bytes.Buffer fmt.Fprintf(&b, "%s := w.List()\n", listMarker) fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v) fmt.Fprint(&b, elemCode) fmt.Fprintf(&b, "}\n") fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) return b.String() } func (op sliceOp) genDecode(ctx *genContext) (string, string) { var sliceV = ctx.temp() // holds the output slice elemResult, elemCode := op.elemOp.genDecode(ctx) var b bytes.Buffer fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify)) fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") fmt.Fprintf(&b, "for dec.MoreDataInList() {\n") fmt.Fprintf(&b, " %s", elemCode) fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult) fmt.Fprintf(&b, "}\n") fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") return sliceV, b.String() } func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) { switch typ := typ.(type) { case *types.Named: if isBigInt(typ) { return bigIntOp{}, nil } if isUint256(typ) { return uint256Op{}, nil } if typ == bctx.rawValueType { return bctx.makeRawValueOp(), nil } if bctx.isDecoder(typ) { return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ) } // TODO: same check for encoder? return bctx.makeOp(typ, typ.Underlying(), tags) case *types.Pointer: if isBigInt(typ.Elem()) { return bigIntOp{pointer: true}, nil } if isUint256(typ.Elem()) { return uint256Op{pointer: true}, nil } // Encoder/Decoder interfaces. if bctx.isEncoder(typ) { if bctx.isDecoder(typ) { return encoderDecoderOp{typ}, nil } return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ) } if bctx.isDecoder(typ) { return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ) } // Default pointer handling. return bctx.makePtrOp(typ.Elem(), tags) case *types.Basic: return bctx.makeBasicOp(typ) case *types.Struct: return bctx.makeStructOp(name, typ) case *types.Slice: etyp := typ.Elem() if isByte(etyp) && !bctx.isEncoder(etyp) { return bctx.makeByteSliceOp(typ), nil } return bctx.makeSliceOp(typ) case *types.Array: etyp := typ.Elem() if isByte(etyp) && !bctx.isEncoder(etyp) { return bctx.makeByteArrayOp(name, typ), nil } return nil, fmt.Errorf("unhandled array type: %v", typ) default: return nil, fmt.Errorf("unhandled type: %v", typ) } } // generateDecoder generates the DecodeRLP method on 'typ'. func generateDecoder(ctx *genContext, typ string, op op) []byte { ctx.resetTemp() ctx.addImport(pathOfPackageRLP) result, code := op.genDecode(ctx) var b bytes.Buffer fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ) fmt.Fprint(&b, code) fmt.Fprintf(&b, " *obj = %s\n", result) fmt.Fprintf(&b, " return nil\n") fmt.Fprintf(&b, "}\n") return b.Bytes() } // generateEncoder generates the EncodeRLP method on 'typ'. func generateEncoder(ctx *genContext, typ string, op op) []byte { ctx.resetTemp() ctx.addImport("io") ctx.addImport(pathOfPackageRLP) var b bytes.Buffer fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n") fmt.Fprint(&b, op.genWrite(ctx, "obj")) fmt.Fprintf(&b, " return w.Flush()\n") fmt.Fprintf(&b, "}\n") return b.Bytes() } func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) { bctx.topType = typ pkg := typ.Obj().Pkg() op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{}) if err != nil { return nil, err } var ( ctx = newGenContext(pkg) encSource []byte decSource []byte ) if encoder { encSource = generateEncoder(ctx, typ.Obj().Name(), op) } if decoder { decSource = generateDecoder(ctx, typ.Obj().Name(), op) } var b bytes.Buffer fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) for _, imp := range ctx.importsList() { fmt.Fprintf(&b, "import %q\n", imp) } if encoder { fmt.Fprintln(&b) b.Write(encSource) } if decoder { fmt.Fprintln(&b) b.Write(decSource) } source := b.Bytes() // fmt.Println(string(source)) return format.Source(source) }