fix(x/tx): recursively traverse nested messages in GetSigners (#18740)
This commit is contained in:
parent
1a496057ac
commit
0c0589813b
@ -29,7 +29,10 @@ Ref: https://keepachangelog.com/en/1.0.0/
|
||||
|
||||
# Changelog
|
||||
|
||||
## [Unreleased]
|
||||
## v0.13.0
|
||||
|
||||
### Improvements
|
||||
* [#18740](https://github.com/cosmos/cosmos-sdk/pull/18740) Support nested messages when fetching signers up to a default depth of 32.
|
||||
|
||||
## v0.12.0
|
||||
|
||||
|
||||
@ -55,6 +55,42 @@ message RepeatedNestedRepeatedSigner {
|
||||
}
|
||||
}
|
||||
|
||||
message DeeplyNestedSigner {
|
||||
option (cosmos.msg.v1.signer) = "inner_one";
|
||||
InnerOne inner_one = 1;
|
||||
|
||||
message InnerOne {
|
||||
option (cosmos.msg.v1.signer) = "inner_two";
|
||||
InnerTwo inner_two = 1;
|
||||
|
||||
message InnerTwo {
|
||||
option (cosmos.msg.v1.signer) = "signer";
|
||||
string signer = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message DeeplyNestedRepeatedSigner {
|
||||
option (cosmos.msg.v1.signer) = "inner";
|
||||
repeated Inner inner = 1;
|
||||
|
||||
message Inner {
|
||||
option (cosmos.msg.v1.signer) = "inner";
|
||||
repeated Inner inner = 1;
|
||||
|
||||
message Inner {
|
||||
option (cosmos.msg.v1.signer) = "inner";
|
||||
repeated Bottom inner = 1;
|
||||
|
||||
message Bottom {
|
||||
option (cosmos.msg.v1.signer) = "signer";
|
||||
repeated string signer = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
message BadSigner {
|
||||
option (cosmos.msg.v1.signer) = "signer";
|
||||
bytes signer = 1;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -30,6 +30,7 @@ type Context struct {
|
||||
validatorAddressCodec address.Codec
|
||||
getSignersFuncs map[protoreflect.FullName]GetSignersFunc
|
||||
customGetSignerFuncs map[protoreflect.FullName]GetSignersFunc
|
||||
maxRecursionDepth int
|
||||
}
|
||||
|
||||
// Options are options for creating Context which will be used for signing operations.
|
||||
@ -47,7 +48,11 @@ type Options struct {
|
||||
// ValidatorAddressCodec is the codec for converting validator addresses between strings and bytes.
|
||||
ValidatorAddressCodec address.Codec
|
||||
|
||||
// CustomGetSigners is a map of message types to custom GetSignersFuncs.
|
||||
CustomGetSigners map[protoreflect.FullName]GetSignersFunc
|
||||
|
||||
// MaxRecursionDepth is the maximum depth of nested messages that will be traversed
|
||||
MaxRecursionDepth int
|
||||
}
|
||||
|
||||
// DefineCustomGetSigners defines a custom GetSigners function for a given
|
||||
@ -90,6 +95,10 @@ func NewContext(options Options) (*Context, error) {
|
||||
return nil, errors.New("validator address codec is required")
|
||||
}
|
||||
|
||||
if options.MaxRecursionDepth <= 0 {
|
||||
options.MaxRecursionDepth = 32
|
||||
}
|
||||
|
||||
customGetSignerFuncs := map[protoreflect.FullName]GetSignersFunc{}
|
||||
for k := range options.CustomGetSigners {
|
||||
customGetSignerFuncs[k] = options.CustomGetSigners[k]
|
||||
@ -102,6 +111,7 @@ func NewContext(options Options) (*Context, error) {
|
||||
validatorAddressCodec: options.ValidatorAddressCodec,
|
||||
getSignersFuncs: map[protoreflect.FullName]GetSignersFunc{},
|
||||
customGetSignerFuncs: customGetSignerFuncs,
|
||||
maxRecursionDepth: options.MaxRecursionDepth,
|
||||
}
|
||||
|
||||
return c, nil
|
||||
@ -208,92 +218,87 @@ func (c *Context) makeGetSignersFunc(descriptor protoreflect.MessageDescriptor)
|
||||
}
|
||||
}
|
||||
case protoreflect.MessageKind:
|
||||
isList := field.IsList()
|
||||
nestedMessage := field.Message()
|
||||
nestedSignersFields, err := getSignersFieldNames(nestedMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(nestedSignersFields) != 1 {
|
||||
return nil, fmt.Errorf("nested cosmos.msg.v1.signer option in message %s must contain only one value", nestedMessage.FullName())
|
||||
}
|
||||
|
||||
nestedFieldName := nestedSignersFields[0]
|
||||
nestedField := nestedMessage.Fields().ByName(protoreflect.Name(nestedFieldName))
|
||||
nestedIsList := nestedField.IsList()
|
||||
if nestedField == nil {
|
||||
return nil, fmt.Errorf("field %s not found in message %s", nestedFieldName, nestedMessage.FullName())
|
||||
}
|
||||
|
||||
if nestedField.Kind() != protoreflect.StringKind || nestedField.IsMap() || nestedField.HasOptionalKeyword() {
|
||||
return nil, fmt.Errorf("nested signer field %s in message %s must be a simple string", nestedFieldName, nestedMessage.FullName())
|
||||
}
|
||||
|
||||
addrCdc := c.getAddressCodec(nestedField)
|
||||
|
||||
if isList {
|
||||
if nestedIsList {
|
||||
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
|
||||
msgs := msg.ProtoReflect().Get(field).List()
|
||||
m := msgs.Len()
|
||||
for i := 0; i < m; i++ {
|
||||
signers := msgs.Get(i).Message().Get(nestedField).List()
|
||||
n := signers.Len()
|
||||
for j := 0; j < n; j++ {
|
||||
addrStr := signers.Get(j).String()
|
||||
addrBz, err := addrCdc.StringToBytes(addrStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, addrBz)
|
||||
}
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
} else {
|
||||
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
|
||||
msgs := msg.ProtoReflect().Get(field).List()
|
||||
m := msgs.Len()
|
||||
for i := 0; i < m; i++ {
|
||||
addrStr := msgs.Get(i).Message().Get(nestedField).String()
|
||||
addrBz, err := addrCdc.StringToBytes(addrStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, addrBz)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
var fieldGetter func(protoreflect.Message, int) ([][]byte, error)
|
||||
fieldGetter = func(msg protoreflect.Message, depth int) ([][]byte, error) {
|
||||
if depth > c.maxRecursionDepth {
|
||||
return nil, fmt.Errorf("maximum recursion depth exceeded")
|
||||
}
|
||||
} else {
|
||||
if nestedIsList {
|
||||
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
|
||||
nestedMsg := msg.ProtoReflect().Get(field).Message()
|
||||
signers := nestedMsg.Get(nestedField).List()
|
||||
n := signers.Len()
|
||||
for j := 0; j < n; j++ {
|
||||
addrStr := signers.Get(j).String()
|
||||
desc := msg.Descriptor()
|
||||
signerFields, err := getSignersFieldNames(desc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(signerFields) != 1 {
|
||||
return nil, fmt.Errorf("nested cosmos.msg.v1.signer option in message %s must contain only one value", desc.FullName())
|
||||
}
|
||||
signerFieldName := signerFields[0]
|
||||
childField := desc.Fields().ByName(protoreflect.Name(signerFieldName))
|
||||
switch {
|
||||
case childField.Kind() == protoreflect.MessageKind:
|
||||
if childField.IsList() {
|
||||
childMsgs := msg.Get(childField).List()
|
||||
var arr [][]byte
|
||||
for i := 0; i < childMsgs.Len(); i++ {
|
||||
res, err := fieldGetter(childMsgs.Get(i).Message(), depth+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, res...)
|
||||
}
|
||||
return arr, nil
|
||||
} else {
|
||||
return fieldGetter(msg.Get(childField).Message(), depth+1)
|
||||
}
|
||||
case childField.IsMap() || childField.HasOptionalKeyword():
|
||||
return nil, fmt.Errorf("cosmos.msg.v1.signer field %s in message %s must not be a map or optional", signerFieldName, desc.FullName())
|
||||
case childField.Kind() == protoreflect.StringKind:
|
||||
addrCdc := c.getAddressCodec(childField)
|
||||
if childField.IsList() {
|
||||
childMsgs := msg.Get(childField).List()
|
||||
n := childMsgs.Len()
|
||||
var res [][]byte
|
||||
for i := 0; i < n; i++ {
|
||||
addrStr := childMsgs.Get(i).String()
|
||||
addrBz, err := addrCdc.StringToBytes(addrStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, addrBz)
|
||||
res = append(res, addrBz)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
} else {
|
||||
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
|
||||
addrStr := msg.ProtoReflect().Get(field).Message().Get(nestedField).String()
|
||||
return res, nil
|
||||
} else {
|
||||
addrStr := msg.Get(childField).String()
|
||||
addrBz, err := addrCdc.StringToBytes(addrStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(arr, addrBz), nil
|
||||
return [][]byte{addrBz}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected field type %s for field %s in message %s, only string and message type are supported",
|
||||
childField.Kind(), signerFieldName, desc.FullName())
|
||||
}
|
||||
|
||||
fieldGetters[i] = func(msg proto.Message, arr [][]byte) ([][]byte, error) {
|
||||
if field.IsList() {
|
||||
signers := msg.ProtoReflect().Get(field).List()
|
||||
n := signers.Len()
|
||||
for i := 0; i < n; i++ {
|
||||
res, err := fieldGetter(signers.Get(i).Message(), 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, res...)
|
||||
}
|
||||
} else {
|
||||
res, err := fieldGetter(msg.ProtoReflect().Get(field).Message(), 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, res...)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected field type %s for field %s in message %s", field.Kind(), fieldName, descriptor.FullName())
|
||||
}
|
||||
|
||||
@ -14,6 +14,43 @@ import (
|
||||
"cosmossdk.io/x/tx/internal/testpb"
|
||||
)
|
||||
|
||||
var deeplyNestedRepeatedSigner = &testpb.DeeplyNestedRepeatedSigner{
|
||||
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner{
|
||||
{
|
||||
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner{
|
||||
{
|
||||
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner_Bottom{
|
||||
{
|
||||
Signer: []string{hex.EncodeToString([]byte("foo")), hex.EncodeToString([]byte("bar"))},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner{
|
||||
{
|
||||
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner_Bottom{
|
||||
{
|
||||
Signer: []string{hex.EncodeToString([]byte("baz"))},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Inner: []*testpb.DeeplyNestedRepeatedSigner_Inner_Inner_Bottom{
|
||||
{
|
||||
Signer: []string{hex.EncodeToString([]byte("qux")), hex.EncodeToString([]byte("fuz"))},
|
||||
},
|
||||
{
|
||||
Signer: []string{hex.EncodeToString([]byte("bing")), hex.EncodeToString([]byte("bap"))},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestGetSigners(t *testing.T) {
|
||||
ctx, err := NewContext(Options{
|
||||
AddressCodec: dummyAddressCodec{},
|
||||
@ -88,7 +125,18 @@ func TestGetSigners(t *testing.T) {
|
||||
want: [][]byte{[]byte("foo"), []byte("bar")},
|
||||
},
|
||||
{
|
||||
name: "nested repeated",
|
||||
name: "deeply nested",
|
||||
msg: &testpb.DeeplyNestedSigner{
|
||||
InnerOne: &testpb.DeeplyNestedSigner_InnerOne{
|
||||
InnerTwo: &testpb.DeeplyNestedSigner_InnerOne_InnerTwo{
|
||||
Signer: hex.EncodeToString([]byte("foo")),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: [][]byte{[]byte("foo")},
|
||||
},
|
||||
{
|
||||
name: "nested repeated #1",
|
||||
msg: &testpb.NestedRepeatedSigner{Inner: &testpb.NestedRepeatedSigner_Inner{
|
||||
Signer: []string{
|
||||
hex.EncodeToString([]byte("foo")),
|
||||
@ -97,6 +145,11 @@ func TestGetSigners(t *testing.T) {
|
||||
}},
|
||||
want: [][]byte{[]byte("foo"), []byte("bar")},
|
||||
},
|
||||
{
|
||||
name: "nested repeated #2",
|
||||
msg: deeplyNestedRepeatedSigner,
|
||||
want: [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("qux"), []byte("fuz"), []byte("bing"), []byte("bap")},
|
||||
},
|
||||
{
|
||||
name: "repeated nested repeated",
|
||||
msg: &testpb.RepeatedNestedRepeatedSigner{Inner: []*testpb.RepeatedNestedRepeatedSigner_Inner{
|
||||
@ -145,6 +198,27 @@ func TestGetSigners(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxRecursionDepth(t *testing.T) {
|
||||
ctx, err := NewContext(Options{
|
||||
AddressCodec: dummyAddressCodec{},
|
||||
ValidatorAddressCodec: dummyValidatorAddressCodec{},
|
||||
MaxRecursionDepth: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ctx.GetSigners(deeplyNestedRepeatedSigner)
|
||||
require.ErrorContains(t, err, "maximum recursion depth exceeded")
|
||||
|
||||
ctx, err = NewContext(Options{
|
||||
AddressCodec: dummyAddressCodec{},
|
||||
ValidatorAddressCodec: dummyValidatorAddressCodec{},
|
||||
MaxRecursionDepth: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ctx.GetSigners(deeplyNestedRepeatedSigner)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDefineCustomGetSigners(t *testing.T) {
|
||||
customMsg := &testpb.Ballot{}
|
||||
signers := [][]byte{[]byte("foo")}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user