package flag import ( "context" "errors" "fmt" "strconv" "strings" cosmos_proto "github.com/cosmos/cosmos-proto" "github.com/spf13/cobra" "github.com/spf13/pflag" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" autocliv1 "cosmossdk.io/api/cosmos/autocli/v1" msgv1 "cosmossdk.io/api/cosmos/msg/v1" "cosmossdk.io/client/v2/internal/flags" "cosmossdk.io/client/v2/internal/util" "cosmossdk.io/core/address" ) const ( AddressStringScalarType = "cosmos.AddressString" ValidatorAddressStringScalarType = "cosmos.ValidatorAddressString" ConsensusAddressStringScalarType = "cosmos.ConsensusAddressString" PubkeyScalarType = "cosmos.Pubkey" DecScalarType = "cosmos.Dec" ) // Builder manages options for building pflag flags for protobuf messages. type Builder struct { // TypeResolver specifies how protobuf types will be resolved. If it is // nil protoregistry.GlobalTypes will be used. TypeResolver interface { protoregistry.MessageTypeResolver protoregistry.ExtensionTypeResolver } // FileResolver specifies how protobuf file descriptors will be resolved. If it is // nil protoregistry.GlobalFiles will be used. FileResolver interface { protodesc.Resolver RangeFiles(func(protoreflect.FileDescriptor) bool) } messageFlagTypes map[protoreflect.FullName]Type scalarFlagTypes map[string]Type // Address Codecs are the address codecs to use for client/v2. AddressCodec address.Codec ValidatorAddressCodec address.ValidatorAddressCodec ConsensusAddressCodec address.ConsensusAddressCodec } func (b *Builder) init() { if b.messageFlagTypes == nil { b.messageFlagTypes = map[protoreflect.FullName]Type{} b.messageFlagTypes["google.protobuf.Timestamp"] = timestampType{} b.messageFlagTypes["google.protobuf.Duration"] = durationType{} b.messageFlagTypes["cosmos.base.v1beta1.Coin"] = coinType{} b.messageFlagTypes["cosmos.base.v1beta1.DecCoin"] = decCoinType{} } if b.scalarFlagTypes == nil { b.scalarFlagTypes = map[string]Type{} b.scalarFlagTypes[AddressStringScalarType] = addressStringType{} b.scalarFlagTypes[ValidatorAddressStringScalarType] = validatorAddressStringType{} b.scalarFlagTypes[ConsensusAddressStringScalarType] = consensusAddressStringType{} b.scalarFlagTypes[PubkeyScalarType] = pubkeyType{} b.scalarFlagTypes[DecScalarType] = decType{} } } // ValidateAndComplete the flag builder fields. // It returns an error if any of the required fields are missing. // If the keyring is nil, it will be set to a no keyring. func (b *Builder) ValidateAndComplete() error { if b.AddressCodec == nil { return errors.New("address codec is required in flag builder") } if b.ValidatorAddressCodec == nil { return errors.New("validator address codec is required in flag builder") } if b.ConsensusAddressCodec == nil { return errors.New("consensus address codec is required in flag builder") } if b.TypeResolver == nil { return errors.New("type resolver is required in flag builder") } if b.FileResolver == nil { return errors.New("file resolver is required in flag builder") } return nil } // DefineMessageFlagType allows to extend custom protobuf message type handling for flags (and positional arguments). func (b *Builder) DefineMessageFlagType(messageName protoreflect.FullName, flagType Type) { b.init() b.messageFlagTypes[messageName] = flagType } // DefineScalarFlagType allows to extend custom scalar type handling for flags (and positional arguments). func (b *Builder) DefineScalarFlagType(scalarName string, flagType Type) { b.init() b.scalarFlagTypes[scalarName] = flagType } // AddMessageFlags adds flags for each field in the message to the flag set. func (b *Builder) AddMessageFlags(ctx *context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions) (*MessageBinder, error) { return b.addMessageFlags(ctx, flagSet, messageType, commandOptions, namingOptions{}) } // addMessageFlags adds flags for each field in the message to the flag set. func (b *Builder) addMessageFlags(ctx *context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions, options namingOptions) (*MessageBinder, error) { messageBinder := &MessageBinder{ messageType: messageType, // positional args are also parsed using a FlagSet so that we can reuse all the same parsers positionalFlagSet: pflag.NewFlagSet("positional", pflag.ContinueOnError), } fields := messageType.Descriptor().Fields() signerFieldName := GetSignerFieldName(messageType.Descriptor()) isPositional := map[string]bool{} positionalArgsLen := len(commandOptions.PositionalArgs) for i, arg := range commandOptions.PositionalArgs { isPositional[arg.ProtoField] = true // verify if a positional field is a signer field if arg.ProtoField == signerFieldName { messageBinder.SignerInfo = SignerInfo{ PositionalArgIndex: i, FieldName: arg.ProtoField, } } if arg.Optional && arg.Varargs { return nil, fmt.Errorf("positional argument %s can't be both optional and varargs", arg.ProtoField) } if arg.Varargs { if i != positionalArgsLen-1 { return nil, fmt.Errorf("varargs positional argument %s must be the last argument", arg.ProtoField) } messageBinder.hasVarargs = true } if arg.Optional { if i != positionalArgsLen-1 { return nil, fmt.Errorf("optional positional argument %s must be the last argument", arg.ProtoField) } messageBinder.hasOptional = true } s := strings.Split(arg.ProtoField, ".") if len(s) == 1 { f, err := b.addFieldBindingToArgs(ctx, messageBinder, protoreflect.Name(arg.ProtoField), fields) if err != nil { return nil, err } messageBinder.positionalArgs = append(messageBinder.positionalArgs, f) } else { err := b.addFlattenFieldBindingToArgs(ctx, arg.ProtoField, s, messageType, messageBinder) if err != nil { return nil, err } } } totalArgs := len(messageBinder.positionalArgs) switch { case messageBinder.hasVarargs: messageBinder.CobraArgs = cobra.MinimumNArgs(totalArgs - 1) messageBinder.mandatoryArgUntil = totalArgs - 1 case messageBinder.hasOptional: messageBinder.CobraArgs = cobra.RangeArgs(totalArgs-1, totalArgs) messageBinder.mandatoryArgUntil = totalArgs - 1 default: messageBinder.CobraArgs = cobra.ExactArgs(totalArgs) messageBinder.mandatoryArgUntil = totalArgs } // validate flag options for name, opts := range commandOptions.FlagOptions { if fields.ByName(protoreflect.Name(name)) == nil { return nil, fmt.Errorf("can't find field %s on %s specified as a flag", name, messageType.Descriptor().FullName()) } // verify if a flag is a signer field if name == signerFieldName { messageBinder.SignerInfo = SignerInfo{ FieldName: name, IsFlag: true, FlagName: opts.Name, } } } // if signer has not been specified as positional arguments, // add it as `--from` flag (instead of --field-name flags) if signerFieldName != "" && messageBinder.SignerInfo == (SignerInfo{}) { if commandOptions.FlagOptions == nil { commandOptions.FlagOptions = make(map[string]*autocliv1.FlagOptions) } commandOptions.FlagOptions[signerFieldName] = &autocliv1.FlagOptions{ Name: flags.FlagFrom, Usage: "Name or address with which to sign the message", Shorthand: "f", } messageBinder.SignerInfo = SignerInfo{ FieldName: signerFieldName, IsFlag: true, FlagName: flags.FlagFrom, } } // define all other fields as flags flagOptsByFlagName := map[string]*autocliv1.FlagOptions{} for i := 0; i < fields.Len(); i++ { field := fields.Get(i) fieldName := string(field.Name()) // skips positional args and signer field if already set if isPositional[fieldName] || (fieldName == signerFieldName && messageBinder.SignerInfo.FlagName == flags.FlagFrom) { continue } flagOpts := commandOptions.FlagOptions[fieldName] name, hasValue, err := b.addFieldFlag(ctx, flagSet, field, flagOpts, options) if err != nil { return nil, err } flagOptsByFlagName[name] = flagOpts messageBinder.flagBindings = append(messageBinder.flagBindings, fieldBinding{ hasValue: hasValue, field: field, }) } flagSet.VisitAll(func(flag *pflag.Flag) { opts := flagOptsByFlagName[flag.Name] if opts != nil { // This is a bit of hacking around the pflag API, but // we need to set these options here using Flag.VisitAll because the flag // constructors that pflag gives us (StringP, Int32P, etc.) do not // actually return the *Flag instance flag.Deprecated = opts.Deprecated flag.ShorthandDeprecated = opts.ShorthandDeprecated flag.Hidden = opts.Hidden } }) return messageBinder, nil } // addFlattenFieldBindingToArgs recursively adds field bindings for nested message fields to the message binder. // It takes a slice of field names representing the path to the target field, where each element is a field name // in the nested message structure. For example, ["foo", "bar", "baz"] would bind the "baz" field inside the "bar" // message which is inside the "foo" message. func (b *Builder) addFlattenFieldBindingToArgs(ctx *context.Context, path string, s []string, msg protoreflect.MessageType, messageBinder *MessageBinder) error { fields := msg.Descriptor().Fields() if len(s) == 1 { f, err := b.addFieldBindingToArgs(ctx, messageBinder, protoreflect.Name(s[0]), fields) if err != nil { return err } f.path = path messageBinder.positionalArgs = append(messageBinder.positionalArgs, f) return nil } fd := fields.ByName(protoreflect.Name(s[0])) var innerMsg protoreflect.MessageType if fd.IsList() { innerMsg = msg.New().Get(fd).List().NewElement().Message().Type() } else { innerMsg = msg.New().Get(fd).Message().Type() } return b.addFlattenFieldBindingToArgs(ctx, path, s[1:], innerMsg, messageBinder) } // addFieldBindingToArgs adds a fieldBinding for a positional argument to the message binder. // The fieldBinding is appended to the positional arguments list in the message binder. func (b *Builder) addFieldBindingToArgs(ctx *context.Context, messageBinder *MessageBinder, name protoreflect.Name, fields protoreflect.FieldDescriptors) (fieldBinding, error) { field := fields.ByName(name) if field == nil { return fieldBinding{}, fmt.Errorf("can't find field %s", name) // TODO: it will improve error if msg.FullName() was included.` } _, hasValue, err := b.addFieldFlag( ctx, messageBinder.positionalFlagSet, field, &autocliv1.FlagOptions{Name: fmt.Sprintf("%d", len(messageBinder.positionalArgs))}, namingOptions{}, ) if err != nil { return fieldBinding{}, err } return fieldBinding{ field: field, hasValue: hasValue, }, nil } // bindPageRequest create a flag for pagination func (b *Builder) bindPageRequest(ctx *context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) { return b.addMessageFlags( ctx, flagSet, util.ResolveMessageType(b.TypeResolver, field.Message()), &autocliv1.RpcCommandOptions{}, namingOptions{Prefix: "page-"}, ) } // namingOptions specifies internal naming options for flags. type namingOptions struct { // Prefix is a prefix to prepend to all flags. Prefix string } // addFieldFlag adds a flag for the provided field to the flag set. func (b *Builder) addFieldFlag(ctx *context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor, opts *autocliv1.FlagOptions, options namingOptions) (name string, hasValue HasValue, err error) { if opts == nil { opts = &autocliv1.FlagOptions{} } if field.Kind() == protoreflect.MessageKind && field.Message().FullName() == "cosmos.base.query.v1beta1.PageRequest" { hasValue, err := b.bindPageRequest(ctx, flagSet, field) return "", hasValue, err } name = opts.Name if name == "" { name = options.Prefix + util.DescriptorKebabName(field) } usage := opts.Usage shorthand := opts.Shorthand defaultValue := opts.DefaultValue if typ := b.resolveFlagType(field); typ != nil { if defaultValue == "" { defaultValue = typ.DefaultValue() } val := typ.NewValue(ctx, b) flagSet.AddFlag(&pflag.Flag{ Name: name, Shorthand: shorthand, Usage: usage, DefValue: defaultValue, Value: val, }) return name, val, nil } // use the built-in pflag StringP, Int32P, etc. functions var val HasValue if field.IsList() { val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage) } else if field.IsMap() { keyKind := field.MapKey().Kind() valKind := field.MapValue().Kind() val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage) } else { val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage) } // This is a bit of hacking around the pflag API, but the // defaultValue is set in this way because this is much easier than trying // to parse the string into the types that StringSliceP, Int32P, etc. if defaultValue != "" { err = flagSet.Set(name, defaultValue) } return name, val, err } func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type { typ := b.resolveFlagTypeBasic(field) if field.IsList() { if typ != nil { return compositeListType{simpleType: typ} } return nil } if field.IsMap() { keyKind := field.MapKey().Kind() valType := b.resolveFlagType(field.MapValue()) if valType != nil { switch keyKind { case protoreflect.StringKind: ct := new(compositeMapType[string]) ct.keyValueResolver = func(s string) (string, error) { return s, nil } ct.valueType = valType ct.keyType = "string" return ct case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: ct := new(compositeMapType[int32]) ct.keyValueResolver = func(s string) (int32, error) { i, err := strconv.ParseInt(s, 10, 32) return int32(i), err } ct.valueType = valType ct.keyType = "int32" return ct case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: ct := new(compositeMapType[int64]) ct.keyValueResolver = func(s string) (int64, error) { i, err := strconv.ParseInt(s, 10, 64) return i, err } ct.valueType = valType ct.keyType = "int64" return ct case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: ct := new(compositeMapType[uint32]) ct.keyValueResolver = func(s string) (uint32, error) { i, err := strconv.ParseUint(s, 10, 32) return uint32(i), err } ct.valueType = valType ct.keyType = "uint32" return ct case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: ct := new(compositeMapType[uint64]) ct.keyValueResolver = func(s string) (uint64, error) { i, err := strconv.ParseUint(s, 10, 64) return i, err } ct.valueType = valType ct.keyType = "uint64" return ct case protoreflect.BoolKind: ct := new(compositeMapType[bool]) ct.keyValueResolver = strconv.ParseBool ct.valueType = valType ct.keyType = "bool" return ct } return nil } return nil } return typ } func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type { scalar, ok := GetScalarType(field) if ok { b.init() if typ, ok := b.scalarFlagTypes[scalar]; ok { return typ } } switch field.Kind() { case protoreflect.BytesKind: return binaryType{} case protoreflect.EnumKind: return enumType{enum: field.Enum()} case protoreflect.MessageKind: b.init() if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok { return flagType } return jsonMessageFlagType{ messageDesc: field.Message(), } default: return nil } } // GetScalarType gets scalar type of a field. func GetScalarType(field protoreflect.FieldDescriptor) (string, bool) { scalar := proto.GetExtension(field.Options(), cosmos_proto.E_Scalar) scalarStr, ok := scalar.(string) return scalarStr, ok } // GetSignerFieldName gets signer field name of a message. // AutoCLI supports only one signer field per message. func GetSignerFieldName(descriptor protoreflect.MessageDescriptor) string { signersFields := proto.GetExtension(descriptor.Options(), msgv1.E_Signer).([]string) if len(signersFields) == 0 { return "" } return signersFields[0] }