subspace_farmer/cluster/controller/
stream_map.rs1use futures::stream::FusedStream;
4use futures::{FutureExt, Stream, StreamExt};
5use std::collections::hash_map::Entry;
6use std::collections::{HashMap, VecDeque};
7use std::future::Future;
8use std::hash::Hash;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio_stream::StreamMap as TokioStreamMap;
12
/// A boxed, pinned future producing `R`; this is the unit of work callers hand to the map.
type TaskFuture<'a, R> = Pin<Box<dyn Future<Output = R> + 'a>>;
/// A boxed, pinned stream of `R` items; the map stores each running task in this form
/// (a future converted with `into_stream`).
type TaskStream<'a, R> = Pin<Box<dyn Stream<Item = R> + Unpin + 'a>>;
15
/// A keyed collection of tasks that runs at most one task per `Index` at a time.
///
/// A task pushed for an index that already has a task in progress is parked in a per-index
/// FIFO queue and started once the running task for that index has produced its result.
pub(super) struct StreamMap<'a, Index, R> {
    /// Tasks currently being polled, keyed by index (one entry per index).
    in_progress: TokioStreamMap<Index, TaskStream<'a, R>>,
    /// Tasks waiting for the in-progress task with the same index to finish, in arrival order.
    queue: HashMap<Index, VecDeque<TaskFuture<'a, R>>>,
}
21
22impl<Index, R> Default for StreamMap<'_, Index, R> {
23 fn default() -> Self {
24 Self {
25 in_progress: TokioStreamMap::default(),
26 queue: HashMap::default(),
27 }
28 }
29}
30
31impl<'a, Index, R: 'a> StreamMap<'a, Index, R>
32where
33 Index: Eq + Hash + Copy + Unpin,
34{
35 pub(super) fn push(&mut self, index: Index, fut: TaskFuture<'a, R>) {
39 if self.in_progress.contains_key(&index) {
40 let queue = self.queue.entry(index).or_default();
41 queue.push_back(fut);
42 } else {
43 self.in_progress
44 .insert(index, Box::pin(fut.into_stream()) as _);
45 }
46 }
47
48 pub(super) fn add_if_not_in_progress(&mut self, index: Index, fut: TaskFuture<'a, R>) -> bool {
51 if self.in_progress.contains_key(&index) {
52 false
53 } else {
54 self.in_progress
55 .insert(index, Box::pin(fut.into_stream()) as _);
56 true
57 }
58 }
59
60 fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<(Index, R)>> {
63 if let Some((index, res)) = std::task::ready!(self.in_progress.poll_next_unpin(cx)) {
64 self.in_progress.remove(&index);
66 self.process_queue(index);
67 Poll::Ready(Some((index, res)))
68 } else {
69 assert!(self.queue.is_empty());
71 Poll::Ready(None)
72 }
73 }
74
75 fn process_queue(&mut self, index: Index) {
77 if let Entry::Occupied(mut next_entry) = self.queue.entry(index) {
78 let task_queue = next_entry.get_mut();
79 if let Some(fut) = task_queue.pop_front() {
80 self.in_progress
81 .insert(index, Box::pin(fut.into_stream()) as _);
82 }
83
84 if task_queue.is_empty() {
86 next_entry.remove();
87 }
88 }
89 }
90}
91
92impl<'a, Index, R: 'a> Stream for StreamMap<'a, Index, R>
93where
94 Index: Eq + Hash + Copy + Unpin,
95{
96 type Item = (Index, R);
97
98 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99 let this = self.get_mut();
100 this.poll_next_entry(cx)
101 }
102}
103
104impl<'a, Index, R: 'a> FusedStream for StreamMap<'a, Index, R>
105where
106 Index: Eq + Hash + Copy + Unpin,
107{
108 fn is_terminated(&self) -> bool {
109 self.in_progress.is_empty() && self.queue.is_empty()
110 }
111}
112
// Unit tests for `StreamMap`; they inspect the private `in_progress` and `queue` fields
// directly to check where each pushed future ends up.
#[cfg(test)]
mod tests {
    use crate::cluster::controller::stream_map::StreamMap;
    use futures::stream::FusedStream;
    use futures::StreamExt;
    use std::task::Context;

    /// Asserts that the map holds no in-progress and no queued tasks and reports itself as
    /// terminated.
    fn assert_is_terminated<'a, R: 'a>(stream_map: &StreamMap<'a, u16, R>) {
        assert!(stream_map.in_progress.is_empty());
        assert!(stream_map.queue.is_empty());
        assert!(stream_map.is_terminated());
    }

    /// A freshly constructed map is empty and therefore terminated.
    #[test]
    fn test_stream_map_default() {
        let stream_map = StreamMap::<u16, ()>::default();
        assert_is_terminated(&stream_map);
    }

    /// The first future pushed for an index goes straight to `in_progress`, not to the queue.
    #[test]
    fn test_stream_map_push() {
        let mut stream_map = StreamMap::default();

        let index = 1;
        let fut = Box::pin(async {});
        stream_map.push(index, fut);
        assert!(stream_map.queue.is_empty());
        assert!(stream_map.in_progress.contains_key(&index));
        assert!(!stream_map.is_terminated());
    }

    /// `add_if_not_in_progress` accepts the first future for an index and rejects a second one
    /// while the first is still in progress.
    #[test]
    fn test_stream_map_add_if_not_in_progress() {
        let mut stream_map = StreamMap::default();

        let index = 1;
        let fut1 = Box::pin(async {});
        let fut2 = Box::pin(async {});
        assert!(stream_map.add_if_not_in_progress(index, fut1));
        assert!(!stream_map.add_if_not_in_progress(index, fut2));
    }

    /// Polling a map with a single immediately-ready future completes it and leaves the map
    /// empty.
    #[test]
    fn test_stream_map_poll_next_entry() {
        let mut stream_map = StreamMap::default();

        let fut = Box::pin(async {});
        stream_map.push(0, fut);

        // A no-op waker is enough here because the future never returns `Pending`
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let poll_result = stream_map.poll_next_entry(&mut cx);
        assert!(poll_result.is_ready());
        assert_is_terminated(&stream_map);
    }

    /// End-to-end check of the `Stream` implementation: per-index ordering, queueing and
    /// promotion of queued futures as earlier ones complete.
    #[tokio::test]
    async fn test_stream_map_stream() {
        let mut stream_map = StreamMap::default();

        let fut00 = Box::pin(async { 0x00 });
        stream_map.push(0, fut00);

        // A single ready future is yielded and the map becomes terminated again
        let next_item = stream_map.next().await;
        assert_eq!(next_item, Some((0, 0x00)));
        assert_is_terminated(&stream_map);

        let fut11 = Box::pin(async { 0x11 });
        let fut12 = Box::pin(async { 0x12 });
        let fut13 = Box::pin(async { 0x13 });
        // `fut21` yields to the scheduler three times before resolving, so the assertions below
        // expect all three index-1 results to come out before it
        let fut21 = Box::pin(async {
            for _ in 0..3 {
                tokio::task::yield_now().await;
            }
            0x21
        });
        let fut22 = Box::pin(async { 0x22 });

        // Two futures for index 1: the first is started, the second is queued behind it
        stream_map.push(1, fut11);
        stream_map.push(1, fut12);
        assert!(!stream_map.is_terminated());
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(stream_map.in_progress.contains_key(&1));
        assert_eq!(stream_map.queue.len(), 1);

        // Index 2 has nothing in progress yet, so `fut21` is started immediately
        stream_map.push(2, fut21);
        assert_eq!(stream_map.in_progress.len(), 2);
        assert!(stream_map.in_progress.contains_key(&2));
        assert_eq!(stream_map.queue.len(), 1);

        // Index 2 is now busy, so `fut22` gets its own queue entry
        stream_map.push(2, fut22);
        assert_eq!(stream_map.in_progress.len(), 2);
        assert_eq!(stream_map.queue.len(), 2);
        assert_eq!(stream_map.queue[&2].len(), 1);

        // A third future for index 1 joins the existing index-1 queue
        stream_map.push(1, fut13);
        assert!(!stream_map.is_terminated());
        assert!(stream_map.in_progress.contains_key(&1));
        assert_eq!(stream_map.in_progress.len(), 2);
        assert_eq!(stream_map.queue[&1].len(), 2);

        // `fut11` completes; `fut12` is promoted from the queue, leaving `fut13` queued
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (1, 0x11));
        assert!(stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert_eq!(stream_map.in_progress.len(), 2);
        assert_eq!(stream_map.queue[&1].len(), 1);

        // `fut12` completes; `fut13` is promoted and the now-empty index-1 queue is removed
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (1, 0x12));
        assert_eq!(stream_map.in_progress.len(), 2);
        assert!(stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert!(!stream_map.queue.contains_key(&1));

        // `fut13` completes; index 1 has nothing left, while `fut22` is still queued for index 2
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (1, 0x13));
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(!stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert!(!stream_map.queue.contains_key(&1));
        assert_eq!(stream_map.queue[&2].len(), 1);

        // `fut21` finally completes; `fut22` is promoted, so index 2 stays in progress and the
        // queue is empty
        let next_item = stream_map.next().await;
        assert!(!stream_map.is_terminated());
        assert_eq!(next_item.unwrap(), (2, 0x21));
        assert_eq!(stream_map.in_progress.len(), 1);
        assert!(!stream_map.in_progress.contains_key(&1));
        assert!(stream_map.in_progress.contains_key(&2));
        assert!(!stream_map.queue.contains_key(&1));
        assert!(!stream_map.queue.contains_key(&2));

        // `fut22` completes and the map is fully drained
        let next_item = stream_map.next().await;
        assert_eq!(next_item, Some((2, 0x22)));
        assert_is_terminated(&stream_map);
    }
}