cosmos-sdk/client/v2/autocli/flag/builder.go

512 lines
16 KiB
Go

package flag
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
cosmos_proto "github.com/cosmos/cosmos-proto"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
autocliv1 "cosmossdk.io/api/cosmos/autocli/v1"
msgv1 "cosmossdk.io/api/cosmos/msg/v1"
"cosmossdk.io/client/v2/internal/flags"
"cosmossdk.io/client/v2/internal/util"
"cosmossdk.io/core/address"
)
const (
AddressStringScalarType = "cosmos.AddressString"
ValidatorAddressStringScalarType = "cosmos.ValidatorAddressString"
ConsensusAddressStringScalarType = "cosmos.ConsensusAddressString"
PubkeyScalarType = "cosmos.Pubkey"
DecScalarType = "cosmos.Dec"
)
// Builder manages options for building pflag flags for protobuf messages.
type Builder struct {
// TypeResolver specifies how protobuf types will be resolved. If it is
// nil protoregistry.GlobalTypes will be used.
TypeResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}
// FileResolver specifies how protobuf file descriptors will be resolved. If it is
// nil protoregistry.GlobalFiles will be used.
FileResolver interface {
protodesc.Resolver
RangeFiles(func(protoreflect.FileDescriptor) bool)
}
messageFlagTypes map[protoreflect.FullName]Type
scalarFlagTypes map[string]Type
// Address Codecs are the address codecs to use for client/v2.
AddressCodec address.Codec
ValidatorAddressCodec address.ValidatorAddressCodec
ConsensusAddressCodec address.ConsensusAddressCodec
}
func (b *Builder) init() {
if b.messageFlagTypes == nil {
b.messageFlagTypes = map[protoreflect.FullName]Type{}
b.messageFlagTypes["google.protobuf.Timestamp"] = timestampType{}
b.messageFlagTypes["google.protobuf.Duration"] = durationType{}
b.messageFlagTypes["cosmos.base.v1beta1.Coin"] = coinType{}
b.messageFlagTypes["cosmos.base.v1beta1.DecCoin"] = decCoinType{}
}
if b.scalarFlagTypes == nil {
b.scalarFlagTypes = map[string]Type{}
b.scalarFlagTypes[AddressStringScalarType] = addressStringType{}
b.scalarFlagTypes[ValidatorAddressStringScalarType] = validatorAddressStringType{}
b.scalarFlagTypes[ConsensusAddressStringScalarType] = consensusAddressStringType{}
b.scalarFlagTypes[PubkeyScalarType] = pubkeyType{}
b.scalarFlagTypes[DecScalarType] = decType{}
}
}
// ValidateAndComplete the flag builder fields.
// It returns an error if any of the required fields are missing.
// If the keyring is nil, it will be set to a no keyring.
func (b *Builder) ValidateAndComplete() error {
if b.AddressCodec == nil {
return errors.New("address codec is required in flag builder")
}
if b.ValidatorAddressCodec == nil {
return errors.New("validator address codec is required in flag builder")
}
if b.ConsensusAddressCodec == nil {
return errors.New("consensus address codec is required in flag builder")
}
if b.TypeResolver == nil {
return errors.New("type resolver is required in flag builder")
}
if b.FileResolver == nil {
return errors.New("file resolver is required in flag builder")
}
return nil
}
// DefineMessageFlagType allows to extend custom protobuf message type handling for flags (and positional arguments).
func (b *Builder) DefineMessageFlagType(messageName protoreflect.FullName, flagType Type) {
b.init()
b.messageFlagTypes[messageName] = flagType
}
// DefineScalarFlagType allows to extend custom scalar type handling for flags (and positional arguments).
func (b *Builder) DefineScalarFlagType(scalarName string, flagType Type) {
b.init()
b.scalarFlagTypes[scalarName] = flagType
}
// AddMessageFlags adds flags for each field in the message to the flag set.
func (b *Builder) AddMessageFlags(ctx *context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions) (*MessageBinder, error) {
return b.addMessageFlags(ctx, flagSet, messageType, commandOptions, namingOptions{})
}
// addMessageFlags adds flags for each field in the message to the flag set.
func (b *Builder) addMessageFlags(ctx *context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions, options namingOptions) (*MessageBinder, error) {
messageBinder := &MessageBinder{
messageType: messageType,
// positional args are also parsed using a FlagSet so that we can reuse all the same parsers
positionalFlagSet: pflag.NewFlagSet("positional", pflag.ContinueOnError),
}
fields := messageType.Descriptor().Fields()
signerFieldName := GetSignerFieldName(messageType.Descriptor())
isPositional := map[string]bool{}
positionalArgsLen := len(commandOptions.PositionalArgs)
for i, arg := range commandOptions.PositionalArgs {
isPositional[arg.ProtoField] = true
// verify if a positional field is a signer field
if arg.ProtoField == signerFieldName {
messageBinder.SignerInfo = SignerInfo{
PositionalArgIndex: i,
FieldName: arg.ProtoField,
}
}
if arg.Optional && arg.Varargs {
return nil, fmt.Errorf("positional argument %s can't be both optional and varargs", arg.ProtoField)
}
if arg.Varargs {
if i != positionalArgsLen-1 {
return nil, fmt.Errorf("varargs positional argument %s must be the last argument", arg.ProtoField)
}
messageBinder.hasVarargs = true
}
if arg.Optional {
if i != positionalArgsLen-1 {
return nil, fmt.Errorf("optional positional argument %s must be the last argument", arg.ProtoField)
}
messageBinder.hasOptional = true
}
s := strings.Split(arg.ProtoField, ".")
if len(s) == 1 {
f, err := b.addFieldBindingToArgs(ctx, messageBinder, protoreflect.Name(arg.ProtoField), fields)
if err != nil {
return nil, err
}
messageBinder.positionalArgs = append(messageBinder.positionalArgs, f)
} else {
err := b.addFlattenFieldBindingToArgs(ctx, arg.ProtoField, s, messageType, messageBinder)
if err != nil {
return nil, err
}
}
}
totalArgs := len(messageBinder.positionalArgs)
switch {
case messageBinder.hasVarargs:
messageBinder.CobraArgs = cobra.MinimumNArgs(totalArgs - 1)
messageBinder.mandatoryArgUntil = totalArgs - 1
case messageBinder.hasOptional:
messageBinder.CobraArgs = cobra.RangeArgs(totalArgs-1, totalArgs)
messageBinder.mandatoryArgUntil = totalArgs - 1
default:
messageBinder.CobraArgs = cobra.ExactArgs(totalArgs)
messageBinder.mandatoryArgUntil = totalArgs
}
// validate flag options
for name, opts := range commandOptions.FlagOptions {
if fields.ByName(protoreflect.Name(name)) == nil {
return nil, fmt.Errorf("can't find field %s on %s specified as a flag", name, messageType.Descriptor().FullName())
}
// verify if a flag is a signer field
if name == signerFieldName {
messageBinder.SignerInfo = SignerInfo{
FieldName: name,
IsFlag: true,
FlagName: opts.Name,
}
}
}
// if signer has not been specified as positional arguments,
// add it as `--from` flag (instead of --field-name flags)
if signerFieldName != "" && messageBinder.SignerInfo == (SignerInfo{}) {
if commandOptions.FlagOptions == nil {
commandOptions.FlagOptions = make(map[string]*autocliv1.FlagOptions)
}
commandOptions.FlagOptions[signerFieldName] = &autocliv1.FlagOptions{
Name: flags.FlagFrom,
Usage: "Name or address with which to sign the message",
Shorthand: "f",
}
messageBinder.SignerInfo = SignerInfo{
FieldName: signerFieldName,
IsFlag: true,
FlagName: flags.FlagFrom,
}
}
// define all other fields as flags
flagOptsByFlagName := map[string]*autocliv1.FlagOptions{}
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
fieldName := string(field.Name())
// skips positional args and signer field if already set
if isPositional[fieldName] ||
(fieldName == signerFieldName && messageBinder.SignerInfo.FlagName == flags.FlagFrom) {
continue
}
flagOpts := commandOptions.FlagOptions[fieldName]
name, hasValue, err := b.addFieldFlag(ctx, flagSet, field, flagOpts, options)
if err != nil {
return nil, err
}
flagOptsByFlagName[name] = flagOpts
messageBinder.flagBindings = append(messageBinder.flagBindings, fieldBinding{
hasValue: hasValue,
field: field,
})
}
flagSet.VisitAll(func(flag *pflag.Flag) {
opts := flagOptsByFlagName[flag.Name]
if opts != nil {
// This is a bit of hacking around the pflag API, but
// we need to set these options here using Flag.VisitAll because the flag
// constructors that pflag gives us (StringP, Int32P, etc.) do not
// actually return the *Flag instance
flag.Deprecated = opts.Deprecated
flag.ShorthandDeprecated = opts.ShorthandDeprecated
flag.Hidden = opts.Hidden
}
})
return messageBinder, nil
}
// addFlattenFieldBindingToArgs recursively adds field bindings for nested message fields to the message binder.
// It takes a slice of field names representing the path to the target field, where each element is a field name
// in the nested message structure. For example, ["foo", "bar", "baz"] would bind the "baz" field inside the "bar"
// message which is inside the "foo" message.
func (b *Builder) addFlattenFieldBindingToArgs(ctx *context.Context, path string, s []string, msg protoreflect.MessageType, messageBinder *MessageBinder) error {
fields := msg.Descriptor().Fields()
if len(s) == 1 {
f, err := b.addFieldBindingToArgs(ctx, messageBinder, protoreflect.Name(s[0]), fields)
if err != nil {
return err
}
f.path = path
messageBinder.positionalArgs = append(messageBinder.positionalArgs, f)
return nil
}
fd := fields.ByName(protoreflect.Name(s[0]))
var innerMsg protoreflect.MessageType
if fd.IsList() {
innerMsg = msg.New().Get(fd).List().NewElement().Message().Type()
} else {
innerMsg = msg.New().Get(fd).Message().Type()
}
return b.addFlattenFieldBindingToArgs(ctx, path, s[1:], innerMsg, messageBinder)
}
// addFieldBindingToArgs adds a fieldBinding for a positional argument to the message binder.
// The fieldBinding is appended to the positional arguments list in the message binder.
func (b *Builder) addFieldBindingToArgs(ctx *context.Context, messageBinder *MessageBinder, name protoreflect.Name, fields protoreflect.FieldDescriptors) (fieldBinding, error) {
field := fields.ByName(name)
if field == nil {
return fieldBinding{}, fmt.Errorf("can't find field %s", name) // TODO: it will improve error if msg.FullName() was included.`
}
_, hasValue, err := b.addFieldFlag(
ctx,
messageBinder.positionalFlagSet,
field,
&autocliv1.FlagOptions{Name: fmt.Sprintf("%d", len(messageBinder.positionalArgs))},
namingOptions{},
)
if err != nil {
return fieldBinding{}, err
}
return fieldBinding{
field: field,
hasValue: hasValue,
}, nil
}
// bindPageRequest create a flag for pagination
func (b *Builder) bindPageRequest(ctx *context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) {
return b.addMessageFlags(
ctx,
flagSet,
util.ResolveMessageType(b.TypeResolver, field.Message()),
&autocliv1.RpcCommandOptions{},
namingOptions{Prefix: "page-"},
)
}
// namingOptions specifies internal naming options for flags.
type namingOptions struct {
// Prefix is a prefix to prepend to all flags.
Prefix string
}
// addFieldFlag adds a flag for the provided field to the flag set.
func (b *Builder) addFieldFlag(ctx *context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor, opts *autocliv1.FlagOptions, options namingOptions) (name string, hasValue HasValue, err error) {
if opts == nil {
opts = &autocliv1.FlagOptions{}
}
if field.Kind() == protoreflect.MessageKind && field.Message().FullName() == "cosmos.base.query.v1beta1.PageRequest" {
hasValue, err := b.bindPageRequest(ctx, flagSet, field)
return "", hasValue, err
}
name = opts.Name
if name == "" {
name = options.Prefix + util.DescriptorKebabName(field)
}
usage := opts.Usage
shorthand := opts.Shorthand
defaultValue := opts.DefaultValue
if typ := b.resolveFlagType(field); typ != nil {
if defaultValue == "" {
defaultValue = typ.DefaultValue()
}
val := typ.NewValue(ctx, b)
flagSet.AddFlag(&pflag.Flag{
Name: name,
Shorthand: shorthand,
Usage: usage,
DefValue: defaultValue,
Value: val,
})
return name, val, nil
}
// use the built-in pflag StringP, Int32P, etc. functions
var val HasValue
if field.IsList() {
val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage)
} else if field.IsMap() {
keyKind := field.MapKey().Kind()
valKind := field.MapValue().Kind()
val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage)
} else {
val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage)
}
// This is a bit of hacking around the pflag API, but the
// defaultValue is set in this way because this is much easier than trying
// to parse the string into the types that StringSliceP, Int32P, etc.
if defaultValue != "" {
err = flagSet.Set(name, defaultValue)
}
return name, val, err
}
func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type {
typ := b.resolveFlagTypeBasic(field)
if field.IsList() {
if typ != nil {
return compositeListType{simpleType: typ}
}
return nil
}
if field.IsMap() {
keyKind := field.MapKey().Kind()
valType := b.resolveFlagType(field.MapValue())
if valType != nil {
switch keyKind {
case protoreflect.StringKind:
ct := new(compositeMapType[string])
ct.keyValueResolver = func(s string) (string, error) { return s, nil }
ct.valueType = valType
ct.keyType = "string"
return ct
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
ct := new(compositeMapType[int32])
ct.keyValueResolver = func(s string) (int32, error) {
i, err := strconv.ParseInt(s, 10, 32)
return int32(i), err
}
ct.valueType = valType
ct.keyType = "int32"
return ct
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
ct := new(compositeMapType[int64])
ct.keyValueResolver = func(s string) (int64, error) {
i, err := strconv.ParseInt(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "int64"
return ct
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
ct := new(compositeMapType[uint32])
ct.keyValueResolver = func(s string) (uint32, error) {
i, err := strconv.ParseUint(s, 10, 32)
return uint32(i), err
}
ct.valueType = valType
ct.keyType = "uint32"
return ct
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
ct := new(compositeMapType[uint64])
ct.keyValueResolver = func(s string) (uint64, error) {
i, err := strconv.ParseUint(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "uint64"
return ct
case protoreflect.BoolKind:
ct := new(compositeMapType[bool])
ct.keyValueResolver = strconv.ParseBool
ct.valueType = valType
ct.keyType = "bool"
return ct
}
return nil
}
return nil
}
return typ
}
func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type {
scalar, ok := GetScalarType(field)
if ok {
b.init()
if typ, ok := b.scalarFlagTypes[scalar]; ok {
return typ
}
}
switch field.Kind() {
case protoreflect.BytesKind:
return binaryType{}
case protoreflect.EnumKind:
return enumType{enum: field.Enum()}
case protoreflect.MessageKind:
b.init()
if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok {
return flagType
}
return jsonMessageFlagType{
messageDesc: field.Message(),
}
default:
return nil
}
}
// GetScalarType gets scalar type of a field.
func GetScalarType(field protoreflect.FieldDescriptor) (string, bool) {
scalar := proto.GetExtension(field.Options(), cosmos_proto.E_Scalar)
scalarStr, ok := scalar.(string)
return scalarStr, ok
}
// GetSignerFieldName gets signer field name of a message.
// AutoCLI supports only one signer field per message.
func GetSignerFieldName(descriptor protoreflect.MessageDescriptor) string {
signersFields := proto.GetExtension(descriptor.Options(), msgv1.E_Signer).([]string)
if len(signersFields) == 0 {
return ""
}
return signersFields[0]
}