rpc: fix unmarshaling of null result in CallContext (#26723)

The change fixes unmarshaling of JSON null results into json.RawMessage.

---------

Co-authored-by: Jason Yuan <jason.yuan@curvegrid.com>
Co-authored-by: Jason Yuan <jason.yuan869@gmail.com>
This commit is contained in:
Felix Lange 2023-02-19 20:23:18 +01:00 committed by GitHub
parent 7c749c947a
commit 1db978ca6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 2 deletions

View File

@ -345,7 +345,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
case len(resp.Result) == 0: case len(resp.Result) == 0:
return ErrNoResult return ErrNoResult
default: default:
return json.Unmarshal(resp.Result, &result) if result == nil {
return nil
}
return json.Unmarshal(resp.Result, result)
} }
} }

View File

@ -69,6 +69,26 @@ func TestClientResponseType(t *testing.T) {
} }
} }
// This test checks calling a method that returns 'null'.
func TestClientNullResponse(t *testing.T) {
server := newTestServer()
defer server.Stop()
client := DialInProc(server)
defer client.Close()
var result json.RawMessage
if err := client.Call(&result, "test_null"); err != nil {
t.Fatal(err)
}
if result == nil {
t.Fatal("Expected non-nil result")
}
if !reflect.DeepEqual(result, json.RawMessage("null")) {
t.Errorf("Expected null, got %s", result)
}
}
// This test checks that server-returned errors with code and data come out of Client.Call. // This test checks that server-returned errors with code and data come out of Client.Call.
func TestClientErrorData(t *testing.T) { func TestClientErrorData(t *testing.T) {
server := newTestServer() server := newTestServer()

View File

@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) {
t.Fatalf("Expected service calc to be registered") t.Fatalf("Expected service calc to be registered")
} }
wantCallbacks := 12 wantCallbacks := 13
if len(svc.callbacks) != wantCallbacks { if len(svc.callbacks) != wantCallbacks {
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
} }

View File

@ -78,6 +78,10 @@ func (o *MarshalErrObj) MarshalText() ([]byte, error) {
func (s *testService) NoArgsRets() {} func (s *testService) NoArgsRets() {}
func (s *testService) Null() any {
return nil
}
func (s *testService) Echo(str string, i int, args *echoArgs) echoResult { func (s *testService) Echo(str string, i int, args *echoArgs) echoResult {
return echoResult{str, i, args} return echoResult{str, i, args}
} }