fix(server/v2): post request fallback (#23361)

This commit is contained in:
Tyler 2025-01-13 16:25:07 -08:00 committed by GitHub
parent b461a3142a
commit 265cb94e8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 12 deletions

View File

@ -1,6 +1,7 @@
package grpcgateway
import (
"bytes"
"errors"
"io"
"net/http"
@ -89,7 +90,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
msgType := gogoproto.MessageType(match.QueryInputName)
msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message)
if !ok {
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName))
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName))
return
}
@ -102,12 +103,12 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
case http.MethodPost:
inputMsg, err = g.createMessageFromPostRequest(in, request, msg)
default:
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET"))
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET"))
return
}
if err != nil {
// the errors returned from the message creation methods return status errors. no need to make one here.
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, err)
return
}
@ -118,7 +119,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
if heightStr != "" && heightStr != "latest" {
height, err = strconv.ParseUint(heightStr, 10, 64)
if err != nil {
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr))
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr))
return
}
}
@ -130,7 +131,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
g.gateway.ServeHTTP(writer, request)
} else {
// for all other errors, we just return the error.
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, err)
}
return
}
@ -143,12 +144,17 @@ func (g *gatewayInterceptor[T]) createMessageFromPostRequest(marshaler runtime.M
if req.ContentLength > MaxBodySize {
return nil, status.Errorf(codes.InvalidArgument, "request body too large: %d bytes, max=%d", req.ContentLength, MaxBodySize)
}
newReader, err := utilities.IOReaderFactory(req.Body)
// this block of code ensures that the body can be re-read. this is needed as if the query fails in the
// app's query handler, we need to pass the request back to the canonical gateway, which needs to be able to
// read the body again.
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if err = marshaler.NewDecoder(newReader()).Decode(input); err != nil && !errors.Is(err, io.EOF) {
if err = marshaler.NewDecoder(bytes.NewReader(bodyBytes)).Decode(input); err != nil && !errors.Is(err, io.EOF) {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
@ -217,12 +223,12 @@ func getHTTPGetAnnotationMapping() (map[string]string, error) {
continue
}
queryInputName := string(methodDesc.Input().FullName())
annotations := append(httpRule.GetAdditionalBindings(), httpRule)
for _, a := range annotations {
if httpAnnotation := a.GetGet(); httpAnnotation != "" {
httpRules := append(httpRule.GetAdditionalBindings(), httpRule)
for _, rule := range httpRules {
if httpAnnotation := rule.GetGet(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
}
if httpAnnotation := a.GetPost(); httpAnnotation != "" {
if httpAnnotation := rule.GetPost(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
}
}

View File

@ -179,7 +179,7 @@ func TestDistrValidatorGRPCQueries(t *testing.T) {
// test validator slashes grpc endpoint
slashURL := baseurl + `/cosmos/distribution/v1beta1/validators/%s/slashes`
invalidHeightOutput := `{"code":"NUMBER", "details":[]interface {}{}, "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}`
invalidHeightOutput := `{"code":"NUMBER", "details":[], "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}`
slashTestCases := []systest.RestTestCase{
{