refactor(server/v2): auto-gateway improvements (#23262)

Co-authored-by: Alex | Interchain Labs <alex@skip.money>
This commit is contained in:
Tyler 2025-01-13 13:13:13 -08:00 committed by GitHub
parent 9f048ebd8a
commit b461a3142a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 599 additions and 335 deletions

View File

@ -0,0 +1,11 @@
// Package grpcgateway provides a custom http mux that utilizes the global gogoproto registry to match
// grpc gateway requests to query handlers. POST requests with JSON bodies and GET requests with query params are supported.
// Wildcard endpoints (i.e. foo/bar/{baz}), as well as catch-all endpoints (i.e. foo/bar/{baz=**} are supported. Using
// header `x-cosmos-block-height` allows you to specify a height for the query.
//
// The URL matching logic is achieved by building regular expressions from the gateway HTTP annotations. These regular expressions
// are then used to match against incoming requests to the HTTP server.
//
// In cases where the custom http mux is unable to handle the query (i.e. no match found), the request will fall back to the
// ServeMux from github.com/grpc-ecosystem/grpc-gateway/runtime.
package grpcgateway

View File

@ -1,12 +1,18 @@
package grpcgateway
import (
"errors"
"io"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"github.com/mitchellh/mapstructure"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -18,18 +24,27 @@ import (
"cosmossdk.io/server/v2/appmanager"
)
const MaxBodySize = 1 << 20 // 1 MB
var _ http.Handler = &gatewayInterceptor[transaction.Tx]{}
// queryMetadata holds information related to handling gateway queries.
type queryMetadata struct {
// queryInputProtoName is the proto name of the query's input type.
queryInputProtoName string
// wildcardKeyNames are the wildcard key names from the query's HTTP annotation.
// for example /foo/bar/{baz}/{qux} would produce []string{"baz", "qux"}
// this is used for building the query's parameter map.
wildcardKeyNames []string
}
// gatewayInterceptor handles routing grpc-gateway queries to the app manager's query router.
type gatewayInterceptor[T transaction.Tx] struct {
logger log.Logger
// gateway is the fallback grpc gateway mux handler.
gateway *runtime.ServeMux
// customEndpointMapping is a mapping of custom GET options on proto RPC handlers, to the fully qualified method name.
//
// example: /cosmos/bank/v1beta1/denoms_metadata -> cosmos.bank.v1beta1.Query.DenomsMetadata
customEndpointMapping map[string]string
matcher uriMatcher
// appManager is used to route queries to the application.
appManager appmanager.AppManager[T]
@ -41,57 +56,74 @@ func newGatewayInterceptor[T transaction.Tx](logger log.Logger, gateway *runtime
if err != nil {
return nil, err
}
// convert the mapping to regular expressions for URL matching.
wildcardMatchers, simpleMatchers := createRegexMapping(logger, getMapping)
matcher := uriMatcher{
wildcardURIMatchers: wildcardMatchers,
simpleMatchers: simpleMatchers,
}
return &gatewayInterceptor[T]{
logger: logger,
gateway: gateway,
customEndpointMapping: getMapping,
appManager: am,
logger: logger,
gateway: gateway,
matcher: matcher,
appManager: am,
}, nil
}
// ServeHTTP implements the http.Handler interface. This function will attempt to match http requests to the
// interceptors internal mapping of http annotations to query request type names.
// If no match can be made, it falls back to the runtime gateway server mux.
// ServeHTTP implements the http.Handler interface. This method will attempt to match request URIs to its internal mapping
// of gateway HTTP annotations. If no match can be made, it falls back to the runtime gateway server mux.
func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
g.logger.Debug("received grpc-gateway request", "request_uri", request.RequestURI)
match := matchURL(request.URL, g.customEndpointMapping)
match := g.matcher.matchURL(request.URL)
if match == nil {
// no match cases fall back to gateway mux.
g.gateway.ServeHTTP(writer, request)
return
}
g.logger.Debug("matched request", "query_input", match.QueryInputName)
_, out := runtime.MarshalerForRequest(g.gateway, request)
var msg gogoproto.Message
var err error
g.logger.Debug("matched request", "query_input", match.QueryInputName)
in, out := runtime.MarshalerForRequest(g.gateway, request)
// extract the proto message type.
msgType := gogoproto.MessageType(match.QueryInputName)
msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message)
if !ok {
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName))
return
}
// msg population based on http method.
var inputMsg gogoproto.Message
var err error
switch request.Method {
case http.MethodPost:
msg, err = createMessageFromJSON(match, request)
case http.MethodGet:
msg, err = createMessage(match)
inputMsg, err = g.createMessageFromGetRequest(request, msg, match.Params)
case http.MethodPost:
inputMsg, err = g.createMessageFromPostRequest(in, request, msg)
default:
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.Unimplemented, "HTTP method must be POST or GET"))
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET"))
return
}
if err != nil {
// the errors returned from the message creation methods return status errors. no need to make one here.
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
return
}
// extract block height header
// get the height from the header.
var height uint64
heightStr := request.Header.Get(GRPCBlockHeightHeader)
if heightStr != "" {
heightStr = strings.Trim(heightStr, `\"`)
if heightStr != "" && heightStr != "latest" {
height, err = strconv.ParseUint(heightStr, 10, 64)
if err != nil {
err = status.Errorf(codes.InvalidArgument, "invalid height: %s", heightStr)
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr))
return
}
}
query, err := g.appManager.Query(request.Context(), height, msg)
responseMsg, err := g.appManager.Query(request.Context(), height, inputMsg)
if err != nil {
// if we couldn't find a handler for this request, just fall back to the gateway mux.
if strings.Contains(err.Error(), "no handler") {
@ -102,8 +134,62 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
}
return
}
// for no errors, we forward the response.
runtime.ForwardResponseMessage(request.Context(), g.gateway, out, writer, request, query)
runtime.ForwardResponseMessage(request.Context(), g.gateway, out, writer, request, responseMsg)
}
func (g *gatewayInterceptor[T]) createMessageFromPostRequest(marshaler runtime.Marshaler, req *http.Request, input gogoproto.Message) (gogoproto.Message, error) {
if req.ContentLength > MaxBodySize {
return nil, status.Errorf(codes.InvalidArgument, "request body too large: %d bytes, max=%d", req.ContentLength, MaxBodySize)
}
newReader, err := utilities.IOReaderFactory(req.Body)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
if err = marshaler.NewDecoder(newReader()).Decode(input); err != nil && !errors.Is(err, io.EOF) {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
return input, nil
}
func (g *gatewayInterceptor[T]) createMessageFromGetRequest(req *http.Request, input gogoproto.Message, wildcardValues map[string]string) (gogoproto.Message, error) {
// decode the path wildcards into the message.
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: input,
TagName: "json",
WeaklyTypedInput: true,
})
if err != nil {
return nil, status.Error(codes.Internal, "failed to create message decoder")
}
if err := decoder.Decode(wildcardValues); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
if err = req.ParseForm(); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
filter := filterFromPathParams(wildcardValues)
err = runtime.PopulateQueryParameters(input, req.Form, filter)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
return input, err
}
func filterFromPathParams(pathParams map[string]string) *utilities.DoubleArray {
var prefixPaths [][]string
for k := range pathParams {
prefixPaths = append(prefixPaths, []string{k})
}
return utilities.NewDoubleArray(prefixPaths)
}
// getHTTPGetAnnotationMapping returns a mapping of RPC Method HTTP GET annotation to the RPC Handler's Request Input type full name.
@ -115,31 +201,74 @@ func getHTTPGetAnnotationMapping() (map[string]string, error) {
return nil, err
}
httpGets := make(map[string]string)
annotationToQueryInputName := make(map[string]string)
protoFiles.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
for i := 0; i < fd.Services().Len(); i++ {
serviceDesc := fd.Services().Get(i)
for j := 0; j < serviceDesc.Methods().Len(); j++ {
methodDesc := serviceDesc.Methods().Get(j)
httpAnnotation := proto.GetExtension(methodDesc.Options(), annotations.E_Http)
if httpAnnotation == nil {
httpExtension := proto.GetExtension(methodDesc.Options(), annotations.E_Http)
if httpExtension == nil {
continue
}
httpRule, ok := httpAnnotation.(*annotations.HttpRule)
httpRule, ok := httpExtension.(*annotations.HttpRule)
if !ok || httpRule == nil {
continue
}
if httpRule.GetGet() == "" {
continue
queryInputName := string(methodDesc.Input().FullName())
annotations := append(httpRule.GetAdditionalBindings(), httpRule)
for _, a := range annotations {
if httpAnnotation := a.GetGet(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
}
if httpAnnotation := a.GetPost(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
}
}
httpGets[httpRule.GetGet()] = string(methodDesc.Input().FullName())
}
}
return true
})
return httpGets, nil
return annotationToQueryInputName, nil
}
// createRegexMapping converts the annotationMapping (HTTP annotation -> query input type name) to a
// map of regular expressions for that HTTP annotation pattern, to queryMetadata.
func createRegexMapping(logger log.Logger, annotationMapping map[string]string) (map[*regexp.Regexp]queryMetadata, map[string]queryMetadata) {
wildcardMatchers := make(map[*regexp.Regexp]queryMetadata)
// seen patterns is a map of URI patterns to annotations. for simple queries (no wildcards) the annotation is used
// for the key.
seenPatterns := make(map[string]string)
simpleMatchers := make(map[string]queryMetadata)
for annotation, queryInputName := range annotationMapping {
pattern, wildcardNames := patternToRegex(annotation)
if len(wildcardNames) == 0 {
if otherAnnotation, ok := seenPatterns[annotation]; ok {
// TODO: eventually we want this to error, but there is currently a duplicate in the protobuf.
// see: https://github.com/cosmos/cosmos-sdk/issues/23281
logger.Warn("duplicate HTTP annotation found", "annotation1", annotation, "annotation2", otherAnnotation, "query_input_name", queryInputName)
}
simpleMatchers[annotation] = queryMetadata{
queryInputProtoName: queryInputName,
wildcardKeyNames: nil,
}
seenPatterns[annotation] = annotation
} else {
reg := regexp.MustCompile(pattern)
if otherAnnotation, ok := seenPatterns[pattern]; ok {
// TODO: eventually we want this to error, but there is currently a duplicate in the protobuf.
// see: https://github.com/cosmos/cosmos-sdk/issues/23281
logger.Warn("duplicate HTTP annotation found", "annotation1", annotation, "annotation2", otherAnnotation, "query_input_name", queryInputName)
}
wildcardMatchers[reg] = queryMetadata{
queryInputProtoName: queryInputName,
wildcardKeyNames: wildcardNames,
}
seenPatterns[pattern] = annotation
}
}
return wildcardMatchers, simpleMatchers
}

View File

@ -0,0 +1,313 @@
package grpcgateway
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"cosmossdk.io/core/transaction"
"cosmossdk.io/log"
)
func Test_createRegexMapping(t *testing.T) {
tests := []struct {
name string
annotations map[string]string
expectedRegex int
expectedSimple int
wantWarn bool
}{
{
name: "no annotations should not warn",
},
{
name: "expected correct amount of regex and simple matchers",
annotations: map[string]string{
"/foo/bar/baz": "",
"/foo/{bar}/baz": "",
"/foo/bar/bell": "",
},
expectedRegex: 1,
expectedSimple: 2,
},
{
name: "different annotations should not warn",
annotations: map[string]string{
"/foo/bar/{baz}": "",
"/crypto/{currency}": "",
},
expectedRegex: 2,
},
{
name: "duplicate annotations should warn",
annotations: map[string]string{
"/hello/{world}": "",
"/hello/{developers}": "",
},
expectedRegex: 2,
wantWarn: true,
},
}
buf := bytes.NewBuffer(nil)
logger := log.NewLogger(buf)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
regex, simple := createRegexMapping(logger, tt.annotations)
if tt.wantWarn {
require.NotEmpty(t, buf.String())
} else {
require.Empty(t, buf.String())
}
require.Equal(t, tt.expectedRegex, len(regex))
require.Equal(t, tt.expectedSimple, len(simple))
})
}
}
func TestCreateMessageFromGetRequest(t *testing.T) {
gogoproto.RegisterType(&DummyProto{}, dummyProtoName)
testCases := []struct {
name string
request func() *http.Request
wildcardValues map[string]string
expected *DummyProto
wantErr bool
errCode codes.Code
}{
{
name: "simple wildcard + query params",
request: func() *http.Request {
// GET with query params: ?bar=true&baz=42&denoms=atom&denoms=osmo
// Also nested pagination params: page.limit=100, page.nest.foo=999
req := httptest.NewRequest(
http.MethodGet,
"/dummy?bar=true&baz=42&denoms=atom&denoms=osmo&page.limit=100&page.nest.foo=999",
nil,
)
return req
},
wildcardValues: map[string]string{
"foo": "wildFooValue", // from path wildcard e.g. /dummy/{foo}
},
expected: &DummyProto{
Foo: "wildFooValue",
Bar: true,
Baz: 42,
Denoms: []string{"atom", "osmo"},
Page: &Pagination{
Limit: 100,
Nest: &Nested{
Foo: 999,
},
},
},
wantErr: false,
},
{
name: "invalid integer in query param",
request: func() *http.Request {
req := httptest.NewRequest(
http.MethodGet,
"/dummy?baz=notanint",
nil,
)
return req
},
wildcardValues: map[string]string{},
expected: &DummyProto{}, // won't get populated
wantErr: true,
errCode: codes.InvalidArgument,
},
{
name: "no query params, but wildcard set",
request: func() *http.Request {
// No query params. Only the wildcard.
req := httptest.NewRequest(
http.MethodGet,
"/dummy",
nil,
)
return req
},
wildcardValues: map[string]string{
"foo": "barFromWildcard",
},
expected: &DummyProto{
Foo: "barFromWildcard",
},
wantErr: false,
},
}
// We only need a minimal gatewayInterceptor instance to call createMessageFromGetRequest,
// so it's fine to leave most fields nil for this unit test.
g := &gatewayInterceptor[transaction.Tx]{}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := tc.request()
inputMsg := &DummyProto{}
gotMsg, err := g.createMessageFromGetRequest(
req,
inputMsg,
tc.wildcardValues,
)
if tc.wantErr {
require.Error(t, err, "expected error but got none")
st, ok := status.FromError(err)
if ok && tc.errCode != codes.OK {
require.Equal(t, tc.errCode, st.Code())
}
} else {
require.NoError(t, err, "unexpected error")
require.Equal(t, tc.expected, gotMsg, "message contents do not match expected")
}
})
}
}
func TestCreateMessageFromPostRequest(t *testing.T) {
gogoproto.RegisterType(&DummyProto{}, dummyProtoName)
gogoproto.RegisterType(&Pagination{}, "pagination")
gogoproto.RegisterType(&Nested{}, "nested")
testCases := []struct {
name string
body any
wantErr bool
errCode codes.Code
expected *DummyProto
}{
{
name: "valid JSON body with nested fields",
body: map[string]any{
"foo": "postFoo",
"bar": true,
"baz": 42,
"denoms": []string{"atom", "osmo"},
"page": map[string]any{
"limit": 100,
"nest": map[string]any{
"foo": 999,
},
},
},
wantErr: false,
expected: &DummyProto{
Foo: "postFoo",
Bar: true,
Baz: 42,
Denoms: []string{"atom", "osmo"},
Page: &Pagination{
Limit: 100,
Nest: &Nested{
Foo: 999,
},
},
},
},
{
name: "invalid JSON structure",
// Provide a broken JSON string:
body: `{"foo": "bad json", "extra": "not closed"`,
wantErr: true,
errCode: codes.InvalidArgument,
},
{
name: "empty JSON object",
body: map[string]any{},
wantErr: false,
expected: &DummyProto{}, // all fields remain zeroed
},
}
g := &gatewayInterceptor[transaction.Tx]{}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var reqBody []byte
switch typedBody := tc.body.(type) {
case string:
// This might be invalid JSON we intentionally want to test
reqBody = []byte(typedBody)
default:
// Marshal the given any into JSON
b, err := json.Marshal(typedBody)
require.NoError(t, err, "failed to marshal test body to JSON")
reqBody = b
}
req := httptest.NewRequest(http.MethodPost, "/dummy", bytes.NewReader(reqBody))
inputMsg := &DummyProto{}
gotMsg, err := g.createMessageFromPostRequest(
&runtime.JSONPb{}, // JSONPb marshaler
req,
inputMsg,
)
if tc.wantErr {
require.Error(t, err, "expected an error but got none")
// Optionally verify the gRPC status code
st, ok := status.FromError(err)
if ok && tc.errCode != codes.OK {
require.Equal(t, tc.errCode, st.Code())
}
} else {
require.NoError(t, err, "did not expect an error")
require.Equal(t, tc.expected, gotMsg)
}
})
}
}
/*
--- Testing Types ---
*/
type Nested struct {
Foo int32 `protobuf:"varint,1,opt,name=foo,proto3" json:"foo,omitempty"`
}
func (n Nested) Reset() {}
func (n Nested) String() string { return "" }
func (n Nested) ProtoMessage() {}
type Pagination struct {
Limit int32 `protobuf:"varint,1,opt,name=limit,proto3" json:"limit,omitempty"`
Nest *Nested `protobuf:"bytes,2,opt,name=nest,proto3" json:"nest,omitempty"`
}
func (p Pagination) Reset() {}
func (p Pagination) String() string { return "" }
func (p Pagination) ProtoMessage() {}
const dummyProtoName = "dummy"
type DummyProto struct {
Foo string `protobuf:"bytes,1,opt,name=foo,proto3" json:"foo,omitempty"`
Bar bool `protobuf:"varint,2,opt,name=bar,proto3" json:"bar,omitempty"`
Baz int32 `protobuf:"varint,3,opt,name=baz,proto3" json:"baz,omitempty"`
Denoms []string `protobuf:"bytes,4,rep,name=denoms,proto3" json:"denoms,omitempty"`
Page *Pagination `protobuf:"bytes,5,opt,name=page,proto3" json:"page,omitempty"`
}
func (d DummyProto) Reset() {}
func (d DummyProto) String() string { return dummyProtoName }
func (d DummyProto) ProtoMessage() {}

View File

@ -59,7 +59,7 @@ func New[T transaction.Tx](
// marshaled in unary requests.
runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler),
// Custom header matcher for mapping request headers to
// Custom header uriMatcher for mapping request headers to
// GRPC metadata
runtime.WithIncomingHeaderMatcher(CustomGRPCHeaderMatcher),
),

View File

@ -1,86 +1,65 @@
package grpcgateway
import (
"io"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
"github.com/cosmos/gogoproto/jsonpb"
gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const maxBodySize = 1 << 20 // 1 MB
// uriMatcher provides functionality to match HTTP request URIs.
type uriMatcher struct {
// wildcardURIMatchers are used for complex URIs that involve wildcards (i.e. /foo/{bar}/baz)
wildcardURIMatchers map[*regexp.Regexp]queryMetadata
// simpleMatchers are used for simple URI's that have no wildcards (i.e. /foo/bar/baz).
simpleMatchers map[string]queryMetadata
}
// uriMatch contains information related to a URI match.
type uriMatch struct {
// QueryInputName is the fully qualified name of the proto input type of the query rpc method.
QueryInputName string
// Params are any wildcard/query params found in the request.
// Params are any wildcard params found in the request.
//
// example:
// - foo/bar/{baz} - foo/bar/qux -> {baz: qux}
// - foo/bar?baz=qux - foo/bar -> {baz: qux}
// example: /foo/bar/{baz} -> /foo/bar/hello = {"baz": "hello"}
Params map[string]string
}
// HasParams reports whether the uriMatch has any params.
func (uri uriMatch) HasParams() bool {
return len(uri.Params) > 0
}
// matchURL attempts to find a match for the given URL.
// NOTE: if no match is found, nil is returned.
func matchURL(u *url.URL, getPatternToQueryInputName map[string]string) *uriMatch {
func (m uriMatcher) matchURL(u *url.URL) *uriMatch {
uriPath := strings.TrimRight(u.Path, "/")
queryParams := u.Query()
params := make(map[string]string)
for key, vals := range queryParams {
if len(vals) > 0 {
// url.Values contains a slice for the values as you are able to specify a key multiple times in URL.
// example: https://localhost:9090/do/something?color=red&color=blue&color=green
// We will just take the first value in the slice.
params[key] = vals[0]
}
}
// for simple cases where there are no wildcards, we can just do a map lookup.
if inputName, ok := getPatternToQueryInputName[uriPath]; ok {
// see if we can get a simple match first.
if qmd, ok := m.simpleMatchers[uriPath]; ok {
return &uriMatch{
QueryInputName: inputName,
QueryInputName: qmd.queryInputProtoName,
Params: params,
}
}
// attempt to find a match in the pattern map.
for getPattern, queryInputName := range getPatternToQueryInputName {
getPattern = strings.TrimRight(getPattern, "/")
regexPattern, wildcardNames := patternToRegex(getPattern)
regex := regexp.MustCompile(regexPattern)
matches := regex.FindStringSubmatch(uriPath)
if len(matches) > 1 {
// first match is the full string, subsequent matches are capture groups
for i, name := range wildcardNames {
// try the complex matchers.
for reg, qmd := range m.wildcardURIMatchers {
matches := reg.FindStringSubmatch(uriPath)
switch {
case len(matches) == 1:
return &uriMatch{
QueryInputName: qmd.queryInputProtoName,
Params: params,
}
case len(matches) > 1:
// first match is the URI, subsequent matches are the wild card values.
for i, name := range qmd.wildcardKeyNames {
params[name] = matches[i+1]
}
return &uriMatch{
QueryInputName: queryInputName,
QueryInputName: qmd.queryInputProtoName,
Params: params,
}
}
}
return nil
}
@ -110,78 +89,3 @@ func patternToRegex(pattern string) (string, []string) {
return "^" + escaped + "$", wildcardNames
}
// createMessageFromJSON creates a message from the uriMatch given the JSON body in the http request.
func createMessageFromJSON(match *uriMatch, r *http.Request) (gogoproto.Message, error) {
requestType := gogoproto.MessageType(match.QueryInputName)
if requestType == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request type")
}
msg, ok := reflect.New(requestType.Elem()).Interface().(gogoproto.Message)
if !ok {
return nil, status.Error(codes.Internal, "failed to cast to proto message")
}
defer r.Body.Close()
limitedReader := io.LimitReader(r.Body, maxBodySize)
err := jsonpb.Unmarshal(limitedReader, msg)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
return msg, nil
}
// createMessage creates a message from the given uriMatch. If the match has params, the message will be populated
// with the value of those params. Otherwise, an empty message is returned.
func createMessage(match *uriMatch) (gogoproto.Message, error) {
requestType := gogoproto.MessageType(match.QueryInputName)
if requestType == nil {
return nil, status.Error(codes.InvalidArgument, "unknown request type")
}
msg, ok := reflect.New(requestType.Elem()).Interface().(gogoproto.Message)
if !ok {
return nil, status.Error(codes.Internal, "failed to create message instance")
}
// if the uri match has params, we need to populate the message with the values of those params.
if match.HasParams() {
// convert flat params map to nested structure
nestedParams := make(map[string]any)
for key, value := range match.Params {
parts := strings.Split(key, ".")
current := nestedParams
// step through nested levels
for i, part := range parts {
if i == len(parts)-1 {
// Last part - set the value
current[part] = value
} else {
// continue nestedness
if _, exists := current[part]; !exists {
current[part] = make(map[string]any)
}
current = current[part].(map[string]any)
}
}
}
// Configure decoder to handle the nested structure
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: msg,
TagName: "json", // Use json tags as they're simpler
WeaklyTypedInput: true,
})
if err != nil {
return nil, status.Error(codes.Internal, "failed to create message instance")
}
if err := decoder.Decode(nestedParams); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
return msg, nil
}

View File

@ -1,15 +1,14 @@
package grpcgateway
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
"regexp"
"testing"
gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/stretchr/testify/require"
"cosmossdk.io/log"
)
func TestMatchURI(t *testing.T) {
@ -26,16 +25,22 @@ func TestMatchURI(t *testing.T) {
expected: &uriMatch{QueryInputName: "query.Bank", Params: map[string]string{}},
},
{
name: "match with query parameters",
uri: "https://localhost:8080/foo/bar?baz=qux",
mapping: map[string]string{"/foo/bar": "query.Bank"},
expected: &uriMatch{QueryInputName: "query.Bank", Params: map[string]string{"baz": "qux"}},
name: "match with wildcard similar to simple match - simple",
uri: "https://localhost:8080/bank/supply/latest",
mapping: map[string]string{
"/bank/supply/{height}": "queryBankHeight",
"/bank/supply/latest": "queryBankLatest",
},
expected: &uriMatch{QueryInputName: "queryBankLatest", Params: map[string]string{}},
},
{
name: "match with multiple query parameters",
uri: "https://localhost:8080/foo/bar?baz=qux&foo=/msg.type.bank.send",
mapping: map[string]string{"/foo/bar": "query.Bank"},
expected: &uriMatch{QueryInputName: "query.Bank", Params: map[string]string{"baz": "qux", "foo": "/msg.type.bank.send"}},
name: "match with wildcard similar to simple match - wildcard",
uri: "https://localhost:8080/bank/supply/52",
mapping: map[string]string{
"/bank/supply/{height}": "queryBankHeight",
"/bank/supply/latest": "queryBankLatest",
},
expected: &uriMatch{QueryInputName: "queryBankHeight", Params: map[string]string{"height": "52"}},
},
{
name: "wildcard match at the end",
@ -81,183 +86,86 @@ func TestMatchURI(t *testing.T) {
},
}
logger := log.NewLogger(os.Stdout)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
u, err := url.Parse(tc.uri)
require.NoError(t, err)
actual := matchURL(u, tc.mapping)
regexpMatchers, simpleMatchers := createRegexMapping(logger, tc.mapping)
matcher := uriMatcher{
wildcardURIMatchers: regexpMatchers,
simpleMatchers: simpleMatchers,
}
actual := matcher.matchURL(u)
require.Equal(t, tc.expected, actual)
})
}
}
func TestURIMatch_HasParams(t *testing.T) {
u := uriMatch{Params: map[string]string{"foo": "bar"}}
require.True(t, u.HasParams())
u = uriMatch{}
require.False(t, u.HasParams())
}
type Nested struct {
Foo int `protobuf:"varint,1,opt,name=foo,proto3" json:"foo,omitempty"`
}
type Pagination struct {
Limit int `protobuf:"varint,1,opt,name=limit,proto3" json:"limit,omitempty"`
Nest *Nested `protobuf:"bytes,2,opt,name=nest,proto3" json:"nest,omitempty"`
}
const dummyProtoName = "dummy"
type DummyProto struct {
Foo string `protobuf:"bytes,1,opt,name=foo,proto3" json:"foo,omitempty"`
Bar bool `protobuf:"varint,2,opt,name=bar,proto3" json:"bar,omitempty"`
Baz int `protobuf:"varint,3,opt,name=baz,proto3" json:"baz,omitempty"`
Page *Pagination `protobuf:"bytes,4,opt,name=page,proto3" json:"page,omitempty"`
}
func (d DummyProto) Reset() {}
func (d DummyProto) String() string { return dummyProtoName }
func (d DummyProto) ProtoMessage() {}
func TestCreateMessage(t *testing.T) {
gogoproto.RegisterType(&DummyProto{}, dummyProtoName)
testCases := []struct {
name string
uri uriMatch
expected gogoproto.Message
expErr bool
func Test_patternToRegex(t *testing.T) {
tests := []struct {
name string
pattern string
wildcards []string
wildcardValues []string
shouldMatch string
shouldNotMatch []string
}{
{
name: "simple, empty message",
uri: uriMatch{QueryInputName: dummyProtoName},
expected: &DummyProto{},
name: "simple match, no wildcards",
pattern: "/foo/bar/baz",
shouldMatch: "/foo/bar/baz",
shouldNotMatch: []string{"/foo/bar", "/foo", "/foo/bar/baz/boo"},
},
{
name: "message with params",
uri: uriMatch{
QueryInputName: dummyProtoName,
Params: map[string]string{"foo": "blah", "bar": "true", "baz": "1352"},
},
expected: &DummyProto{
Foo: "blah",
Bar: true,
Baz: 1352,
},
name: "match with wildcard",
pattern: "/foo/bar/{baz}",
wildcards: []string{"baz"},
shouldMatch: "/foo/bar/hello",
wildcardValues: []string{"hello"},
shouldNotMatch: []string{"/foo/bar", "/foo/bar/baz/boo"},
},
{
name: "message with nested params",
uri: uriMatch{
QueryInputName: dummyProtoName,
Params: map[string]string{"foo": "blah", "bar": "true", "baz": "1352", "page.limit": "3"},
},
expected: &DummyProto{
Foo: "blah",
Bar: true,
Baz: 1352,
Page: &Pagination{Limit: 3},
},
name: "match with multiple wildcards",
pattern: "/foo/{bar}/{baz}/meow",
wildcards: []string{"bar", "baz"},
shouldMatch: "/foo/hello/world/meow",
wildcardValues: []string{"hello", "world"},
shouldNotMatch: []string{"/foo/bar/baz/boo", "/foo/bar/baz"},
},
{
name: "message with multi nested params",
uri: uriMatch{
QueryInputName: dummyProtoName,
Params: map[string]string{"foo": "blah", "bar": "true", "baz": "1352", "page.limit": "3", "page.nest.foo": "5"},
},
expected: &DummyProto{
Foo: "blah",
Bar: true,
Baz: 1352,
Page: &Pagination{Limit: 3, Nest: &Nested{Foo: 5}},
},
},
{
name: "invalid params should error out",
uri: uriMatch{
QueryInputName: dummyProtoName,
Params: map[string]string{"foo": "blah", "bar": "235235", "baz": "true"},
},
expErr: true,
},
{
name: "unknown input type",
uri: uriMatch{
QueryInputName: "foobar",
},
expErr: true,
name: "match catch-all wildcard",
pattern: `/foo/bar/{baz=**}`,
wildcards: []string{"baz"},
shouldMatch: `/foo/bar/this/is/a/long/wildcard`,
wildcardValues: []string{"this/is/a/long/wildcard"},
shouldNotMatch: []string{"/foo/bar", "/foo", "/foo/baz/bar/long/wild/card"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
regString, wildcards := patternToRegex(tt.pattern)
// should produce the same wildcard keys
require.Equal(t, tt.wildcards, wildcards)
reg := regexp.MustCompile(regString)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, err := createMessage(&tc.uri)
if tc.expErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.expected, actual)
}
})
}
}
func TestCreateMessageFromJson(t *testing.T) {
gogoproto.RegisterType(&DummyProto{}, dummyProtoName)
testCases := []struct {
name string
uri uriMatch
request func() *http.Request
expected gogoproto.Message
expErr bool
}{
{
name: "simple, empty message",
uri: uriMatch{QueryInputName: dummyProtoName},
request: func() *http.Request {
return &http.Request{Body: io.NopCloser(bytes.NewReader([]byte("{}")))}
},
expected: &DummyProto{},
},
{
name: "message with json input",
uri: uriMatch{QueryInputName: dummyProtoName},
request: func() *http.Request {
d := DummyProto{
Foo: "hello",
Bar: true,
Baz: 320,
}
bz, err := json.Marshal(d)
require.NoError(t, err)
return &http.Request{Body: io.NopCloser(bytes.NewReader(bz))}
},
expected: &DummyProto{
Foo: "hello",
Bar: true,
Baz: 320,
},
},
{
name: "message with invalid json",
uri: uriMatch{QueryInputName: dummyProtoName},
request: func() *http.Request {
return &http.Request{Body: io.NopCloser(bytes.NewReader([]byte(`{"foo":12,dfi3}"`)))}
},
expErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, err := createMessageFromJSON(&tc.uri, tc.request())
if tc.expErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.expected, actual)
// handle the "should match" case.
matches := reg.FindStringSubmatch(tt.shouldMatch)
require.True(t, len(matches) > 0) // there should always be a match.
// when matches > 1, this means we got wildcard values to handle. the test should have wildcard values.
if len(matches) > 1 {
require.Greater(t, len(tt.wildcardValues), 0)
}
// matches[0] is the URL, everything else should be those wildcard values.
if len(tt.wildcardValues) > 0 {
require.Equal(t, matches[1:], tt.wildcardValues)
}
// should never match these.
for _, notMatch := range tt.shouldNotMatch {
require.Len(t, reg.FindStringSubmatch(notMatch), 0)
}
})
}

View File

@ -263,7 +263,7 @@ func TestBankGRPCQueries(t *testing.T) {
"error when querying supply with height greater than block height",
supplyUrl,
map[string]string{
blockHeightHeader: fmt.Sprintf("%d", blockHeight+5),
blockHeightHeader: fmt.Sprintf("%d", blockHeight+5000),
},
http.StatusBadRequest,
"invalid height",

View File

@ -179,21 +179,20 @@ func TestDistrValidatorGRPCQueries(t *testing.T) {
// test validator slashes grpc endpoint
slashURL := baseurl + `/cosmos/distribution/v1beta1/validators/%s/slashes`
invalidStartingHeightOutput := `{"code":3, "message":"1 error(s) decoding:\n\n* cannot parse 'starting_height' as uint: strconv.ParseUint: parsing \"-3\": invalid syntax", "details":[]}`
invalidEndingHeightOutput := `{"code":3, "message":"1 error(s) decoding:\n\n* cannot parse 'ending_height' as uint: strconv.ParseUint: parsing \"-3\": invalid syntax", "details":[]}`
invalidHeightOutput := `{"code":"NUMBER", "details":[]interface {}{}, "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}`
slashTestCases := []systest.RestTestCase{
{
Name: "invalid start height",
Url: fmt.Sprintf(slashURL+`?starting_height=%s&ending_height=%s`, valOperAddr, "-3", "3"),
ExpCode: http.StatusBadRequest,
ExpOut: invalidStartingHeightOutput,
ExpOut: invalidHeightOutput,
},
{
Name: "invalid end height",
Url: fmt.Sprintf(slashURL+`?starting_height=%s&ending_height=%s`, valOperAddr, "1", "-3"),
ExpCode: http.StatusBadRequest,
ExpOut: invalidEndingHeightOutput,
ExpOut: invalidHeightOutput,
},
{
Name: "valid request get slashes",