From 1db978ca6b28b28d5eda9cf04b519ee301b05345 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Sun, 19 Feb 2023 20:23:18 +0100 Subject: [PATCH] rpc: fix unmarshaling of null result in CallContext (#26723) The change fixes unmarshaling of JSON null results into json.RawMessage. --------- Co-authored-by: Jason Yuan Co-authored-by: Jason Yuan --- rpc/client.go | 5 ++++- rpc/client_test.go | 20 ++++++++++++++++++++ rpc/server_test.go | 2 +- rpc/testservice_test.go | 4 ++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/rpc/client.go b/rpc/client.go index a509cb2e0..69ff4851e 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -345,7 +345,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str case len(resp.Result) == 0: return ErrNoResult default: - return json.Unmarshal(resp.Result, &result) + if result == nil { + return nil + } + return json.Unmarshal(resp.Result, result) } } diff --git a/rpc/client_test.go b/rpc/client_test.go index 0a88ce40b..a94a54929 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -69,6 +69,26 @@ func TestClientResponseType(t *testing.T) { } } +// This test checks calling a method that returns 'null'. +func TestClientNullResponse(t *testing.T) { + server := newTestServer() + defer server.Stop() + + client := DialInProc(server) + defer client.Close() + + var result json.RawMessage + if err := client.Call(&result, "test_null"); err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("Expected non-nil result") + } + if !reflect.DeepEqual(result, json.RawMessage("null")) { + t.Errorf("Expected null, got %s", result) + } +} + // This test checks that server-returned errors with code and data come out of Client.Call. func TestClientErrorData(t *testing.T) { server := newTestServer() diff --git a/rpc/server_test.go b/rpc/server_test.go index c9abe53e5..f1a9b3d5c 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) { t.Fatalf("Expected service calc to be registered") } - wantCallbacks := 12 + wantCallbacks := 13 if len(svc.callbacks) != wantCallbacks { t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) } diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go index 8454a4019..eab67f1dd 100644 --- a/rpc/testservice_test.go +++ b/rpc/testservice_test.go @@ -78,6 +78,10 @@ func (o *MarshalErrObj) MarshalText() ([]byte, error) { func (s *testService) NoArgsRets() {} +func (s *testService) Null() any { + return nil +} + func (s *testService) Echo(str string, i int, args *echoArgs) echoResult { return echoResult{str, i, args} }