1pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
10use async_trait::async_trait;
11use bytes::Bytes;
12use event_listener_primitives::{Bag, HandlerId};
13use futures::channel::mpsc;
14use futures::stream::FuturesUnordered;
15use futures::{select, stream, FutureExt, Sink, SinkExt, StreamExt};
16use prometheus_client::registry::Registry;
17use std::any::type_name;
18use std::error::Error;
19use std::fmt;
20use std::future::pending;
21use std::marker::PhantomData;
22use std::num::NonZeroUsize;
23use std::pin::pin;
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::sync::Arc;
26use std::task::Poll;
27use std::time::Instant;
28use subspace_core_primitives::sectors::SectorIndex;
29use subspace_core_primitives::PublicKey;
30use subspace_data_retrieval::piece_getter::PieceGetter;
31use subspace_erasure_coding::ErasureCoding;
32use subspace_farmer_components::plotting::{
33 download_sector, encode_sector, write_sector, CpuRecordsEncoder, DownloadSectorOptions,
34 EncodeSectorOptions, PlottingError,
35};
36use subspace_farmer_components::FarmerProtocolInfo;
37use subspace_kzg::Kzg;
38use subspace_proof_of_space::Table;
39use tokio::task::yield_now;
40use tracing::{warn, Instrument};
41
/// Shareable callback invoked with borrowed references to three arguments.
///
/// Wrapped in an `Arc` so the same callback can be cloned cheaply, and required to be
/// `Send + Sync` so it can be called from any thread.
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Sync + Send + 'static>;
44type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
45
46#[derive(Default, Debug)]
47struct Handlers {
48 plotting_progress: Handler3<PublicKey, SectorIndex, SectorPlottingProgress>,
49}
50
/// Plotter that downloads sectors through a piece getter and encodes them on CPU thread pools.
///
/// Dropping the plotter sets the abort flag shared with in-progress encodings and closes the
/// channel used to hand plotting tasks to the background task.
pub struct CpuPlotter<PG, PosTable> {
    /// Source of pieces used when downloading sectors
    piece_getter: PG,
    /// Bounds how many sectors are plotted at once; a permit is held for the whole plotting of a
    /// sector and, on success, until the returned sector stream is dropped
    downloading_semaphore: Arc<Semaphore>,
    /// Provides the (plotting / replotting) thread pools used for encoding
    plotting_thread_pool_manager: PlottingThreadPoolManager,
    /// Number of PoS table generators created for each encoded sector
    record_encoding_concurrency: NonZeroUsize,
    /// Briefly acquired before each download and also passed to the records encoder
    global_mutex: Arc<AsyncMutex<()>>,
    kzg: Kzg,
    erasure_coding: ErasureCoding,
    /// Subscribers to plotting progress events
    handlers: Arc<Handlers>,
    /// Hands spawned plotting tasks to the background task that drives them
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task collecting plotting tasks; kept alive for the lifetime of the plotter
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Set to `true` on drop; shared with `encode_sector` so in-progress encodings can stop
    abort_early: Arc<AtomicBool>,
    /// Metrics, present only when a registry was passed to `new`
    metrics: Option<Arc<CpuPlotterMetrics>>,
    _phantom: PhantomData<PosTable>,
}
67
68impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
69 #[inline]
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("CpuPlotter").finish_non_exhaustive()
72 }
73}
74
75impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
76 #[inline]
77 fn drop(&mut self) {
78 self.abort_early.store(true, Ordering::Release);
79 self.tasks_sender.close_channel();
80 }
81}
82
83#[async_trait]
84impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
85where
86 PG: PieceGetter + Clone + Send + Sync + 'static,
87 PosTable: Table,
88{
89 async fn has_free_capacity(&self) -> Result<bool, String> {
90 Ok(self.downloading_semaphore.try_acquire().is_some())
91 }
92
93 async fn plot_sector(
94 &self,
95 public_key: PublicKey,
96 sector_index: SectorIndex,
97 farmer_protocol_info: FarmerProtocolInfo,
98 pieces_in_sector: u16,
99 replotting: bool,
100 progress_sender: mpsc::Sender<SectorPlottingProgress>,
101 ) {
102 let start = Instant::now();
103
104 let downloading_permit = self.downloading_semaphore.acquire_arc().await;
107
108 self.plot_sector_internal(
109 start,
110 downloading_permit,
111 public_key,
112 sector_index,
113 farmer_protocol_info,
114 pieces_in_sector,
115 replotting,
116 progress_sender,
117 )
118 .await
119 }
120
121 async fn try_plot_sector(
122 &self,
123 public_key: PublicKey,
124 sector_index: SectorIndex,
125 farmer_protocol_info: FarmerProtocolInfo,
126 pieces_in_sector: u16,
127 replotting: bool,
128 progress_sender: mpsc::Sender<SectorPlottingProgress>,
129 ) -> bool {
130 let start = Instant::now();
131
132 let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
133 return false;
134 };
135
136 self.plot_sector_internal(
137 start,
138 downloading_permit,
139 public_key,
140 sector_index,
141 farmer_protocol_info,
142 pieces_in_sector,
143 replotting,
144 progress_sender,
145 )
146 .await;
147
148 true
149 }
150}
151
152impl<PG, PosTable> CpuPlotter<PG, PosTable>
153where
154 PG: PieceGetter + Clone + Send + Sync + 'static,
155 PosTable: Table,
156{
157 #[allow(clippy::too_many_arguments)]
159 pub fn new(
160 piece_getter: PG,
161 downloading_semaphore: Arc<Semaphore>,
162 plotting_thread_pool_manager: PlottingThreadPoolManager,
163 record_encoding_concurrency: NonZeroUsize,
164 global_mutex: Arc<AsyncMutex<()>>,
165 kzg: Kzg,
166 erasure_coding: ErasureCoding,
167 registry: Option<&mut Registry>,
168 ) -> Self {
169 let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
170
171 let background_tasks = AsyncJoinOnDrop::new(
173 tokio::spawn(async move {
174 let background_tasks = FuturesUnordered::new();
175 let mut background_tasks = pin!(background_tasks);
176 background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
178
179 loop {
180 select! {
181 maybe_background_task = tasks_receiver.next().fuse() => {
182 let Some(background_task) = maybe_background_task else {
183 break;
184 };
185
186 background_tasks.push(background_task);
187 },
188 _ = background_tasks.select_next_some() => {
189 }
191 }
192 }
193 }),
194 true,
195 );
196
197 let abort_early = Arc::new(AtomicBool::new(false));
198 let metrics = registry.map(|registry| {
199 Arc::new(CpuPlotterMetrics::new(
200 registry,
201 type_name::<PosTable>(),
202 plotting_thread_pool_manager.thread_pool_pairs(),
203 ))
204 });
205
206 Self {
207 piece_getter,
208 downloading_semaphore,
209 plotting_thread_pool_manager,
210 record_encoding_concurrency,
211 global_mutex,
212 kzg,
213 erasure_coding,
214 handlers: Arc::default(),
215 tasks_sender,
216 _background_tasks: background_tasks,
217 abort_early,
218 metrics,
219 _phantom: PhantomData,
220 }
221 }
222
223 pub fn on_plotting_progress(
225 &self,
226 callback: HandlerFn3<PublicKey, SectorIndex, SectorPlottingProgress>,
227 ) -> HandlerId {
228 self.handlers.plotting_progress.add(callback)
229 }
230
231 #[allow(clippy::too_many_arguments)]
232 async fn plot_sector_internal<PS>(
233 &self,
234 start: Instant,
235 downloading_permit: SemaphoreGuardArc,
236 public_key: PublicKey,
237 sector_index: SectorIndex,
238 farmer_protocol_info: FarmerProtocolInfo,
239 pieces_in_sector: u16,
240 replotting: bool,
241 mut progress_sender: PS,
242 ) where
243 PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
244 PS::Error: Error,
245 {
246 if let Some(metrics) = &self.metrics {
247 metrics.sector_plotting.inc();
248 }
249
250 let progress_updater = ProgressUpdater {
251 public_key,
252 sector_index,
253 handlers: Arc::clone(&self.handlers),
254 metrics: self.metrics.clone(),
255 };
256
257 let plotting_fut = {
258 let piece_getter = self.piece_getter.clone();
259 let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
260 let record_encoding_concurrency = self.record_encoding_concurrency;
261 let global_mutex = Arc::clone(&self.global_mutex);
262 let kzg = self.kzg.clone();
263 let erasure_coding = self.erasure_coding.clone();
264 let abort_early = Arc::clone(&self.abort_early);
265 let metrics = self.metrics.clone();
266
267 async move {
268 let downloaded_sector = {
270 if !progress_updater
271 .update_progress_and_events(
272 &mut progress_sender,
273 SectorPlottingProgress::Downloading,
274 )
275 .await
276 {
277 return;
278 }
279
280 global_mutex.lock().await;
282
283 let downloading_start = Instant::now();
284
285 let downloaded_sector_fut = download_sector(DownloadSectorOptions {
286 public_key: &public_key,
287 sector_index,
288 piece_getter: &piece_getter,
289 farmer_protocol_info,
290 kzg: &kzg,
291 erasure_coding: &erasure_coding,
292 pieces_in_sector,
293 });
294
295 let downloaded_sector = match downloaded_sector_fut.await {
296 Ok(downloaded_sector) => downloaded_sector,
297 Err(error) => {
298 warn!(%error, "Failed to download sector");
299
300 progress_updater
301 .update_progress_and_events(
302 &mut progress_sender,
303 SectorPlottingProgress::Error {
304 error: format!("Failed to download sector: {error}"),
305 },
306 )
307 .await;
308
309 return;
310 }
311 };
312
313 if !progress_updater
314 .update_progress_and_events(
315 &mut progress_sender,
316 SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
317 )
318 .await
319 {
320 return;
321 }
322
323 downloaded_sector
324 };
325
326 let (sector, plotted_sector) = {
328 let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
329 if let Some(metrics) = &metrics {
330 metrics.plotting_capacity_used.inc();
331 }
332
333 yield_now().await;
335
336 if !progress_updater
337 .update_progress_and_events(
338 &mut progress_sender,
339 SectorPlottingProgress::Encoding,
340 )
341 .await
342 {
343 if let Some(metrics) = &metrics {
344 metrics.plotting_capacity_used.dec();
345 }
346 return;
347 }
348
349 let encoding_start = Instant::now();
350
351 let plotting_result = tokio::task::block_in_place(move || {
352 let thread_pool = if replotting {
353 &thread_pools.replotting
354 } else {
355 &thread_pools.plotting
356 };
357
358 let encoded_sector = thread_pool.install(|| {
359 let mut generators = (0..record_encoding_concurrency.get())
360 .map(|_| PosTable::generator())
361 .collect::<Vec<_>>();
362 let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
363 &mut generators,
364 &erasure_coding,
365 &global_mutex,
366 );
367
368 encode_sector(
369 downloaded_sector,
370 EncodeSectorOptions {
371 sector_index,
372 records_encoder: &mut records_encoder,
373 abort_early: &abort_early,
374 },
375 )
376 })?;
377
378 if abort_early.load(Ordering::Acquire) {
379 return Err(PlottingError::AbortEarly);
380 }
381
382 drop(thread_pools);
383
384 let mut sector = Vec::new();
385
386 write_sector(&encoded_sector, &mut sector)?;
387
388 Ok((sector, encoded_sector.plotted_sector))
389 });
390
391 if let Some(metrics) = &metrics {
392 metrics.plotting_capacity_used.dec();
393 }
394
395 match plotting_result {
396 Ok(plotting_result) => {
397 if !progress_updater
398 .update_progress_and_events(
399 &mut progress_sender,
400 SectorPlottingProgress::Encoded(encoding_start.elapsed()),
401 )
402 .await
403 {
404 return;
405 }
406
407 plotting_result
408 }
409 Err(PlottingError::AbortEarly) => {
410 return;
411 }
412 Err(error) => {
413 progress_updater
414 .update_progress_and_events(
415 &mut progress_sender,
416 SectorPlottingProgress::Error {
417 error: format!("Failed to encode sector: {error}"),
418 },
419 )
420 .await;
421
422 return;
423 }
424 }
425 };
426
427 progress_updater
428 .update_progress_and_events(
429 &mut progress_sender,
430 SectorPlottingProgress::Finished {
431 plotted_sector,
432 time: start.elapsed(),
433 sector: Box::pin({
434 let mut sector = Some(Ok(Bytes::from(sector)));
435
436 stream::poll_fn(move |_cx| {
437 let _downloading_permit = &downloading_permit;
439
440 Poll::Ready(sector.take())
441 })
442 }),
443 },
444 )
445 .await;
446 }
447 };
448
449 let plotting_task =
451 AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
452 if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
453 warn!(%error, "Failed to send plotting task");
454
455 let progress = SectorPlottingProgress::Error {
456 error: format!("Failed to send plotting task: {error}"),
457 };
458
459 self.handlers
460 .plotting_progress
461 .call_simple(&public_key, §or_index, &progress);
462 }
463 }
464}
465
/// Per-sector helper that fans progress updates out to metrics, registered handlers and the
/// caller-provided progress sender.
struct ProgressUpdater {
    /// Public key passed to progress handlers
    public_key: PublicKey,
    /// Index of the sector being plotted, passed to progress handlers
    sector_index: SectorIndex,
    /// Handlers shared with the owning `CpuPlotter`
    handlers: Arc<Handlers>,
    /// Metrics shared with the owning `CpuPlotter`, if enabled
    metrics: Option<Arc<CpuPlotterMetrics>>,
}
472
473impl ProgressUpdater {
474 async fn update_progress_and_events<PS>(
476 &self,
477 progress_sender: &mut PS,
478 progress: SectorPlottingProgress,
479 ) -> bool
480 where
481 PS: Sink<SectorPlottingProgress> + Unpin,
482 PS::Error: Error,
483 {
484 if let Some(metrics) = &self.metrics {
485 match &progress {
486 SectorPlottingProgress::Downloading => {
487 metrics.sector_downloading.inc();
488 }
489 SectorPlottingProgress::Downloaded(time) => {
490 metrics.sector_downloading_time.observe(time.as_secs_f64());
491 metrics.sector_downloaded.inc();
492 }
493 SectorPlottingProgress::Encoding => {
494 metrics.sector_encoding.inc();
495 }
496 SectorPlottingProgress::Encoded(time) => {
497 metrics.sector_encoding_time.observe(time.as_secs_f64());
498 metrics.sector_encoded.inc();
499 }
500 SectorPlottingProgress::Finished { time, .. } => {
501 metrics.sector_plotting_time.observe(time.as_secs_f64());
502 metrics.sector_plotted.inc();
503 }
504 SectorPlottingProgress::Error { .. } => {
505 metrics.sector_plotting_error.inc();
506 }
507 }
508 }
509 self.handlers.plotting_progress.call_simple(
510 &self.public_key,
511 &self.sector_index,
512 &progress,
513 );
514
515 if let Err(error) = progress_sender.send(progress).await {
516 warn!(%error, "Failed to send progress update");
517
518 false
519 } else {
520 true
521 }
522 }
523}