subspace_proof_of_space_gpu/cuda.rs
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#[cfg(test)]
mod tests;
use rust_kzg_blst::types::fr::FsFr;
use std::ops::DerefMut;
use subspace_core_primitives::pieces::Record;
use subspace_core_primitives::pos::{PosProof, PosSeed};
use subspace_core_primitives::ScalarBytes;
use subspace_kzg::Scalar;
extern "C" {
/// # Returns
/// * `usize` - The number of available GPUs.
fn gpu_count() -> usize;
/// # Parameters
/// * `k` - The size parameter for the table.
/// * `seed` - A pointer to the seed data.
/// * `lg_record_size` - The base-2 logarithm of the record size.
/// * `challenge_index` - A mutable pointer to store the index of the challenge.
/// * `record` - A pointer to the record data.
/// * `chunks_scratch` - A mutable pointer to a scratch space for chunk data.
/// * `proof_count` - A mutable pointer to store the count of proofs.
/// * `parity_record_chunks` - A mutable pointer to the parity record chunks.
/// * `gpu_id` - The ID of the GPU to use.
///
/// # Returns
/// * `sppark::Error` - An error code indicating the result of the operation.
///
/// # Assumptions
/// * `seed` must be a valid pointer to a 32-byte seed.
/// * `record` must be a valid pointer to the record data (`*const Record`), with a length of `1 << lg_record_size`.
/// * `parity_record_chunks` must be a valid mutable pointer to `FsFr` elements, with a length of `1 << lg_record_size`.
/// * `chunks_scratch` must be a valid mutable pointer where up to `Record::NUM_S_BUCKETS` 32-byte chunks of GPU-calculated data will be written.
/// * `gpu_id` must be a valid identifier of an available GPU. The available GPUs can be determined by using the `gpu_count` function.
fn generate_and_encode_pospace_dispatch(
k: u32,
seed: *const [u8; 32],
lg_record_size: u32,
challenge_index: *mut u32,
record: *const [u8; 32],
chunks_scratch: *mut [u8; 32],
proof_count: *mut u32,
parity_record_chunks: *mut FsFr,
gpu_id: i32,
) -> sppark::Error;
}
/// Returns [`CudaDevice`] for each available device
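///
/// A minimal usage sketch (hedged: assumes the crate was built with CUDA
/// support and that at least one device is present):
/// ```ignore
/// for device in cuda_devices() {
///     println!("Found CUDA device with id {}", device.id());
/// }
/// ```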
pub fn cuda_devices() -> Vec<CudaDevice> {
let num_devices = unsafe { gpu_count() };
(0i32..)
.take(num_devices)
.map(|gpu_id| CudaDevice { gpu_id })
.collect()
}
/// Wrapper data structure encapsulating a single CUDA-capable device
#[derive(Debug)]
pub struct CudaDevice {
gpu_id: i32,
}
impl CudaDevice {
/// CUDA device ID
pub fn id(&self) -> i32 {
self.gpu_id
}
/// Generates and encodes PoSpace on the GPU
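///
/// A hedged usage sketch; `PosSeed::from` and `Record::new_boxed` are assumed
/// constructors from `subspace_core_primitives`, and the `bool` bitmask shown
/// here is illustrative:
/// ```ignore
/// let seed = PosSeed::from([1u8; 32]);
/// let mut record = Record::new_boxed();
/// let mut encoded_chunks_used = vec![false; Record::NUM_S_BUCKETS];
/// device.generate_and_encode_pospace(
///     &seed,
///     &mut record,
///     encoded_chunks_used.iter_mut(),
/// )?;
/// ```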
pub fn generate_and_encode_pospace(
&self,
seed: &PosSeed,
record: &mut Record,
encoded_chunks_used_output: impl ExactSizeIterator<Item = impl DerefMut<Target = bool>>,
) -> Result<(), String> {
let record_len = Record::NUM_CHUNKS;
let challenge_len = Record::NUM_S_BUCKETS;
let lg_record_size = record_len.ilog2();
if challenge_len > u32::MAX as usize {
return Err(String::from("challenge_len is too large to fit in u32"));
}
let mut proof_count = 0u32;
let mut chunks_scratch_gpu =
Vec::<[u8; ScalarBytes::FULL_BYTES]>::with_capacity(challenge_len);
let mut challenge_index_gpu = Vec::<u32>::with_capacity(challenge_len);
let mut parity_record_chunks = Vec::<Scalar>::with_capacity(Record::NUM_CHUNKS);
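// SAFETY: all pointers passed below are backed by live allocations sized
// according to the documented assumptions of
// `generate_and_encode_pospace_dispatch`, and `gpu_id` originates from
// `cuda_devices()`.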
let error = unsafe {
generate_and_encode_pospace_dispatch(
u32::from(PosProof::K),
&**seed,
lg_record_size,
challenge_index_gpu.as_mut_ptr(),
record.as_ptr(),
chunks_scratch_gpu.as_mut_ptr(),
&mut proof_count,
Scalar::slice_mut_to_repr(&mut parity_record_chunks).as_mut_ptr(),
self.gpu_id,
)
};
if error.code != 0 {
let error = error.to_string();
if error.contains("the provided PTX was compiled with an unsupported toolchain.") {
return Err(format!(
"Nvidia driver is likely too old, make sure install version 550 or newer: \
{error}"
));
}
return Err(error);
}
let proof_count = proof_count as usize;
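// SAFETY: the dispatch call reported success, so the GPU wrote exactly
// `proof_count` entries to `chunks_scratch_gpu`/`challenge_index_gpu` and
// fully initialized `parity_record_chunks`.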
unsafe {
chunks_scratch_gpu.set_len(proof_count);
challenge_index_gpu.set_len(proof_count);
parity_record_chunks.set_len(Record::NUM_CHUNKS);
}
let mut encoded_chunks_used = vec![false; challenge_len];
let source_record_chunks = record.to_vec();
let mut chunks_scratch = challenge_index_gpu
.into_iter()
.zip(chunks_scratch_gpu)
.collect::<Vec<_>>();
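// Sort proofs by challenge index so that the truncation below keeps the
// proofs with the lowest s-bucket indices.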
chunks_scratch.sort_unstable_by_key(|(out_index, _)| *out_index);
// We don't need all the proofs
chunks_scratch.truncate(proof_count.min(Record::NUM_CHUNKS));
for (out_index, _chunk) in &chunks_scratch {
encoded_chunks_used[*out_index as usize] = true;
}
encoded_chunks_used_output
.zip(&encoded_chunks_used)
.for_each(|(mut output, input)| *output = *input);
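// Rebuild the record: encoded chunks (in challenge-index order) fill the
// front, and the remainder is topped up with interleaved source/parity
// chunks whose s-buckets were not encoded.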
record
.iter_mut()
.zip(
chunks_scratch
.into_iter()
.map(|(_out_index, chunk)| chunk)
.chain(
source_record_chunks
.into_iter()
.zip(parity_record_chunks)
.flat_map(|(a, b)| [a, b.to_bytes()])
.zip(encoded_chunks_used.iter())
// Skip chunks that were used previously
.filter_map(|(record_chunk, encoded_chunk_used)| {
if *encoded_chunk_used {
None
} else {
Some(record_chunk)
}
}),
),
)
.for_each(|(output_chunk, input_chunk)| {
*output_chunk = input_chunk;
});
Ok(())
}
}