// subspace_proof_of_space_gpu/cuda.rs

#[cfg(test)]
mod tests;

use rust_kzg_blst::types::fr::FsFr;
use std::ops::DerefMut;
use subspace_core_primitives::pieces::Record;
use subspace_core_primitives::pos::{PosProof, PosSeed};
use subspace_core_primitives::ScalarBytes;
use subspace_kzg::Scalar;

extern "C" {
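    /// Returns the number of CUDA-capable GPUs visible to the driver.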
    fn gpu_count() -> usize;

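    /// Dispatches proof-of-space table generation for `seed` and record encoding to the GPU
    /// selected by `gpu_id`. Judging by the wrapper below: `challenge_index` and
    /// `chunks_scratch` receive one entry per found proof (the count is written to
    /// `proof_count`), `parity_record_chunks` receives `Record::NUM_CHUNKS` field elements,
    /// and all output buffers must be allocated by the caller in advance.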
    fn generate_and_encode_pospace_dispatch(
        k: u32,
        seed: *const [u8; 32],
        lg_record_size: u32,
        challenge_index: *mut u32,
        record: *const [u8; 32],
        chunks_scratch: *mut [u8; 32],
        proof_count: *mut u32,
        parity_record_chunks: *mut FsFr,
        gpu_id: i32,
    ) -> sppark::Error;
}

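/// Returns a [`CudaDevice`] handle for each CUDA device available on this machine.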
pub fn cuda_devices() -> Vec<CudaDevice> {
    let num_devices = unsafe { gpu_count() };

    (0i32..)
        .take(num_devices)
        .map(|gpu_id| CudaDevice { gpu_id })
        .collect()
}

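/// Wrapper data structure encapsulating a single CUDA-capable device.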
#[derive(Debug)]
pub struct CudaDevice {
    gpu_id: i32,
}

impl CudaDevice {
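    /// CUDA device ID.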
    pub fn id(&self) -> i32 {
        self.gpu_id
    }

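    /// Generates a proof-of-space table from `seed` on this GPU and uses it to encode `record`
    /// in place, reporting through `encoded_chunks_used_output` which chunks ended up encoded.
    /// The output iterator is expected to yield [`Record::NUM_S_BUCKETS`] items.
    ///
    /// A minimal usage sketch (hypothetical `seed` and `record` values; assumes at least one
    /// CUDA device is present):
    ///
    /// ```ignore
    /// let devices = cuda_devices();
    /// let device = devices.first().expect("no CUDA device found");
    /// let mut encoded_chunks_used = vec![false; Record::NUM_S_BUCKETS];
    /// device.generate_and_encode_pospace(&seed, &mut record, encoded_chunks_used.iter_mut())?;
    /// ```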
    pub fn generate_and_encode_pospace(
        &self,
        seed: &PosSeed,
        record: &mut Record,
        encoded_chunks_used_output: impl ExactSizeIterator<Item = impl DerefMut<Target = bool>>,
    ) -> Result<(), String> {
        let record_len = Record::NUM_CHUNKS;
        let challenge_len = Record::NUM_S_BUCKETS;
        let lg_record_size = record_len.ilog2();

        if challenge_len > u32::MAX as usize {
            return Err(String::from("challenge_len is too large to fit in u32"));
        }

        let mut proof_count = 0u32;
        let mut chunks_scratch_gpu =
            Vec::<[u8; ScalarBytes::FULL_BYTES]>::with_capacity(challenge_len);
        let mut challenge_index_gpu = Vec::<u32>::with_capacity(challenge_len);
        let mut parity_record_chunks = Vec::<Scalar>::with_capacity(Record::NUM_CHUNKS);

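        // Safety: every output buffer above was allocated with enough capacity for what the
        // GPU writes; lengths are corrected after the call once `proof_count` is known.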
        let error = unsafe {
            generate_and_encode_pospace_dispatch(
                u32::from(PosProof::K),
                &**seed,
                lg_record_size,
                challenge_index_gpu.as_mut_ptr(),
                record.as_ptr(),
                chunks_scratch_gpu.as_mut_ptr(),
                &mut proof_count,
                Scalar::slice_mut_to_repr(&mut parity_record_chunks).as_mut_ptr(),
                self.gpu_id,
            )
        };

        if error.code != 0 {
            let error = error.to_string();
            if error.contains("the provided PTX was compiled with an unsupported toolchain.") {
                return Err(format!(
                    "Nvidia driver is likely too old, make sure to install version 550 or \
                    newer: {error}"
                ));
            }
            return Err(error);
        }

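        // The GPU wrote into the spare capacity of these `Vec`s without updating their
        // lengths, so set the lengths to match what was actually produced.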
        let proof_count = proof_count as usize;
        unsafe {
            chunks_scratch_gpu.set_len(proof_count);
            challenge_index_gpu.set_len(proof_count);
            parity_record_chunks.set_len(Record::NUM_CHUNKS);
        }

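        // Track which s-buckets hold an encoded chunk, and keep a copy of the source record
        // since `record` is about to be overwritten in place.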
        let mut encoded_chunks_used = vec![false; challenge_len];
        let source_record_chunks = record.to_vec();

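        // Pair each proof's s-bucket index with its encoded chunk so they can be sorted by
        // s-bucket index.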
        let mut chunks_scratch = challenge_index_gpu
            .into_iter()
            .zip(chunks_scratch_gpu)
            .collect::<Vec<_>>();

        chunks_scratch
            .sort_unstable_by(|(a_out_index, _), (b_out_index, _)| a_out_index.cmp(b_out_index));

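        // At most `Record::NUM_CHUNKS` encoded chunks fit into the record; drop the rest.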
        chunks_scratch.truncate(proof_count.min(Record::NUM_CHUNKS));

        for (out_index, _chunk) in &chunks_scratch {
            encoded_chunks_used[*out_index as usize] = true;
        }

        encoded_chunks_used_output
            .zip(&encoded_chunks_used)
            .for_each(|(mut output, input)| *output = *input);

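        // Rebuild the record: encoded chunks (in s-bucket order) come first, and the remaining
        // positions are filled with the source/parity chunks whose s-buckets produced no proof.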
        record
            .iter_mut()
            .zip(
                chunks_scratch
                    .into_iter()
                    .map(|(_out_index, chunk)| chunk)
                    .chain(
                        source_record_chunks
                            .into_iter()
                            .zip(parity_record_chunks)
                            .flat_map(|(a, b)| [a, b.to_bytes()])
                            .zip(encoded_chunks_used.iter())
                            .filter_map(|(record_chunk, encoded_chunk_used)| {
                                if *encoded_chunk_used {
                                    None
                                } else {
                                    Some(record_chunk)
                                }
                            }),
                    ),
            )
            .for_each(|(output_chunk, input_chunk)| {
                *output_chunk = input_chunk;
            });

        Ok(())
    }
}