pub mod ss58;
#[cfg(test)]
mod tests;

use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
use rayon::{
    ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, current_thread_index,
};
use std::num::NonZeroUsize;
use std::process::exit;
use std::{fmt, io, iter, thread};
use thread_priority::{ThreadPriority, set_current_thread_priority};
use tokio::runtime::Handle;
use tokio::task;
use tracing::warn;

/// Cap on the default number of farming threads, used by
/// [`recommended_number_of_farming_threads`]
const MAX_DEFAULT_FARMING_THREADS: usize = 32;

/// A set of CPU cores, optionally paired with the system topology it was derived from
#[derive(Clone)]
pub struct CpuCoreSet {
    /// CPU core indices in this set
    cores: Vec<usize>,
    /// Full system topology, used for pinning threads to these cores
    #[cfg(feature = "numa")]
    topology: Option<std::sync::Arc<hwlocality::Topology>>,
}

impl fmt::Debug for CpuCoreSet {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("CpuCoreSet");
        #[cfg(not(feature = "numa"))]
        if self.cores.array_windows::<2>().all(|&[a, b]| a + 1 == b) {
            // Contiguous cores are shown as a compact `first-last` range
            s.field(
                "cores",
                &format!(
                    "{}-{}",
                    self.cores.first().expect("List of cores is not empty; qed"),
                    self.cores.last().expect("List of cores is not empty; qed")
                ),
            );
        } else {
            s.field(
                "cores",
                &self
                    .cores
                    .iter()
                    .map(usize::to_string)
                    .collect::<Vec<_>>()
                    .join(","),
            );
        }
        #[cfg(feature = "numa")]
        {
            use hwlocality::cpu::cpuset::CpuSet;
            use hwlocality::ffi::PositiveInt;

            // With the `numa` feature, rely on hwloc's `CpuSet` representation
            s.field(
                "cores",
                &CpuSet::from_iter(
                    self.cores.iter().map(|&core| {
                        PositiveInt::try_from(core).expect("Valid CPU core index; qed")
                    }),
                ),
            );
        }
        s.finish_non_exhaustive()
    }
}

impl CpuCoreSet {
    /// Get CPU core indices in this set
    pub fn cpu_cores(&self) -> &[usize] {
        &self.cores
    }

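    /// Truncate the set to at most `num_cores` cores (`num_cores` is clamped to
    /// `1..=self.cores.len()`).
    ///
    /// With the `numa` feature enabled and topology available, cores are removed proportionally
    /// from each group of cores sharing an L2 cache (keeping at least one core per group), so the
    /// remaining cores stay spread across cache groups; otherwise the list is truncated from the
    /// end.
    ///
    /// # Example
    ///
    /// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
    ///
    /// ```ignore
    /// let mut set = parse_cpu_cores_sets("0-7")?.remove(0);
    /// set.truncate(4);
    /// // Without NUMA support this leaves cores 0-3; with it, a proportional slice of each
    /// // L2 cache group
    /// ```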
    pub fn truncate(&mut self, num_cores: usize) {
        // Zero cores are not allowed, and we can't keep more cores than the set contains
        let num_cores = num_cores.clamp(1, self.cores.len());

        #[cfg(feature = "numa")]
        if let Some(topology) = &self.topology {
            use hwlocality::object::attributes::ObjectAttributes;
            use hwlocality::object::types::ObjectType;

            // Group cores by (L2 cache size, number of cores sharing that cache) so truncation
            // can remove cores proportionally from each group
            let mut grouped_by_l2_cache_size_and_core_count =
                std::collections::HashMap::<(usize, usize), Vec<usize>>::new();
            topology
                .objects_with_type(ObjectType::L2Cache)
                .for_each(|object| {
                    let l2_cache_size =
                        if let Some(ObjectAttributes::Cache(cache)) = object.attributes() {
                            cache
                                .size()
                                .map(|size| size.get() as usize)
                                .unwrap_or_default()
                        } else {
                            0
                        };
                    if let Some(cpuset) = object.complete_cpuset() {
                        let cpuset = cpuset
                            .into_iter()
                            .map(usize::from)
                            .filter(|core| self.cores.contains(core))
                            .collect::<Vec<_>>();
                        let cpuset_len = cpuset.len();

                        if !cpuset.is_empty() {
                            grouped_by_l2_cache_size_and_core_count
                                .entry((l2_cache_size, cpuset_len))
                                .or_default()
                                .extend(cpuset);
                        }
                    }
                });

            // Only use the grouping if L2 cache objects covered every core in the set, otherwise
            // fall back to plain truncation below
            if grouped_by_l2_cache_size_and_core_count
                .values()
                .flatten()
                .count()
                == self.cores.len()
            {
                // Take a proportional number of cores from each group, but at least one per group
                self.cores = grouped_by_l2_cache_size_and_core_count
                    .into_values()
                    .flat_map(|cores| {
                        let limit = cores.len() * num_cores / self.cores.len();
                        cores.into_iter().take(limit.max(1))
                    })
                    .collect();

                self.cores.sort();

                return;
            }
        }
        self.cores.truncate(num_cores);
    }

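    /// Pin the current thread to all CPU cores in this set (the whole set, not a single core).
    ///
    /// This is a no-op unless the `numa` feature is enabled and topology is available; binding
    /// failures are logged as warnings rather than returned as errors.
    ///
    /// # Example
    ///
    /// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
    ///
    /// ```ignore
    /// let set = all_cpu_cores().remove(0);
    /// std::thread::spawn(move || {
    ///     set.pin_current_thread();
    ///     // CPU-heavy work here is now constrained to this core set
    /// });
    /// ```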
    pub fn pin_current_thread(&self) {
        #[cfg(feature = "numa")]
        if let Some(topology) = &self.topology {
            use hwlocality::cpu::binding::CpuBindingFlags;
            use hwlocality::cpu::cpuset::CpuSet;
            use hwlocality::current_thread_id;
            use hwlocality::ffi::PositiveInt;

            let cpu_cores = CpuSet::from_iter(
                self.cores
                    .iter()
                    .map(|&core| PositiveInt::try_from(core).expect("Valid CPU core index; qed")),
            );

            if let Err(error) =
                topology.bind_thread_cpu(current_thread_id(), &cpu_cores, CpuBindingFlags::empty())
            {
                warn!(%error, ?cpu_cores, "Failed to pin thread to CPU cores")
            }
        }
    }
}

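/// Recommended number of farming threads: the number of CPU cores in the first NUMA node when the
/// `numa` feature is enabled and topology is available, the total number of CPUs otherwise, in
/// both cases capped at [`MAX_DEFAULT_FARMING_THREADS`].
///
/// # Example
///
/// A minimal sketch of sizing a thread pool with it, marked `ignore` since it assumes this
/// module's items are in scope:
///
/// ```ignore
/// let farming_thread_pool = rayon::ThreadPoolBuilder::new()
///     .num_threads(recommended_number_of_farming_threads())
///     .build()?;
/// ```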
pub fn recommended_number_of_farming_threads() -> usize {
    #[cfg(feature = "numa")]
    match hwlocality::Topology::new().map(std::sync::Arc::new) {
        Ok(topology) => {
            return topology
                // Iterate over NUMA nodes
                .objects_at_depth(hwlocality::object::depth::Depth::NUMANode)
                // For each NUMA node get its CPU set
                .filter_map(|node| node.cpuset())
                // Count the CPU cores in each node
                .map(|cpuset| cpuset.iter_set().count())
                // Take the first NUMA node with a non-zero number of cores
                .find(|&count| count > 0)
                // Fall back to the total number of CPUs if nothing was found
                .unwrap_or_else(num_cpus::get)
                .min(MAX_DEFAULT_FARMING_THREADS);
        }
        Err(error) => {
            warn!(%error, "Failed to get NUMA topology");
        }
    }
    num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
}

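/// All CPU cores, grouped into sets of cores sharing an L3 cache when the `numa` feature is
/// enabled and topology is available, or a single set with all cores otherwise.
///
/// The returned vector is guaranteed to be non-empty, with a non-zero number of cores in each set.
///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
///
/// ```ignore
/// for (index, set) in all_cpu_cores().iter().enumerate() {
///     println!("CPU core set #{index}: {set:?}");
/// }
/// ```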
pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
    #[cfg(feature = "numa")]
    match hwlocality::Topology::new().map(std::sync::Arc::new) {
        Ok(topology) => {
            let cpu_cores = topology
                // Iterate over groups of cores that share an L3 cache
                .objects_with_type(hwlocality::object::types::ObjectType::L3Cache)
                // For each L3 cache get its CPU set
                .filter_map(|node| node.cpuset())
                // Collect CPU core indices
                .map(|cpuset| cpuset.iter_set().map(usize::from).collect::<Vec<_>>())
                .filter(|cores| !cores.is_empty())
                .map(|cores| CpuCoreSet {
                    cores,
                    topology: Some(std::sync::Arc::clone(&topology)),
                })
                .collect::<Vec<_>>();

            if !cpu_cores.is_empty() {
                return cpu_cores;
            }
        }
        Err(error) => {
            warn!(%error, "Failed to get L3 cache topology");
        }
    }
    vec![CpuCoreSet {
        cores: (0..num_cpus::get()).collect(),
        #[cfg(feature = "numa")]
        topology: None,
    }]
}

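/// Parse a space-separated list of CPU core groups, where each group is a comma-separated list of
/// core indices or inclusive `start-end` ranges, into a vector of [`CpuCoreSet`]s suitable for
/// [`create_plotting_thread_pool_manager`].
///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
///
/// ```ignore
/// // Two sets: cores 0-3 and cores 4-7
/// let sets = parse_cpu_cores_sets("0-3 4-7")?;
/// assert_eq!(sets.len(), 2);
/// assert_eq!(sets[0].cpu_cores(), &[0, 1, 2, 3]);
///
/// // A single set mixing individual cores and a range
/// let sets = parse_cpu_cores_sets("0,2,4-6")?;
/// assert_eq!(sets[0].cpu_cores(), &[0, 2, 4, 5, 6]);
/// ```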
pub fn parse_cpu_cores_sets(
    s: &str,
) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
    #[cfg(feature = "numa")]
    let topology = hwlocality::Topology::new().map(std::sync::Arc::new).ok();

    s.split(' ')
        .map(|s| {
            let mut cores = Vec::new();
            for s in s.split(',') {
                let mut parts = s.split('-');
                let range_start = parts
                    .next()
                    .ok_or(
                        "Bad string format, must be a comma-separated list of CPU cores or ranges",
                    )?
                    .parse()?;

                if let Some(range_end) = parts.next() {
                    let range_end = range_end.parse()?;

                    cores.extend(range_start..=range_end);
                } else {
                    cores.push(range_start);
                }
            }

            Ok(CpuCoreSet {
                cores,
                #[cfg(feature = "numa")]
                topology: topology.clone(),
            })
        })
        .collect()
}

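/// CPU core sets to be used by each thread pool.
///
/// If `thread_pools` is specified, that many sets are produced; otherwise, if `thread_pool_size`
/// is specified, one set is produced per element of [`all_cpu_cores`]. If neither is specified,
/// [`all_cpu_cores`] is returned as is. With a fixed `thread_pool_size`, cores are assigned
/// round-robin and may repeat across sets.
///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
///
/// ```ignore
/// // One thread pool of 8 threads per L3 cache group
/// let indices = thread_pool_core_indices(NonZeroUsize::new(8), None);
///
/// // All cores split evenly between exactly two thread pools
/// let indices = thread_pool_core_indices(None, NonZeroUsize::new(2));
/// ```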
pub fn thread_pool_core_indices(
    thread_pool_size: Option<NonZeroUsize>,
    thread_pools: Option<NonZeroUsize>,
) -> Vec<CpuCoreSet> {
    thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
}

fn thread_pool_core_indices_internal(
    all_cpu_cores: Vec<CpuCoreSet>,
    thread_pool_size: Option<NonZeroUsize>,
    thread_pools: Option<NonZeroUsize>,
) -> Vec<CpuCoreSet> {
    #[cfg(feature = "numa")]
    let topology = &all_cpu_cores
        .first()
        .expect("Not empty according to function description; qed")
        .topology;

    // If the number of thread pools is not specified, but a thread pool size is, default to one
    // thread pool per CPU core set
    let thread_pools = thread_pools
        .map(|thread_pools| thread_pools.get())
        .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));

    if let Some(thread_pools) = thread_pools {
        let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);

        let total_cpu_cores = all_cpu_cores.iter().flat_map(|set| set.cpu_cores()).count();

        if let Some(thread_pool_size) = thread_pool_size {
            // With a fixed thread pool size, cycle over all CPU cores as many times as necessary
            // to fill the requested number of thread pools
            let mut cpu_cores_iterator = iter::repeat(
                all_cpu_cores
                    .iter()
                    .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
                    .copied(),
            )
            .flatten();

            for _ in 0..thread_pools {
                let cpu_cores = cpu_cores_iterator
                    .by_ref()
                    .take(thread_pool_size.get())
                    // Wrap around in case `thread_pool_size` exceeds the total number of cores
                    .map(|core_index| core_index % total_cpu_cores)
                    .collect();

                thread_pool_core_indices.push(CpuCoreSet {
                    cores: cpu_cores,
                    #[cfg(feature = "numa")]
                    topology: topology.clone(),
                });
            }
        } else {
            // Without a thread pool size, split all CPU cores evenly between the requested number
            // of thread pools
            let all_cpu_cores = all_cpu_cores
                .iter()
                .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
                .copied()
                .collect::<Vec<_>>();

            thread_pool_core_indices = all_cpu_cores
                .chunks(total_cpu_cores.div_ceil(thread_pools))
                .map(|cpu_cores| CpuCoreSet {
                    cores: cpu_cores.to_vec(),
                    #[cfg(feature = "numa")]
                    topology: topology.clone(),
                })
                .collect();
        }
        thread_pool_core_indices
    } else {
        // Neither the number of thread pools nor their size was specified; use the CPU core sets
        // as they come from the physical layout
        all_cpu_cores
    }
}

fn create_plotting_thread_pool_manager_thread_pool_pair(
    thread_prefix: &'static str,
    thread_pool_index: usize,
    cpu_core_set: CpuCoreSet,
    thread_priority: Option<ThreadPriority>,
) -> Result<ThreadPool, ThreadPoolBuildError> {
    let thread_name =
        move |thread_index| format!("{thread_prefix:9}-{thread_pool_index:02}.{thread_index:02}");
    // Log panics that reach the thread pool and terminate the process
    let panic_handler = move |panic_info| {
        if let Some(index) = current_thread_index() {
            eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
        } else {
            eprintln!("rayon panic handler called on non-rayon thread: {panic_info:?}");
        }
        exit(1);
    };

    ThreadPoolBuilder::new()
        .thread_name(thread_name)
        .num_threads(cpu_core_set.cpu_cores().len())
        .panic_handler(panic_handler)
        .spawn_handler({
            let handle = Handle::current();

            rayon_custom_spawn_handler(move |thread| {
                let cpu_core_set = cpu_core_set.clone();
                let handle = handle.clone();

                move || {
                    // Pin each thread to the whole CPU core set and optionally adjust its priority
                    cpu_core_set.pin_current_thread();
                    if let Some(thread_priority) = thread_priority
                        && let Err(error) = set_current_thread_priority(thread_priority)
                    {
                        warn!(%error, "Failed to set thread priority");
                    }
                    drop(cpu_core_set);

                    // Make the current tokio runtime available to code running in this pool
                    let _guard = handle.enter();

                    task::block_in_place(|| thread.run())
                }
            })
        })
        .build()
}

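/// Create a [`PlottingThreadPoolManager`] with a plotting/replotting thread pool pair for each
/// pair of CPU core sets. Each pool gets one thread per core in its set; every thread is pinned
/// to the whole set (not to an individual core), optionally has `thread_priority` applied, and
/// inherits the current tokio runtime.
///
/// Must be called from within a tokio runtime context, and `cpu_core_sets` must be non-empty.
///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
///
/// ```ignore
/// let core_indices = thread_pool_core_indices(None, None);
/// let thread_pool_manager = create_plotting_thread_pool_manager(
///     core_indices.into_iter().map(|set| (set.clone(), set)),
///     Some(ThreadPriority::Min),
/// )?;
/// ```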
pub fn create_plotting_thread_pool_manager<I>(
    mut cpu_core_sets: I,
    thread_priority: Option<ThreadPriority>,
) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
where
    I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
{
    let total_thread_pools = cpu_core_sets.len();

    PlottingThreadPoolManager::new(
        |thread_pool_index| {
            let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
                .next()
                .expect("Number of thread pools is the same as cpu core sets; qed");

            Ok(PlottingThreadPoolPair {
                plotting: create_plotting_thread_pool_manager_thread_pool_pair(
                    "plotting",
                    thread_pool_index,
                    plotting_cpu_core_set,
                    thread_priority,
                )?,
                replotting: create_plotting_thread_pool_manager_thread_pool_pair(
                    "replotting",
                    thread_pool_index,
                    replotting_cpu_core_set,
                    thread_priority,
                )?,
            })
        },
        NonZeroUsize::new(total_thread_pools)
            .expect("Thread pool is guaranteed to be non-empty; qed"),
    )
}

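/// For use with [`rayon::ThreadPoolBuilder::spawn_handler()`] to spawn pool threads with custom
/// logic.
///
/// `spawn_handler_builder` is called once per thread with that thread's [`ThreadBuilder`] and
/// returns the closure that will run on the spawned OS thread; the closure must eventually call
/// `thread.run()` to hand control to rayon.
///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes this module's items are in scope:
///
/// ```ignore
/// let thread_pool = rayon::ThreadPoolBuilder::new()
///     .spawn_handler(rayon_custom_spawn_handler(move |thread| {
///         move || {
///             // Per-thread setup (pinning, priority, etc.) would go here
///             thread.run()
///         }
///     }))
///     .build()?;
/// ```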
pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
    mut spawn_handler_builder: SpawnHandlerBuilder,
) -> impl FnMut(ThreadBuilder) -> io::Result<()>
where
    SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
    SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
    SpawnHandlerResult: Send + 'static,
{
    move |thread: ThreadBuilder| {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }

        b.spawn(spawn_handler_builder(thread))?;
        Ok(())
    }
}

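/// For use with [`rayon::ThreadPoolBuilder::spawn_handler()`] to spawn pool threads that inherit
/// the current tokio runtime, making tokio APIs usable from the thread pool.
///
/// Must be called from within a tokio runtime context (it captures [`Handle::current()`]).
///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes a multi-threaded tokio runtime:
///
/// ```ignore
/// let thread_pool = rayon::ThreadPoolBuilder::new()
///     .spawn_handler(tokio_rayon_spawn_handler())
///     .build()?;
/// ```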
pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
    let handle = Handle::current();

    rayon_custom_spawn_handler(move |thread| {
        let handle = handle.clone();

        move || {
            // Entering the handle makes tokio APIs usable from this rayon thread, while
            // `block_in_place` tells the runtime that this thread is allowed to block
            let _guard = handle.enter();

            task::block_in_place(|| thread.run())
        }
    })
}