subspace_proof_of_space_gpu/
cuda.rs

// Originally written by Supranational LLC

#[cfg(test)]
mod tests;

use rust_kzg_blst::types::fr::FsFr;
use std::ops::DerefMut;
use subspace_core_primitives::pieces::Record;
use subspace_core_primitives::pos::{PosProof, PosSeed};
use subspace_core_primitives::ScalarBytes;
use subspace_kzg::Scalar;

extern "C" {
    /// # Returns
    /// * `usize` - The number of available GPUs.
    fn gpu_count() -> usize;

    /// # Parameters
    /// * `k` - The size parameter for the table.
    /// * `seed` - A pointer to the seed data.
    /// * `lg_record_size` - The base-2 logarithm of the record size.
    /// * `challenge_index` - A mutable pointer to store the index of the challenge.
    /// * `record` - A pointer to the record data.
    /// * `chunks_scratch` - A mutable pointer to a scratch space for chunk data.
    /// * `proof_count` - A mutable pointer to store the count of proofs.
    /// * `parity_record_chunks` - A mutable pointer to the parity record chunks.
    /// * `gpu_id` - The ID of the GPU to use.
    ///
    /// # Returns
    /// * `sppark::Error` - An error code indicating the result of the operation.
    ///
    /// # Assumptions
    /// * `seed` must be a valid pointer to a 32-byte seed.
    /// * `record` must be a valid pointer to the record data (`*const Record`), with a length of `1 << lg_record_size`.
    /// * `parity_record_chunks` must be a valid mutable pointer to `Scalar` elements, with a length of `1 << lg_record_size`.
    /// * `chunks_scratch` must be a valid mutable pointer where up to `challenges_count` 32-byte chunks of GPU-calculated data will be written.
    /// * `gpu_id` must be a valid identifier of an available GPU. The available GPUs can be determined by using the `gpu_count` function.
    fn generate_and_encode_pospace_dispatch(
        k: u32,
        seed: *const [u8; 32],
        lg_record_size: u32,
        challenge_index: *mut u32,
        record: *const [u8; 32],
        chunks_scratch: *mut [u8; 32],
        proof_count: *mut u32,
        parity_record_chunks: *mut FsFr,
        gpu_id: i32,
    ) -> sppark::Error;
}

/// Returns [`CudaDevice`] for each available device
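///
/// A minimal usage sketch (assumes at least one CUDA-capable GPU is installed):
///
/// ```ignore
/// for device in cuda_devices() {
///     println!("Found CUDA device {}", device.id());
/// }
/// ```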
pub fn cuda_devices() -> Vec<CudaDevice> {
    let num_devices = unsafe { gpu_count() };

    (0i32..)
        .take(num_devices)
        .map(|gpu_id| CudaDevice { gpu_id })
        .collect()
}

/// Wrapper data structure encapsulating a single CUDA-capable device
#[derive(Debug)]
pub struct CudaDevice {
    gpu_id: i32,
}

impl CudaDevice {
    /// CUDA device ID
    pub fn id(&self) -> i32 {
        self.gpu_id
    }

    /// Generates and encodes PoSpace on the GPU
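    ///
    /// A minimal usage sketch; the seed bytes and record construction below are placeholder
    /// assumptions for illustration, not part of this API:
    ///
    /// ```ignore
    /// let device = cuda_devices().into_iter().next().expect("no CUDA devices found");
    /// let seed = PosSeed::from([1u8; 32]);
    /// let mut record = Record::new_boxed();
    /// let mut encoded_chunks_used = vec![false; Record::NUM_S_BUCKETS];
    ///
    /// device
    ///     .generate_and_encode_pospace(&seed, &mut record, encoded_chunks_used.iter_mut())
    ///     .expect("GPU encoding failed");
    /// ```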
    pub fn generate_and_encode_pospace(
        &self,
        seed: &PosSeed,
        record: &mut Record,
        encoded_chunks_used_output: impl ExactSizeIterator<Item = impl DerefMut<Target = bool>>,
    ) -> Result<(), String> {
        let record_len = Record::NUM_CHUNKS;
        let challenge_len = Record::NUM_S_BUCKETS;
        let lg_record_size = record_len.ilog2();

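        // Challenge indices are returned from the GPU as `u32`, so the number of challenges
        // must fit into one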
        if challenge_len > u32::MAX as usize {
            return Err(String::from("challenge_len is too large to fit in u32"));
        }

        let mut proof_count = 0u32;
        let mut chunks_scratch_gpu =
            Vec::<[u8; ScalarBytes::FULL_BYTES]>::with_capacity(challenge_len);
        let mut challenge_index_gpu = Vec::<u32>::with_capacity(challenge_len);
        let mut parity_record_chunks = Vec::<Scalar>::with_capacity(Record::NUM_CHUNKS);

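        // Safety: the buffers above are allocated with enough capacity for everything the
        // GPU writes, and `seed`/`record` are valid for reads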
        let error = unsafe {
            generate_and_encode_pospace_dispatch(
                u32::from(PosProof::K),
                &**seed,
                lg_record_size,
                challenge_index_gpu.as_mut_ptr(),
                record.as_ptr(),
                chunks_scratch_gpu.as_mut_ptr(),
                &mut proof_count,
                Scalar::slice_mut_to_repr(&mut parity_record_chunks).as_mut_ptr(),
                self.gpu_id,
            )
        };

        if error.code != 0 {
            let error = error.to_string();
            if error.contains("the provided PTX was compiled with an unsupported toolchain.") {
                return Err(format!(
                    "Nvidia driver is likely too old, make sure to install version 550 or \
                    newer: {error}"
                ));
            }
            return Err(error);
        }

        let proof_count = proof_count as usize;
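        // Safety: the dispatch call above initialized exactly `proof_count` entries of the
        // scratch buffers and all `Record::NUM_CHUNKS` parity record chunks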
        unsafe {
            chunks_scratch_gpu.set_len(proof_count);
            challenge_index_gpu.set_len(proof_count);
            parity_record_chunks.set_len(Record::NUM_CHUNKS);
        }

        let mut encoded_chunks_used = vec![false; challenge_len];
        let source_record_chunks = record.to_vec();

        let mut chunks_scratch = challenge_index_gpu
            .into_iter()
            .zip(chunks_scratch_gpu)
            .collect::<Vec<_>>();

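        // Sort by output index so encoded chunks are applied in a deterministic order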
        chunks_scratch
            .sort_unstable_by(|(a_out_index, _), (b_out_index, _)| a_out_index.cmp(b_out_index));

        // We don't need all the proofs
        chunks_scratch.truncate(proof_count.min(Record::NUM_CHUNKS));

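        // Mark which output indices received an encoded chunk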
        for (out_index, _chunk) in &chunks_scratch {
            encoded_chunks_used[*out_index as usize] = true;
        }

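        // Propagate the bitmap of encoded chunks to the caller-provided output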
        encoded_chunks_used_output
            .zip(&encoded_chunks_used)
            .for_each(|(mut output, input)| *output = *input);

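        // Write encoded chunks into the record first (in output-index order), then fill the
        // remaining positions with interleaved source/parity chunks that were not encoded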
        record
            .iter_mut()
            .zip(
                chunks_scratch
                    .into_iter()
                    .map(|(_out_index, chunk)| chunk)
                    .chain(
                        source_record_chunks
                            .into_iter()
                            .zip(parity_record_chunks)
                            .flat_map(|(a, b)| [a, b.to_bytes()])
                            .zip(encoded_chunks_used.iter())
                            // Skip chunks that were used previously
                            .filter_map(|(record_chunk, encoded_chunk_used)| {
                                if *encoded_chunk_used {
                                    None
                                } else {
                                    Some(record_chunk)
                                }
                            }),
                    ),
            )
            .for_each(|(output_chunk, input_chunk)| {
                *output_chunk = input_chunk;
            });

        Ok(())
    }
}