diff --git a/server/v2/api/grpcgateway/interceptor.go b/server/v2/api/grpcgateway/interceptor.go index 81d1c3e32f..ee8fc598ca 100644 --- a/server/v2/api/grpcgateway/interceptor.go +++ b/server/v2/api/grpcgateway/interceptor.go @@ -1,6 +1,7 @@ package grpcgateway import ( + "bytes" "errors" "io" "net/http" @@ -89,7 +90,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h msgType := gogoproto.MessageType(match.QueryInputName) msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message) if !ok { - runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName)) + runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName)) return } @@ -102,12 +103,12 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h case http.MethodPost: inputMsg, err = g.createMessageFromPostRequest(in, request, msg) default: - runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET")) + runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET")) return } if err != nil { // the errors returned from the message creation methods return status errors. no need to make one here. - runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err) + runtime.HTTPError(request.Context(), g.gateway, out, writer, request, err) return } @@ -118,7 +119,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h if heightStr != "" && heightStr != "latest" { height, err = strconv.ParseUint(heightStr, 10, 64) if err != nil { - runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr)) + runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr)) return } } @@ -130,7 +131,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h g.gateway.ServeHTTP(writer, request) } else { // for all other errors, we just return the error. - runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err) + runtime.HTTPError(request.Context(), g.gateway, out, writer, request, err) } return } @@ -143,12 +144,17 @@ func (g *gatewayInterceptor[T]) createMessageFromPostRequest(marshaler runtime.M if req.ContentLength > MaxBodySize { return nil, status.Errorf(codes.InvalidArgument, "request body too large: %d bytes, max=%d", req.ContentLength, MaxBodySize) } - newReader, err := utilities.IOReaderFactory(req.Body) + + // this block of code ensures that the body can be re-read. this is needed as if the query fails in the + // app's query handler, we need to pass the request back to the canonical gateway, which needs to be able to + // read the body again. + bodyBytes, err := io.ReadAll(req.Body) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "%v", err) } + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - if err = marshaler.NewDecoder(newReader()).Decode(input); err != nil && !errors.Is(err, io.EOF) { + if err = marshaler.NewDecoder(bytes.NewReader(bodyBytes)).Decode(input); err != nil && !errors.Is(err, io.EOF) { return nil, status.Errorf(codes.InvalidArgument, "%v", err) } @@ -217,12 +223,12 @@ func getHTTPGetAnnotationMapping() (map[string]string, error) { continue } queryInputName := string(methodDesc.Input().FullName()) - annotations := append(httpRule.GetAdditionalBindings(), httpRule) - for _, a := range annotations { - if httpAnnotation := a.GetGet(); httpAnnotation != "" { + httpRules := append(httpRule.GetAdditionalBindings(), httpRule) + for _, rule := range httpRules { + if httpAnnotation := rule.GetGet(); httpAnnotation != "" { annotationToQueryInputName[httpAnnotation] = queryInputName } - if httpAnnotation := a.GetPost(); httpAnnotation != "" { + if httpAnnotation := rule.GetPost(); httpAnnotation != "" { annotationToQueryInputName[httpAnnotation] = queryInputName } } diff --git a/tests/systemtests/distribution_test.go b/tests/systemtests/distribution_test.go index 9c41597fe5..73114ee831 100644 --- a/tests/systemtests/distribution_test.go +++ b/tests/systemtests/distribution_test.go @@ -179,7 +179,7 @@ func TestDistrValidatorGRPCQueries(t *testing.T) { // test validator slashes grpc endpoint slashURL := baseurl + `/cosmos/distribution/v1beta1/validators/%s/slashes` - invalidHeightOutput := `{"code":"NUMBER", "details":[]interface {}{}, "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}` + invalidHeightOutput := `{"code":"NUMBER", "details":[], "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}` slashTestCases := []systest.RestTestCase{ {