subspace_proof_of_space_gpu/
cuda.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// Originally written by Supranational LLC

#[cfg(test)]
mod tests;

use rust_kzg_blst::types::fr::FsFr;
use std::ops::DerefMut;
use subspace_core_primitives::pieces::Record;
use subspace_core_primitives::pos::{PosProof, PosSeed};
use subspace_core_primitives::ScalarBytes;
use subspace_kzg::Scalar;

extern "C" {
    /// # Returns
    /// * `usize` - The number of available GPUs.
    fn gpu_count() -> usize;

    /// # Parameters
    /// * `k: The size parameter for the table.
    /// * `seed: A pointer to the seed data.
    /// * `lg_record_size: The logarithm of the record size.
    /// * `challenge_index: A mutable pointer to store the index of the challenge.
    /// * `record: A pointer to the record data.
    /// * `chunks_scratch: A mutable pointer to a scratch space for chunk data.
    /// * `proof_count: A mutable pointer to store the count of proofs.
    /// * `parity_record_chunks: A mutable pointer to the parity record chunks.
    /// * `gpu_id: The ID of the GPU to use.
    ///
    /// # Returns
    /// * `sppark::Error` - An error code indicating the result of the operation.
    ///
    /// # Assumptions
    /// * `seed` must be a valid pointer to a 32-byte.
    /// * `record` must be a valid pointer to the record data (`*const Record`), with a length of `1 << lg_record_size`.
    /// * `parity_record_chunks` must be valid mutable pointer to `Scalar` elements, each with a length of `1 << lg_record_size`.
    /// * `chunks_scratch` must be a valid mutable pointer where up to `challenges_count` 32-byte chunks of GPU-calculated data will be written.
    /// * `gpu_id` must be a valid identifier of an available GPU. The available GPUs can be determined by using the `gpu_count` function.
    fn generate_and_encode_pospace_dispatch(
        k: u32,
        seed: *const [u8; 32],
        lg_record_size: u32,
        challenge_index: *mut u32,
        record: *const [u8; 32],
        chunks_scratch: *mut [u8; 32],
        proof_count: *mut u32,
        parity_record_chunks: *mut FsFr,
        gpu_id: i32,
    ) -> sppark::Error;
}

/// Returns [`CudaDevice`] for each available device
///
/// Device IDs are assigned consecutively starting at 0, one per GPU reported by the native
/// `gpu_count` function.
pub fn cuda_devices() -> Vec<CudaDevice> {
    // SAFETY: `gpu_count` takes no arguments, so there are no pointer preconditions to uphold.
    let device_count = unsafe { gpu_count() };

    let mut devices = Vec::new();
    for gpu_id in (0i32..).take(device_count) {
        devices.push(CudaDevice { gpu_id });
    }
    devices
}

/// Wrapper data structure encapsulating a single CUDA-capable device
///
/// Created by [`cuda_devices`]; the wrapped ID is forwarded as `gpu_id` to the native dispatch
/// function when generating and encoding PoSpace.
#[derive(Debug)]
pub struct CudaDevice {
    /// Device index as assigned by [`cuda_devices`] (counting up from 0).
    gpu_id: i32,
}

impl CudaDevice {
    /// CUDA device ID
    pub fn id(&self) -> i32 {
        self.gpu_id
    }

    /// Generates and encodes PoSpace on the GPU
    ///
    /// Runs the native dispatch for `seed` on this device and overwrites `record` in place.
    /// First come the chunks produced by the GPU for the found proofs, ordered by challenge
    /// index (at most `Record::NUM_CHUNKS` of them). After them come the interleaved
    /// source/parity chunks whose challenge slot was not taken by an encoded chunk.
    ///
    /// `encoded_chunks_used_output` receives, per challenge index, whether an encoded chunk was
    /// placed for it. Only the first `min(len, Record::NUM_S_BUCKETS)` items are written.
    ///
    /// # Errors
    /// Returns an error message in these cases:
    /// * the native call fails; the PTX toolchain error gets a hint about the driver version;
    /// * the GPU reports more proofs than the output buffers can hold;
    /// * the GPU reports a challenge index outside `0..Record::NUM_S_BUCKETS`.
    ///
    /// On any error neither `record` nor `encoded_chunks_used_output` is modified.
    pub fn generate_and_encode_pospace(
        &self,
        seed: &PosSeed,
        record: &mut Record,
        encoded_chunks_used_output: impl ExactSizeIterator<Item = impl DerefMut<Target = bool>>,
    ) -> Result<(), String> {
        let record_len = Record::NUM_CHUNKS;
        let challenge_len = Record::NUM_S_BUCKETS;
        let lg_record_size = record_len.ilog2();

        // Challenge indices are exchanged with the GPU as `u32`.
        if challenge_len > u32::MAX as usize {
            return Err(String::from("challenge_len is too large to fit in u32"));
        }

        let mut proof_count = 0u32;
        let mut chunks_scratch_gpu =
            Vec::<[u8; ScalarBytes::FULL_BYTES]>::with_capacity(challenge_len);
        let mut challenge_index_gpu = Vec::<u32>::with_capacity(challenge_len);
        let mut parity_record_chunks = Vec::<Scalar>::with_capacity(Record::NUM_CHUNKS);

        // SAFETY: `seed` is a reference to 32 bytes, and `record` provides `Record::NUM_CHUNKS`
        // (`1 << lg_record_size`) chunks. `challenge_index_gpu` and `chunks_scratch_gpu` have
        // capacity for `challenge_len` entries, and `parity_record_chunks` for
        // `Record::NUM_CHUNKS` entries. `proof_count` is a valid `&mut u32`.
        // NOTE(review): the parity pointer is taken through a zero-length slice view of the
        // still-empty Vec. It addresses the Vec's allocation, but it relies on the capacity
        // reserved above rather than on the slice length.
        let error = unsafe {
            generate_and_encode_pospace_dispatch(
                u32::from(PosProof::K),
                &**seed,
                lg_record_size,
                challenge_index_gpu.as_mut_ptr(),
                record.as_ptr(),
                chunks_scratch_gpu.as_mut_ptr(),
                &mut proof_count,
                Scalar::slice_mut_to_repr(&mut parity_record_chunks).as_mut_ptr(),
                self.gpu_id,
            )
        };

        if error.code != 0 {
            let error = error.to_string();
            if error.contains("the provided PTX was compiled with an unsupported toolchain.") {
                return Err(format!(
                    "Nvidia driver is likely too old, make sure install version 550 or newer: \
                    {error}"
                ));
            }
            return Err(error);
        }

        let proof_count = proof_count as usize;
        // The scratch buffers were only allocated for `challenge_len` entries. Trusting a larger
        // count in `set_len` below would expose memory outside the allocation.
        if proof_count > challenge_len {
            return Err(format!(
                "GPU reported {proof_count} proofs, but at most {challenge_len} are possible"
            ));
        }

        // SAFETY: `proof_count <= challenge_len` was checked above, so both scratch buffers have
        // enough capacity. The native call is relied upon to have initialized the first
        // `proof_count` entries of each scratch buffer and all `Record::NUM_CHUNKS` parity
        // chunks.
        unsafe {
            chunks_scratch_gpu.set_len(proof_count);
            challenge_index_gpu.set_len(proof_count);
            parity_record_chunks.set_len(Record::NUM_CHUNKS);
        }

        let mut chunks_scratch = challenge_index_gpu
            .into_iter()
            .zip(chunks_scratch_gpu)
            .collect::<Vec<_>>();
        chunks_scratch.sort_unstable_by_key(|&(out_index, _)| out_index);
        // We don't need all the proofs
        chunks_scratch.truncate(Record::NUM_CHUNKS);

        let mut encoded_chunks_used = vec![false; challenge_len];
        for &(out_index, _) in &chunks_scratch {
            // The index comes from the GPU; reject it instead of panicking on a bad value.
            let used = encoded_chunks_used
                .get_mut(out_index as usize)
                .ok_or_else(|| {
                    format!(
                        "GPU returned challenge index {out_index}, expected less than \
                        {challenge_len}"
                    )
                })?;
            *used = true;
        }

        encoded_chunks_used_output
            .zip(&encoded_chunks_used)
            .for_each(|(mut output, input)| *output = *input);

        // Snapshot the source chunks before `record` is overwritten below.
        let source_record_chunks = record.to_vec();

        let encoded_chunks = chunks_scratch
            .into_iter()
            .map(|(_out_index, chunk)| chunk);
        // Source and parity chunks are interleaved so that position `i` lines up with challenge
        // index `i`. Chunks whose slot was used by an encoded chunk are skipped.
        let unused_chunks = source_record_chunks
            .into_iter()
            .zip(parity_record_chunks)
            .flat_map(|(a, b)| [a, b.to_bytes()])
            .zip(encoded_chunks_used.iter())
            .filter_map(|(record_chunk, &encoded_chunk_used)| {
                (!encoded_chunk_used).then_some(record_chunk)
            });

        record
            .iter_mut()
            .zip(encoded_chunks.chain(unused_chunks))
            .for_each(|(output_chunk, input_chunk)| {
                *output_chunk = input_chunk;
            });

        Ok(())
    }
}