rpc: fix connection tracking set in Server (#26180)

rpc: fix connection tracking in Server

When upgrading to mapset/v2 with generics, the set element type used in
rpc.Server had to be changed to *ServerCodec because ServerCodec is not
'comparable'. While the distinction is technically correct, we know all
possible ServerCodec types, and all of them are comparable. So just use
a map instead.
This commit is contained in:
Felix Lange 2022-11-15 14:05:16 +01:00 committed by GitHub
parent 9afc6816d2
commit ae42148093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,9 +19,9 @@ package rpc
import ( import (
"context" "context"
"io" "io"
"sync"
"sync/atomic" "sync/atomic"
mapset "github.com/deckarep/golang-set/v2"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
) )
@ -45,13 +45,19 @@ const (
type Server struct { type Server struct {
services serviceRegistry services serviceRegistry
idgen func() ID idgen func() ID
run int32
codecs mapset.Set[*ServerCodec] mutex sync.Mutex
codecs map[ServerCodec]struct{}
run int32
} }
// NewServer creates a new server instance with no registered handlers. // NewServer creates a new server instance with no registered handlers.
func NewServer() *Server { func NewServer() *Server {
server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet[*ServerCodec](), run: 1} server := &Server{
idgen: randomIDGenerator(),
codecs: make(map[ServerCodec]struct{}),
run: 1,
}
// Register the default service providing meta information about the RPC service such // Register the default service providing meta information about the RPC service such
// as the services and methods it offers. // as the services and methods it offers.
rpcService := &RPCService{server} rpcService := &RPCService{server}
@ -75,20 +81,34 @@ func (s *Server) RegisterName(name string, receiver interface{}) error {
func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
defer codec.close() defer codec.close()
// Don't serve if server is stopped. if !s.trackCodec(codec) {
if atomic.LoadInt32(&s.run) == 0 {
return return
} }
defer s.untrackCodec(codec)
// Add the codec to the set so it can be closed by Stop.
s.codecs.Add(&codec)
defer s.codecs.Remove(&codec)
c := initClient(codec, s.idgen, &s.services) c := initClient(codec, s.idgen, &s.services)
<-codec.closed() <-codec.closed()
c.Close() c.Close()
} }
func (s *Server) trackCodec(codec ServerCodec) bool {
s.mutex.Lock()
defer s.mutex.Unlock()
if atomic.LoadInt32(&s.run) == 0 {
return false // Don't serve if server is stopped.
}
s.codecs[codec] = struct{}{}
return true
}
func (s *Server) untrackCodec(codec ServerCodec) {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.codecs, codec)
}
// serveSingleRequest reads and processes a single RPC request from the given codec. This // serveSingleRequest reads and processes a single RPC request from the given codec. This
// is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in // is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in
// this mode. // this mode.
@ -120,12 +140,14 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
// requests to finish, then closes all codecs which will cancel pending requests and // requests to finish, then closes all codecs which will cancel pending requests and
// subscriptions. // subscriptions.
func (s *Server) Stop() { func (s *Server) Stop() {
s.mutex.Lock()
defer s.mutex.Unlock()
if atomic.CompareAndSwapInt32(&s.run, 1, 0) { if atomic.CompareAndSwapInt32(&s.run, 1, 0) {
log.Debug("RPC server shutting down") log.Debug("RPC server shutting down")
s.codecs.Each(func(c *ServerCodec) bool { for codec := range s.codecs {
(*c).close() codec.close()
return true }
})
} }
} }