fix(x/tx): recursively traverse nested messages in GetSigners (#18740)

This commit is contained in:
Matt Kocubinski 2023-12-18 11:43:54 -06:00 committed by GitHub
parent 1a496057ac
commit 0c0589813b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 3891 additions and 145 deletions

View File

@ -29,7 +29,10 @@ Ref: https://keepachangelog.com/en/1.0.0/
# Changelog
## [Unreleased]
## v0.13.0
### Improvements
* [#18740](https://github.com/cosmos/cosmos-sdk/pull/18740) Support nested messages when fetching signers up to a default depth of 32.
## v0.12.0

View File

@ -55,6 +55,42 @@ message RepeatedNestedRepeatedSigner {
}
}
message DeeplyNestedSigner {
option (cosmos.msg.v1.signer) = "inner_one";
InnerOne inner_one = 1;
message InnerOne {
option (cosmos.msg.v1.signer) = "inner_two";
InnerTwo inner_two = 1;
message InnerTwo {
option (cosmos.msg.v1.signer) = "signer";
string signer = 1;
}
}
}
message DeeplyNestedRepeatedSigner {
option (cosmos.msg.v1.signer) = "inner";
repeated Inner inner = 1;
message Inner {
option (cosmos.msg.v1.signer) = "inner";
repeated Inner inner = 1;
message Inner {
option (cosmos.msg.v1.signer) = "inner";
repeated Bottom inner = 1;
message Bottom {
option (cosmos.msg.v1.signer) = "signer";
repeated string signer = 1;
}
}
}
}
message BadSigner {
option (cosmos.msg.v1.signer) = "signer";
bytes signer = 1;

File diff suppressed because it is too large Load Diff

View File

@ -30,6 +30,7 @@ type Context struct {
validatorAddressCodec address.Codec
getSignersFuncs map[protoreflect.FullName]GetSignersFunc
customGetSignerFuncs map[protoreflect.FullName]GetSignersFunc
maxRecursionDepth int
}
// Options are options for creating Context which will be used for signing operations.
@ -47,7 +48,11 @@ type Options struct {
// ValidatorAddressCodec is the codec for converting validator addresses between strings and bytes.
ValidatorAddressCodec address.Codec
// CustomGetSigners is a map of message types to custom GetSignersFuncs.
CustomGetSigners map[protoreflect.FullName]GetSignersFunc
// MaxRecursionDepth is the maximum depth of nested messages that will be traversed
MaxRecursionDepth int
}
// DefineCustomGetSigners defines a custom GetSigners function for a given
@ -90,6 +95,10 @@ func NewContext(options Options) (*Context, error) {
return nil, errors.New("validator address codec is required")
}
if options.MaxRecursionDepth <= 0 {
options.MaxRecursionDepth = 32
}
customGetSignerFuncs := map[protoreflect.FullName]GetSignersFunc{}
for k := range options.CustomGetSigners {
customGetSignerFuncs[k] = options.CustomGetSigners[k]
@ -102,6 +111,7 @@ func NewContext(options Options) (*Context, error) {
validatorAddressCodec: options.ValidatorAddressCodec,
getSignersFuncs: map[protoreflect.FullName]GetSignersFunc{},
customGetSignerFuncs: customGetSignerFuncs,
maxRecursionDepth: options.MaxRecursionDepth,
}
return c, nil
@ -208,92 +218,87 @@ func (c *Context) makeGetSignersFunc(descriptor protoreflect.MessageDescriptor)
}
}
case protoreflect.MessageKind:
isList := field.IsList()
nestedMessage := field.Message()
nestedSignersFields, err := getSignersFieldNames(nestedMessage)
if err != nil {
return nil, err
}
if len(nestedSignersFields) != 1 {
return nil, fmt.Errorf("nested cosmos.msg.v1.signer option in message %s must contain only one value", nestedMessage.FullName())
}
nestedFieldName := nestedSignersFields[0]
nestedField := nestedMessage.Fields().ByName(protoreflect.Name(nestedFieldName))
nestedIsList := nestedField.IsList()
if nestedField == nil {
return nil, fmt.Errorf("field %s not found in message %s", nestedFieldName, nestedMessage.FullName())
}
if nestedField.Kind() != protoreflect.StringKind || nestedField.IsMap() || nestedField.HasOptionalKeyword() {
return nil, fmt.Errorf("nested signer field %s in message %s must be a simple string", nestedFieldName, nestedMessage.FullName())
}
addrCdc := c.getAddressCodec(nestedField)
if isList {
if nestedIsList {
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
msgs := msg.ProtoReflect().Get(field).List()
m := msgs.Len()
for i := 0; i < m; i++ {
signers := msgs.Get(i).Message().Get(nestedField).List()
n := signers.Len()
for j := 0; j < n; j++ {
addrStr := signers.Get(j).String()
addrBz, err := addrCdc.StringToBytes(addrStr)
if err != nil {
return nil, err
}
arr = append(arr, addrBz)
}
}
return arr, nil
}
} else {
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
msgs := msg.ProtoReflect().Get(field).List()
m := msgs.Len()
for i := 0; i < m; i++ {
addrStr := msgs.Get(i).Message().Get(nestedField).String()
addrBz, err := addrCdc.StringToBytes(addrStr)
if err != nil {
return nil, err
}
arr = append(arr, addrBz)
}
return arr, nil
}
var fieldGetter func(protoreflect.Message, int) ([][]byte, error)
fieldGetter = func(msg protoreflect.Message, depth int) ([][]byte, error) {
if depth > c.maxRecursionDepth {
return nil, fmt.Errorf("maximum recursion depth exceeded")
}
} else {
if nestedIsList {
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
nestedMsg := msg.ProtoReflect().Get(field).Message()
signers := nestedMsg.Get(nestedField).List()
n := signers.Len()
for j := 0; j < n; j++ {
addrStr := signers.Get(j).String()
desc := msg.Descriptor()
signerFields, err := getSignersFieldNames(desc)
if err != nil {
return nil, err
}
if len(signerFields) != 1 {
return nil, fmt.Errorf("nested cosmos.msg.v1.signer option in message %s must contain only one value", desc.FullName())
}
signerFieldName := signerFields[0]
childField := desc.Fields().ByName(protoreflect.Name(signerFieldName))
switch {
case childField.Kind() == protoreflect.MessageKind:
if childField.IsList() {
childMsgs := msg.Get(childField).List()
var arr [][]byte
for i := 0; i < childMsgs.Len(); i++ {
res, err := fieldGetter(childMsgs.Get(i).Message(), depth+1)
if err != nil {
return nil, err
}
arr = append(arr, res...)
}
return arr, nil
} else {
return fieldGetter(msg.Get(childField).Message(), depth+1)
}
case childField.IsMap() || childField.HasOptionalKeyword():
return nil, fmt.Errorf("cosmos.msg.v1.signer field %s in message %s must not be a map or optional", signerFieldName, desc.FullName())
case childField.Kind() == protoreflect.StringKind:
addrCdc := c.getAddressCodec(childField)
if childField.IsList() {
childMsgs := msg.Get(childField).List()
n := childMsgs.Len()
var res [][]byte
for i := 0; i < n; i++ {
addrStr := childMsgs.Get(i).String()
addrBz, err := addrCdc.StringToBytes(addrStr)
if err != nil {
return nil, err
}
arr = append(arr, addrBz)
res = append(res, addrBz)
}
return arr, nil
}
} else {
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
addrStr := msg.ProtoReflect().Get(field).Message().Get(nestedField).String()
return res, nil
} else {
addrStr := msg.Get(childField).String()
addrBz, err := addrCdc.StringToBytes(addrStr)
if err != nil {
return nil, err
}
return append(arr, addrBz), nil
return [][]byte{addrBz}, nil
}
}
return nil, fmt.Errorf("unexpected field type %s for field %s in message %s, only string and message type are supported",
childField.Kind(), signerFieldName, desc.FullName())
}
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
if field.IsList() {
signers := msg.ProtoReflect().Get(field).List()
n := signers.Len()
for i := 0; i < n; i++ {
res, err := fieldGetter(signers.Get(i).Message(), 0)
if err != nil {
return nil, err
}
arr = append(arr, res...)
}
} else {
res, err := fieldGetter(msg.ProtoReflect().Get(field).Message(), 0)
if err != nil {
return nil, err
}
arr = append(arr, res...)
}
return arr, nil
}
default:
return nil, fmt.Errorf("unexpected field type %s for field %s in message %s", field.Kind(), fieldName, descriptor.FullName())
}

View File

@ -14,6 +14,43 @@ import (
"cosmossdk.io/x/tx/internal/testpb"
)
var deeplyNestedRepeatedSigner = &testpb.DeeplyNestedRepeatedSigner{
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner{
{
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner{
{
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner_Bottom{
{
Signer: []string{hex.EncodeToString([]byte("foo")), hex.EncodeToString([]byte("bar"))},
},
},
},
},
},
{
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner{
{
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner_Bottom{
{
Signer: []string{hex.EncodeToString([]byte("baz"))},
},
},
},
{
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner_Bottom{
{
Signer: []string{hex.EncodeToString([]byte("qux")), hex.EncodeToString([]byte("fuz"))},
},
{
Signer: []string{hex.EncodeToString([]byte("bing")), hex.EncodeToString([]byte("bap"))},
},
},
},
},
},
},
}
func TestGetSigners(t *testing.T) {
ctx, err := NewContext(Options{
AddressCodec: dummyAddressCodec{},
@ -88,7 +125,18 @@ func TestGetSigners(t *testing.T) {
want: [][]byte{[]byte("foo"), []byte("bar")},
},
{
name: "nested repeated",
name: "deeply nested",
msg: &testpb.DeeplyNestedSigner{
InnerOne: &testpb.DeeplyNestedSigner_InnerOne{
InnerTwo: &testpb.DeeplyNestedSigner_InnerOne_InnerTwo{
Signer: hex.EncodeToString([]byte("foo")),
},
},
},
want: [][]byte{[]byte("foo")},
},
{
name: "nested repeated #1",
msg: &testpb.NestedRepeatedSigner{Inner: &testpb.NestedRepeatedSigner_Inner{
Signer: []string{
hex.EncodeToString([]byte("foo")),
@ -97,6 +145,11 @@ func TestGetSigners(t *testing.T) {
}},
want: [][]byte{[]byte("foo"), []byte("bar")},
},
{
name: "nested repeated #2",
msg: deeplyNestedRepeatedSigner,
want: [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("qux"), []byte("fuz"), []byte("bing"), []byte("bap")},
},
{
name: "repeated nested repeated",
msg: &testpb.RepeatedNestedRepeatedSigner{Inner: []*testpb.RepeatedNestedRepeatedSigner_Inner{
@ -145,6 +198,27 @@ func TestGetSigners(t *testing.T) {
}
}
func TestMaxRecursionDepth(t *testing.T) {
ctx, err := NewContext(Options{
AddressCodec: dummyAddressCodec{},
ValidatorAddressCodec: dummyValidatorAddressCodec{},
MaxRecursionDepth: 1,
})
require.NoError(t, err)
_, err = ctx.GetSigners(deeplyNestedRepeatedSigner)
require.ErrorContains(t, err, "maximum recursion depth exceeded")
ctx, err = NewContext(Options{
AddressCodec: dummyAddressCodec{},
ValidatorAddressCodec: dummyValidatorAddressCodec{},
MaxRecursionDepth: 5,
})
require.NoError(t, err)
_, err = ctx.GetSigners(deeplyNestedRepeatedSigner)
require.NoError(t, err)
}
func TestDefineCustomGetSigners(t *testing.T) {
customMsg := &testpb.Ballot{}
signers := [][]byte{[]byte("foo")}