// Copyright 2022 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. package main import ( "bytes" "fmt" "go/ast" "go/importer" "go/parser" "go/token" "go/types" "os" "path/filepath" "testing" ) // Package RLP is loaded only once and reused for all tests. var ( testFset = token.NewFileSet() testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom) testPackageRLP *types.Package ) func init() { cwd, err := os.Getwd() if err != nil { panic(err) } testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0) if err != nil { panic(fmt.Errorf("can't load package RLP: %v", err)) } } var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} func TestOutput(t *testing.T) { for _, test := range tests { test := test t.Run(test, func(t *testing.T) { inputFile := filepath.Join("testdata", test+".in.txt") outputFile := filepath.Join("testdata", test+".out.txt") bctx, typ, err := loadTestSource(inputFile, "Test") if err != nil { t.Fatal("error loading test source:", err) } output, err := bctx.generate(typ, true, true) if err != nil { t.Fatal("error in generate:", err) } // Set this environment variable to regenerate the test outputs. if os.Getenv("WRITE_TEST_FILES") != "" { os.WriteFile(outputFile, output, 0644) } // Check if output matches. wantOutput, err := os.ReadFile(outputFile) if err != nil { t.Fatal("error loading expected test output:", err) } if !bytes.Equal(output, wantOutput) { t.Fatalf("output mismatch, want: %v got %v", string(wantOutput), string(output)) } }) } } func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) { // Load the test input. content, err := os.ReadFile(file) if err != nil { return nil, nil, err } f, err := parser.ParseFile(testFset, file, content, 0) if err != nil { return nil, nil, err } conf := types.Config{Importer: testImporter} pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil) if err != nil { return nil, nil, err } // Find the test struct. bctx := newBuildContext(testPackageRLP) typ, err := lookupStructType(pkg.Scope(), typeName) if err != nil { return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err) } return bctx, typ, nil }