use libp2p::{
core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
swarm::NegotiatedSubstream,
};
use futures::{future::BoxFuture, prelude::*};
use std::{io, iter};
/// Configuration for the `/taple/tell/1.0.0` substream protocol.
///
/// A single instance serves both directions:
/// - outbound, it carries the payload to send;
/// - inbound, it carries the size limit applied to the peer's declared length.
#[derive(Debug, Clone)]
pub struct TellProtocol {
    /// Payload written (after a varint length prefix) on an outbound upgrade.
    pub message: Vec<u8>,
    /// Largest declared payload length, in bytes, accepted on an inbound upgrade.
    pub max_message_size: u64,
}
impl UpgradeInfo for TellProtocol {
    type Info = &'static [u8];
    type InfoIter = iter::Once<Self::Info>;

    /// Returns the one protocol identifier this upgrade negotiates.
    fn protocol_info(&self) -> Self::InfoIter {
        const PROTOCOL_NAME: &[u8] = b"/taple/tell/1.0.0";
        iter::once(PROTOCOL_NAME)
    }
}
impl OutboundUpgrade<NegotiatedSubstream> for TellProtocol {
    type Output = ();
    type Error = io::Error;
    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;

    /// Sends `self.message` over the negotiated substream and closes it.
    ///
    /// Wire format: an unsigned-varint encoding of the payload length, followed
    /// by the raw payload bytes. Any write or close failure is returned as-is.
    fn upgrade_outbound(self, mut io: NegotiatedSubstream, _: Self::Info) -> Self::Future {
        let Self { message, .. } = self;
        Box::pin(async move {
            let mut len_scratch = unsigned_varint::encode::usize_buffer();
            let len_prefix = unsigned_varint::encode::usize(message.len(), &mut len_scratch);
            io.write_all(len_prefix).await?;
            io.write_all(&message).await?;
            io.close().await
        })
    }
}
impl InboundUpgrade<NegotiatedSubstream> for TellProtocol {
    type Output = Vec<u8>;
    type Error = io::Error;
    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;

    /// Reads one length-prefixed message from the negotiated substream.
    ///
    /// The peer is expected to send an unsigned-varint length followed by exactly
    /// that many payload bytes. The declared length is checked against
    /// `self.max_message_size` *before* the receive buffer is allocated.
    ///
    /// # Errors
    ///
    /// - A transport failure while reading the prefix (e.g. EOF or reset) is
    ///   returned with its original [`io::ErrorKind`].
    /// - A malformed varint prefix yields [`io::ErrorKind::InvalidInput`].
    /// - A declared length above `max_message_size` yields
    ///   [`io::ErrorKind::InvalidInput`].
    /// - A failure while reading the payload (including a short read) is
    ///   propagated from `read_exact`.
    fn upgrade_inbound(self, mut io: NegotiatedSubstream, _: Self::Info) -> Self::Future {
        async move {
            let length = unsigned_varint::aio::read_usize(&mut io)
                .await
                .map_err(|err| match err {
                    // Keep genuine I/O failures intact so callers can tell a dropped
                    // connection apart from a peer that sent a malformed prefix.
                    unsigned_varint::io::ReadError::Io(e) => e,
                    other => io::Error::new(io::ErrorKind::InvalidInput, other),
                })?;
            // On targets where the limit does not fit in usize, every usize fits
            // under it, so saturate instead of failing.
            if length > usize::try_from(self.max_message_size).unwrap_or(usize::MAX) {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "Message size exceeds limit: {} > {}",
                        length, self.max_message_size
                    ),
                ));
            }
            let mut buffer = vec![0; length];
            io.read_exact(&mut buffer).await?;
            Ok(buffer)
        }
        .boxed()
    }
}