Improved RPC handling. WIP

This commit is contained in:
Age Manning 2019-07-09 15:44:23 +10:00
parent bb0e28b8e3
commit 4a84b2f7cc
No known key found for this signature in database
GPG Key ID: 05EED64B79E06A93
7 changed files with 720 additions and 339 deletions

View File

@ -22,3 +22,7 @@ futures = "0.1.25"
error-chain = "0.12.0" error-chain = "0.12.0"
tokio-timer = "0.2.10" tokio-timer = "0.2.10"
dirs = "2.0.1" dirs = "2.0.1"
tokio-io = "0.1.12"
smallvec = "0.6.10"
fnv = "1.0.6"
unsigned-varint = "0.2.2"

View File

@ -1,5 +1,5 @@
use crate::discovery::Discovery; use crate::discovery::Discovery;
use crate::rpc::{RPCEvent, RPCMessage, Rpc}; use crate::rpc::{RPCEvent, RPCMessage, RPC};
use crate::{error, NetworkConfig}; use crate::{error, NetworkConfig};
use crate::{Topic, TopicHash}; use crate::{Topic, TopicHash};
use futures::prelude::*; use futures::prelude::*;
@ -29,7 +29,7 @@ pub struct Behaviour<TSubstream: AsyncRead + AsyncWrite> {
/// The routing pub-sub mechanism for eth2. /// The routing pub-sub mechanism for eth2.
gossipsub: Gossipsub<TSubstream>, gossipsub: Gossipsub<TSubstream>,
/// The serenity RPC specified in the wire-0 protocol. /// The serenity RPC specified in the wire-0 protocol.
serenity_rpc: Rpc<TSubstream>, serenity_rpc: RPC<TSubstream>,
/// Keep regular connection to peers and disconnect if absent. /// Keep regular connection to peers and disconnect if absent.
ping: Ping<TSubstream>, ping: Ping<TSubstream>,
/// Kademlia for peer discovery. /// Kademlia for peer discovery.
@ -57,7 +57,7 @@ impl<TSubstream: AsyncRead + AsyncWrite> Behaviour<TSubstream> {
.with_keep_alive(false); .with_keep_alive(false);
Ok(Behaviour { Ok(Behaviour {
serenity_rpc: Rpc::new(log), serenity_rpc: RPC::new(log),
gossipsub: Gossipsub::new(local_peer_id.clone(), net_conf.gs_config.clone()), gossipsub: Gossipsub::new(local_peer_id.clone(), net_conf.gs_config.clone()),
discovery: Discovery::new(local_key, net_conf, log)?, discovery: Discovery::new(local_key, net_conf, log)?,
ping: Ping::new(ping_config), ping: Ping::new(ping_config),

View File

@ -1,37 +1,37 @@
use libp2p::core::protocols_handler::{ use super::protocol::{ProtocolId, RPCError, RPCProtocol, RPCRequest};
KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, use super::RPCEvent;
SubstreamProtocol use fnv::FnvHashMap;
};
use libp2p::core::upgrade::{InboundUpgrade, OutboundUpgrade};
use futures::prelude::*; use futures::prelude::*;
use libp2p::core::protocols_handler::{
KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol,
};
use libp2p::core::upgrade::{self, InboundUpgrade, OutboundUpgrade, WriteOne};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::{error, marker::PhantomData, time::Duration}; use std::time::{Duration, Instant};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use wasm_timer::Instant;
/// The time (in seconds) before a substream that is awaiting a response times out. /// The time (in seconds) before a substream that is awaiting a response times out.
pub const RESPONSE_TIMEOUT: u64 = 9; pub const RESPONSE_TIMEOUT: u64 = 9;
/// Implementation of `ProtocolsHandler` for the RPC protocol. /// Implementation of `ProtocolsHandler` for the RPC protocol.
pub struct RPCHandler<TSubstream> { pub struct RPCHandler<TSubstream> {
/// The upgrade for inbound substreams. /// The upgrade for inbound substreams.
listen_protocol: SubstreamProtocol<RPCProtocol>, listen_protocol: SubstreamProtocol<RPCProtocol>,
/// If `Some`, something bad happened and we should shut down the handler with an error. /// If `Some`, something bad happened and we should shut down the handler with an error.
pending_error: Option<ProtocolsHandlerUpgrErr<RPCRequest::Error>>, pending_error: Option<ProtocolsHandlerUpgrErr<RPCError>>,
/// Queue of events to produce in `poll()`. /// Queue of events to produce in `poll()`.
events_out: SmallVec<[TOutEvent; 4]>, events_out: SmallVec<[RPCEvent; 4]>,
/// Queue of outbound substreams to open. /// Queue of outbound substreams to open.
dial_queue: SmallVec<[(usize,TOutProto); 4]>, dial_queue: SmallVec<[(usize, RPCRequest); 4]>,
/// Current number of concurrent outbound substreams being opened. /// Current number of concurrent outbound substreams being opened.
dial_negotiated: u32, dial_negotiated: u32,
/// Map of current substreams awaiting a response to an RPC request. /// Map of current substreams awaiting a response to an RPC request.
waiting_substreams: FnvHashMap<u64, SubstreamState<TSubstream> waiting_substreams: FnvHashMap<usize, SubstreamState<TSubstream>>,
/// Sequential Id for waiting substreams. /// Sequential Id for waiting substreams.
current_substream_id: usize, current_substream_id: usize,
@ -50,19 +50,21 @@ pub struct RPCHandler<TSubstream> {
pub enum SubstreamState<TSubstream> { pub enum SubstreamState<TSubstream> {
/// An outbound substream is waiting a response from the user. /// An outbound substream is waiting a response from the user.
WaitingResponse { WaitingResponse {
stream: <TSubstream>, /// The negotiated substream.
timeout: Duration, substream: upgrade::Negotiated<TSubstream>,
} /// The protocol that was negotiated.
negotiated_protocol: ProtocolId,
/// The time until we close the substream.
timeout: Instant,
},
/// A response has been sent and we are waiting for the stream to close. /// A response has been sent and we are waiting for the stream to close.
ResponseSent(WriteOne<TSubstream, Vec<u8>) PendingWrite(WriteOne<upgrade::Negotiated<TSubstream>, Vec<u8>>),
} }
impl<TSubstream> impl<TSubstream> RPCHandler<TSubstream> {
RPCHandler<TSubstream>
{
pub fn new( pub fn new(
listen_protocol: SubstreamProtocol<RPCProtocol>, listen_protocol: SubstreamProtocol<RPCProtocol>,
inactive_timeout: Duration inactive_timeout: Duration,
) -> Self { ) -> Self {
RPCHandler { RPCHandler {
listen_protocol, listen_protocol,
@ -71,7 +73,7 @@ impl<TSubstream>
dial_queue: SmallVec::new(), dial_queue: SmallVec::new(),
dial_negotiated: 0, dial_negotiated: 0,
waiting_substreams: FnvHashMap::default(), waiting_substreams: FnvHashMap::default(),
curent_substream_id: 0, current_substream_id: 0,
max_dial_negotiated: 8, max_dial_negotiated: 8,
keep_alive: KeepAlive::Yes, keep_alive: KeepAlive::Yes,
inactive_timeout, inactive_timeout,
@ -87,7 +89,7 @@ impl<TSubstream>
/// ///
/// > **Note**: If you modify the protocol, modifications will only applies to future inbound /// > **Note**: If you modify the protocol, modifications will only applies to future inbound
/// > substreams, not the ones already being negotiated. /// > substreams, not the ones already being negotiated.
pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInProto> { pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<RPCProtocol> {
&self.listen_protocol &self.listen_protocol
} }
@ -95,36 +97,35 @@ impl<TSubstream>
/// ///
/// > **Note**: If you modify the protocol, modifications will only applies to future inbound /// > **Note**: If you modify the protocol, modifications will only applies to future inbound
/// > substreams, not the ones already being negotiated. /// > substreams, not the ones already being negotiated.
pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInProto> { pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<RPCProtocol> {
&mut self.listen_protocol &mut self.listen_protocol
} }
/// Opens an outbound substream with `upgrade`. /// Opens an outbound substream with `upgrade`.
#[inline] #[inline]
pub fn send_request(&mut self, request_id, u64, upgrade: RPCRequest) { pub fn send_request(&mut self, request_id: usize, upgrade: RPCRequest) {
self.keep_alive = KeepAlive::Yes; self.keep_alive = KeepAlive::Yes;
self.dial_queue.push((request_id, upgrade)); self.dial_queue.push((request_id, upgrade));
} }
} }
impl<TSubstream> Default impl<TSubstream> Default for RPCHandler<TSubstream> {
for RPCHandler<TSubstream>
{
fn default() -> Self { fn default() -> Self {
RPCHandler::new(SubstreamProtocol::new(RPCProtocol), Duration::from_secs(30)) RPCHandler::new(SubstreamProtocol::new(RPCProtocol), Duration::from_secs(30))
} }
} }
impl<TSubstream> ProtocolsHandler impl<TSubstream> ProtocolsHandler for RPCHandler<TSubstream>
for RPCHandler<TSubstream> where
TSubstream: AsyncRead + AsyncWrite,
{ {
type InEvent = RPCEvent; type InEvent = RPCEvent;
type OutEvent = RPCEvent; type OutEvent = RPCEvent;
type Error = ProtocolsHandlerUpgrErr<RPCRequest::Error>; type Error = ProtocolsHandlerUpgrErr<RPCError>;
type Substream = TSubstream; type Substream = TSubstream;
type InboundProtocol = RPCProtocol; type InboundProtocol = RPCProtocol;
type OutboundProtocol = RPCRequest; type OutboundProtocol = RPCRequest;
type OutboundOpenInfo = u64; // request_id type OutboundOpenInfo = usize; // request_id
#[inline] #[inline]
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> { fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
@ -134,35 +135,43 @@ impl<TSubstream> ProtocolsHandler
#[inline] #[inline]
fn inject_fully_negotiated_inbound( fn inject_fully_negotiated_inbound(
&mut self, &mut self,
out: RPCProtocol::Output, out: <RPCProtocol as InboundUpgrade<TSubstream>>::Output,
) { ) {
let (stream, req) = out; let (substream, req, negotiated_protocol) = out;
// drop the stream and return a 0 id for goodbye "requests" // drop the stream and return a 0 id for goodbye "requests"
if let req @ RPCRequest::Goodbye(_) = req { if let r @ RPCRequest::Goodbye(_) = req {
self.events_out.push(RPCEvent::Request(0, req)); self.events_out.push(RPCEvent::Request(0, r));
return; return;
} }
// New inbound request. Store the stream and tag the output. // New inbound request. Store the stream and tag the output.
let awaiting_stream = SubstreamState::WaitingResponse { stream, timeout: Instant::now() + Duration::from_secs(RESPONSE_TIMEOUT) }; let awaiting_stream = SubstreamState::WaitingResponse {
self.waiting_substreams.insert(self.current_substream_id, awaiting_stream); substream,
negotiated_protocol,
timeout: Instant::now() + Duration::from_secs(RESPONSE_TIMEOUT),
};
self.waiting_substreams
.insert(self.current_substream_id, awaiting_stream);
self.events_out.push(RPCEvent::Request(self.current_substream_id, req)); self.events_out
.push(RPCEvent::Request(self.current_substream_id, req));
self.current_substream_id += 1; self.current_substream_id += 1;
} }
#[inline] #[inline]
fn inject_fully_negotiated_outbound( fn inject_fully_negotiated_outbound(
&mut self, &mut self,
out: RPCResponse, out: <RPCRequest as OutboundUpgrade<TSubstream>>::Output,
request_id: Self::OutboundOpenInfo, request_id: Self::OutboundOpenInfo,
) { ) {
self.dial_negotiated -= 1; self.dial_negotiated -= 1;
if self.dial_negotiated == 0 && self.dial_queue.is_empty() && self.waiting_substreams.is_empty() { if self.dial_negotiated == 0
&& self.dial_queue.is_empty()
&& self.waiting_substreams.is_empty()
{
self.keep_alive = KeepAlive::Until(Instant::now() + self.inactive_timeout); self.keep_alive = KeepAlive::Until(Instant::now() + self.inactive_timeout);
} } else {
else {
self.keep_alive = KeepAlive::Yes; self.keep_alive = KeepAlive::Yes;
} }
@ -177,10 +186,19 @@ impl<TSubstream> ProtocolsHandler
RPCEvent::Request(rpc_id, req) => self.send_request(rpc_id, req), RPCEvent::Request(rpc_id, req) => self.send_request(rpc_id, req),
RPCEvent::Response(rpc_id, res) => { RPCEvent::Response(rpc_id, res) => {
// check if the stream matching the response still exists // check if the stream matching the response still exists
if let Some(mut waiting_stream) = self.waiting_substreams.get_mut(&rpc_id) { if let Some(waiting_stream) = self.waiting_substreams.get_mut(&rpc_id) {
// only send one response per stream. This must be in the waiting state. // only send one response per stream. This must be in the waiting state.
if let SubstreamState::WaitingResponse {substream, .. } = waiting_stream { if let SubstreamState::WaitingResponse {
waiting_stream = SubstreamState::PendingWrite(upgrade::write_one(substream, res)); substream,
negotiated_protocol,
..
} = *waiting_stream
{
*waiting_stream = SubstreamState::PendingWrite(upgrade::write_one(
substream,
res.encode(negotiated_protocol)
.expect("Response should always be encodeable"),
));
} }
} }
} }
@ -195,6 +213,7 @@ impl<TSubstream> ProtocolsHandler
<Self::OutboundProtocol as OutboundUpgrade<Self::Substream>>::Error, <Self::OutboundProtocol as OutboundUpgrade<Self::Substream>>::Error,
>, >,
) { ) {
dbg!(error);
if self.pending_error.is_none() { if self.pending_error.is_none() {
self.pending_error = Some(error); self.pending_error = Some(error);
} }
@ -217,20 +236,24 @@ impl<TSubstream> ProtocolsHandler
// prioritise sending responses for waiting substreams // prioritise sending responses for waiting substreams
self.waiting_substreams.retain(|_k, mut waiting_stream| { self.waiting_substreams.retain(|_k, mut waiting_stream| {
match waiting_stream => { match waiting_stream {
SubstreamState::PendingWrite(write_one) => { SubstreamState::PendingWrite(write_one) => {
match write_one.poll() => { match write_one.poll() {
Ok(Async::Ready(_socket)) => false, Ok(Async::Ready(_socket)) => false,
Ok(Async::NotReady()) => true, Ok(Async::NotReady) => true,
Err(_e) => { Err(_e) => {
//TODO: Add logging //TODO: Add logging
// throw away streams that error // throw away streams that error
false false
} }
} }
}, }
SubstreamState::WaitingResponse { timeout, .. } => { SubstreamState::WaitingResponse { timeout, .. } => {
if Instant::now() > timeout { false} else { true } if Instant::now() > *timeout {
false
} else {
true
}
} }
} }
}); });

View File

@ -4,23 +4,6 @@ use ssz::{impl_decode_via_from, impl_encode_via_from};
use ssz_derive::{Decode, Encode}; use ssz_derive::{Decode, Encode};
use types::{Epoch, Hash256, Slot}; use types::{Epoch, Hash256, Slot};
#[derive(Debug, Clone)]
pub enum RPCResponse {
Hello(HelloMessage),
Goodbye, // empty value - required for protocol handler
BeaconBlockRoots(BeaconBlockRootsResponse),
BeaconBlockHeaders(BeaconBlockHeadersResponse),
BeaconBlockBodies(BeaconBlockBodiesResponse),
BeaconChainState(BeaconChainStateResponse),
}
pub enum ResponseCode {
Success = 0,
EncodingError = 1,
InvalidRequest = 2,
ServerError = 3,
}
/* Request/Response data structures for RPC methods */ /* Request/Response data structures for RPC methods */
/* Requests */ /* Requests */
@ -78,7 +61,6 @@ impl From<u64> for Goodbye {
} }
} }
impl_encode_via_from!(Goodbye, u64);
impl_decode_via_from!(Goodbye, u64); impl_decode_via_from!(Goodbye, u64);
/// Request a number of beacon block roots from a peer. /// Request a number of beacon block roots from a peer.
@ -108,7 +90,7 @@ pub struct BlockRootSlot {
pub slot: Slot, pub slot: Slot,
} }
/// The response of a beacl block roots request. /// The response of a beacon block roots request.
impl BeaconBlockRootsResponse { impl BeaconBlockRootsResponse {
/// Returns `true` if each `self.roots.slot[i]` is higher than the preceding `i`. /// Returns `true` if each `self.roots.slot[i]` is higher than the preceding `i`.
pub fn slots_are_ascending(&self) -> bool { pub fn slots_are_ascending(&self) -> bool {

View File

@ -1,39 +1,41 @@
/// The Ethereum 2.0 Wire Protocol ///! The Ethereum 2.0 Wire Protocol
/// ///!
/// This protocol is a purpose built Ethereum 2.0 libp2p protocol. It's role is to facilitate ///! This protocol is a purpose built Ethereum 2.0 libp2p protocol. It's role is to facilitate
/// direct peer-to-peer communication primarily for sending/receiving chain information for ///! direct peer-to-peer communication primarily for sending/receiving chain information for
/// syncing. ///! syncing.
///
pub mod methods;
mod protocol;
use futures::prelude::*; use futures::prelude::*;
use libp2p::core::protocols_handler::{OneShotHandler, ProtocolsHandler}; use handler::RPCHandler;
use libp2p::core::protocols_handler::ProtocolsHandler;
use libp2p::core::swarm::{ use libp2p::core::swarm::{
ConnectedPoint, NetworkBehaviour, NetworkBehaviourAction, PollParameters, ConnectedPoint, NetworkBehaviour, NetworkBehaviourAction, PollParameters,
}; };
use libp2p::{Multiaddr, PeerId}; use libp2p::{Multiaddr, PeerId};
pub use methods::{HelloMessage, RPCResponse}; pub use methods::HelloMessage;
pub use protocol::{RPCProtocol, RPCRequest}; pub use protocol::{RPCProtocol, RPCRequest, RPCResponse};
use slog::o; use slog::o;
use std::marker::PhantomData; use std::marker::PhantomData;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
mod handler;
pub mod methods;
mod protocol;
mod request_response;
/// The return type used in the behaviour and the resultant event from the protocols handler. /// The return type used in the behaviour and the resultant event from the protocols handler.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum RPCEvent { pub enum RPCEvent {
/// A request that was received from the RPC protocol. The first parameter is a sequential /// A request that was received from the RPC protocol. The first parameter is a sequential
/// id which tracks an awaiting substream for the response. /// id which tracks an awaiting substream for the response.
Request(u64, RPCRequest), Request(usize, RPCRequest),
/// A response that has been received from the RPC protocol. The first parameter returns /// A response that has been received from the RPC protocol. The first parameter returns
/// that which was sent with the corresponding request. /// that which was sent with the corresponding request.
Response(u64, RPCResponse), Response(usize, RPCResponse),
} }
/// Rpc implements the libp2p `NetworkBehaviour` trait and therefore manages network-level /// Implements the libp2p `NetworkBehaviour` trait and therefore manages network-level
/// logic. /// logic.
pub struct Rpc<TSubstream> { pub struct RPC<TSubstream> {
/// Queue of events to processed. /// Queue of events to processed.
events: Vec<NetworkBehaviourAction<RPCEvent, RPCMessage>>, events: Vec<NetworkBehaviourAction<RPCEvent, RPCMessage>>,
/// Pins the generic substream. /// Pins the generic substream.
@ -42,10 +44,10 @@ pub struct Rpc<TSubstream> {
_log: slog::Logger, _log: slog::Logger,
} }
impl<TSubstream> Rpc<TSubstream> { impl<TSubstream> RPC<TSubstream> {
pub fn new(log: &slog::Logger) -> Self { pub fn new(log: &slog::Logger) -> Self {
let log = log.new(o!("Service" => "Libp2p-RPC")); let log = log.new(o!("Service" => "Libp2p-RPC"));
Rpc { RPC {
events: Vec::new(), events: Vec::new(),
marker: PhantomData, marker: PhantomData,
_log: log, _log: log,
@ -63,7 +65,7 @@ impl<TSubstream> Rpc<TSubstream> {
} }
} }
impl<TSubstream> NetworkBehaviour for Rpc<TSubstream> impl<TSubstream> NetworkBehaviour for RPC<TSubstream>
where where
TSubstream: AsyncRead + AsyncWrite, TSubstream: AsyncRead + AsyncWrite,
{ {
@ -95,12 +97,6 @@ where
source: PeerId, source: PeerId,
event: <Self::ProtocolsHandler as ProtocolsHandler>::OutEvent, event: <Self::ProtocolsHandler as ProtocolsHandler>::OutEvent,
) { ) {
// ignore successful send events
let event = match event {
HandlerEvent::Rx(event) => event,
HandlerEvent::Sent => return,
};
// send the event to the user // send the event to the user
self.events self.events
.push(NetworkBehaviourAction::GenerateEvent(RPCMessage::RPC( .push(NetworkBehaviourAction::GenerateEvent(RPCMessage::RPC(
@ -129,26 +125,3 @@ pub enum RPCMessage {
RPC(PeerId, RPCEvent), RPC(PeerId, RPCEvent),
PeerDialed(PeerId), PeerDialed(PeerId),
} }
/// The output type received from the `OneShotHandler`.
#[derive(Debug)]
pub enum HandlerEvent {
/// An RPC was received from a remote.
Rx(RPCEvent),
/// An RPC was sent.
Sent,
}
impl From<RPCEvent> for HandlerEvent {
#[inline]
fn from(rpc: RPCEvent) -> HandlerEvent {
HandlerEvent::Rx(rpc)
}
}
impl From<()> for HandlerEvent {
#[inline]
fn from(_: ()) -> HandlerEvent {
HandlerEvent::Sent
}
}

View File

@ -1,10 +1,13 @@
use super::methods::*; use super::methods::*;
use super::request_response::{rpc_request_response, RPCRequestResponse};
use futures::future::Future;
use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use ssz::{Decode, Encode}; use ssz::{Decode, Encode};
use std::hash::Hasher;
use std::io; use std::io;
use std::iter; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::prelude::future::MapErr;
use tokio::util::FutureExt;
/// The maximum bytes that can be sent across the RPC. /// The maximum bytes that can be sent across the RPC.
const MAX_RPC_SIZE: usize = 4_194_304; // 4M const MAX_RPC_SIZE: usize = 4_194_304; // 4M
@ -33,42 +36,6 @@ impl UpgradeInfo for RPCProtocol {
} }
} }
/* Inbound upgrade */
// The inbound protocol reads the request, decodes it and returns the stream to the protocol
// handler to respond to once ready.
type FnDecodeRPCEvent<TSocket> = fn(
upgrade::Negotiated<TSocket>,
Vec<u8>,
(),
) -> Result<(upgrade::Negotiated<TSocket>, RPCEvent), RPCError>;
impl<TSocket> InboundUpgrade<TSocket> for RPCProtocol
where
TSocket: AsyncRead + AsyncWrite,
{
type Output = (upgrade::Negotiated<TSocket>, RPCEvent);
type Error = RPCError;
type Future = upgrade::ReadRespond<upgrade::Negotiated<TSocket>, (), FnDecodeRPCEvent<TSocket>>;
fn upgrade_inbound(
self,
socket: upgrade::Negotiated<TSocket>,
protocol: Self::Info,
) -> Self::Future {
upgrade::read_respond(socket, MAX_RPC_SIZE, (), |socket, packet, ()| {
Ok((socket, decode_request(packet, protocol)?))
})
.timeout(Duration::from_secs(RESPONSE_TIMEOUT))
}
}
/* Outbound request */
// Combines all the RPC requests into a single enum to implement `UpgradeInfo` and
// `OutboundUpgrade`
/// The raw protocol id sent over the wire. /// The raw protocol id sent over the wire.
type RawProtocolId = Vec<u8>; type RawProtocolId = Vec<u8>;
@ -86,38 +53,100 @@ pub struct ProtocolId {
/// An RPC protocol ID. /// An RPC protocol ID.
impl ProtocolId { impl ProtocolId {
pub fn new(message_name: String, version: usize, encoding: String) -> Self { pub fn new(message_name: &str, version: usize, encoding: &str) -> Self {
ProtocolId { ProtocolId {
message_name, message_name: message_name.into(),
version, version,
encoding, encoding: encoding.into(),
} }
} }
/// Converts a raw RPC protocol id string into an `RPCProtocolId` /// Converts a raw RPC protocol id string into an `RPCProtocolId`
pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, RPCError> { pub fn from_bytes(bytes: &[u8]) -> Result<Self, RPCError> {
let protocol_string = String::from_utf8(bytes.as_vec()) let protocol_string = String::from_utf8(bytes.to_vec())
.map_err(|_| RPCError::InvalidProtocol("Invalid protocol Id"))?; .map_err(|_| RPCError::InvalidProtocol("Invalid protocol Id"))?;
let protocol_string = protocol_string.as_str().split('/'); let protocol_list: Vec<&str> = protocol_string.as_str().split('/').take(5).collect();
if protocol_list.len() != 5 {
return Err(RPCError::InvalidProtocol("Not enough '/'"));
}
Ok(ProtocolId { Ok(ProtocolId {
message_name: protocol_string[3], message_name: protocol_list[3].into(),
version: protocol_string[4], version: protocol_list[4]
encoding: protocol_string[5], .parse()
.map_err(|_| RPCError::InvalidProtocol("Invalid version"))?,
encoding: protocol_list[5].into(),
}) })
} }
} }
impl Into<RawProtocolId> for ProtocolId { impl Into<RawProtocolId> for ProtocolId {
fn into(&self) -> [u8] { fn into(self) -> RawProtocolId {
&format!( format!(
"{}/{}/{}/{}", "{}/{}/{}/{}",
PROTOCOL_PREFIX, self.message_name, self.version, self.encoding PROTOCOL_PREFIX, self.message_name, self.version, self.encoding
) )
.as_bytes() .as_bytes()
.to_vec()
} }
} }
/* Inbound upgrade */
// The inbound protocol reads the request, decodes it and returns the stream to the protocol
// handler to respond to once ready.
type FnDecodeRPCEvent<TSocket> =
fn(
upgrade::Negotiated<TSocket>,
Vec<u8>,
&'static [u8], // protocol id
) -> Result<(upgrade::Negotiated<TSocket>, RPCRequest, ProtocolId), RPCError>;
impl<TSocket> InboundUpgrade<TSocket> for RPCProtocol
where
TSocket: AsyncRead + AsyncWrite,
{
type Output = (upgrade::Negotiated<TSocket>, RPCRequest, ProtocolId);
type Error = RPCError;
type Future = MapErr<
tokio_timer::Timeout<
upgrade::ReadRespond<
upgrade::Negotiated<TSocket>,
Self::Info,
FnDecodeRPCEvent<TSocket>,
>,
>,
fn(tokio::timer::timeout::Error<RPCError>) -> RPCError,
>;
fn upgrade_inbound(
self,
socket: upgrade::Negotiated<TSocket>,
protocol: &'static [u8],
) -> Self::Future {
upgrade::read_respond(socket, MAX_RPC_SIZE, protocol, {
|socket, packet, protocol| {
let protocol_id = ProtocolId::from_bytes(protocol)?;
Ok((
socket,
RPCRequest::decode(packet, protocol_id)?,
protocol_id,
))
}
}
as FnDecodeRPCEvent<TSocket>)
.timeout(Duration::from_secs(RESPONSE_TIMEOUT))
.map_err(RPCError::from)
}
}
/* Outbound request */
// Combines all the RPC requests into a single enum to implement `UpgradeInfo` and
// `OutboundUpgrade`
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum RPCRequest { pub enum RPCRequest {
Hello(HelloMessage), Hello(HelloMessage),
@ -154,18 +183,16 @@ impl RPCRequest {
RPCRequest::BeaconBlockBodies(_) => { RPCRequest::BeaconBlockBodies(_) => {
vec![ProtocolId::new("beacon_block_bodies", 1, "ssz").into()] vec![ProtocolId::new("beacon_block_bodies", 1, "ssz").into()]
} }
RPCRequest::BeaconBlockState(_) => { RPCRequest::BeaconChainState(_) => {
vec![ProtocolId::new("beacon_block_state", 1, "ssz").into()] vec![ProtocolId::new("beacon_block_state", 1, "ssz").into()]
} }
} }
} }
/// Encodes the Request object based on the negotiated protocol. /// Encodes the Request object based on the negotiated protocol.
pub fn encode(&self, protocol: RawProtocolId) -> Result<Vec<u8>, io::Error> { pub fn encode(&self, protocol: ProtocolId) -> Result<Vec<u8>, RPCError> {
// Assume select has given a supported protocol.
let protocol = ProtocolId::from_bytes(protocol)?;
// Match on the encoding and in the future, the version // Match on the encoding and in the future, the version
match protocol.encoding { match protocol.encoding.as_str() {
"ssz" => Ok(self.ssz_encode()), "ssz" => Ok(self.ssz_encode()),
_ => { _ => {
return Err(RPCError::Custom(format!( return Err(RPCError::Custom(format!(
@ -176,7 +203,7 @@ impl RPCRequest {
} }
} }
fn ssz_encode(&self) { fn ssz_encode(&self) -> Vec<u8> {
match self { match self {
RPCRequest::Hello(req) => req.as_ssz_bytes(), RPCRequest::Hello(req) => req.as_ssz_bytes(),
RPCRequest::Goodbye(req) => req.as_ssz_bytes(), RPCRequest::Goodbye(req) => req.as_ssz_bytes(),
@ -186,61 +213,34 @@ impl RPCRequest {
RPCRequest::BeaconChainState(req) => req.as_ssz_bytes(), RPCRequest::BeaconChainState(req) => req.as_ssz_bytes(),
} }
} }
}
/* Outbound upgrades */
impl<TSocket> OutboundUpgrade<TSocket> for RPCRequest
where
TSocket: AsyncWrite,
{
type Output = RPCResponse;
type Error = RPCResponse;
type Future = upgrade::RequestResponse<upgrade::Negotiated<TSocket>>;
fn upgrade_outbound(
self,
socket: upgrade::Negotiated<TSocket>,
protocol: Self::Info,
) -> Self::Future {
let bytes = self.encode(protocol);
wait_for_response = if let RPCRequest::Goodbye(_) = self {
false
} else {
true
};
// TODO: Reimplement request_response
upgrade::request_response(socket, bytes, MAX_RPC_SIZE, protocol, |packet, protocol| {
Ok(decode_response(packet, protocol)?)
})
.timeout(Duration::from_secs(RESPONSE_TIMEOUT))
}
}
/* Decoding for Requests/Responses */
// This function can be extended to provide further logic for supporting various protocol versions/encoding // This function can be extended to provide further logic for supporting various protocol versions/encoding
fn decode_request(packet: Vec<u8>, protocol: ProtocolId) -> Result<RPCRequest, io::Error> { /// Decodes a request received from our peer.
let protocol_id = ProtocolId::from_bytes(protocol); pub fn decode(packet: Vec<u8>, protocol: ProtocolId, response_code: ResponseCode) -> Result<Self, RPCError> {
match protocol_id.message_name { match response_code {
"hello" => match protocol_id.version { ResponseCode::
"1" => match protocol_id.encoding {
match protocol.message_name.as_str() {
"hello" => match protocol.version {
1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCRequest::Hello(HelloMessage::from_ssz_bytes(&packet)?)), "ssz" => Ok(RPCRequest::Hello(HelloMessage::from_ssz_bytes(&packet)?)),
_ => Err(RPCError::InvalidProtocol("Unknown HELLO encoding")), _ => Err(RPCError::InvalidProtocol("Unknown HELLO encoding")),
}, },
_ => Err(RPCError::InvalidProtocol("Unknown HELLO version")), _ => Err(RPCError::InvalidProtocol("Unknown HELLO version")),
}, },
"goodbye" => match protocol_id.version { "goodbye" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCRequest::Goodbye(Goodbye::from_ssz_bytes(&packet)?)), "ssz" => Ok(RPCRequest::Goodbye(Goodbye::from_ssz_bytes(&packet)?)),
_ => Err(RPCError::InvalidProtocol("Unknown GOODBYE encoding")), _ => Err(RPCError::InvalidProtocol("Unknown GOODBYE encoding")),
}, },
_ => Err(RPCError::InvalidProtocol("Unknown GOODBYE version")), _ => Err(RPCError::InvalidProtocol("Unknown GOODBYE version")),
}, },
"beacon_block_roots" => match protocol_id.version { "beacon_block_roots" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCRequest::BeaconBlockRooots( "ssz" => Ok(RPCRequest::BeaconBlockRoots(
BeaconBlockRootsRequest::from_ssz_bytes(&packet)?, BeaconBlockRootsRequest::from_ssz_bytes(&packet)?,
)), )),
_ => Err(RPCError::InvalidProtocol( _ => Err(RPCError::InvalidProtocol(
@ -251,10 +251,10 @@ fn decode_request(packet: Vec<u8>, protocol: ProtocolId) -> Result<RPCRequest, i
"Unknown BEACON_BLOCK_ROOTS version", "Unknown BEACON_BLOCK_ROOTS version",
)), )),
}, },
"beacon_block_headers" => match protocol_id.version { "beacon_block_headers" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCRequest::BeaconBlockHeaders( "ssz" => Ok(RPCRequest::BeaconBlockHeaders(
BeaconBlockHeadersRequest::from_ssz_bytes(&packet), BeaconBlockHeadersRequest::from_ssz_bytes(&packet)?,
)), )),
_ => Err(RPCError::InvalidProtocol( _ => Err(RPCError::InvalidProtocol(
"Unknown BEACON_BLOCK_HEADERS encoding", "Unknown BEACON_BLOCK_HEADERS encoding",
@ -264,8 +264,8 @@ fn decode_request(packet: Vec<u8>, protocol: ProtocolId) -> Result<RPCRequest, i
"Unknown BEACON_BLOCK_HEADERS version", "Unknown BEACON_BLOCK_HEADERS version",
)), )),
}, },
"beacon_block_bodies" => match protocol_id.version { "beacon_block_bodies" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCRequest::BeaconBlockBodies( "ssz" => Ok(RPCRequest::BeaconBlockBodies(
BeaconBlockBodiesRequest::from_ssz_bytes(&packet)?, BeaconBlockBodiesRequest::from_ssz_bytes(&packet)?,
)), )),
@ -277,8 +277,8 @@ fn decode_request(packet: Vec<u8>, protocol: ProtocolId) -> Result<RPCRequest, i
"Unknown BEACON_BLOCK_BODIES version", "Unknown BEACON_BLOCK_BODIES version",
)), )),
}, },
"beacon_chain_state" => match protocol_id.version { "beacon_chain_state" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCRequest::BeaconChainState( "ssz" => Ok(RPCRequest::BeaconChainState(
BeaconChainStateRequest::from_ssz_bytes(&packet)?, BeaconChainStateRequest::from_ssz_bytes(&packet)?,
)), )),
@ -292,23 +292,66 @@ fn decode_request(packet: Vec<u8>, protocol: ProtocolId) -> Result<RPCRequest, i
}, },
} }
} }
}
/* Response Type */
#[derive(Debug, Clone)]
pub enum RPCResponse {
/// A HELLO message.
Hello(HelloMessage),
/// An empty field returned from sending a GOODBYE request.
Goodbye, // empty value - required for protocol handler
/// A response to a get BEACON_BLOCK_ROOTS request.
BeaconBlockRoots(BeaconBlockRootsResponse),
/// A response to a get BEACON_BLOCK_HEADERS request.
BeaconBlockHeaders(BeaconBlockHeadersResponse),
/// A response to a get BEACON_BLOCK_BODIES request.
BeaconBlockBodies(BeaconBlockBodiesResponse),
/// A response to a get BEACON_CHAIN_STATE request.
BeaconChainState(BeaconChainStateResponse),
/// The Error returned from the peer during a request.
Error(String),
}
pub enum ResponseCode {
Success = 0,
EncodingError = 1,
InvalidRequest = 2,
ServerError = 3,
Unknown,
}
impl From<u64> for ResponseCode {
fn from(val: u64) -> ResponseCode {
match val {
0 => ResponseCode::Success,
1 => ResponseCode::EncodingError,
2 => ResponseCode::InvalidRequest,
3 => ResponseCode::ServerError,
_ => ResponseCode::Unknown,
}
}
}
impl RPCResponse {
/// Decodes a response that was received on the same stream as a request. The response type should /// Decodes a response that was received on the same stream as a request. The response type should
/// therefore match the request protocol type. /// therefore match the request protocol type.
fn decode_response(packet: Vec<u8>, protocol: RawProtocolId) -> Result<RPCResponse, RPCError> { fn decode(packet: Vec<u8>, protocol: ProtocolId) -> Result<Self, RPCError> {
let protocol_id = ProtocolId::from_bytes(protocol)?; match protocol.message_name.as_str() {
"hello" => match protocol.version {
match protocol_id.message_name { 1 => match protocol.encoding.as_str() {
"hello" => match protocol_id.version {
"1" => match protocol_id.encoding {
"ssz" => Ok(RPCResponse::Hello(HelloMessage::from_ssz_bytes(&packet)?)), "ssz" => Ok(RPCResponse::Hello(HelloMessage::from_ssz_bytes(&packet)?)),
_ => Err(RPCError::InvalidProtocol("Unknown HELLO encoding")), _ => Err(RPCError::InvalidProtocol("Unknown HELLO encoding")),
}, },
_ => Err(RPCError::InvalidProtocol("Unknown HELLO version")), _ => Err(RPCError::InvalidProtocol("Unknown HELLO version")),
}, },
"goodbye" => Err(RPCError::Custom("GOODBYE should not have a response")), "goodbye" => Err(RPCError::Custom(
"beacon_block_roots" => match protocol_id.version { "GOODBYE should not have a response".into(),
"1" => match protocol_id.encoding { )),
"beacon_block_roots" => match protocol.version {
1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCResponse::BeaconBlockRoots( "ssz" => Ok(RPCResponse::BeaconBlockRoots(
BeaconBlockRootsResponse::from_ssz_bytes(&packet)?, BeaconBlockRootsResponse::from_ssz_bytes(&packet)?,
)), )),
@ -320,8 +363,8 @@ fn decode_response(packet: Vec<u8>, protocol: RawProtocolId) -> Result<RPCRespon
"Unknown BEACON_BLOCK_ROOTS version", "Unknown BEACON_BLOCK_ROOTS version",
)), )),
}, },
"beacon_block_headers" => match protocol_id.version { "beacon_block_headers" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCResponse::BeaconBlockHeaders( "ssz" => Ok(RPCResponse::BeaconBlockHeaders(
BeaconBlockHeadersResponse { headers: packet }, BeaconBlockHeadersResponse { headers: packet },
)), )),
@ -333,8 +376,8 @@ fn decode_response(packet: Vec<u8>, protocol: RawProtocolId) -> Result<RPCRespon
"Unknown BEACON_BLOCK_HEADERS version", "Unknown BEACON_BLOCK_HEADERS version",
)), )),
}, },
"beacon_block_bodies" => match protocol_id.version { "beacon_block_bodies" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(RPCResponse::BeaconBlockBodies(BeaconBlockBodiesResponse { "ssz" => Ok(RPCResponse::BeaconBlockBodies(BeaconBlockBodiesResponse {
block_bodies: packet, block_bodies: packet,
})), })),
@ -346,9 +389,11 @@ fn decode_response(packet: Vec<u8>, protocol: RawProtocolId) -> Result<RPCRespon
"Unknown BEACON_BLOCK_BODIES version", "Unknown BEACON_BLOCK_BODIES version",
)), )),
}, },
"beacon_chain_state" => match protocol_id.version { "beacon_chain_state" => match protocol.version {
"1" => match protocol_id.encoding { 1 => match protocol.encoding.as_str() {
"ssz" => Ok(BeaconChainStateRequest::from_ssz_bytes(&packet)?), "ssz" => Ok(RPCResponse::BeaconChainState(
BeaconChainStateResponse::from_ssz_bytes(&packet)?,
)),
_ => Err(RPCError::InvalidProtocol( _ => Err(RPCError::InvalidProtocol(
"Unknown BEACON_CHAIN_STATE encoding", "Unknown BEACON_CHAIN_STATE encoding",
)), )),
@ -360,6 +405,74 @@ fn decode_response(packet: Vec<u8>, protocol: RawProtocolId) -> Result<RPCRespon
} }
} }
/// Encodes the Response object based on the negotiated protocol.
pub fn encode(&self, protocol: ProtocolId) -> Result<Vec<u8>, RPCError> {
// Match on the encoding and in the future, the version
match protocol.encoding.as_str() {
"ssz" => Ok(self.ssz_encode()),
_ => {
return Err(RPCError::Custom(format!(
"Unknown Encoding: {}",
protocol.encoding
)))
}
}
}
fn ssz_encode(&self) -> Vec<u8> {
match self {
RPCResponse::Hello(res) => res.as_ssz_bytes(),
RPCResponse::Goodbye => unreachable!(),
RPCResponse::BeaconBlockRoots(res) => res.as_ssz_bytes(),
RPCResponse::BeaconBlockHeaders(res) => res.headers, // already raw bytes
RPCResponse::BeaconBlockBodies(res) => res.block_bodies, // already raw bytes
RPCResponse::BeaconChainState(res) => res.as_ssz_bytes(),
}
}
}
/* Outbound upgrades */
impl<TSocket> OutboundUpgrade<TSocket> for RPCRequest
where
TSocket: AsyncRead + AsyncWrite,
{
type Output = RPCResponse;
type Error = RPCError;
type Future = MapErr<
tokio_timer::Timeout<RPCRequestResponse<upgrade::Negotiated<TSocket>, Vec<u8>>>,
fn(tokio::timer::timeout::Error<RPCError>) -> RPCError,
>;
fn upgrade_outbound(
self,
socket: upgrade::Negotiated<TSocket>,
protocol: Self::Info,
) -> Self::Future {
let protocol_id = ProtocolId::from_bytes(&protocol)
.expect("Protocol ID must be valid for outbound requests");
let request_bytes = self
.encode(protocol_id)
.expect("Should be able to encode a supported protocol");
// if sending a goodbye, drop the stream and return an empty GOODBYE response
let short_circuit_return = if let RPCRequest::Goodbye(_) = self {
Some(RPCResponse::Goodbye)
} else {
None
};
rpc_request_response(
socket,
request_bytes,
MAX_RPC_SIZE,
short_circuit_return,
protocol_id,
)
.timeout(Duration::from_secs(RESPONSE_TIMEOUT))
.map_err(RPCError::from)
}
}
/// Error in RPC Encoding/Decoding. /// Error in RPC Encoding/Decoding.
#[derive(Debug)] #[derive(Debug)]
pub enum RPCError { pub enum RPCError {
@ -367,8 +480,12 @@ pub enum RPCError {
ReadError(upgrade::ReadOneError), ReadError(upgrade::ReadOneError),
/// Error when decoding the raw buffer from ssz. /// Error when decoding the raw buffer from ssz.
SSZDecodeError(ssz::DecodeError), SSZDecodeError(ssz::DecodeError),
/// Invalid Protocol ID /// Invalid Protocol ID.
InvalidProtocol(&'static str), InvalidProtocol(&'static str),
/// IO Error.
IoError(io::Error),
/// Waiting for a request/response timed out, or timer error'd.
StreamTimeout,
/// Custom message. /// Custom message.
Custom(String), Custom(String),
} }
@ -386,3 +503,46 @@ impl From<ssz::DecodeError> for RPCError {
RPCError::SSZDecodeError(err) RPCError::SSZDecodeError(err)
} }
} }
impl<T> From<tokio::timer::timeout::Error<T>> for RPCError {
fn from(err: tokio::timer::timeout::Error<T>) -> Self {
if err.is_elapsed() {
RPCError::StreamTimeout
} else {
RPCError::Custom("Stream timer failed".into())
}
}
}
impl From<io::Error> for RPCError {
fn from(err: io::Error) -> Self {
RPCError::IoError(err)
}
}
// Error trait is required for `ProtocolsHandler`
impl std::fmt::Display for RPCError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match *self {
RPCError::ReadError(ref err) => write!(f, "Error while reading from socket: {}", err),
RPCError::SSZDecodeError(ref err) => write!(f, "Error while decoding ssz: {:?}", err),
RPCError::InvalidProtocol(ref err) => write!(f, "Invalid Protocol: {}", err),
RPCError::IoError(ref err) => write!(f, "IO Error: {}", err),
RPCError::StreamTimeout => write!(f, "Stream Timeout"),
RPCError::Custom(ref err) => write!(f, "{}", err),
}
}
}
impl std::error::Error for RPCError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match *self {
RPCError::ReadError(ref err) => Some(err),
RPCError::SSZDecodeError(ref err) => None,
RPCError::InvalidProtocol(ref err) => None,
RPCError::IoError(ref err) => Some(err),
RPCError::StreamTimeout => None,
RPCError::Custom(ref err) => None,
}
}
}

View File

@ -0,0 +1,239 @@
use super::protocol::{ProtocolId, RPCError, RPCResponse, ResponseCode};
use futures::prelude::*;
use futures::try_ready;
use libp2p::core::upgrade::{read_one, ReadOne, ReadOneError};
use std::mem;
use tokio_io::{io, AsyncRead, AsyncWrite};
/// Sends a message over a socket, waits for a response code, then optionally waits for a response.
///
/// The response code is a 1-byte code which determines whether the request succeeded or not.
/// Depending on the response-code, an error may be returned. On success, a response is then
/// retrieved if required.
/// This function also gives an option to terminate the socket and return a default value, allowing for
/// one-shot requests.
///
/// The `short_circuit_return` parameter, if specified, returns the value without awaiting for a
/// response to a request and performing the logic in `then`.
#[inline]
pub fn rpc_request_response<TSocket, TData>(
socket: TSocket,
data: TData, // data sent as a request
max_size: usize, // maximum bytes to read in a response
short_circuit_return: Option<RPCResponse>, // default value to return right after a request, do not wait for a response
protocol: ProtocolId, // the protocol being negotiated
) -> RPCRequestResponse<TSocket, TData>
where
TSocket: AsyncRead + AsyncWrite,
TData: AsRef<[u8]>,
{
RPCRequestResponse {
protocol,
inner: RPCRequestResponseInner::Write(
write_one(socket, data).inner,
max_size,
short_circuit_return,
),
}
}
/// Future that makes `rpc_request_response` work.
pub struct RPCRequestResponse<TSocket, TData = Vec<u8>> {
protocol: ProtocolId,
inner: RPCRequestResponseInner<TSocket, TData>,
}
enum RPCRequestResponseInner<TSocket, TData> {
// We need to write data to the socket.
Write(WriteOneInner<TSocket, TData>, usize, Option<RPCResponse>),
// We need to read the response code.
ReadResponseCode(io::ReadExact<TSocket, io::Window<Vec<u8>>>, usize),
// We need to read a final data packet. The second parameter is the response code
Read(ReadOne<TSocket>, ResponseCode),
// An error happened during the processing.
Poisoned,
}
impl<TSocket, TData> Future for RPCRequestResponse<TSocket, TData>
where
TSocket: AsyncRead + AsyncWrite,
TData: AsRef<[u8]>,
{
type Item = RPCResponse;
type Error = RPCError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.inner, RPCRequestResponseInner::Poisoned) {
RPCRequestResponseInner::Write(mut inner, max_size, sc_return) => {
match inner.poll().map_err(ReadOneError::Io)? {
Async::Ready(socket) => {
// short-circuit the future if `short_circuit_return` is specified
if let Some(return_val) = sc_return {
return Ok(Async::Ready(return_val));
}
// begin reading the 1-byte response code
let mut data_buf = vec![0; 1];
let mut data_buf = io::Window::new(data_buf);
self.inner = RPCRequestResponseInner::ReadResponseCode(
io::read_exact(socket, data_buf),
max_size,
);
}
Async::NotReady => {
self.inner = RPCRequestResponseInner::Write(inner, max_size, sc_return);
return Ok(Async::NotReady);
}
}
}
RPCRequestResponseInner::ReadResponseCode(mut inner, max_size) => {
match inner.poll()? {
Async::Ready((socket, data)) => {
let response_code =
ResponseCode::from(u64::from_be_bytes(data.into_inner()));
// known response codes
match response_code {
ResponseCode::Success
| ResponseCode::InvalidRequest
| ResponseCode::ServerError => {
// need to read another packet
self.inner = RPCRequestResponseInner::Read(
read_one(socket, max_size),
response_code,
)
}
ResponseCode::EncodingError => {
// invalid encoding
let response = RPCResponse::Error("Invalid Encoding".into());
return Ok(Async::Ready(response));
}
ResponseCode::Unknown => {
// unknown response code
let response = RPCResponse::Error(format!(
"Unknown response code: {}",
response_code
));
return Ok(Async::Ready(response));
}
}
}
Async::NotReady => {
self.inner = RPCRequestResponseInner::ReadResponseCode(inner, max_size);
return Ok(Async::NotReady);
}
}
}
RPCRequestResponseInner::Read(mut inner, response_code) => match inner.poll()? {
Async::Ready(packet) => {
return Ok(Async::Ready(RPCResponse::decode(
packet,
self.protocol,
response_code,
)?))
}
Async::NotReady => {
self.inner = RPCRequestResponseInner::Read(inner, response_code);
return Ok(Async::NotReady);
}
},
RPCRequestResponseInner::Poisoned => panic!(),
};
}
}
}
/* Copied from rust-libp2p (https://github.com/libp2p/rust-libp2p) to access private members */
/// Send a message to the given socket, then shuts down the writing side.
///
/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
/// > compatible with what `read_one` expects.
#[inline]
pub fn write_one<TSocket, TData>(socket: TSocket, data: TData) -> WriteOne<TSocket, TData>
where
TSocket: AsyncWrite,
TData: AsRef<[u8]>,
{
let len_data = build_int_buffer(data.as_ref().len());
WriteOne {
inner: WriteOneInner::WriteLen(io::write_all(socket, len_data), data),
}
}
enum WriteOneInner<TSocket, TData> {
/// We need to write the data length to the socket.
WriteLen(io::WriteAll<TSocket, io::Window<[u8; 10]>>, TData),
/// We need to write the actual data to the socket.
Write(io::WriteAll<TSocket, TData>),
/// We need to shut down the socket.
Shutdown(io::Shutdown<TSocket>),
/// A problem happened during the processing.
Poisoned,
}
impl<TSocket, TData> Future for WriteOneInner<TSocket, TData>
where
TSocket: AsyncWrite,
TData: AsRef<[u8]>,
{
type Item = TSocket;
type Error = std::io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(self, WriteOneInner::Poisoned) {
WriteOneInner::WriteLen(mut inner, data) => match inner.poll()? {
Async::Ready((socket, _)) => {
*self = WriteOneInner::Write(io::write_all(socket, data));
}
Async::NotReady => {
*self = WriteOneInner::WriteLen(inner, data);
}
},
WriteOneInner::Write(mut inner) => match inner.poll()? {
Async::Ready((socket, _)) => {
*self = WriteOneInner::Shutdown(tokio_io::io::shutdown(socket));
}
Async::NotReady => {
*self = WriteOneInner::Write(inner);
}
},
WriteOneInner::Shutdown(ref mut inner) => {
let socket = try_ready!(inner.poll());
return Ok(Async::Ready(socket));
}
WriteOneInner::Poisoned => panic!(),
}
}
}
}
/// Builds a buffer that contains the given integer encoded as variable-length.
fn build_int_buffer(num: usize) -> io::Window<[u8; 10]> {
let mut len_data = unsigned_varint::encode::u64_buffer();
let encoded_len = unsigned_varint::encode::u64(num as u64, &mut len_data).len();
let mut len_data = io::Window::new(len_data);
len_data.set_end(encoded_len);
len_data
}
/// Future that makes `write_one` work.
struct WriteOne<TSocket, TData = Vec<u8>> {
inner: WriteOneInner<TSocket, TData>,
}
impl<TSocket, TData> Future for WriteOne<TSocket, TData>
where
TSocket: AsyncWrite,
TData: AsRef<[u8]>,
{
type Item = ();
type Error = std::io::Error;
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Ok(self.inner.poll()?.map(|_socket| ()))
}
}