cmd/abigen: change --exc to exclude by type name (#22620)

The abigen exclusion pattern, previously on the form "path:type", now supports wildcards. Examples "*:type" to exclude a named type in all files, or "/path/to/foo.sol:*" all types in foo.sol.
This commit is contained in:
Sebastian Stammler 2022-09-23 19:04:02 +02:00 committed by GitHub
parent 65f3c1b46f
commit e87806727d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 109 additions and 7 deletions

View File

@ -154,9 +154,12 @@ func abigen(c *cli.Context) error {
types = append(types, kind) types = append(types, kind)
} else { } else {
// Generate the list of types to exclude from binding // Generate the list of types to exclude from binding
exclude := make(map[string]bool) var exclude *nameFilter
for _, kind := range strings.Split(c.String(excFlag.Name), ",") { if c.IsSet(excFlag.Name) {
exclude[strings.ToLower(kind)] = true var err error
if exclude, err = newNameFilter(strings.Split(c.String(excFlag.Name), ",")...); err != nil {
utils.Fatalf("Failed to parse excludes: %v", err)
}
} }
var contracts map[string]*compiler.Contract var contracts map[string]*compiler.Contract
@ -181,7 +184,11 @@ func abigen(c *cli.Context) error {
} }
// Gather all non-excluded contract for binding // Gather all non-excluded contract for binding
for name, contract := range contracts { for name, contract := range contracts {
if exclude[strings.ToLower(name)] { // fully qualified name is of the form <solFilePath>:<type>
nameParts := strings.Split(name, ":")
typeName := nameParts[len(nameParts)-1]
if exclude != nil && exclude.Matches(name) {
fmt.Fprintf(os.Stderr, "excluding: %v\n", name)
continue continue
} }
abi, err := json.Marshal(contract.Info.AbiDefinition) // Flatten the compiler parse abi, err := json.Marshal(contract.Info.AbiDefinition) // Flatten the compiler parse
@ -191,15 +198,14 @@ func abigen(c *cli.Context) error {
abis = append(abis, string(abi)) abis = append(abis, string(abi))
bins = append(bins, contract.Code) bins = append(bins, contract.Code)
sigs = append(sigs, contract.Hashes) sigs = append(sigs, contract.Hashes)
nameParts := strings.Split(name, ":") types = append(types, typeName)
types = append(types, nameParts[len(nameParts)-1])
// Derive the library placeholder which is a 34 character prefix of the // Derive the library placeholder which is a 34 character prefix of the
// hex encoding of the keccak256 hash of the fully qualified library name. // hex encoding of the keccak256 hash of the fully qualified library name.
// Note that the fully qualified library name is the path of its source // Note that the fully qualified library name is the path of its source
// file and the library name separated by ":". // file and the library name separated by ":".
libPattern := crypto.Keccak256Hash([]byte(name)).String()[2:36] // the first 2 chars are 0x libPattern := crypto.Keccak256Hash([]byte(name)).String()[2:36] // the first 2 chars are 0x
libs[libPattern] = nameParts[len(nameParts)-1] libs[libPattern] = typeName
} }
} }
// Extract all aliases from the flags // Extract all aliases from the flags

58
cmd/abigen/namefilter.go Normal file
View File

@ -0,0 +1,58 @@
package main
import (
"fmt"
"strings"
)
type nameFilter struct {
fulls map[string]bool // path/to/contract.sol:Type
files map[string]bool // path/to/contract.sol:*
types map[string]bool // *:Type
}
func newNameFilter(patterns ...string) (*nameFilter, error) {
f := &nameFilter{
fulls: make(map[string]bool),
files: make(map[string]bool),
types: make(map[string]bool),
}
for _, pattern := range patterns {
if err := f.add(pattern); err != nil {
return nil, err
}
}
return f, nil
}
func (f *nameFilter) add(pattern string) error {
ft := strings.Split(pattern, ":")
if len(ft) != 2 {
// filenames and types must not include ':' symbol
return fmt.Errorf("invalid pattern: %s", pattern)
}
file, typ := ft[0], ft[1]
if file == "*" {
f.types[typ] = true
return nil
} else if typ == "*" {
f.files[file] = true
return nil
}
f.fulls[pattern] = true
return nil
}
func (f *nameFilter) Matches(name string) bool {
ft := strings.Split(name, ":")
if len(ft) != 2 {
// If contract names are always of the fully-qualified form
// <filePath>:<type>, then this case will never happen.
return false
}
file, typ := ft[0], ft[1]
// full paths > file paths > types
return f.fulls[name] || f.files[file] || f.types[typ]
}

View File

@ -0,0 +1,38 @@
package main
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNameFilter(t *testing.T) {
_, err := newNameFilter("Foo")
require.Error(t, err)
_, err = newNameFilter("too/many:colons:Foo")
require.Error(t, err)
f, err := newNameFilter("a/path:A", "*:B", "c/path:*")
require.NoError(t, err)
for _, tt := range []struct {
name string
match bool
}{
{"a/path:A", true},
{"unknown/path:A", false},
{"a/path:X", false},
{"unknown/path:X", false},
{"any/path:B", true},
{"c/path:X", true},
{"c/path:foo:B", false},
} {
match := f.Matches(tt.name)
if tt.match {
assert.True(t, match, "expected match")
} else {
assert.False(t, match, "expected no match")
}
}
}