rpc: minor cleanups to RPC PR

This commit is contained in:
Péter Szilágyi 2017-11-17 14:18:46 +02:00
parent c5b8569707
commit 3c6b9c5d72
No known key found for this signature in database
GPG Key ID: E9AE538CEDF8293D
2 changed files with 38 additions and 28 deletions

View File

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -151,41 +152,36 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && r.ContentLength == 0 && r.URL.RawQuery == "" { if r.Method == "GET" && r.ContentLength == 0 && r.URL.RawQuery == "" {
return return
} }
if responseCode, errorMessage := httpErrorResponse(r); responseCode != 0 { if code, err := validateRequest(r); err != nil {
http.Error(w, errorMessage, responseCode) http.Error(w, err.Error(), code)
return return
} }
// All checks passed, create a codec that reads direct from the request body // All checks passed, create a codec that reads direct from the request body
// untilEOF and writes the response to w and order the server to process a // untilEOF and writes the response to w and order the server to process a
// single request. // single request.
codec := NewJSONCodec(&httpReadWriteNopCloser{r.Body, w}) codec := NewJSONCodec(&httpReadWriteNopCloser{r.Body, w})
defer codec.Close() defer codec.Close()
w.Header().Set("content-type", "application/json") w.Header().Set("content-type", contentType)
srv.ServeSingleRequest(codec, OptionMethodInvocation) srv.ServeSingleRequest(codec, OptionMethodInvocation)
} }
// Returns a non-zero response code and error message if the request is invalid. // validateRequest returns a non-zero response code and error message if the
func httpErrorResponse(r *http.Request) (int, string) { // request is invalid.
func validateRequest(r *http.Request) (int, error) {
if r.Method == "PUT" || r.Method == "DELETE" { if r.Method == "PUT" || r.Method == "DELETE" {
errorMessage := "method not allowed" return http.StatusMethodNotAllowed, errors.New("method not allowed")
return http.StatusMethodNotAllowed, errorMessage
} }
if r.ContentLength > maxHTTPRequestContentLength { if r.ContentLength > maxHTTPRequestContentLength {
errorMessage := fmt.Sprintf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength) err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength)
return http.StatusRequestEntityTooLarge, errorMessage return http.StatusRequestEntityTooLarge, err
} }
mt, _, err := mime.ParseMediaType(r.Header.Get("content-type"))
ct := r.Header.Get("content-type")
mt, _, err := mime.ParseMediaType(ct)
if err != nil || mt != contentType { if err != nil || mt != contentType {
errorMessage := fmt.Sprintf("invalid content type, only %s is supported", contentType) err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, errorMessage return http.StatusUnsupportedMediaType, err
} }
return 0, nil
return 0, ""
} }
func newCorsHandler(srv *Server, allowedOrigins []string) http.Handler { func newCorsHandler(srv *Server, allowedOrigins []string) http.Handler {

View File

@ -1,3 +1,19 @@
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc package rpc
import ( import (
@ -8,33 +24,31 @@ import (
) )
func TestHTTPErrorResponseWithDelete(t *testing.T) { func TestHTTPErrorResponseWithDelete(t *testing.T) {
httpErrorResponseTest(t, "DELETE", contentType, "", http.StatusMethodNotAllowed) testHTTPErrorResponse(t, "DELETE", contentType, "", http.StatusMethodNotAllowed)
} }
func TestHTTPErrorResponseWithPut(t *testing.T) { func TestHTTPErrorResponseWithPut(t *testing.T) {
httpErrorResponseTest(t, "PUT", contentType, "", http.StatusMethodNotAllowed) testHTTPErrorResponse(t, "PUT", contentType, "", http.StatusMethodNotAllowed)
} }
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
body := make([]rune, maxHTTPRequestContentLength+1, maxHTTPRequestContentLength+1) body := make([]rune, maxHTTPRequestContentLength+1, maxHTTPRequestContentLength+1)
httpErrorResponseTest(t, testHTTPErrorResponse(t,
"POST", contentType, string(body), http.StatusRequestEntityTooLarge) "POST", contentType, string(body), http.StatusRequestEntityTooLarge)
} }
func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) { func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) {
httpErrorResponseTest(t, "POST", "", "", http.StatusUnsupportedMediaType) testHTTPErrorResponse(t, "POST", "", "", http.StatusUnsupportedMediaType)
} }
func TestHTTPErrorResponseWithValidRequest(t *testing.T) { func TestHTTPErrorResponseWithValidRequest(t *testing.T) {
httpErrorResponseTest(t, "POST", contentType, "", 0) testHTTPErrorResponse(t, "POST", contentType, "", 0)
} }
func httpErrorResponseTest(t *testing.T, func testHTTPErrorResponse(t *testing.T, method, contentType, body string, expected int) {
method, contentType, body string, expectedResponse int) {
request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body)) request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body))
request.Header.Set("content-type", contentType) request.Header.Set("content-type", contentType)
if response, _ := httpErrorResponse(request); response != expectedResponse { if code, _ := validateRequest(request); code != expected {
t.Fatalf("response code should be %d not %d", expectedResponse, response) t.Fatalf("response code should be %d not %d", expected, code)
} }
} }