1use futures::channel::oneshot;
2use futures::channel::oneshot::Canceled;
3use futures::future::{Either, FusedFuture};
4use std::fmt::Display;
5use std::future::Future;
6use std::ops::Deref;
7use std::pin::{Pin, pin};
8use std::process::exit;
9use std::task::{Context, Poll};
10use std::{io, panic, thread};
11use tokio::runtime::Handle;
12use tokio::{signal, task};
13use tracing::level_filters::LevelFilter;
14use tracing::{debug, info, warn};
15use tracing_subscriber::layer::SubscriberExt;
16use tracing_subscriber::util::SubscriberInitExt;
17use tracing_subscriber::{EnvFilter, Layer, fmt};
18
19#[cfg(test)]
20mod tests;
21
22pub fn init_logger() {
23 let enable_color = if cfg!(windows) {
26 false
27 } else {
28 supports_color::on(supports_color::Stream::Stderr).is_some()
29 };
30
31 let res = tracing_subscriber::registry()
32 .with(
33 fmt::layer().with_ansi(enable_color).with_filter(
34 EnvFilter::builder()
35 .with_default_directive(LevelFilter::INFO.into())
36 .from_env_lossy(),
37 ),
38 )
39 .try_init();
40
41 if let Err(e) = res {
42 eprintln!(
45 "Failed to initialize logger: {e}. \
46 This is expected when running nexttest test functions under `cargo test`."
47 );
48 }
49}
50
51#[derive(Debug)]
54pub struct AsyncJoinOnDrop<T> {
55 handle: Option<task::JoinHandle<T>>,
56 abort_on_drop: bool,
57}
58
59impl<T> Drop for AsyncJoinOnDrop<T> {
60 #[inline]
61 fn drop(&mut self) {
62 if let Some(handle) = self.handle.take() {
63 if self.abort_on_drop {
64 handle.abort();
65 }
66
67 if !handle.is_finished() {
68 task::block_in_place(move || {
69 let _ = Handle::current().block_on(handle);
70 });
71 }
72 }
73 }
74}
75
76impl<T> AsyncJoinOnDrop<T> {
77 #[inline]
79 pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
80 Self {
81 handle: Some(handle),
82 abort_on_drop,
83 }
84 }
85}
86
87impl<T> FusedFuture for AsyncJoinOnDrop<T> {
88 fn is_terminated(&self) -> bool {
89 self.handle.is_none()
90 }
91}
92
93impl<T> Future for AsyncJoinOnDrop<T> {
94 type Output = Result<T, task::JoinError>;
95
96 #[inline]
97 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
98 if let Some(handle) = self.handle.as_mut() {
99 let result = Pin::new(handle).poll(cx);
100 if result.is_ready() {
101 self.handle.take();
103 }
104 result
105 } else {
106 Poll::Pending
107 }
108 }
109}
110
/// Owns an OS thread handle and joins the thread when dropped.
///
/// Dropping panics (via `expect`) if the joined thread panicked. While alive, the wrapper derefs to
/// the underlying [`thread::JoinHandle`].
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);

impl Drop for JoinOnDrop {
    #[inline]
    fn drop(&mut self) {
        let handle = self.0.take().expect("Always called exactly once; qed");
        let outcome = handle.join();
        outcome.expect("Panic if background thread panicked");
    }
}

impl JoinOnDrop {
    /// Takes ownership of `handle`; the thread is joined when the returned value is dropped.
    #[inline]
    pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
        let inner = Some(handle);
        Self(inner)
    }
}

impl Deref for JoinOnDrop {
    type Target = thread::JoinHandle<()>;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // The handle is only taken out inside `Drop`, after which `deref` can no longer run.
        let handle = self.0.as_ref();
        handle.expect("Only dropped in Drop impl; qed")
    }
}
141
142pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
147 create_future: CreateFut,
148 thread_name: String,
149) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
150where
151 CreateFut: (FnOnce() -> Fut) + Send + 'static,
152 Fut: Future<Output = T> + 'static,
153 T: Send + 'static,
154{
155 let (drop_tx, drop_rx) = oneshot::channel::<()>();
156 let (result_tx, result_rx) = oneshot::channel();
157 let handle = Handle::current();
158 let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
159 let _tokio_handle_guard = handle.enter();
160
161 let future = pin!(create_future());
162
163 let result = match handle.block_on(futures::future::select(future, drop_rx)) {
164 Either::Left((result, _)) => result,
165 Either::Right(_) => {
166 return;
168 }
169 };
170 if let Err(_error) = result_tx.send(result) {
171 debug!(
172 thread_name = ?thread::current().name(),
173 "Future finished, but receiver was already dropped",
174 );
175 }
176 })?;
177 let join_on_drop = JoinOnDrop::new(join_handle);
179
180 Ok(async move {
181 let result = result_rx.await;
182 drop(drop_tx);
183 drop(join_on_drop);
184 result
185 })
186}
187
188#[cfg(unix)]
190pub async fn shutdown_signal(process_kind: impl Display) {
191 use futures::FutureExt;
192 use std::pin::pin;
193
194 let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt())
195 .expect("Setting signal handlers must never fail");
196 let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
197 .expect("Setting signal handlers must never fail");
198
199 futures::future::select(
200 pin!(sigint.recv().map(|_| {
201 info!("Received SIGINT, shutting down {process_kind}...");
202 }),),
203 pin!(sigterm.recv().map(|_| {
204 info!("Received SIGTERM, shutting down {process_kind}...");
205 }),),
206 )
207 .await;
208}
209
/// Waits for Ctrl+C and logs that `process_kind` is shutting down.
///
/// Panics if the Ctrl+C handler cannot be installed.
#[cfg(not(unix))]
pub async fn shutdown_signal(process_kind: impl Display) {
    let outcome = signal::ctrl_c().await;
    outcome.expect("Setting signal handlers must never fail");

    info!("Received Ctrl+C, shutting down {process_kind}...");
}
219
220pub fn raise_fd_limit() {
222 match fdlimit::raise_fd_limit() {
223 Ok(fdlimit::Outcome::LimitRaised { from, to }) => {
224 debug!(
225 "Increased file descriptor limit from previous (most likely soft) limit {} to \
226 new (most likely hard) limit {}",
227 from, to
228 );
229 }
230 Ok(fdlimit::Outcome::Unsupported) => {
231 }
233 Err(error) => {
234 warn!(
235 "Failed to increase file descriptor limit for the process due to an error: {}.",
236 error
237 );
238 }
239 }
240}
241
/// Makes every panic, on any thread, terminate the whole process with exit code 1.
///
/// The previously installed panic hook still runs first, so the panic is reported as before. The
/// hook runs before unwinding, so even panics that would later be caught with `catch_unwind` end
/// the process.
pub fn set_exit_on_panic() {
    let previous_hook = panic::take_hook();
    panic::set_hook(Box::new(move |info| {
        previous_hook(info);
        exit(1)
    }));
}
250}