// CUDA backend, only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda;
// Private module that provides `GpuRecordsEncoderManager`.
mod gpu_encoders_manager;
// Module that provides `GpuPlotterMetrics`.
pub mod metrics;
// ROCm backend, only compiled when the `rocm` feature is enabled.
#[cfg(feature = "rocm")]
pub mod rocm;
9
10use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
11use crate::plotter::gpu::metrics::GpuPlotterMetrics;
12use crate::plotter::{Plotter, SectorPlottingProgress};
13use crate::utils::AsyncJoinOnDrop;
14use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
15use async_trait::async_trait;
16use bytes::Bytes;
17use event_listener_primitives::{Bag, HandlerId};
18use futures::channel::mpsc;
19use futures::stream::FuturesUnordered;
20use futures::{select, stream, FutureExt, Sink, SinkExt, StreamExt};
21use prometheus_client::registry::Registry;
22use std::error::Error;
23use std::fmt;
24use std::future::pending;
25use std::num::TryFromIntError;
26use std::pin::pin;
27use std::sync::atomic::{AtomicBool, Ordering};
28use std::sync::Arc;
29use std::task::Poll;
30use std::time::Instant;
31use subspace_core_primitives::sectors::SectorIndex;
32use subspace_core_primitives::PublicKey;
33use subspace_data_retrieval::piece_getter::PieceGetter;
34use subspace_erasure_coding::ErasureCoding;
35use subspace_farmer_components::plotting::{
36 download_sector, encode_sector, write_sector, DownloadSectorOptions, EncodeSectorOptions,
37 PlottingError, RecordsEncoder,
38};
39use subspace_farmer_components::FarmerProtocolInfo;
40use subspace_kzg::Kzg;
41use tokio::task::yield_now;
42use tracing::{warn, Instrument};
43
/// Shared, thread-safe callback that receives three arguments by reference.
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
/// [`Bag`] of [`HandlerFn3`] callbacks; callbacks are registered with `add()` and invoked with
/// `call_simple()` elsewhere in this file.
type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
47
/// Event handlers registered on a [`GpuPlotter`].
#[derive(Default, Debug)]
struct Handlers {
    /// Callbacks invoked with the public key, the sector index and the progress update every time
    /// sector plotting progress is reported.
    plotting_progress: Handler3<PublicKey, SectorIndex, SectorPlottingProgress>,
}
52
/// Records encoder that can be used by [`GpuPlotter`].
pub trait GpuRecordsEncoder: RecordsEncoder + fmt::Debug + Send {
    /// Name of the encoder type; used in the `Debug` output of [`GpuPlotter`] and passed to
    /// [`GpuPlotterMetrics::new`].
    const TYPE: &'static str;
}
58
/// Plotter that downloads sectors and encodes them with GPU records encoders.
pub struct GpuPlotter<PG, GRE> {
    /// Piece source passed to `download_sector`.
    piece_getter: PG,
    /// Semaphore from which a permit is taken before a sector is plotted; the permit is kept
    /// alive inside the stream that carries the finished sector.
    downloading_semaphore: Arc<Semaphore>,
    /// Hands out records encoders via `get_encoder()`.
    gpu_records_encoders_manager: GpuRecordsEncoderManager<GRE>,
    /// Mutex that is locked (and immediately released) right before a sector download starts.
    global_mutex: Arc<AsyncMutex<()>>,
    /// Passed to `download_sector`.
    kzg: Kzg,
    /// Passed to `download_sector`.
    erasure_coding: ErasureCoding,
    /// Registered event handlers.
    handlers: Arc<Handlers>,
    /// Sender used to hand spawned plotting tasks over to the background task.
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task that receives plotting tasks and polls them.
    // NOTE(review): presumably aborted/joined when dropped — confirm against `AsyncJoinOnDrop`.
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Flag set to `true` when the plotter is dropped; passed to `encode_sector` and checked
    /// after encoding.
    abort_early: Arc<AtomicBool>,
    /// Metrics, present only when a registry was provided to [`GpuPlotter::new`].
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
73
74impl<PG, GRE> fmt::Debug for GpuPlotter<PG, GRE>
75where
76 GRE: GpuRecordsEncoder + 'static,
77{
78 #[inline]
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct(&format!("GpuPlotter[type = {}]", GRE::TYPE))
81 .finish_non_exhaustive()
82 }
83}
84
85impl<PG, RE> Drop for GpuPlotter<PG, RE> {
86 #[inline]
87 fn drop(&mut self) {
88 self.abort_early.store(true, Ordering::Release);
89 self.tasks_sender.close_channel();
90 }
91}
92
93#[async_trait]
94impl<PG, GRE> Plotter for GpuPlotter<PG, GRE>
95where
96 PG: PieceGetter + Clone + Send + Sync + 'static,
97 GRE: GpuRecordsEncoder + 'static,
98{
99 async fn has_free_capacity(&self) -> Result<bool, String> {
100 Ok(self.downloading_semaphore.try_acquire().is_some())
101 }
102
103 async fn plot_sector(
104 &self,
105 public_key: PublicKey,
106 sector_index: SectorIndex,
107 farmer_protocol_info: FarmerProtocolInfo,
108 pieces_in_sector: u16,
109 _replotting: bool,
110 progress_sender: mpsc::Sender<SectorPlottingProgress>,
111 ) {
112 let start = Instant::now();
113
114 let downloading_permit = self.downloading_semaphore.acquire_arc().await;
117
118 self.plot_sector_internal(
119 start,
120 downloading_permit,
121 public_key,
122 sector_index,
123 farmer_protocol_info,
124 pieces_in_sector,
125 progress_sender,
126 )
127 .await
128 }
129
130 async fn try_plot_sector(
131 &self,
132 public_key: PublicKey,
133 sector_index: SectorIndex,
134 farmer_protocol_info: FarmerProtocolInfo,
135 pieces_in_sector: u16,
136 _replotting: bool,
137 progress_sender: mpsc::Sender<SectorPlottingProgress>,
138 ) -> bool {
139 let start = Instant::now();
140
141 let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
142 return false;
143 };
144
145 self.plot_sector_internal(
146 start,
147 downloading_permit,
148 public_key,
149 sector_index,
150 farmer_protocol_info,
151 pieces_in_sector,
152 progress_sender,
153 )
154 .await;
155
156 true
157 }
158}
159
160impl<PG, GRE> GpuPlotter<PG, GRE>
161where
162 PG: PieceGetter + Clone + Send + Sync + 'static,
163 GRE: GpuRecordsEncoder + 'static,
164{
165 pub fn new(
169 piece_getter: PG,
170 downloading_semaphore: Arc<Semaphore>,
171 gpu_records_encoders: Vec<GRE>,
172 global_mutex: Arc<AsyncMutex<()>>,
173 kzg: Kzg,
174 erasure_coding: ErasureCoding,
175 registry: Option<&mut Registry>,
176 ) -> Result<Self, TryFromIntError> {
177 let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
178
179 let background_tasks = AsyncJoinOnDrop::new(
181 tokio::spawn(async move {
182 let background_tasks = FuturesUnordered::new();
183 let mut background_tasks = pin!(background_tasks);
184 background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
186
187 loop {
188 select! {
189 maybe_background_task = tasks_receiver.next().fuse() => {
190 let Some(background_task) = maybe_background_task else {
191 break;
192 };
193
194 background_tasks.push(background_task);
195 },
196 _ = background_tasks.select_next_some() => {
197 }
199 }
200 }
201 }),
202 true,
203 );
204
205 let abort_early = Arc::new(AtomicBool::new(false));
206 let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
207 let metrics = registry.map(|registry| {
208 Arc::new(GpuPlotterMetrics::new(
209 registry,
210 GRE::TYPE,
211 gpu_records_encoders_manager.gpu_records_encoders(),
212 ))
213 });
214
215 Ok(Self {
216 piece_getter,
217 downloading_semaphore,
218 gpu_records_encoders_manager,
219 global_mutex,
220 kzg,
221 erasure_coding,
222 handlers: Arc::default(),
223 tasks_sender,
224 _background_tasks: background_tasks,
225 abort_early,
226 metrics,
227 })
228 }
229
230 pub fn on_plotting_progress(
232 &self,
233 callback: HandlerFn3<PublicKey, SectorIndex, SectorPlottingProgress>,
234 ) -> HandlerId {
235 self.handlers.plotting_progress.add(callback)
236 }
237
238 #[allow(clippy::too_many_arguments)]
239 async fn plot_sector_internal<PS>(
240 &self,
241 start: Instant,
242 downloading_permit: SemaphoreGuardArc,
243 public_key: PublicKey,
244 sector_index: SectorIndex,
245 farmer_protocol_info: FarmerProtocolInfo,
246 pieces_in_sector: u16,
247 mut progress_sender: PS,
248 ) where
249 PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
250 PS::Error: Error,
251 {
252 if let Some(metrics) = &self.metrics {
253 metrics.sector_plotting.inc();
254 }
255
256 let progress_updater = ProgressUpdater {
257 public_key,
258 sector_index,
259 handlers: Arc::clone(&self.handlers),
260 metrics: self.metrics.clone(),
261 };
262
263 let plotting_fut = {
264 let piece_getter = self.piece_getter.clone();
265 let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
266 let global_mutex = Arc::clone(&self.global_mutex);
267 let kzg = self.kzg.clone();
268 let erasure_coding = self.erasure_coding.clone();
269 let abort_early = Arc::clone(&self.abort_early);
270 let metrics = self.metrics.clone();
271
272 async move {
273 let downloaded_sector = {
275 if !progress_updater
276 .update_progress_and_events(
277 &mut progress_sender,
278 SectorPlottingProgress::Downloading,
279 )
280 .await
281 {
282 return;
283 }
284
285 global_mutex.lock().await;
287
288 let downloading_start = Instant::now();
289
290 let downloaded_sector_fut = download_sector(DownloadSectorOptions {
291 public_key: &public_key,
292 sector_index,
293 piece_getter: &piece_getter,
294 farmer_protocol_info,
295 kzg: &kzg,
296 erasure_coding: &erasure_coding,
297 pieces_in_sector,
298 });
299
300 let downloaded_sector = match downloaded_sector_fut.await {
301 Ok(downloaded_sector) => downloaded_sector,
302 Err(error) => {
303 warn!(%error, "Failed to download sector");
304
305 progress_updater
306 .update_progress_and_events(
307 &mut progress_sender,
308 SectorPlottingProgress::Error {
309 error: format!("Failed to download sector: {error}"),
310 },
311 )
312 .await;
313
314 return;
315 }
316 };
317
318 if !progress_updater
319 .update_progress_and_events(
320 &mut progress_sender,
321 SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
322 )
323 .await
324 {
325 return;
326 }
327
328 downloaded_sector
329 };
330
331 let (sector, plotted_sector) = {
333 let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
334 if let Some(metrics) = &metrics {
335 metrics.plotting_capacity_used.inc();
336 }
337
338 yield_now().await;
340
341 if !progress_updater
342 .update_progress_and_events(
343 &mut progress_sender,
344 SectorPlottingProgress::Encoding,
345 )
346 .await
347 {
348 if let Some(metrics) = &metrics {
349 metrics.plotting_capacity_used.dec();
350 }
351 return;
352 }
353
354 let encoding_start = Instant::now();
355
356 let plotting_result = tokio::task::block_in_place(move || {
357 let encoded_sector = encode_sector(
358 downloaded_sector,
359 EncodeSectorOptions {
360 sector_index,
361 records_encoder: &mut *records_encoder,
362 abort_early: &abort_early,
363 },
364 )?;
365
366 if abort_early.load(Ordering::Acquire) {
367 return Err(PlottingError::AbortEarly);
368 }
369
370 drop(records_encoder);
371
372 let mut sector = Vec::new();
373
374 write_sector(&encoded_sector, &mut sector)?;
375
376 Ok((sector, encoded_sector.plotted_sector))
377 });
378
379 if let Some(metrics) = &metrics {
380 metrics.plotting_capacity_used.dec();
381 }
382
383 match plotting_result {
384 Ok(plotting_result) => {
385 if !progress_updater
386 .update_progress_and_events(
387 &mut progress_sender,
388 SectorPlottingProgress::Encoded(encoding_start.elapsed()),
389 )
390 .await
391 {
392 return;
393 }
394
395 plotting_result
396 }
397 Err(PlottingError::AbortEarly) => {
398 return;
399 }
400 Err(error) => {
401 progress_updater
402 .update_progress_and_events(
403 &mut progress_sender,
404 SectorPlottingProgress::Error {
405 error: format!("Failed to encode sector: {error}"),
406 },
407 )
408 .await;
409
410 return;
411 }
412 }
413 };
414
415 progress_updater
416 .update_progress_and_events(
417 &mut progress_sender,
418 SectorPlottingProgress::Finished {
419 plotted_sector,
420 time: start.elapsed(),
421 sector: Box::pin({
422 let mut sector = Some(Ok(Bytes::from(sector)));
423
424 stream::poll_fn(move |_cx| {
425 let _downloading_permit = &downloading_permit;
427
428 Poll::Ready(sector.take())
429 })
430 }),
431 },
432 )
433 .await;
434 }
435 };
436
437 let plotting_task =
439 AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
440 if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
441 warn!(%error, "Failed to send plotting task");
442
443 let progress = SectorPlottingProgress::Error {
444 error: format!("Failed to send plotting task: {error}"),
445 };
446
447 self.handlers
448 .plotting_progress
449 .call_simple(&public_key, §or_index, &progress);
450 }
451 }
452}
453
/// Helper that reports progress of plotting of a single sector.
struct ProgressUpdater {
    /// Public key passed to progress handlers.
    public_key: PublicKey,
    /// Index of the sector passed to progress handlers.
    sector_index: SectorIndex,
    /// Handlers that are called on every progress update.
    handlers: Arc<Handlers>,
    /// Metrics updated on every progress update, if enabled.
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
460
461impl ProgressUpdater {
462 async fn update_progress_and_events<PS>(
464 &self,
465 progress_sender: &mut PS,
466 progress: SectorPlottingProgress,
467 ) -> bool
468 where
469 PS: Sink<SectorPlottingProgress> + Unpin,
470 PS::Error: Error,
471 {
472 if let Some(metrics) = &self.metrics {
473 match &progress {
474 SectorPlottingProgress::Downloading => {
475 metrics.sector_downloading.inc();
476 }
477 SectorPlottingProgress::Downloaded(time) => {
478 metrics.sector_downloading_time.observe(time.as_secs_f64());
479 metrics.sector_downloaded.inc();
480 }
481 SectorPlottingProgress::Encoding => {
482 metrics.sector_encoding.inc();
483 }
484 SectorPlottingProgress::Encoded(time) => {
485 metrics.sector_encoding_time.observe(time.as_secs_f64());
486 metrics.sector_encoded.inc();
487 }
488 SectorPlottingProgress::Finished { time, .. } => {
489 metrics.sector_plotting_time.observe(time.as_secs_f64());
490 metrics.sector_plotted.inc();
491 }
492 SectorPlottingProgress::Error { .. } => {
493 metrics.sector_plotting_error.inc();
494 }
495 }
496 }
497 self.handlers.plotting_progress.call_simple(
498 &self.public_key,
499 &self.sector_index,
500 &progress,
501 );
502
503 if let Err(error) = progress_sender.send(progress).await {
504 warn!(%error, "Failed to send progress update");
505
506 false
507 } else {
508 true
509 }
510 }
511}