1pub mod ss58;
4#[cfg(test)]
5mod tests;
6
7use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
8use futures::channel::oneshot;
9use futures::channel::oneshot::Canceled;
10use futures::future::Either;
11use rayon::{
12 current_thread_index, ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder,
13};
14use std::future::Future;
15use std::num::NonZeroUsize;
16use std::ops::Deref;
17use std::pin::{pin, Pin};
18use std::process::exit;
19use std::task::{Context, Poll};
20use std::{fmt, io, iter, thread};
21use thread_priority::{set_current_thread_priority, ThreadPriority};
22use tokio::runtime::Handle;
23use tokio::task;
24use tracing::{debug, warn};
25
// Upper bound used by `recommended_number_of_farming_threads` when deriving a default thread
// count from the detected CPU topology
const MAX_DEFAULT_FARMING_THREADS: usize = 32;
28
/// Wrapper around a tokio [`task::JoinHandle`] that joins (and optionally aborts) the task when
/// dropped, and can itself be awaited as a future.
#[derive(Debug)]
pub struct AsyncJoinOnDrop<T> {
    // `Option` so the handle can be taken out in the `Drop` impl
    handle: Option<task::JoinHandle<T>>,
    // When `true`, the task is aborted before being joined on drop
    abort_on_drop: bool,
}
35
impl<T> Drop for AsyncJoinOnDrop<T> {
    #[inline]
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            if self.abort_on_drop {
                handle.abort();
            }

            // Even after `abort()` the task may still be running briefly; wait for it to
            // actually finish so nothing outlives this wrapper
            if !handle.is_finished() {
                // NOTE(review): both `block_in_place` and `Handle::current()` panic when called
                // outside a multi-threaded tokio runtime context — confirm drop always happens
                // on a runtime worker thread
                task::block_in_place(move || {
                    let _ = Handle::current().block_on(handle);
                });
            }
        }
    }
}
52
53impl<T> AsyncJoinOnDrop<T> {
54 #[inline]
56 pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
57 Self {
58 handle: Some(handle),
59 abort_on_drop,
60 }
61 }
62}
63
64impl<T> Future for AsyncJoinOnDrop<T> {
65 type Output = Result<T, task::JoinError>;
66
67 #[inline]
68 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69 Pin::new(
70 self.handle
71 .as_mut()
72 .expect("Only dropped in Drop impl; qed"),
73 )
74 .poll(cx)
75 }
76}
77
/// Wrapper around a [`thread::JoinHandle`] that joins the thread when dropped.
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);
80
81impl Drop for JoinOnDrop {
82 #[inline]
83 fn drop(&mut self) {
84 self.0
85 .take()
86 .expect("Always called exactly once; qed")
87 .join()
88 .expect("Panic if background thread panicked");
89 }
90}
91
92impl JoinOnDrop {
93 #[inline]
95 pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
96 Self(Some(handle))
97 }
98}
99
100impl Deref for JoinOnDrop {
101 type Target = thread::JoinHandle<()>;
102
103 #[inline]
104 fn deref(&self) -> &Self::Target {
105 self.0.as_ref().expect("Only dropped in Drop impl; qed")
106 }
107}
108
/// Runs the future produced by `create_future` to completion in a dedicated OS thread with the
/// specified `thread_name`.
///
/// Returns a future that resolves to the inner future's output, or [`Canceled`] if the
/// background thread stopped without sending a result. Dropping the returned future drops the
/// sender half of the drop channel (which makes the background thread stop polling) and then
/// joins the thread.
pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
    create_future: CreateFut,
    thread_name: String,
) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
where
    CreateFut: (FnOnce() -> Fut) + Send + 'static,
    Fut: Future<Output = T> + 'static,
    T: Send + 'static,
{
    // `drop_rx` resolving signals the background thread that the caller lost interest
    let (drop_tx, drop_rx) = oneshot::channel::<()>();
    let (result_tx, result_rx) = oneshot::channel();
    // Capture the current tokio runtime handle so the new thread can block on futures
    let handle = Handle::current();
    let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
        let _tokio_handle_guard = handle.enter();

        let future = pin!(create_future());

        // Poll the user future and the drop signal concurrently; whichever finishes first wins
        let result = match handle.block_on(futures::future::select(future, drop_rx)) {
            Either::Left((result, _)) => result,
            Either::Right(_) => {
                // Caller dropped the returned future (or the sender), so there is no one left
                // to report a result to
                return;
            }
        };
        if let Err(_error) = result_tx.send(result) {
            debug!(
                thread_name = ?thread::current().name(),
                "Future finished, but receiver was already dropped",
            );
        }
    })?;
    // Ensures the background thread is joined when the returned future is dropped
    let join_on_drop = JoinOnDrop::new(join_handle);

    Ok(async move {
        let result = result_rx.await;
        // Drop the sender first so the background thread unblocks (its `drop_rx` resolves)
        // before `join_on_drop` blocks on joining it
        drop(drop_tx);
        drop(join_on_drop);
        result
    })
}
152
/// A set of logical CPU core indices, optionally paired with the hwlocality topology they were
/// derived from (only with the `numa` feature).
#[derive(Clone)]
pub struct CpuCoreSet {
    // Logical CPU core indices; assumed to be non-empty throughout this module
    cores: Vec<usize>,
    // Shared topology used for NUMA-aware truncation and thread pinning, when available
    #[cfg(feature = "numa")]
    topology: Option<std::sync::Arc<hwlocality::Topology>>,
}
161
impl fmt::Debug for CpuCoreSet {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("CpuCoreSet");
        // Without NUMA support: render a compact `first-last` range when cores are contiguous,
        // otherwise a comma-separated list. Panics (via `expect`) if the set is empty.
        #[cfg(not(feature = "numa"))]
        if self.cores.array_windows::<2>().all(|&[a, b]| a + 1 == b) {
            s.field(
                "cores",
                &format!(
                    "{}-{}",
                    self.cores.first().expect("List of cores is not empty; qed"),
                    self.cores.last().expect("List of cores is not empty; qed")
                ),
            );
        } else {
            s.field(
                "cores",
                &self
                    .cores
                    .iter()
                    .map(usize::to_string)
                    .collect::<Vec<_>>()
                    .join(","),
            );
        }
        // With NUMA support: delegate formatting to hwlocality's `CpuSet`
        #[cfg(feature = "numa")]
        {
            use hwlocality::cpu::cpuset::CpuSet;
            use hwlocality::ffi::PositiveInt;

            s.field(
                "cores",
                &CpuSet::from_iter(
                    self.cores.iter().map(|&core| {
                        PositiveInt::try_from(core).expect("Valid CPU core index; qed")
                    }),
                ),
            );
        }
        s.finish_non_exhaustive()
    }
}
204
impl CpuCoreSet {
    /// Get CPU core indices in this set
    pub fn cpu_cores(&self) -> &[usize] {
        &self.cores
    }

    /// Truncates the list of CPU cores to `num_cores`.
    ///
    /// `num_cores` is clamped to `1..=self.cores.len()`, so at least one core always remains.
    ///
    /// With the `numa` feature enabled and a topology available, cores are removed
    /// proportionally from groups that share the same (L2 cache size, group size), so the
    /// remaining cores stay spread across the topology instead of only keeping the head of the
    /// list.
    pub fn truncate(&mut self, num_cores: usize) {
        // NOTE(review): `clamp` panics if `self.cores` is empty (min > max); the module assumes
        // sets are never empty — confirm this invariant holds for all constructors
        let num_cores = num_cores.clamp(1, self.cores.len());

        #[cfg(feature = "numa")]
        if let Some(topology) = &self.topology {
            use hwlocality::object::attributes::ObjectAttributes;
            use hwlocality::object::types::ObjectType;

            // Group this set's cores by `(L2 cache size, number of this set's cores under
            // that cache)` so truncation can take a proportional share from each kind of group
            let mut grouped_by_l2_cache_size_and_core_count =
                std::collections::HashMap::<(usize, usize), Vec<usize>>::new();
            topology
                .objects_with_type(ObjectType::L2Cache)
                .for_each(|object| {
                    // Cache size is 0 when attributes are missing or not cache attributes
                    let l2_cache_size =
                        if let Some(ObjectAttributes::Cache(cache)) = object.attributes() {
                            cache
                                .size()
                                .map(|size| size.get() as usize)
                                .unwrap_or_default()
                        } else {
                            0
                        };
                    if let Some(cpuset) = object.complete_cpuset() {
                        // Only consider cores that actually belong to this set
                        let cpuset = cpuset
                            .into_iter()
                            .map(usize::from)
                            .filter(|core| self.cores.contains(core))
                            .collect::<Vec<_>>();
                        let cpuset_len = cpuset.len();

                        if !cpuset.is_empty() {
                            grouped_by_l2_cache_size_and_core_count
                                .entry((l2_cache_size, cpuset_len))
                                .or_default()
                                .extend(cpuset);
                        }
                    }
                });

            // Only use the grouped view when every core of this set was covered by some L2
            // cache object; otherwise fall through to the plain truncation below
            if grouped_by_l2_cache_size_and_core_count
                .values()
                .flatten()
                .count()
                == self.cores.len()
            {
                self.cores = grouped_by_l2_cache_size_and_core_count
                    .into_values()
                    .flat_map(|cores| {
                        // Proportional share of each group (integer division), but keep at
                        // least one core per group
                        let limit = cores.len() * num_cores / self.cores.len();
                        cores.into_iter().take(limit.max(1))
                    })
                    .collect();

                self.cores.sort();

                return;
            }
        }
        // Fallback: simply keep the first `num_cores` cores
        self.cores.truncate(num_cores);
    }

    /// Pins the current thread to the CPU cores in this set.
    ///
    /// Does nothing unless the `numa` feature is enabled and a topology is available; binding
    /// failures are logged but otherwise ignored.
    pub fn pin_current_thread(&self) {
        #[cfg(feature = "numa")]
        if let Some(topology) = &self.topology {
            use hwlocality::cpu::binding::CpuBindingFlags;
            use hwlocality::cpu::cpuset::CpuSet;
            use hwlocality::current_thread_id;
            use hwlocality::ffi::PositiveInt;

            // Build a CpuSet from all cores of this set; the thread may run on any of them
            let cpu_cores = CpuSet::from_iter(
                self.cores
                    .iter()
                    .map(|&core| PositiveInt::try_from(core).expect("Valid CPU core index; qed")),
            );

            // Binding failure is non-fatal: the thread just runs unpinned
            if let Err(error) =
                topology.bind_thread_cpu(current_thread_id(), &cpu_cores, CpuBindingFlags::empty())
            {
                warn!(%error, ?cpu_cores, "Failed to pin thread to CPU cores")
            }
        }
    }
}
307
/// Recommended number of threads to use for farming.
///
/// With the `numa` feature this is the CPU core count of the first NUMA node that has any
/// cores; otherwise (or when reading the topology fails) it is the total number of logical
/// CPUs. Either way the result is capped at [`MAX_DEFAULT_FARMING_THREADS`].
pub fn recommended_number_of_farming_threads() -> usize {
    #[cfg(feature = "numa")]
    match hwlocality::Topology::new().map(std::sync::Arc::new) {
        Ok(topology) => {
            return topology
                // Iterate over NUMA nodes
                .objects_at_depth(hwlocality::object::depth::Depth::NUMANode)
                // For each NUMA node get its CPU set, if any
                .filter_map(|node| node.cpuset())
                // Count CPU cores in the set
                .map(|cpuset| cpuset.iter_set().count())
                // Take the first node that has any cores at all
                .find(|&count| count > 0)
                // Fall back to total logical CPU count if no NUMA node had cores
                .unwrap_or_else(num_cpus::get)
                .min(MAX_DEFAULT_FARMING_THREADS);
        }
        Err(error) => {
            warn!(%error, "Failed to get NUMA topology");
        }
    }
    num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
}
331
/// Get all CPU cores, grouped into sets.
///
/// With the `numa` feature enabled, cores are grouped by shared L3 cache; without it (or when
/// topology retrieval fails or yields no non-empty groups) all cores end up in a single set.
pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
    #[cfg(feature = "numa")]
    match hwlocality::Topology::new().map(std::sync::Arc::new) {
        Ok(topology) => {
            let cpu_cores = topology
                // Iterate over groups of cores that share an L3 cache
                .objects_with_type(hwlocality::object::types::ObjectType::L3Cache)
                .filter_map(|node| node.cpuset())
                .map(|cpuset| cpuset.iter_set().map(usize::from).collect::<Vec<_>>())
                // Skip groups without any cores in them
                .filter(|cores| !cores.is_empty())
                .map(|cores| CpuCoreSet {
                    cores,
                    topology: Some(std::sync::Arc::clone(&topology)),
                })
                .collect::<Vec<_>>();

            if !cpu_cores.is_empty() {
                return cpu_cores;
            }
        }
        Err(error) => {
            warn!(%error, "Failed to get L3 cache topology");
        }
    }
    // Fallback: one set containing every logical CPU
    vec![CpuCoreSet {
        cores: (0..num_cpus::get()).collect(),
        #[cfg(feature = "numa")]
        topology: None,
    }]
}
368
/// Parses a space-separated list of CPU core groups into [`CpuCoreSet`]s.
///
/// Each group is a comma-separated list of individual cores or dash-separated inclusive
/// ranges, e.g. `"0-3 4-7"`, `"0,1,2"` or `"0-1,6-7 2-3,8-9"`.
///
/// # Errors
/// Returns an error if any component fails to parse as a CPU core number.
pub fn parse_cpu_cores_sets(
    s: &str,
) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
    // Best-effort: sets parsed here still work without a topology (no pinning then)
    #[cfg(feature = "numa")]
    let topology = hwlocality::Topology::new().map(std::sync::Arc::new).ok();

    // Each space-separated group becomes one `CpuCoreSet`
    s.split(' ')
        .map(|s| {
            let mut cores = Vec::new();
            for s in s.split(',') {
                let mut parts = s.split('-');
                let range_start = parts
                    .next()
                    .ok_or(
                        "Bad string format, must be comma separated list of CPU cores or ranges",
                    )?
                    .parse()?;

                if let Some(range_end) = parts.next() {
                    let range_end = range_end.parse()?;

                    // Inclusive range of cores, e.g. `0-3`
                    cores.extend(range_start..=range_end);
                } else {
                    // Single core, e.g. `5`
                    cores.push(range_start);
                }
            }

            Ok(CpuCoreSet {
                cores,
                #[cfg(feature = "numa")]
                topology: topology.clone(),
            })
        })
        .collect()
}
406
407pub fn thread_pool_core_indices(
409 thread_pool_size: Option<NonZeroUsize>,
410 thread_pools: Option<NonZeroUsize>,
411) -> Vec<CpuCoreSet> {
412 thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
413}
414
/// Shared implementation of [`thread_pool_core_indices`] operating on pre-computed CPU core
/// sets.
///
/// `all_cpu_cores` must be non-empty. When neither `thread_pool_size` nor `thread_pools` is
/// provided, the input sets are returned unchanged (one thread pool per set).
fn thread_pool_core_indices_internal(
    all_cpu_cores: Vec<CpuCoreSet>,
    thread_pool_size: Option<NonZeroUsize>,
    thread_pools: Option<NonZeroUsize>,
) -> Vec<CpuCoreSet> {
    #[cfg(feature = "numa")]
    let topology = &all_cpu_cores
        .first()
        .expect("Not empty according to function description; qed")
        .topology;

    // Number of pools to create: the explicit value if given; otherwise, when only a pool size
    // was specified, one pool per detected core set
    let thread_pools = thread_pools
        .map(|thread_pools| thread_pools.get())
        .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));

    if let Some(thread_pools) = thread_pools {
        let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);

        let total_cpu_cores = all_cpu_cores.iter().flat_map(|set| set.cpu_cores()).count();

        if let Some(thread_pool_size) = thread_pool_size {
            // Cycle over all cores endlessly so pools can overlap when more cores are
            // requested than physically exist
            let mut cpu_cores_iterator = iter::repeat(
                all_cpu_cores
                    .iter()
                    .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
                    .copied(),
            )
            .flatten();

            for _ in 0..thread_pools {
                let cpu_cores = cpu_cores_iterator
                    .by_ref()
                    .take(thread_pool_size.get())
                    // NOTE(review): this wrap-around assumes core indices are contiguous and
                    // start at 0; with sparse core numbering the modulo would remap to
                    // different cores — confirm that is intended
                    .map(|core_index| core_index % total_cpu_cores)
                    .collect();

                thread_pool_core_indices.push(CpuCoreSet {
                    cores: cpu_cores,
                    #[cfg(feature = "numa")]
                    topology: topology.clone(),
                });
            }
        } else {
            // No explicit pool size: split all cores evenly across the requested number of
            // pools (the last chunk may be smaller)
            let all_cpu_cores = all_cpu_cores
                .iter()
                .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
                .copied()
                .collect::<Vec<_>>();

            thread_pool_core_indices = all_cpu_cores
                .chunks(total_cpu_cores.div_ceil(thread_pools))
                .map(|cpu_cores| CpuCoreSet {
                    cores: cpu_cores.to_vec(),
                    #[cfg(feature = "numa")]
                    topology: topology.clone(),
                })
                .collect();
        }
        thread_pool_core_indices
    } else {
        // Neither option was provided: use the detected sets as-is
        all_cpu_cores
    }
}
487
488fn create_plotting_thread_pool_manager_thread_pool_pair(
489 thread_prefix: &'static str,
490 thread_pool_index: usize,
491 cpu_core_set: CpuCoreSet,
492 thread_priority: Option<ThreadPriority>,
493) -> Result<ThreadPool, ThreadPoolBuildError> {
494 let thread_name =
495 move |thread_index| format!("{thread_prefix}-{thread_pool_index}.{thread_index}");
496 let panic_handler = move |panic_info| {
500 if let Some(index) = current_thread_index() {
501 eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
502 } else {
503 eprintln!(
505 "rayon panic handler called on non-rayon thread: {:?}",
506 panic_info
507 );
508 }
509 exit(1);
510 };
511
512 ThreadPoolBuilder::new()
513 .thread_name(thread_name)
514 .num_threads(cpu_core_set.cpu_cores().len())
515 .panic_handler(panic_handler)
516 .spawn_handler({
517 let handle = Handle::current();
518
519 rayon_custom_spawn_handler(move |thread| {
520 let cpu_core_set = cpu_core_set.clone();
521 let handle = handle.clone();
522
523 move || {
524 cpu_core_set.pin_current_thread();
525 if let Some(thread_priority) = thread_priority {
526 if let Err(error) = set_current_thread_priority(thread_priority) {
527 warn!(%error, "Failed to set thread priority");
528 }
529 }
530 drop(cpu_core_set);
531
532 let _guard = handle.enter();
533
534 task::block_in_place(|| thread.run())
535 }
536 })
537 })
538 .build()
539}
540
541pub fn create_plotting_thread_pool_manager<I>(
551 mut cpu_core_sets: I,
552 thread_priority: Option<ThreadPriority>,
553) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
554where
555 I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
556{
557 let total_thread_pools = cpu_core_sets.len();
558
559 PlottingThreadPoolManager::new(
560 |thread_pool_index| {
561 let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
562 .next()
563 .expect("Number of thread pools is the same as cpu core sets; qed");
564
565 Ok(PlottingThreadPoolPair {
566 plotting: create_plotting_thread_pool_manager_thread_pool_pair(
567 "plotting",
568 thread_pool_index,
569 plotting_cpu_core_set,
570 thread_priority,
571 )?,
572 replotting: create_plotting_thread_pool_manager_thread_pool_pair(
573 "replotting",
574 thread_pool_index,
575 replotting_cpu_core_set,
576 thread_priority,
577 )?,
578 })
579 },
580 NonZeroUsize::new(total_thread_pools)
581 .expect("Thread pool is guaranteed to be non-empty; qed"),
582 )
583}
584
585pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
591 mut spawn_handler_builder: SpawnHandlerBuilder,
592) -> impl FnMut(ThreadBuilder) -> io::Result<()>
593where
594 SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
595 SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
596 SpawnHandlerResult: Send + 'static,
597{
598 move |thread: ThreadBuilder| {
599 let mut b = thread::Builder::new();
600 if let Some(name) = thread.name() {
601 b = b.name(name.to_owned());
602 }
603 if let Some(stack_size) = thread.stack_size() {
604 b = b.stack_size(stack_size);
605 }
606
607 b.spawn(spawn_handler_builder(thread))?;
608 Ok(())
609 }
610}
611
612pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
615 let handle = Handle::current();
616
617 rayon_custom_spawn_handler(move |thread| {
618 let handle = handle.clone();
619
620 move || {
621 let _guard = handle.enter();
622
623 task::block_in_place(|| thread.run())
624 }
625 })
626}