subspace_farmer/plotter/gpu/
cuda.rs

1//! CUDA GPU records encoder
2
3use crate::plotter::gpu::GpuRecordsEncoder;
4use async_lock::Mutex as AsyncMutex;
5use parking_lot::Mutex;
6use rayon::{current_thread_index, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
7use std::process::exit;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use subspace_core_primitives::pieces::{PieceOffset, Record};
11use subspace_core_primitives::sectors::SectorId;
12use subspace_farmer_components::plotting::RecordsEncoder;
13use subspace_farmer_components::sector::SectorContentsMap;
14use subspace_proof_of_space_gpu::cuda::CudaDevice;
15
16/// CUDA implementation of [`GpuRecordsEncoder`]
17#[derive(Debug)]
18pub struct CudaRecordsEncoder {
19    cuda_device: CudaDevice,
20    thread_pool: ThreadPool,
21    global_mutex: Arc<AsyncMutex<()>>,
22}
23
24impl GpuRecordsEncoder for CudaRecordsEncoder {
25    const TYPE: &'static str = "cuda";
26}
27
28impl RecordsEncoder for CudaRecordsEncoder {
29    fn encode_records(
30        &mut self,
31        sector_id: &SectorId,
32        records: &mut [Record],
33        abort_early: &AtomicBool,
34    ) -> anyhow::Result<SectorContentsMap> {
35        let pieces_in_sector = records
36            .len()
37            .try_into()
38            .map_err(|error| anyhow::anyhow!("Failed to convert pieces in sector: {error}"))?;
39        let mut sector_contents_map = SectorContentsMap::new(pieces_in_sector);
40
41        {
42            let iter = Mutex::new(
43                (PieceOffset::ZERO..)
44                    .zip(records.iter_mut())
45                    .zip(sector_contents_map.iter_record_bitfields_mut()),
46            );
47            let plotting_error = Mutex::new(None::<String>);
48
49            self.thread_pool.scope(|scope| {
50                scope.spawn_broadcast(|_scope, _ctx| loop {
51                    // Take mutex briefly to make sure encoding is allowed right now
52                    self.global_mutex.lock_blocking();
53
54                    // This instead of `while` above because otherwise mutex will be held for the
55                    // duration of the loop and will limit concurrency to 1 record
56                    let Some(((piece_offset, record), mut encoded_chunks_used)) =
57                        iter.lock().next()
58                    else {
59                        return;
60                    };
61                    let pos_seed = sector_id.derive_evaluation_seed(piece_offset);
62
63                    if let Err(error) = self.cuda_device.generate_and_encode_pospace(
64                        &pos_seed,
65                        record,
66                        encoded_chunks_used.iter_mut(),
67                    ) {
68                        plotting_error.lock().replace(error);
69                        return;
70                    }
71
72                    if abort_early.load(Ordering::Relaxed) {
73                        return;
74                    }
75                });
76            });
77
78            let plotting_error = plotting_error.lock().take();
79            if let Some(error) = plotting_error {
80                return Err(anyhow::Error::msg(error));
81            }
82        }
83
84        Ok(sector_contents_map)
85    }
86}
87
88impl CudaRecordsEncoder {
89    /// Create new instance
90    pub fn new(
91        cuda_device: CudaDevice,
92        global_mutex: Arc<AsyncMutex<()>>,
93    ) -> Result<Self, ThreadPoolBuildError> {
94        let id = cuda_device.id();
95        let thread_name = move |thread_index| format!("cuda-{id}.{thread_index}");
96        // TODO: remove this panic handler when rayon logs panic_info
97        // https://github.com/rayon-rs/rayon/issues/1208
98        let panic_handler = move |panic_info| {
99            if let Some(index) = current_thread_index() {
100                eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
101            } else {
102                // We want to guarantee exit, rather than panicking in a panic handler.
103                eprintln!(
104                    "rayon panic handler called on non-rayon thread: {:?}",
105                    panic_info
106                );
107            }
108            exit(1);
109        };
110
111        let thread_pool = ThreadPoolBuilder::new()
112            .thread_name(thread_name)
113            .panic_handler(panic_handler)
114            // Make sure there is overlap between records, so GPU is almost always busy
115            .num_threads(2)
116            .build()?;
117
118        Ok(Self {
119            cuda_device,
120            thread_pool,
121            global_mutex,
122        })
123    }
124}