// subspace_proof_of_space_gpu/cuda.rs
#[cfg(test)]
mod tests;

use rust_kzg_blst::types::fr::FsFr;
use std::ops::DerefMut;
use subspace_core_primitives::ScalarBytes;
use subspace_core_primitives::pieces::Record;
use subspace_core_primitives::pos::{PosProof, PosSeed};
use subspace_kzg::Scalar;

unsafe extern "C" {
    /// Returns the number of CUDA GPUs visible to the driver.
    fn gpu_count() -> usize;

    /// Generates a PoSpace table and encodes a record on the GPU `gpu_id`.
    ///
    /// NOTE(review): implemented on the CUDA/C++ side; the required buffer
    /// capacities and write counts are dictated by that kernel — confirm
    /// against its source. As used below, the caller passes pre-allocated
    /// buffers and, on success, the kernel reports via `proof_count` how many
    /// entries were written to `challenge_index` and `chunks_scratch`, while
    /// `parity_record_chunks` is filled with `Record::NUM_CHUNKS` field
    /// elements.
    fn generate_and_encode_pospace_dispatch(
        k: u32,
        seed: *const [u8; 32],
        lg_record_size: u32,
        challenge_index: *mut u32,
        record: *const [u8; 32],
        chunks_scratch: *mut [u8; 32],
        proof_count: *mut u32,
        parity_record_chunks: *mut FsFr,
        gpu_id: i32,
    ) -> sppark::Error;
}
51pub fn cuda_devices() -> Vec<CudaDevice> {
53 let num_devices = unsafe { gpu_count() };
54
55 (0i32..)
56 .take(num_devices)
57 .map(|gpu_id| CudaDevice { gpu_id })
58 .collect()
59}
/// A single CUDA device on which proof-of-space tables can be generated.
#[derive(Debug)]
pub struct CudaDevice {
    // Device index as assigned by the CUDA driver (passed through to FFI)
    gpu_id: i32,
}
67impl CudaDevice {
68 pub fn id(&self) -> i32 {
70 self.gpu_id
71 }
72
73 pub fn generate_and_encode_pospace(
75 &self,
76 seed: &PosSeed,
77 record: &mut Record,
78 encoded_chunks_used_output: impl ExactSizeIterator<Item = impl DerefMut<Target = bool>>,
79 ) -> Result<(), String> {
80 let record_len = Record::NUM_CHUNKS;
81 let challenge_len = Record::NUM_S_BUCKETS;
82 let lg_record_size = record_len.ilog2();
83
84 if challenge_len > u32::MAX as usize {
85 return Err(String::from("challenge_len is too large to fit in u32"));
86 }
87
88 let mut proof_count = 0u32;
89 let mut chunks_scratch_gpu =
90 Vec::<[u8; ScalarBytes::FULL_BYTES]>::with_capacity(challenge_len);
91 let mut challenge_index_gpu = Vec::<u32>::with_capacity(challenge_len);
92 let mut parity_record_chunks = Vec::<Scalar>::with_capacity(Record::NUM_CHUNKS);
93
94 let error = unsafe {
95 generate_and_encode_pospace_dispatch(
96 u32::from(PosProof::K),
97 &**seed,
98 lg_record_size,
99 challenge_index_gpu.as_mut_ptr(),
100 record.as_ptr(),
101 chunks_scratch_gpu.as_mut_ptr(),
102 &mut proof_count,
103 Scalar::slice_mut_to_repr(&mut parity_record_chunks).as_mut_ptr(),
104 self.gpu_id,
105 )
106 };
107
108 if error.code != 0 {
109 let error = error.to_string();
110 if error.contains("the provided PTX was compiled with an unsupported toolchain.") {
111 return Err(format!(
112 "Nvidia driver is likely too old, make sure install version 550 or newer: \
113 {error}"
114 ));
115 }
116 return Err(error);
117 }
118
119 let proof_count = proof_count as usize;
120 unsafe {
121 chunks_scratch_gpu.set_len(proof_count);
122 challenge_index_gpu.set_len(proof_count);
123 parity_record_chunks.set_len(Record::NUM_CHUNKS);
124 }
125
126 let mut encoded_chunks_used = vec![false; challenge_len];
127 let source_record_chunks = record.to_vec();
128
129 let mut chunks_scratch = challenge_index_gpu
130 .into_iter()
131 .zip(chunks_scratch_gpu)
132 .collect::<Vec<_>>();
133
134 chunks_scratch.sort_unstable_by_key(|(a_out_index, _)| *a_out_index);
135
136 chunks_scratch.truncate(proof_count.min(Record::NUM_CHUNKS));
138
139 for (out_index, _chunk) in &chunks_scratch {
140 encoded_chunks_used[*out_index as usize] = true;
141 }
142
143 encoded_chunks_used_output
144 .zip(&encoded_chunks_used)
145 .for_each(|(mut output, input)| *output = *input);
146
147 record
148 .iter_mut()
149 .zip(
150 chunks_scratch
151 .into_iter()
152 .map(|(_out_index, chunk)| chunk)
153 .chain(
154 source_record_chunks
155 .into_iter()
156 .zip(parity_record_chunks)
157 .flat_map(|(a, b)| [a, b.to_bytes()])
158 .zip(encoded_chunks_used.iter())
159 .filter_map(|(record_chunk, encoded_chunk_used)| {
161 if *encoded_chunk_used {
162 None
163 } else {
164 Some(record_chunk)
165 }
166 }),
167 ),
168 )
169 .for_each(|(output_chunk, input_chunk)| {
170 *output_chunk = input_chunk;
171 });
172
173 Ok(())
174 }
175}