common/hexutil: add UnmarshalFixedUnprefixedText
This commit is contained in:
		
							parent
							
								
									04fa6a3744
								
							
						
					
					
						commit
						b4547a560b
					
				| @ -51,7 +51,7 @@ func (b *Bytes) UnmarshalJSON(input []byte) error { | ||||
| 
 | ||||
| // UnmarshalText implements encoding.TextUnmarshaler.
 | ||||
| func (b *Bytes) UnmarshalText(input []byte) error { | ||||
| 	raw, err := checkText(input) | ||||
| 	raw, err := checkText(input, true) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -73,7 +73,28 @@ func (b Bytes) String() string { | ||||
| // determines the required input length. This function is commonly used to implement the
 | ||||
| // UnmarshalText method for fixed-size types.
 | ||||
| func UnmarshalFixedText(typname string, input, out []byte) error { | ||||
| 	raw, err := checkText(input) | ||||
| 	raw, err := checkText(input, true) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if len(raw)/2 != len(out) { | ||||
| 		return fmt.Errorf("hex string has length %d, want %d for %s", len(raw), len(out)*2, typname) | ||||
| 	} | ||||
| 	// Pre-verify syntax before modifying out.
 | ||||
| 	for _, b := range raw { | ||||
| 		if decodeNibble(b) == badNibble { | ||||
| 			return ErrSyntax | ||||
| 		} | ||||
| 	} | ||||
| 	hex.Decode(out, raw) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // UnmarshalFixedUnprefixedText decodes the input as a string with optional 0x prefix. The
 | ||||
| // length of out determines the required input length. This function is commonly used to
 | ||||
| // implement the UnmarshalText method for fixed-size types.
 | ||||
| func UnmarshalFixedUnprefixedText(typname string, input, out []byte) error { | ||||
| 	raw, err := checkText(input, false) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -243,14 +264,15 @@ func bytesHave0xPrefix(input []byte) bool { | ||||
| 	return len(input) >= 2 && input[0] == '0' && (input[1] == 'x' || input[1] == 'X') | ||||
| } | ||||
| 
 | ||||
| func checkText(input []byte) ([]byte, error) { | ||||
| func checkText(input []byte, wantPrefix bool) ([]byte, error) { | ||||
| 	if len(input) == 0 { | ||||
| 		return nil, nil // empty strings are allowed
 | ||||
| 	} | ||||
| 	if !bytesHave0xPrefix(input) { | ||||
| 	if bytesHave0xPrefix(input) { | ||||
| 		input = input[2:] | ||||
| 	} else if wantPrefix { | ||||
| 		return nil, ErrMissingPrefix | ||||
| 	} | ||||
| 	input = input[2:] | ||||
| 	if len(input)%2 != 0 { | ||||
| 		return nil, ErrOddLength | ||||
| 	} | ||||
|  | ||||
| @ -337,3 +337,38 @@ func TestUnmarshalUint(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUnmarshalFixedUnprefixedText(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input   string | ||||
| 		want    []byte | ||||
| 		wantErr error | ||||
| 	}{ | ||||
| 		{input: "0x2", wantErr: ErrOddLength}, | ||||
| 		{input: "2", wantErr: ErrOddLength}, | ||||
| 		{input: "4444", wantErr: errors.New("hex string has length 4, want 8 for x")}, | ||||
| 		{input: "4444", wantErr: errors.New("hex string has length 4, want 8 for x")}, | ||||
| 		// check that output is not modified for partially correct input
 | ||||
| 		{input: "444444gg", wantErr: ErrSyntax, want: []byte{0, 0, 0, 0}}, | ||||
| 		{input: "0x444444gg", wantErr: ErrSyntax, want: []byte{0, 0, 0, 0}}, | ||||
| 		// valid inputs
 | ||||
| 		{input: "44444444", want: []byte{0x44, 0x44, 0x44, 0x44}}, | ||||
| 		{input: "0x44444444", want: []byte{0x44, 0x44, 0x44, 0x44}}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		out := make([]byte, 4) | ||||
| 		err := UnmarshalFixedUnprefixedText("x", []byte(test.input), out) | ||||
| 		switch { | ||||
| 		case err == nil && test.wantErr != nil: | ||||
| 			t.Errorf("%q: got no error, expected %q", test.input, test.wantErr) | ||||
| 		case err != nil && test.wantErr == nil: | ||||
| 			t.Errorf("%q: unexpected error %q", test.input, err) | ||||
| 		case err != nil && err.Error() != test.wantErr.Error(): | ||||
| 			t.Errorf("%q: error mismatch: got %q, want %q", test.input, err, test.wantErr) | ||||
| 		} | ||||
| 		if test.want != nil && !bytes.Equal(out, test.want) { | ||||
| 			t.Errorf("%q: output mismatch: got %x, want %x", test.input, out, test.want) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user