// subspace_farmer/cluster/controller/stream_map.rs

//! A stream map that keeps track of the futures currently being processed for each `Index`,
//! running at most one future per index at a time.
2
3use futures::stream::FusedStream;
4use futures::{FutureExt, Stream, StreamExt};
5use std::collections::hash_map::Entry;
6use std::collections::{HashMap, VecDeque};
7use std::future::Future;
8use std::hash::Hash;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio_stream::StreamMap as TokioStreamMap;
12
13type TaskFuture<'a, R> = Pin<Box<dyn Future<Output = R> + 'a>>;
14type TaskStream<'a, R> = Pin<Box<dyn Stream<Item = R> + Unpin + 'a>>;
15
16/// A StreamMap that keeps track of futures that are currently being processed for each `index`.
17pub(super) struct StreamMap<'a, Index, R> {
18    in_progress: TokioStreamMap<Index, TaskStream<'a, R>>,
19    queue: HashMap<Index, VecDeque<TaskFuture<'a, R>>>,
20}
21
22impl<Index, R> Default for StreamMap<'_, Index, R> {
23    fn default() -> Self {
24        Self {
25            in_progress: TokioStreamMap::default(),
26            queue: HashMap::default(),
27        }
28    }
29}
30
31impl<'a, Index, R: 'a> StreamMap<'a, Index, R>
32where
33    Index: Eq + Hash + Copy + Unpin,
34{
35    /// When pushing a new task, it first checks if there is already a future for the given `index` in `in_progress`.
36    ///   - If there is, the task is added to `queue`.
37    ///   - If not, the task is directly added to `in_progress`.
38    pub(super) fn push(&mut self, index: Index, fut: TaskFuture<'a, R>) {
39        if self.in_progress.contains_key(&index) {
40            let queue = self.queue.entry(index).or_default();
41            queue.push_back(fut);
42        } else {
43            self.in_progress
44                .insert(index, Box::pin(fut.into_stream()) as _);
45        }
46    }
47
48    /// Skip the task if there is already a future for the given `index` in `in_progress`.
49    /// Returns `true` if the task is added to `in_progress`, `false` otherwise.
50    pub(super) fn add_if_not_in_progress(&mut self, index: Index, fut: TaskFuture<'a, R>) -> bool {
51        if self.in_progress.contains_key(&index) {
52            false
53        } else {
54            self.in_progress
55                .insert(index, Box::pin(fut.into_stream()) as _);
56            true
57        }
58    }
59
60    /// Polls the next entry in `in_progress` and moves the next task from `queue` to `in_progress` if there is any.
61    /// If there are no more tasks to execute, returns `None`.
62    fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<(Index, R)>> {
63        if let Some((index, res)) = std::task::ready!(self.in_progress.poll_next_unpin(cx)) {
64            // Current task completed, remove from in_progress queue and check for more tasks
65            self.in_progress.remove(&index);
66            self.process_queue(index);
67            Poll::Ready(Some((index, res)))
68        } else {
69            // No more tasks to execute
70            assert!(self.queue.is_empty());
71            Poll::Ready(None)
72        }
73    }
74
75    /// Process the next task from the tasks queue for the given `index`
76    fn process_queue(&mut self, index: Index) {
77        if let Entry::Occupied(mut next_entry) = self.queue.entry(index) {
78            let task_queue = next_entry.get_mut();
79            if let Some(fut) = task_queue.pop_front() {
80                self.in_progress
81                    .insert(index, Box::pin(fut.into_stream()) as _);
82            }
83
84            // Remove the index from the map if there are no more tasks
85            if task_queue.is_empty() {
86                next_entry.remove();
87            }
88        }
89    }
90}
91
92impl<'a, Index, R: 'a> Stream for StreamMap<'a, Index, R>
93where
94    Index: Eq + Hash + Copy + Unpin,
95{
96    type Item = (Index, R);
97
98    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99        let this = self.get_mut();
100        this.poll_next_entry(cx)
101    }
102}
103
104impl<'a, Index, R: 'a> FusedStream for StreamMap<'a, Index, R>
105where
106    Index: Eq + Hash + Copy + Unpin,
107{
108    fn is_terminated(&self) -> bool {
109        self.in_progress.is_empty() && self.queue.is_empty()
110    }
111}
112
#[cfg(test)]
mod tests {
    use crate::cluster::controller::stream_map::StreamMap;
    use futures::stream::FusedStream;
    use futures::StreamExt;
    use std::task::{Context, Poll};

    /// Asserts that `stream_map` has neither running nor queued futures and reports itself as
    /// terminated.
    fn assert_is_terminated<'a, R: 'a>(stream_map: &StreamMap<'a, u16, R>) {
        assert!(stream_map.in_progress.is_empty());
        assert!(stream_map.queue.is_empty());
        assert!(stream_map.is_terminated());
    }

    #[test]
    fn test_stream_map_default() {
        let stream_map = StreamMap::<u16, ()>::default();
        assert_is_terminated(&stream_map);
    }

    #[test]
    fn test_stream_map_push() {
        let mut stream_map = StreamMap::default();

        let index = 1;
        let fut = Box::pin(async {});
        stream_map.push(index, fut);
        assert!(stream_map.queue.is_empty());
        assert!(stream_map.in_progress.contains_key(&index));
        assert!(!stream_map.is_terminated());
    }

    #[test]
    fn test_stream_map_add_if_not_in_progress() {
        let mut stream_map = StreamMap::default();

        let index = 1;
        let fut1 = Box::pin(async {});
        let fut2 = Box::pin(async {});
        assert!(stream_map.add_if_not_in_progress(index, fut1));
        // The second future is rejected outright: it is neither started nor queued.
        assert!(!stream_map.add_if_not_in_progress(index, fut2));
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(stream_map.queue.is_empty());
    }

    #[test]
    fn test_stream_map_poll_next_entry() {
        let mut stream_map = StreamMap::default();

        let fut = Box::pin(async {});
        stream_map.push(0, fut);

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let poll_result = stream_map.poll_next_entry(&mut cx);
        assert_eq!(poll_result, Poll::Ready(Some((0, ()))));
        assert_is_terminated(&stream_map);
    }

    #[tokio::test]
    async fn test_stream_map_stream() {
        let mut stream_map = StreamMap::default();

        let fut00 = Box::pin(async { 0x00 });
        stream_map.push(0, fut00);

        let next_item = stream_map.next().await;
        assert_eq!(next_item, Some((0, 0x00)));
        assert_is_terminated(&stream_map);

        let fut11 = Box::pin(async { 0x11 });
        let fut12 = Box::pin(async { 0x12 });
        let fut13 = Box::pin(async { 0x13 });
        let fut21 = Box::pin(async {
            // Yield the current task three times to ensure that fut21 completes only after
            // every future for farm index 1 has completed.
            for _ in 0..3 {
                tokio::task::yield_now().await;
            }
            0x21
        });
        let fut22 = Box::pin(async { 0x22 });

        // Push 2 futs into the same farm index 1: fut11 goes straight into `in_progress`,
        // while fut12 is parked in `queue` until fut11 finishes.
        stream_map.push(1, fut11);
        stream_map.push(1, fut12);
        assert!(!stream_map.is_terminated());
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(stream_map.in_progress.contains_key(&1));
        assert_eq!(stream_map.queue.len(), 1);

        // Push fut21 into farm index 2, we have 2 in progress futures now
        stream_map.push(2, fut21);
        assert_eq!(stream_map.in_progress.len(), 2);
        assert!(stream_map.in_progress.contains_key(&2));
        assert_eq!(stream_map.queue.len(), 1);

        // Push fut22 into farm index 2: the number of in-progress futures does not change,
        // but `queue` now has entries for 2 farm indices
        stream_map.push(2, fut22);
        assert_eq!(stream_map.in_progress.len(), 2);
        assert_eq!(stream_map.queue.len(), 2);
        assert_eq!(stream_map.queue[&2].len(), 1);

        // Push fut13 into farm index 1, fut13 should complete after fut11 and fut12
        stream_map.push(1, fut13);
        assert!(!stream_map.is_terminated());
        assert!(stream_map.in_progress.contains_key(&1));
        assert_eq!(stream_map.in_progress.len(), 2);
        assert_eq!(stream_map.queue[&1].len(), 2);

        // Poll the next item in the stream: fut11 completes first, and fut12 is promoted
        // from `queue` into `in_progress`
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (1, 0x11));
        assert!(stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert_eq!(stream_map.in_progress.len(), 2);
        assert_eq!(stream_map.queue[&1].len(), 1);

        // fut12 completes before fut21 because fut21 has yield points.
        // fut13 is promoted into `in_progress`. No more futures are waiting for farm index 1,
        // so farm index 1 should be removed from `queue`.
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (1, 0x12));
        assert_eq!(stream_map.in_progress.len(), 2);
        assert!(stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert!(!stream_map.queue.contains_key(&1));

        // Poll the next item in the stream, fut13 should complete next.
        // All futures for farm index 1 have now completed, so farm index 1 should be removed
        // from `in_progress`.
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (1, 0x13));
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(!stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert!(!stream_map.queue.contains_key(&1));
        assert_eq!(stream_map.queue[&2].len(), 1);

        // Futures with the same index complete in the order they were pushed,
        // so fut21 should complete next and fut22 is promoted into `in_progress`.
        // No more futures are waiting for farm index 2, so farm index 2
        // should be removed from `queue`.
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (2, 0x21));
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(!stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert!(!stream_map.queue.contains_key(&1));
        assert!(!stream_map.queue.contains_key(&2));

        // Poll the next item in the stream, fut22 should complete next.
        // All futures for farm index 2 have now completed, so farm index 2 should be removed
        // from `in_progress`.
        // Finally, the stream should be terminated.
        let next_item = stream_map.next().await;
        assert_eq!(next_item, Some((2, 0x22)));
        assert_is_terminated(&stream_map);
    }
}