// subspace_farmer/plotter/gpu/cuda.rs

use crate::plotter::gpu::GpuRecordsEncoder;
use async_lock::Mutex as AsyncMutex;
use parking_lot::Mutex;
use rayon::{current_thread_index, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::any::Any;
use std::process::exit;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use subspace_core_primitives::pieces::{PieceOffset, Record};
use subspace_core_primitives::sectors::SectorId;
use subspace_farmer_components::plotting::RecordsEncoder;
use subspace_farmer_components::sector::SectorContentsMap;
use subspace_proof_of_space_gpu::cuda::CudaDevice;

16#[derive(Debug)]
18pub struct CudaRecordsEncoder {
19 cuda_device: CudaDevice,
20 thread_pool: ThreadPool,
21 global_mutex: Arc<AsyncMutex<()>>,
22}
23
24impl GpuRecordsEncoder for CudaRecordsEncoder {
25 const TYPE: &'static str = "cuda";
26}
27
28impl RecordsEncoder for CudaRecordsEncoder {
29 fn encode_records(
30 &mut self,
31 sector_id: &SectorId,
32 records: &mut [Record],
33 abort_early: &AtomicBool,
34 ) -> anyhow::Result<SectorContentsMap> {
35 let pieces_in_sector = records
36 .len()
37 .try_into()
38 .map_err(|error| anyhow::anyhow!("Failed to convert pieces in sector: {error}"))?;
39 let mut sector_contents_map = SectorContentsMap::new(pieces_in_sector);
40
41 {
42 let iter = Mutex::new(
43 (PieceOffset::ZERO..)
44 .zip(records.iter_mut())
45 .zip(sector_contents_map.iter_record_bitfields_mut()),
46 );
47 let plotting_error = Mutex::new(None::<String>);
48
49 self.thread_pool.scope(|scope| {
50 scope.spawn_broadcast(|_scope, _ctx| loop {
51 self.global_mutex.lock_blocking();
53
54 let Some(((piece_offset, record), mut encoded_chunks_used)) =
57 iter.lock().next()
58 else {
59 return;
60 };
61 let pos_seed = sector_id.derive_evaluation_seed(piece_offset);
62
63 if let Err(error) = self.cuda_device.generate_and_encode_pospace(
64 &pos_seed,
65 record,
66 encoded_chunks_used.iter_mut(),
67 ) {
68 plotting_error.lock().replace(error);
69 return;
70 }
71
72 if abort_early.load(Ordering::Relaxed) {
73 return;
74 }
75 });
76 });
77
78 let plotting_error = plotting_error.lock().take();
79 if let Some(error) = plotting_error {
80 return Err(anyhow::Error::msg(error));
81 }
82 }
83
84 Ok(sector_contents_map)
85 }
86}
87
88impl CudaRecordsEncoder {
89 pub fn new(
91 cuda_device: CudaDevice,
92 global_mutex: Arc<AsyncMutex<()>>,
93 ) -> Result<Self, ThreadPoolBuildError> {
94 let id = cuda_device.id();
95 let thread_name = move |thread_index| format!("cuda-{id}.{thread_index}");
96 let panic_handler = move |panic_info| {
99 if let Some(index) = current_thread_index() {
100 eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
101 } else {
102 eprintln!(
104 "rayon panic handler called on non-rayon thread: {:?}",
105 panic_info
106 );
107 }
108 exit(1);
109 };
110
111 let thread_pool = ThreadPoolBuilder::new()
112 .thread_name(thread_name)
113 .panic_handler(panic_handler)
114 .num_threads(2)
116 .build()?;
117
118 Ok(Self {
119 cuda_device,
120 thread_pool,
121 global_mutex,
122 })
123 }
124}