cosmos-sdk/client/v2/autocli/prompt/message.go
2024-12-11 10:28:23 +00:00

260 lines
8.8 KiB
Go

package prompt
import (
"fmt"
"io"
"strconv"
"strings"
"github.com/manifoldco/promptui"
"google.golang.org/protobuf/reflect/protoreflect"
"cosmossdk.io/client/v2/autocli/flag"
addresscodec "cosmossdk.io/core/address"
)
// PromptMessage interactively asks the user for a value for each field of msg
// and stores the answers on it. The three address codecs are used to validate
// account, validator and consensus address fields respectively, and
// promptPrefix is included in every prompt label.
//
// It is the exported entry point around promptMessage and forwards a nil
// input stream to it.
// NOTE(review): with a nil stream the prompt library presumably falls back to
// its own default input (the terminal) — confirm against promptui.
func PromptMessage(
	addressCodec, validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
	promptPrefix string, msg protoreflect.Message,
) (protoreflect.Message, error) {
	// No input stream is injected outside of tests.
	var noInjectedInput io.ReadCloser

	return promptMessage(
		addressCodec,
		validatorAddressCodec,
		consensusAddressCodec,
		promptPrefix,
		noInjectedInput,
		msg,
	)
}
// promptMessage prompts the user for values to populate a protobuf message interactively.
// stdIn is provided to make the function easier to unit test by allowing injection of predefined inputs.
//
// Every field of msg's descriptor is visited in declaration order:
//   - map fields are skipped (there is no prompt for them);
//   - message fields are delegated to promptInnerMessageKind, which recurses;
//   - repeated scalar fields are read as one comma-separated answer via promptList;
//   - every other field is read with a single prompt and converted by valueOf.
//     Kinds valueOf does not support are left unset.
//
// Fields annotated as account, validator or consensus address strings are
// validated with the matching codec; other annotated scalars are not validated,
// and unannotated fields must be non-empty.
//
// On success the populated msg is returned. On failure the error is non-nil and
// the returned message must not be relied upon (it is nil for nested/list
// failures and the partially populated msg otherwise).
func promptMessage(
	addressCodec, validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
	promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) (protoreflect.Message, error) {
	fields := msg.Descriptor().Fields()
	for i := 0; i < fields.Len(); i++ {
		field := fields.Get(i)
		fieldName := string(field.Name())

		// Map fields report MessageKind but are not lists, so they would be
		// routed to promptInnerMessage, where reading the field value as a
		// Message panics. Skip them instead of crashing.
		if field.IsMap() {
			continue
		}

		promptUi := promptui.Prompt{
			Validate: ValidatePromptNotEmpty,
			Stdin:    stdIn,
		}

		// If this signer field has already a valid default value set,
		// use that value as the default prompt value. This is useful for
		// commands that have an authority such as gov.
		if strings.EqualFold(fieldName, flag.GetSignerFieldName(msg.Descriptor())) {
			if defaultValue := msg.Get(field); defaultValue.IsValid() {
				promptUi.Default = defaultValue.String()
			}
		}

		// validate address fields
		scalarField, ok := flag.GetScalarType(field)
		if ok {
			switch scalarField {
			case flag.AddressStringScalarType:
				promptUi.Validate = ValidateAddress(addressCodec)
			case flag.ValidatorAddressStringScalarType:
				promptUi.Validate = ValidateAddress(validatorAddressCodec)
			case flag.ConsensusAddressStringScalarType:
				promptUi.Validate = ValidateAddress(consensusAddressCodec)
			default:
				// prompt.Validate = ValidatePromptNotEmpty (we possibly don't want to force all fields to be non-empty)
				promptUi.Validate = nil
			}
		}

		// handle nested message fields recursively
		if field.Kind() == protoreflect.MessageKind {
			err := promptInnerMessageKind(field, addressCodec, validatorAddressCodec, consensusAddressCodec, promptPrefix, stdIn, msg)
			if err != nil {
				return nil, err
			}
			continue
		}

		// handle repeated fields by prompting for a comma-separated list of values
		if field.IsList() {
			list, err := promptList(field, msg, promptUi, promptPrefix)
			if err != nil {
				return nil, err
			}
			msg.Set(field, protoreflect.ValueOfList(list))
			continue
		}

		promptUi.Label = fmt.Sprintf("Enter %s %s", promptPrefix, fieldName)
		result, err := promptUi.Run()
		if err != nil {
			return msg, fmt.Errorf("failed to prompt for %s: %w", fieldName, err)
		}

		v, err := valueOf(field, result)
		if err != nil {
			return msg, err
		}

		// valueOf returns an invalid Value (with a nil error) for kinds it does
		// not support; storing such a value would panic, so leave the field unset.
		if !v.IsValid() {
			continue
		}
		msg.Set(field, v)
	}

	return msg, nil
}
// valueOf converts a string input value to a protoreflect.Value based on the field's type.
// It handles string, integer, floating point, bool, bytes and enum field types.
//
// Numeric input is parsed in base 10 with the exact bit width of the field
// kind, so out-of-range input is reported as an error, and the returned Value
// carries the Go type matching that kind (e.g. uint32 for Uint32Kind).
// Bytes fields receive the raw bytes of the input string, and enum fields are
// resolved by value name.
//
// Returns the converted value and any error that occurred during conversion.
// For kinds that are not handled here the result is an invalid (zero) Value
// with a nil error; callers must check IsValid before storing it.
func valueOf(field protoreflect.FieldDescriptor, result string) (protoreflect.Value, error) {
	switch field.Kind() {
	case protoreflect.StringKind:
		return protoreflect.ValueOfString(result), nil

	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
		resultUint, err := strconv.ParseUint(result, 10, 32)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for uint: %w", err)
		}
		// bitSize 32 guarantees the conversion below cannot truncate.
		return protoreflect.ValueOfUint32(uint32(resultUint)), nil

	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		resultUint, err := strconv.ParseUint(result, 10, 64)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for uint: %w", err)
		}
		return protoreflect.ValueOfUint64(resultUint), nil

	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		resultInt, err := strconv.ParseInt(result, 10, 32)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for int: %w", err)
		}
		// bitSize 32 guarantees the conversion below cannot truncate.
		return protoreflect.ValueOfInt32(int32(resultInt)), nil

	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		// Use an explicit 64-bit width rather than the platform int size so
		// the accepted range does not depend on the build architecture.
		resultInt, err := strconv.ParseInt(result, 10, 64)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for int: %w", err)
		}
		return protoreflect.ValueOfInt64(resultInt), nil

	case protoreflect.FloatKind:
		resultFloat, err := strconv.ParseFloat(result, 32)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for float: %w", err)
		}
		// ParseFloat with bitSize 32 returns a value exactly representable as float32.
		return protoreflect.ValueOfFloat32(float32(resultFloat)), nil

	case protoreflect.DoubleKind:
		resultFloat, err := strconv.ParseFloat(result, 64)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for float: %w", err)
		}
		return protoreflect.ValueOfFloat64(resultFloat), nil

	case protoreflect.BoolKind:
		resultBool, err := strconv.ParseBool(result)
		if err != nil {
			return protoreflect.Value{}, fmt.Errorf("invalid value for bool: %w", err)
		}
		return protoreflect.ValueOfBool(resultBool), nil

	case protoreflect.BytesKind:
		resultBytes := []byte(result)
		return protoreflect.ValueOfBytes(resultBytes), nil

	case protoreflect.EnumKind:
		enumValue := field.Enum().Values().ByName(protoreflect.Name(result))
		if enumValue == nil {
			return protoreflect.Value{}, fmt.Errorf("invalid enum value %q", result)
		}
		return protoreflect.ValueOfEnum(enumValue.Number()), nil

	default:
		// TODO: add more kinds
		// skip any other types
		return protoreflect.Value{}, nil
	}
}
// promptList prompts the user for a comma-separated list of values for a repeated field.
// The user will be prompted to enter values separated by commas which will be parsed
// according to the field's type using valueOf.
//
// promptUi is received by value, so the adjustments made here (label and
// validator) do not leak back to the caller. If promptUi carries a validator it
// is applied to each comma-separated item rather than to the raw answer, since
// it was chosen for a single value of the field (e.g. one address).
//
// Surrounding whitespace is trimmed from every item, so "1, 2" is accepted.
// A blank answer adds no elements. Items whose kind valueOf does not support
// are skipped. Converted values are appended to the list msg already holds for
// the field, and that list is returned.
func promptList(field protoreflect.FieldDescriptor, msg protoreflect.Message, promptUi promptui.Prompt, promptPrefix string) (protoreflect.List, error) {
	fieldName := string(field.Name())

	// Validate item by item: running a single-value validator on "a,b" would
	// reject every answer that contains more than one element.
	if validateItem := promptUi.Validate; validateItem != nil {
		promptUi.Validate = func(input string) error {
			for _, item := range strings.Split(input, ",") {
				if err := validateItem(strings.TrimSpace(item)); err != nil {
					return err
				}
			}
			return nil
		}
	}

	promptUi.Label = fmt.Sprintf("Enter %s %s list (separate values with ',')", promptPrefix, fieldName)
	result, err := promptUi.Run()
	if err != nil {
		return nil, fmt.Errorf("failed to prompt for %s: %w", fieldName, err)
	}

	list := msg.Mutable(field).List()

	// strings.Split("", ",") yields one empty item; treat a blank answer as
	// "no elements" instead of converting that empty string.
	if strings.TrimSpace(result) == "" {
		return list, nil
	}

	for _, item := range strings.Split(result, ",") {
		v, err := valueOf(field, strings.TrimSpace(item))
		if err != nil {
			return nil, err
		}
		// valueOf returns an invalid Value for unsupported kinds; appending it
		// would panic.
		if !v.IsValid() {
			continue
		}
		list.Append(v)
	}

	return list, nil
}
// promptInnerMessageKind prompts for a field whose kind is a message.
// Repeated message fields are routed to promptMessageList; any other message
// field is routed to promptInnerMessage. All arguments are forwarded unchanged
// and the chosen handler's error is returned as is.
func promptInnerMessageKind(
	f protoreflect.FieldDescriptor, addressCodec addresscodec.Codec,
	validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
	promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) error {
	// Both handlers share one signature, so pick one and call it once.
	handle := promptInnerMessage
	if f.IsList() {
		handle = promptMessageList
	}

	return handle(
		f,
		addressCodec,
		validatorAddressCodec,
		consensusAddressCodec,
		promptPrefix,
		stdIn,
		msg,
	)
}
// promptInnerMessage prompts for one nested message field f of msg.
// A new, empty instance of the field's message type is created, filled in by a
// recursive promptMessage call, and then stored on msg. The nested prompts are
// labelled with promptPrefix extended by "." and the field name.
// A failure of the nested prompt is wrapped with that label and returned; msg
// is left untouched in that case.
func promptInnerMessage(
	f protoreflect.FieldDescriptor, addressCodec addresscodec.Codec,
	validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
	promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) error {
	nestedPrefix := promptPrefix + "." + string(f.Name())

	// The current field value is only used to obtain a new instance of its type.
	blank := msg.Get(f).Message().New()

	filled, err := promptMessage(addressCodec, validatorAddressCodec, consensusAddressCodec, nestedPrefix, stdIn, blank)
	if err != nil {
		return fmt.Errorf("failed to prompt for nested message %s: %w", nestedPrefix, err)
	}

	msg.Set(f, protoreflect.ValueOfMessage(filled))
	return nil
}
// promptMessageList prompts for a repeated message field f of msg.
// Each round creates a new element of the list, fills it in through
// promptMessage, appends it, and then shows an "Add another item?" selection
// offering "No" and "Yes". The loop ends when "No" is selected, so at least one
// element is always prompted for.
// Errors from the element prompt or from the selection are wrapped and
// returned; elements appended in earlier rounds stay in the list.
func promptMessageList(
	f protoreflect.FieldDescriptor, addressCodec addresscodec.Codec,
	validatorAddressCodec, consensusAddressCodec addresscodec.Codec,
	promptPrefix string, stdIn io.ReadCloser, msg protoreflect.Message,
) error {
	items := msg.Mutable(f).List()
	// The label prefix is the same for every element.
	itemPrefix := promptPrefix + "." + string(f.Name())

	for addMore := true; addMore; {
		// Create and populate a new message for the list
		element, err := promptMessage(
			addressCodec, validatorAddressCodec, consensusAddressCodec,
			itemPrefix, stdIn, items.NewElement().Message(),
		)
		if err != nil {
			return fmt.Errorf("failed to prompt for list item in %s: %w", itemPrefix, err)
		}
		items.Append(protoreflect.ValueOfMessage(element))

		// Prompt whether to continue
		// TODO: may be better yes/no rather than interactive?
		again := promptui.Select{
			Label: "Add another item?",
			Items: []string{"No", "Yes"},
			Stdin: stdIn,
		}
		_, choice, err := again.Run()
		if err != nil {
			return fmt.Errorf("failed to prompt for continuation: %w", err)
		}

		addMore = choice != "No"
	}

	return nil
}