subspace_farmer/plotter/
gpu.rs

1//! GPU plotter
2
3#[cfg(feature = "cuda")]
4pub mod cuda;
5mod gpu_encoders_manager;
6pub mod metrics;
7#[cfg(feature = "rocm")]
8pub mod rocm;
9
10use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
11use crate::plotter::gpu::metrics::GpuPlotterMetrics;
12use crate::plotter::{Plotter, SectorPlottingProgress};
13use crate::utils::AsyncJoinOnDrop;
14use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
15use async_trait::async_trait;
16use bytes::Bytes;
17use event_listener_primitives::{Bag, HandlerId};
18use futures::channel::mpsc;
19use futures::stream::FuturesUnordered;
20use futures::{select, stream, FutureExt, Sink, SinkExt, StreamExt};
21use prometheus_client::registry::Registry;
22use std::error::Error;
23use std::fmt;
24use std::future::pending;
25use std::num::TryFromIntError;
26use std::pin::pin;
27use std::sync::atomic::{AtomicBool, Ordering};
28use std::sync::Arc;
29use std::task::Poll;
30use std::time::Instant;
31use subspace_core_primitives::sectors::SectorIndex;
32use subspace_core_primitives::PublicKey;
33use subspace_data_retrieval::piece_getter::PieceGetter;
34use subspace_erasure_coding::ErasureCoding;
35use subspace_farmer_components::plotting::{
36    download_sector, encode_sector, write_sector, DownloadSectorOptions, EncodeSectorOptions,
37    PlottingError, RecordsEncoder,
38};
39use subspace_farmer_components::FarmerProtocolInfo;
40use subspace_kzg::Kzg;
41use tokio::task::yield_now;
42use tracing::{warn, Instrument};
43
/// Shared, thread-safe callback receiving three borrowed arguments; used as the element type of
/// event handler bags.
pub type HandlerFn3<First, Second, Third> =
    Arc<dyn Fn(&First, &Second, &Third) + Send + Sync + 'static>;
46type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
47
48#[derive(Default, Debug)]
49struct Handlers {
50    plotting_progress: Handler3<PublicKey, SectorIndex, SectorPlottingProgress>,
51}
52
53/// GPU-specific [`RecordsEncoder`] with extra APIs
54pub trait GpuRecordsEncoder: RecordsEncoder + fmt::Debug + Send {
55    /// GPU encoder type, typically related to GPU vendor
56    const TYPE: &'static str;
57}
58
59/// GPU plotter
60pub struct GpuPlotter<PG, GRE> {
61    piece_getter: PG,
62    downloading_semaphore: Arc<Semaphore>,
63    gpu_records_encoders_manager: GpuRecordsEncoderManager<GRE>,
64    global_mutex: Arc<AsyncMutex<()>>,
65    kzg: Kzg,
66    erasure_coding: ErasureCoding,
67    handlers: Arc<Handlers>,
68    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
69    _background_tasks: AsyncJoinOnDrop<()>,
70    abort_early: Arc<AtomicBool>,
71    metrics: Option<Arc<GpuPlotterMetrics>>,
72}
73
74impl<PG, GRE> fmt::Debug for GpuPlotter<PG, GRE>
75where
76    GRE: GpuRecordsEncoder + 'static,
77{
78    #[inline]
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct(&format!("GpuPlotter[type = {}]", GRE::TYPE))
81            .finish_non_exhaustive()
82    }
83}
84
85impl<PG, RE> Drop for GpuPlotter<PG, RE> {
86    #[inline]
87    fn drop(&mut self) {
88        self.abort_early.store(true, Ordering::Release);
89        self.tasks_sender.close_channel();
90    }
91}
92
93#[async_trait]
94impl<PG, GRE> Plotter for GpuPlotter<PG, GRE>
95where
96    PG: PieceGetter + Clone + Send + Sync + 'static,
97    GRE: GpuRecordsEncoder + 'static,
98{
99    async fn has_free_capacity(&self) -> Result<bool, String> {
100        Ok(self.downloading_semaphore.try_acquire().is_some())
101    }
102
103    async fn plot_sector(
104        &self,
105        public_key: PublicKey,
106        sector_index: SectorIndex,
107        farmer_protocol_info: FarmerProtocolInfo,
108        pieces_in_sector: u16,
109        _replotting: bool,
110        progress_sender: mpsc::Sender<SectorPlottingProgress>,
111    ) {
112        let start = Instant::now();
113
114        // Done outside the future below as a backpressure, ensuring that it is not possible to
115        // schedule unbounded number of plotting tasks
116        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
117
118        self.plot_sector_internal(
119            start,
120            downloading_permit,
121            public_key,
122            sector_index,
123            farmer_protocol_info,
124            pieces_in_sector,
125            progress_sender,
126        )
127        .await
128    }
129
130    async fn try_plot_sector(
131        &self,
132        public_key: PublicKey,
133        sector_index: SectorIndex,
134        farmer_protocol_info: FarmerProtocolInfo,
135        pieces_in_sector: u16,
136        _replotting: bool,
137        progress_sender: mpsc::Sender<SectorPlottingProgress>,
138    ) -> bool {
139        let start = Instant::now();
140
141        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
142            return false;
143        };
144
145        self.plot_sector_internal(
146            start,
147            downloading_permit,
148            public_key,
149            sector_index,
150            farmer_protocol_info,
151            pieces_in_sector,
152            progress_sender,
153        )
154        .await;
155
156        true
157    }
158}
159
160impl<PG, GRE> GpuPlotter<PG, GRE>
161where
162    PG: PieceGetter + Clone + Send + Sync + 'static,
163    GRE: GpuRecordsEncoder + 'static,
164{
165    /// Create new instance.
166    ///
167    /// Returns an error if empty list of encoders is provided.
168    pub fn new(
169        piece_getter: PG,
170        downloading_semaphore: Arc<Semaphore>,
171        gpu_records_encoders: Vec<GRE>,
172        global_mutex: Arc<AsyncMutex<()>>,
173        kzg: Kzg,
174        erasure_coding: ErasureCoding,
175        registry: Option<&mut Registry>,
176    ) -> Result<Self, TryFromIntError> {
177        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
178
179        // Basically runs plotting tasks in the background and allows to abort on drop
180        let background_tasks = AsyncJoinOnDrop::new(
181            tokio::spawn(async move {
182                let background_tasks = FuturesUnordered::new();
183                let mut background_tasks = pin!(background_tasks);
184                // Just so that `FuturesUnordered` will never end
185                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
186
187                loop {
188                    select! {
189                        maybe_background_task = tasks_receiver.next().fuse() => {
190                            let Some(background_task) = maybe_background_task else {
191                                break;
192                            };
193
194                            background_tasks.push(background_task);
195                        },
196                        _ = background_tasks.select_next_some() => {
197                            // Nothing to do
198                        }
199                    }
200                }
201            }),
202            true,
203        );
204
205        let abort_early = Arc::new(AtomicBool::new(false));
206        let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
207        let metrics = registry.map(|registry| {
208            Arc::new(GpuPlotterMetrics::new(
209                registry,
210                GRE::TYPE,
211                gpu_records_encoders_manager.gpu_records_encoders(),
212            ))
213        });
214
215        Ok(Self {
216            piece_getter,
217            downloading_semaphore,
218            gpu_records_encoders_manager,
219            global_mutex,
220            kzg,
221            erasure_coding,
222            handlers: Arc::default(),
223            tasks_sender,
224            _background_tasks: background_tasks,
225            abort_early,
226            metrics,
227        })
228    }
229
230    /// Subscribe to plotting progress notifications
231    pub fn on_plotting_progress(
232        &self,
233        callback: HandlerFn3<PublicKey, SectorIndex, SectorPlottingProgress>,
234    ) -> HandlerId {
235        self.handlers.plotting_progress.add(callback)
236    }
237
238    #[allow(clippy::too_many_arguments)]
239    async fn plot_sector_internal<PS>(
240        &self,
241        start: Instant,
242        downloading_permit: SemaphoreGuardArc,
243        public_key: PublicKey,
244        sector_index: SectorIndex,
245        farmer_protocol_info: FarmerProtocolInfo,
246        pieces_in_sector: u16,
247        mut progress_sender: PS,
248    ) where
249        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
250        PS::Error: Error,
251    {
252        if let Some(metrics) = &self.metrics {
253            metrics.sector_plotting.inc();
254        }
255
256        let progress_updater = ProgressUpdater {
257            public_key,
258            sector_index,
259            handlers: Arc::clone(&self.handlers),
260            metrics: self.metrics.clone(),
261        };
262
263        let plotting_fut = {
264            let piece_getter = self.piece_getter.clone();
265            let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
266            let global_mutex = Arc::clone(&self.global_mutex);
267            let kzg = self.kzg.clone();
268            let erasure_coding = self.erasure_coding.clone();
269            let abort_early = Arc::clone(&self.abort_early);
270            let metrics = self.metrics.clone();
271
272            async move {
273                // Downloading
274                let downloaded_sector = {
275                    if !progress_updater
276                        .update_progress_and_events(
277                            &mut progress_sender,
278                            SectorPlottingProgress::Downloading,
279                        )
280                        .await
281                    {
282                        return;
283                    }
284
285                    // Take mutex briefly to make sure plotting is allowed right now
286                    global_mutex.lock().await;
287
288                    let downloading_start = Instant::now();
289
290                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
291                        public_key: &public_key,
292                        sector_index,
293                        piece_getter: &piece_getter,
294                        farmer_protocol_info,
295                        kzg: &kzg,
296                        erasure_coding: &erasure_coding,
297                        pieces_in_sector,
298                    });
299
300                    let downloaded_sector = match downloaded_sector_fut.await {
301                        Ok(downloaded_sector) => downloaded_sector,
302                        Err(error) => {
303                            warn!(%error, "Failed to download sector");
304
305                            progress_updater
306                                .update_progress_and_events(
307                                    &mut progress_sender,
308                                    SectorPlottingProgress::Error {
309                                        error: format!("Failed to download sector: {error}"),
310                                    },
311                                )
312                                .await;
313
314                            return;
315                        }
316                    };
317
318                    if !progress_updater
319                        .update_progress_and_events(
320                            &mut progress_sender,
321                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
322                        )
323                        .await
324                    {
325                        return;
326                    }
327
328                    downloaded_sector
329                };
330
331                // Plotting
332                let (sector, plotted_sector) = {
333                    let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
334                    if let Some(metrics) = &metrics {
335                        metrics.plotting_capacity_used.inc();
336                    }
337
338                    // Give a chance to interrupt plotting if necessary
339                    yield_now().await;
340
341                    if !progress_updater
342                        .update_progress_and_events(
343                            &mut progress_sender,
344                            SectorPlottingProgress::Encoding,
345                        )
346                        .await
347                    {
348                        if let Some(metrics) = &metrics {
349                            metrics.plotting_capacity_used.dec();
350                        }
351                        return;
352                    }
353
354                    let encoding_start = Instant::now();
355
356                    let plotting_result = tokio::task::block_in_place(move || {
357                        let encoded_sector = encode_sector(
358                            downloaded_sector,
359                            EncodeSectorOptions {
360                                sector_index,
361                                records_encoder: &mut *records_encoder,
362                                abort_early: &abort_early,
363                            },
364                        )?;
365
366                        if abort_early.load(Ordering::Acquire) {
367                            return Err(PlottingError::AbortEarly);
368                        }
369
370                        drop(records_encoder);
371
372                        let mut sector = Vec::new();
373
374                        write_sector(&encoded_sector, &mut sector)?;
375
376                        Ok((sector, encoded_sector.plotted_sector))
377                    });
378
379                    if let Some(metrics) = &metrics {
380                        metrics.plotting_capacity_used.dec();
381                    }
382
383                    match plotting_result {
384                        Ok(plotting_result) => {
385                            if !progress_updater
386                                .update_progress_and_events(
387                                    &mut progress_sender,
388                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
389                                )
390                                .await
391                            {
392                                return;
393                            }
394
395                            plotting_result
396                        }
397                        Err(PlottingError::AbortEarly) => {
398                            return;
399                        }
400                        Err(error) => {
401                            progress_updater
402                                .update_progress_and_events(
403                                    &mut progress_sender,
404                                    SectorPlottingProgress::Error {
405                                        error: format!("Failed to encode sector: {error}"),
406                                    },
407                                )
408                                .await;
409
410                            return;
411                        }
412                    }
413                };
414
415                progress_updater
416                    .update_progress_and_events(
417                        &mut progress_sender,
418                        SectorPlottingProgress::Finished {
419                            plotted_sector,
420                            time: start.elapsed(),
421                            sector: Box::pin({
422                                let mut sector = Some(Ok(Bytes::from(sector)));
423
424                                stream::poll_fn(move |_cx| {
425                                    // Just so that permit is dropped with stream itself
426                                    let _downloading_permit = &downloading_permit;
427
428                                    Poll::Ready(sector.take())
429                                })
430                            }),
431                        },
432                    )
433                    .await;
434            }
435        };
436
437        // Spawn a separate task such that `block_in_place` inside will not affect anything else
438        let plotting_task =
439            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
440        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
441            warn!(%error, "Failed to send plotting task");
442
443            let progress = SectorPlottingProgress::Error {
444                error: format!("Failed to send plotting task: {error}"),
445            };
446
447            self.handlers
448                .plotting_progress
449                .call_simple(&public_key, &sector_index, &progress);
450        }
451    }
452}
453
454struct ProgressUpdater {
455    public_key: PublicKey,
456    sector_index: SectorIndex,
457    handlers: Arc<Handlers>,
458    metrics: Option<Arc<GpuPlotterMetrics>>,
459}
460
461impl ProgressUpdater {
462    /// Returns `true` on success and `false` if progress receiver channel is gone
463    async fn update_progress_and_events<PS>(
464        &self,
465        progress_sender: &mut PS,
466        progress: SectorPlottingProgress,
467    ) -> bool
468    where
469        PS: Sink<SectorPlottingProgress> + Unpin,
470        PS::Error: Error,
471    {
472        if let Some(metrics) = &self.metrics {
473            match &progress {
474                SectorPlottingProgress::Downloading => {
475                    metrics.sector_downloading.inc();
476                }
477                SectorPlottingProgress::Downloaded(time) => {
478                    metrics.sector_downloading_time.observe(time.as_secs_f64());
479                    metrics.sector_downloaded.inc();
480                }
481                SectorPlottingProgress::Encoding => {
482                    metrics.sector_encoding.inc();
483                }
484                SectorPlottingProgress::Encoded(time) => {
485                    metrics.sector_encoding_time.observe(time.as_secs_f64());
486                    metrics.sector_encoded.inc();
487                }
488                SectorPlottingProgress::Finished { time, .. } => {
489                    metrics.sector_plotting_time.observe(time.as_secs_f64());
490                    metrics.sector_plotted.inc();
491                }
492                SectorPlottingProgress::Error { .. } => {
493                    metrics.sector_plotting_error.inc();
494                }
495            }
496        }
497        self.handlers.plotting_progress.call_simple(
498            &self.public_key,
499            &self.sector_index,
500            &progress,
501        );
502
503        if let Err(error) = progress_sender.send(progress).await {
504            warn!(%error, "Failed to send progress update");
505
506            false
507        } else {
508            true
509        }
510    }
511}