1use futures::channel::oneshot;
2use futures::channel::oneshot::Canceled;
3use futures::future::{Either, FusedFuture};
4use std::fmt::Display;
5use std::future::Future;
6use std::ops::Deref;
7use std::pin::{Pin, pin};
8use std::process::exit;
9use std::task::{Context, Poll};
10use std::{io, panic, thread};
11use tokio::runtime::Handle;
12use tokio::{signal, task};
13use tracing::level_filters::LevelFilter;
14use tracing::{debug, info, warn};
15use tracing_subscriber::layer::SubscriberExt;
16use tracing_subscriber::util::SubscriberInitExt;
17use tracing_subscriber::{EnvFilter, Layer, fmt};
18
19#[cfg(test)]
20mod tests;
21
22pub fn init_logger() {
23 console::set_colors_enabled(false);
27 console::set_colors_enabled_stderr(false);
28
29 let enable_color = if cfg!(windows) {
32 false
33 } else {
34 supports_color::on(supports_color::Stream::Stderr).is_some()
35 };
36
37 let res = tracing_subscriber::registry()
38 .with(
39 fmt::layer().with_ansi(enable_color).with_filter(
40 EnvFilter::builder()
41 .with_default_directive(LevelFilter::INFO.into())
42 .from_env_lossy(),
43 ),
44 )
45 .try_init();
46
47 if let Err(e) = res {
48 eprintln!(
51 "Failed to initialize logger: {e}. \
52 This is expected when running nexttest test functions under `cargo test`."
53 );
54 }
55}
56
/// Owning wrapper around a Tokio [`task::JoinHandle`].
///
/// It can be awaited to obtain the task's result. If it is dropped before that result was
/// observed, the drop waits for the task to finish, optionally aborting it first.
#[derive(Debug)]
pub struct AsyncJoinOnDrop<T> {
    // `Some` until the task's result has been returned by `poll` or the wrapper is dropped.
    handle: Option<task::JoinHandle<T>>,
    // When `true`, `drop` aborts the task before waiting for it.
    abort_on_drop: bool,
}
64
65impl<T> Drop for AsyncJoinOnDrop<T> {
66 #[inline]
67 fn drop(&mut self) {
68 if let Some(handle) = self.handle.take() {
69 if self.abort_on_drop {
70 handle.abort();
71 }
72
73 if !handle.is_finished() {
74 task::block_in_place(move || {
75 let _ = Handle::current().block_on(handle);
76 });
77 }
78 }
79 }
80}
81
82impl<T> AsyncJoinOnDrop<T> {
83 #[inline]
85 pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
86 Self {
87 handle: Some(handle),
88 abort_on_drop,
89 }
90 }
91}
92
93impl<T> FusedFuture for AsyncJoinOnDrop<T> {
94 fn is_terminated(&self) -> bool {
95 self.handle.is_none()
96 }
97}
98
99impl<T> Future for AsyncJoinOnDrop<T> {
100 type Output = Result<T, task::JoinError>;
101
102 #[inline]
103 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104 if let Some(handle) = self.handle.as_mut() {
105 let result = Pin::new(handle).poll(cx);
106 if result.is_ready() {
107 self.handle.take();
109 }
110 result
111 } else {
112 Poll::Pending
113 }
114 }
115}
116
/// Owns a background thread's [`thread::JoinHandle`] and joins that thread when dropped.
///
/// Until then the handle remains accessible through [`Deref`].
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);

impl Drop for JoinOnDrop {
    /// Blocks until the background thread has exited.
    ///
    /// # Panics
    ///
    /// Panics if the background thread panicked. The exception is when the current thread is
    /// already unwinding from another panic: a second panic during unwinding would abort the
    /// whole process, so in that case the background thread's failure is not re-raised. The
    /// panic hook normally reported it on that thread already.
    #[inline]
    fn drop(&mut self) {
        let join_result = self
            .0
            .take()
            .expect("Always called exactly once; qed")
            .join();

        if thread::panicking() {
            // Already unwinding: re-raising here would turn the unwind into a process abort.
            return;
        }

        join_result.expect("Panic if background thread panicked");
    }
}

impl JoinOnDrop {
    /// Takes ownership of `handle`; the thread is joined when the returned value is dropped.
    #[inline]
    pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
        Self(Some(handle))
    }
}

impl Deref for JoinOnDrop {
    type Target = thread::JoinHandle<()>;

    /// Borrows the wrapped join handle.
    ///
    /// The handle is only taken out inside `drop`, so it is always present here.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect("Only dropped in Drop impl; qed")
    }
}
147
148pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
153 create_future: CreateFut,
154 thread_name: String,
155) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
156where
157 CreateFut: (FnOnce() -> Fut) + Send + 'static,
158 Fut: Future<Output = T> + 'static,
159 T: Send + 'static,
160{
161 let (drop_tx, drop_rx) = oneshot::channel::<()>();
162 let (result_tx, result_rx) = oneshot::channel();
163 let handle = Handle::current();
164 let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
165 let _tokio_handle_guard = handle.enter();
166
167 let future = pin!(create_future());
168
169 let result = match handle.block_on(futures::future::select(future, drop_rx)) {
170 Either::Left((result, _)) => result,
171 Either::Right(_) => {
172 return;
174 }
175 };
176 if let Err(_error) = result_tx.send(result) {
177 debug!(
178 thread_name = ?thread::current().name(),
179 "Future finished, but receiver was already dropped",
180 );
181 }
182 })?;
183 let join_on_drop = JoinOnDrop::new(join_handle);
185
186 Ok(async move {
187 let result = result_rx.await;
188 drop(drop_tx);
189 drop(join_on_drop);
190 result
191 })
192}
193
194#[cfg(unix)]
196pub async fn shutdown_signal(process_kind: impl Display) {
197 use futures::FutureExt;
198 use std::pin::pin;
199
200 let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt())
201 .expect("Setting signal handlers must never fail");
202 let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
203 .expect("Setting signal handlers must never fail");
204
205 futures::future::select(
206 pin!(sigint.recv().map(|_| {
207 info!("Received SIGINT, shutting down {process_kind}...");
208 }),),
209 pin!(sigterm.recv().map(|_| {
210 info!("Received SIGTERM, shutting down {process_kind}...");
211 }),),
212 )
213 .await;
214}
215
/// Completes once the process receives Ctrl+C, then logs it together with `process_kind`.
///
/// # Panics
///
/// Panics if the Ctrl+C listener cannot be registered.
#[cfg(not(unix))]
pub async fn shutdown_signal(process_kind: impl Display) {
    let outcome = signal::ctrl_c().await;
    outcome.expect("Setting signal handlers must never fail");

    info!("Received Ctrl+C, shutting down {process_kind}...");
}
225
226pub fn raise_fd_limit() {
228 match fdlimit::raise_fd_limit() {
229 Ok(fdlimit::Outcome::LimitRaised { from, to }) => {
230 debug!(
231 "Increased file descriptor limit from previous (most likely soft) limit {} to \
232 new (most likely hard) limit {}",
233 from, to
234 );
235 }
236 Ok(fdlimit::Outcome::Unsupported) => {
237 }
239 Err(error) => {
240 warn!(
241 "Failed to increase file descriptor limit for the process due to an error: {}.",
242 error
243 );
244 }
245 }
246}
247
/// Makes every panic terminate the whole process with exit code 1.
///
/// The panic hook that was installed before this call still runs first, so the panic is
/// reported as before. Afterwards [`exit`] is called, even for panics on non-main threads or
/// panics that would otherwise be caught.
pub fn set_exit_on_panic() {
    let report_panic = panic::take_hook();
    panic::set_hook(Box::new(move |panic_info| {
        // Delegate reporting to the previously active hook, then terminate the process.
        report_panic(panic_info);
        exit(1)
    }));
}