subspace_farmer/plotter/
cpu.rs

1//! CPU plotter
2
3pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
10use async_trait::async_trait;
11use bytes::Bytes;
12use event_listener_primitives::{Bag, HandlerId};
13use futures::channel::mpsc;
14use futures::stream::FuturesUnordered;
15use futures::{select, stream, FutureExt, Sink, SinkExt, StreamExt};
16use prometheus_client::registry::Registry;
17use std::any::type_name;
18use std::error::Error;
19use std::fmt;
20use std::future::pending;
21use std::marker::PhantomData;
22use std::num::NonZeroUsize;
23use std::pin::pin;
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::sync::Arc;
26use std::task::Poll;
27use std::time::Instant;
28use subspace_core_primitives::sectors::SectorIndex;
29use subspace_core_primitives::PublicKey;
30use subspace_data_retrieval::piece_getter::PieceGetter;
31use subspace_erasure_coding::ErasureCoding;
32use subspace_farmer_components::plotting::{
33    download_sector, encode_sector, write_sector, CpuRecordsEncoder, DownloadSectorOptions,
34    EncodeSectorOptions, PlottingError,
35};
36use subspace_farmer_components::FarmerProtocolInfo;
37use subspace_kzg::Kzg;
38use subspace_proof_of_space::Table;
39use tokio::task::yield_now;
40use tracing::{warn, Instrument};
41
42/// Type alias used for event handlers
43pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
44type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
45
46#[derive(Default, Debug)]
47struct Handlers {
48    plotting_progress: Handler3<PublicKey, SectorIndex, SectorPlottingProgress>,
49}
50
51/// CPU plotter
52pub struct CpuPlotter<PG, PosTable> {
53    piece_getter: PG,
54    downloading_semaphore: Arc<Semaphore>,
55    plotting_thread_pool_manager: PlottingThreadPoolManager,
56    record_encoding_concurrency: NonZeroUsize,
57    global_mutex: Arc<AsyncMutex<()>>,
58    kzg: Kzg,
59    erasure_coding: ErasureCoding,
60    handlers: Arc<Handlers>,
61    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
62    _background_tasks: AsyncJoinOnDrop<()>,
63    abort_early: Arc<AtomicBool>,
64    metrics: Option<Arc<CpuPlotterMetrics>>,
65    _phantom: PhantomData<PosTable>,
66}
67
68impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
69    #[inline]
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("CpuPlotter").finish_non_exhaustive()
72    }
73}
74
75impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
76    #[inline]
77    fn drop(&mut self) {
78        self.abort_early.store(true, Ordering::Release);
79        self.tasks_sender.close_channel();
80    }
81}
82
83#[async_trait]
84impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
85where
86    PG: PieceGetter + Clone + Send + Sync + 'static,
87    PosTable: Table,
88{
89    async fn has_free_capacity(&self) -> Result<bool, String> {
90        Ok(self.downloading_semaphore.try_acquire().is_some())
91    }
92
93    async fn plot_sector(
94        &self,
95        public_key: PublicKey,
96        sector_index: SectorIndex,
97        farmer_protocol_info: FarmerProtocolInfo,
98        pieces_in_sector: u16,
99        replotting: bool,
100        progress_sender: mpsc::Sender<SectorPlottingProgress>,
101    ) {
102        let start = Instant::now();
103
104        // Done outside the future below as a backpressure, ensuring that it is not possible to
105        // schedule unbounded number of plotting tasks
106        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
107
108        self.plot_sector_internal(
109            start,
110            downloading_permit,
111            public_key,
112            sector_index,
113            farmer_protocol_info,
114            pieces_in_sector,
115            replotting,
116            progress_sender,
117        )
118        .await
119    }
120
121    async fn try_plot_sector(
122        &self,
123        public_key: PublicKey,
124        sector_index: SectorIndex,
125        farmer_protocol_info: FarmerProtocolInfo,
126        pieces_in_sector: u16,
127        replotting: bool,
128        progress_sender: mpsc::Sender<SectorPlottingProgress>,
129    ) -> bool {
130        let start = Instant::now();
131
132        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
133            return false;
134        };
135
136        self.plot_sector_internal(
137            start,
138            downloading_permit,
139            public_key,
140            sector_index,
141            farmer_protocol_info,
142            pieces_in_sector,
143            replotting,
144            progress_sender,
145        )
146        .await;
147
148        true
149    }
150}
151
152impl<PG, PosTable> CpuPlotter<PG, PosTable>
153where
154    PG: PieceGetter + Clone + Send + Sync + 'static,
155    PosTable: Table,
156{
157    /// Create new instance
158    #[allow(clippy::too_many_arguments)]
159    pub fn new(
160        piece_getter: PG,
161        downloading_semaphore: Arc<Semaphore>,
162        plotting_thread_pool_manager: PlottingThreadPoolManager,
163        record_encoding_concurrency: NonZeroUsize,
164        global_mutex: Arc<AsyncMutex<()>>,
165        kzg: Kzg,
166        erasure_coding: ErasureCoding,
167        registry: Option<&mut Registry>,
168    ) -> Self {
169        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
170
171        // Basically runs plotting tasks in the background and allows to abort on drop
172        let background_tasks = AsyncJoinOnDrop::new(
173            tokio::spawn(async move {
174                let background_tasks = FuturesUnordered::new();
175                let mut background_tasks = pin!(background_tasks);
176                // Just so that `FuturesUnordered` will never end
177                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
178
179                loop {
180                    select! {
181                        maybe_background_task = tasks_receiver.next().fuse() => {
182                            let Some(background_task) = maybe_background_task else {
183                                break;
184                            };
185
186                            background_tasks.push(background_task);
187                        },
188                        _ = background_tasks.select_next_some() => {
189                            // Nothing to do
190                        }
191                    }
192                }
193            }),
194            true,
195        );
196
197        let abort_early = Arc::new(AtomicBool::new(false));
198        let metrics = registry.map(|registry| {
199            Arc::new(CpuPlotterMetrics::new(
200                registry,
201                type_name::<PosTable>(),
202                plotting_thread_pool_manager.thread_pool_pairs(),
203            ))
204        });
205
206        Self {
207            piece_getter,
208            downloading_semaphore,
209            plotting_thread_pool_manager,
210            record_encoding_concurrency,
211            global_mutex,
212            kzg,
213            erasure_coding,
214            handlers: Arc::default(),
215            tasks_sender,
216            _background_tasks: background_tasks,
217            abort_early,
218            metrics,
219            _phantom: PhantomData,
220        }
221    }
222
223    /// Subscribe to plotting progress notifications
224    pub fn on_plotting_progress(
225        &self,
226        callback: HandlerFn3<PublicKey, SectorIndex, SectorPlottingProgress>,
227    ) -> HandlerId {
228        self.handlers.plotting_progress.add(callback)
229    }
230
231    #[allow(clippy::too_many_arguments)]
232    async fn plot_sector_internal<PS>(
233        &self,
234        start: Instant,
235        downloading_permit: SemaphoreGuardArc,
236        public_key: PublicKey,
237        sector_index: SectorIndex,
238        farmer_protocol_info: FarmerProtocolInfo,
239        pieces_in_sector: u16,
240        replotting: bool,
241        mut progress_sender: PS,
242    ) where
243        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
244        PS::Error: Error,
245    {
246        if let Some(metrics) = &self.metrics {
247            metrics.sector_plotting.inc();
248        }
249
250        let progress_updater = ProgressUpdater {
251            public_key,
252            sector_index,
253            handlers: Arc::clone(&self.handlers),
254            metrics: self.metrics.clone(),
255        };
256
257        let plotting_fut = {
258            let piece_getter = self.piece_getter.clone();
259            let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
260            let record_encoding_concurrency = self.record_encoding_concurrency;
261            let global_mutex = Arc::clone(&self.global_mutex);
262            let kzg = self.kzg.clone();
263            let erasure_coding = self.erasure_coding.clone();
264            let abort_early = Arc::clone(&self.abort_early);
265            let metrics = self.metrics.clone();
266
267            async move {
268                // Downloading
269                let downloaded_sector = {
270                    if !progress_updater
271                        .update_progress_and_events(
272                            &mut progress_sender,
273                            SectorPlottingProgress::Downloading,
274                        )
275                        .await
276                    {
277                        return;
278                    }
279
280                    // Take mutex briefly to make sure plotting is allowed right now
281                    global_mutex.lock().await;
282
283                    let downloading_start = Instant::now();
284
285                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
286                        public_key: &public_key,
287                        sector_index,
288                        piece_getter: &piece_getter,
289                        farmer_protocol_info,
290                        kzg: &kzg,
291                        erasure_coding: &erasure_coding,
292                        pieces_in_sector,
293                    });
294
295                    let downloaded_sector = match downloaded_sector_fut.await {
296                        Ok(downloaded_sector) => downloaded_sector,
297                        Err(error) => {
298                            warn!(%error, "Failed to download sector");
299
300                            progress_updater
301                                .update_progress_and_events(
302                                    &mut progress_sender,
303                                    SectorPlottingProgress::Error {
304                                        error: format!("Failed to download sector: {error}"),
305                                    },
306                                )
307                                .await;
308
309                            return;
310                        }
311                    };
312
313                    if !progress_updater
314                        .update_progress_and_events(
315                            &mut progress_sender,
316                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
317                        )
318                        .await
319                    {
320                        return;
321                    }
322
323                    downloaded_sector
324                };
325
326                // Plotting
327                let (sector, plotted_sector) = {
328                    let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
329                    if let Some(metrics) = &metrics {
330                        metrics.plotting_capacity_used.inc();
331                    }
332
333                    // Give a chance to interrupt plotting if necessary
334                    yield_now().await;
335
336                    if !progress_updater
337                        .update_progress_and_events(
338                            &mut progress_sender,
339                            SectorPlottingProgress::Encoding,
340                        )
341                        .await
342                    {
343                        if let Some(metrics) = &metrics {
344                            metrics.plotting_capacity_used.dec();
345                        }
346                        return;
347                    }
348
349                    let encoding_start = Instant::now();
350
351                    let plotting_result = tokio::task::block_in_place(move || {
352                        let thread_pool = if replotting {
353                            &thread_pools.replotting
354                        } else {
355                            &thread_pools.plotting
356                        };
357
358                        let encoded_sector = thread_pool.install(|| {
359                            let mut generators = (0..record_encoding_concurrency.get())
360                                .map(|_| PosTable::generator())
361                                .collect::<Vec<_>>();
362                            let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
363                                &mut generators,
364                                &erasure_coding,
365                                &global_mutex,
366                            );
367
368                            encode_sector(
369                                downloaded_sector,
370                                EncodeSectorOptions {
371                                    sector_index,
372                                    records_encoder: &mut records_encoder,
373                                    abort_early: &abort_early,
374                                },
375                            )
376                        })?;
377
378                        if abort_early.load(Ordering::Acquire) {
379                            return Err(PlottingError::AbortEarly);
380                        }
381
382                        drop(thread_pools);
383
384                        let mut sector = Vec::new();
385
386                        write_sector(&encoded_sector, &mut sector)?;
387
388                        Ok((sector, encoded_sector.plotted_sector))
389                    });
390
391                    if let Some(metrics) = &metrics {
392                        metrics.plotting_capacity_used.dec();
393                    }
394
395                    match plotting_result {
396                        Ok(plotting_result) => {
397                            if !progress_updater
398                                .update_progress_and_events(
399                                    &mut progress_sender,
400                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
401                                )
402                                .await
403                            {
404                                return;
405                            }
406
407                            plotting_result
408                        }
409                        Err(PlottingError::AbortEarly) => {
410                            return;
411                        }
412                        Err(error) => {
413                            progress_updater
414                                .update_progress_and_events(
415                                    &mut progress_sender,
416                                    SectorPlottingProgress::Error {
417                                        error: format!("Failed to encode sector: {error}"),
418                                    },
419                                )
420                                .await;
421
422                            return;
423                        }
424                    }
425                };
426
427                progress_updater
428                    .update_progress_and_events(
429                        &mut progress_sender,
430                        SectorPlottingProgress::Finished {
431                            plotted_sector,
432                            time: start.elapsed(),
433                            sector: Box::pin({
434                                let mut sector = Some(Ok(Bytes::from(sector)));
435
436                                stream::poll_fn(move |_cx| {
437                                    // Just so that permit is dropped with stream itself
438                                    let _downloading_permit = &downloading_permit;
439
440                                    Poll::Ready(sector.take())
441                                })
442                            }),
443                        },
444                    )
445                    .await;
446            }
447        };
448
449        // Spawn a separate task such that `block_in_place` inside will not affect anything else
450        let plotting_task =
451            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
452        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
453            warn!(%error, "Failed to send plotting task");
454
455            let progress = SectorPlottingProgress::Error {
456                error: format!("Failed to send plotting task: {error}"),
457            };
458
459            self.handlers
460                .plotting_progress
461                .call_simple(&public_key, &sector_index, &progress);
462        }
463    }
464}
465
466struct ProgressUpdater {
467    public_key: PublicKey,
468    sector_index: SectorIndex,
469    handlers: Arc<Handlers>,
470    metrics: Option<Arc<CpuPlotterMetrics>>,
471}
472
473impl ProgressUpdater {
474    /// Returns `true` on success and `false` if progress receiver channel is gone
475    async fn update_progress_and_events<PS>(
476        &self,
477        progress_sender: &mut PS,
478        progress: SectorPlottingProgress,
479    ) -> bool
480    where
481        PS: Sink<SectorPlottingProgress> + Unpin,
482        PS::Error: Error,
483    {
484        if let Some(metrics) = &self.metrics {
485            match &progress {
486                SectorPlottingProgress::Downloading => {
487                    metrics.sector_downloading.inc();
488                }
489                SectorPlottingProgress::Downloaded(time) => {
490                    metrics.sector_downloading_time.observe(time.as_secs_f64());
491                    metrics.sector_downloaded.inc();
492                }
493                SectorPlottingProgress::Encoding => {
494                    metrics.sector_encoding.inc();
495                }
496                SectorPlottingProgress::Encoded(time) => {
497                    metrics.sector_encoding_time.observe(time.as_secs_f64());
498                    metrics.sector_encoded.inc();
499                }
500                SectorPlottingProgress::Finished { time, .. } => {
501                    metrics.sector_plotting_time.observe(time.as_secs_f64());
502                    metrics.sector_plotted.inc();
503                }
504                SectorPlottingProgress::Error { .. } => {
505                    metrics.sector_plotting_error.inc();
506                }
507            }
508        }
509        self.handlers.plotting_progress.call_simple(
510            &self.public_key,
511            &self.sector_index,
512            &progress,
513        );
514
515        if let Err(error) = progress_sender.send(progress).await {
516            warn!(%error, "Failed to send progress update");
517
518            false
519        } else {
520            true
521        }
522    }
523}