Skip to content
Snippets Groups Projects
Commit 57c40919 authored by Jim Newsome's avatar Jim Newsome
Browse files

WIP: MpscChannelMux

parent 9b168789
No related branches found
Tags tor-0.0.8pre2
No related merge requests found
Pipeline #179279 failed
//! Types and code for mapping StreamIDs to streams on a circuit.
mod counted_map;
mod mpsc_channel_mux;
use crate::circuit::halfstream::HalfStream;
use crate::circuit::sendme;
......
use std::collections::{HashMap, VecDeque};
use std::hash::{BuildHasher, Hash, RandomState};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use futures::{SinkExt, StreamExt};
/// Internal state shared between the receiver and senders, used to track how
/// many items are available from each receiver, and do scheduling.
struct Notifications<K, S = RandomState> {
/// Number of items available for each key.
counts: HashMap<K, usize, S>,
/// Keys that have non-zero items ready, in priority order.
ready_keys: VecDeque<K>,
/// Wakers to be notified when a notification arrives.
wakers: Vec<std::task::Waker>,
}
impl<K, S> Notifications<K, S>
where
K: Eq + Hash + Copy,
S: BuildHasher,
{
/// Increment number of items available for `key`, waking any previously
/// stored wakers.
fn track_push(&mut self, key: K) {
match self.counts.entry(key) {
std::collections::hash_map::Entry::Occupied(mut o) => {
*o.get_mut() += 1;
}
std::collections::hash_map::Entry::Vacant(v) => {
v.insert(1);
self.ready_keys.push_back(key);
}
};
for waker in self.wakers.drain(..) {
waker.wake()
}
}
fn track_pop(&mut self) -> Option<K> {
let key = self.ready_keys.pop_front()?;
let std::collections::hash_map::Entry::Occupied(mut o) = self.counts.entry(key) else {
unreachable!("Unexpected missing key");
};
if *o.get() > 0 {
// Decrement number of ready items and push to back of ready queue.
*o.get_mut() -= 1;
self.ready_keys.push_back(key);
} else {
// Remove entry altogether and don't push to back of ready queue.
o.remove();
}
Some(key)
}
}
pub struct ChannelMux<K, V, S = RandomState> {
/// The "data" receiver for each key.
data_recvs: HashMap<K, futures::channel::mpsc::Receiver<V>, S>,
/// Tracks which data channels have data available.
notifications: Arc<Mutex<Notifications<K, S>>>,
}
impl<K, V, S> ChannelMux<K, V, S>
where
S: Default + std::hash::BuildHasher,
K: std::hash::Hash + std::cmp::Eq,
{
pub fn new() -> Self {
Self::with_hashers(Default::default(), Default::default())
}
}
impl<K, V, S> ChannelMux<K, V, S>
where
S: std::hash::BuildHasher,
K: std::hash::Hash + std::cmp::Eq,
{
// XXX: really expose that we need two builders here? Or just require and use `Default`? (or `Clone`?)
pub fn with_hashers(hash_builder1: S, hash_builder2: S) -> Self {
Self {
data_recvs: HashMap::with_hasher(hash_builder1),
notifications: Arc::new(Mutex::new(Notifications {
wakers: Vec::new(),
counts: HashMap::with_hasher(hash_builder2),
ready_keys: VecDeque::new(),
})),
}
}
}
impl<K, V, S> ChannelMux<K, V, S>
where
S: std::hash::BuildHasher + Unpin,
K: std::hash::Hash + std::cmp::Eq + Copy + Unpin,
{
pub fn add_channel(&mut self, key: K, channel: (futures::channel::mpsc::Sender<V>, futures::channel::mpsc::Receiver<V>)) -> ChannelMuxSender<K, V, S> {
let (send, recv) = channel;
let res = self.data_recvs.insert(key, recv);
// XXX. Maybe return an Option here instead?
assert!(res.is_none());
ChannelMuxSender {
key,
data_send: send,
notifications: self.notifications.clone(),
}
}
}
impl<K, V, S> futures::Stream for ChannelMux<K, V, S>
where
S: std::hash::BuildHasher + Unpin,
K: std::hash::Hash + std::cmp::Eq + Copy + Unpin,
{
type Item = (K, Option<V>);
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<(K, Option<V>)>> {
let mut_self = self.get_mut();
let mut notifications = mut_self.notifications.lock().expect("Poisoned lock");
let Some(key) = notifications.ready_keys.front() else {
// No ready data receivers.
notifications.wakers.push(cx.waker().clone());
return Poll::Pending;
};
let key = *key;
let data_recv = mut_self
.data_recvs
.get_mut(&key)
.expect("Missing data receiver for key")
.poll_next_unpin(cx);
let opt_value = match data_recv {
Poll::Ready(v) => {
notifications.track_pop();
v
}
Poll::Pending => {
// We can get here because we don't have explicit synchronization guaranteeing that data is pushed
// into the channel and ready before the notification arrives.
//
// For now, just return Pending. The current task will be notified when the data is actually ready.
//
// Alternatively we could move on to the next receiver, but that
// may introduce some surprising behavior.
return Poll::Pending;
}
};
if opt_value.is_none() {
let receiver = mut_self.data_recvs.remove(&key);
debug_assert!(receiver.is_some());
}
Poll::Ready(Some((key, opt_value)))
}
fn size_hint(&self) -> (usize, Option<usize>) {
let notifications = self.notifications.lock().expect("lock");
let n = notifications.counts.values().sum();
(n, Some(n))
}
}
pub struct ChannelMuxSender<K, V, S = RandomState>
where
K: Unpin,
{
/// Our key
key: K,
/// The data channel.
data_send: futures::channel::mpsc::Sender<V>,
/// Shared state of available data.
notifications: Arc<Mutex<Notifications<K, S>>>,
}
impl<K, V> futures::Sink<V> for ChannelMuxSender<K, V>
where
K: Unpin + Copy + Hash + Eq,
{
fn start_send(mut self: Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
let k = self.as_ref().key.clone();
self.as_mut().data_send.start_send(item)?;
let mut notifications = self.notifications.lock().expect("Poisoned lock");
notifications.track_push(k);
Ok(())
}
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.data_send.poll_ready(cx)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.data_send.poll_flush_unpin(cx)
}
// XXX should wrap this to avoid leaking impl
type Error = futures::channel::mpsc::SendError;
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let res = self.data_send.poll_close_unpin(cx);
let mut notifications = self.notifications.lock().expect("Poisoned lock");
notifications.track_push(self.key);
res
}
}
#[cfg(test)]
mod tests {
use futures::{Sink as _, Stream as _};
use super::*;
#[test]
fn it_works() {
let mut mux = Box::pin(ChannelMux::<usize, String>::new());
let mut sender1 = Box::pin(mux.add_channel(1, futures::channel::mpsc::channel(10)));
let mut sender2 = Box::pin(mux.add_channel(2, futures::channel::mpsc::channel(10)));
futures::executor::block_on(futures::future::poll_fn(|cx| {
assert!(matches!(mux.as_mut().poll_next(cx), Poll::Pending));
assert!(matches!(
sender1.as_mut().poll_ready(cx),
Poll::Ready(Ok(()))
));
sender1.as_mut().start_send("Hi".to_owned()).unwrap();
assert!(matches!(
sender1.as_mut().poll_flush(cx),
Poll::Ready(Ok(()))
));
// XXX
assert!(matches!(mux.as_mut().poll_next(cx), Poll::Ready(_)));
assert!(matches!(mux.as_mut().poll_next(cx), Poll::Pending));
Poll::Ready(())
}));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment