subspace_runtime_primitives/
extension.rs1#[cfg(feature = "runtime-benchmarks")]
2pub mod benchmarking;
3
4use crate::utility::{MaybeNestedCall, nested_call_iter};
5use crate::weights::balance_transfer_check_extension::WeightInfo as SubstrateWeightInfo;
6use core::marker::PhantomData;
7use frame_support::RuntimeDebugNoBound;
8use frame_support::pallet_prelude::Weight;
9use frame_system::Config;
10use frame_system::pallet_prelude::{OriginFor, RuntimeCallFor};
11use pallet_balances::Call as BalancesCall;
12use parity_scale_codec::{Decode, Encode};
13use scale_info::TypeInfo;
14use scale_info::prelude::fmt;
15use sp_runtime::DispatchResult;
16use sp_runtime::traits::{
17 AsSystemOriginSigner, DispatchInfoOf, DispatchOriginOf, Dispatchable, PostDispatchInfoOf,
18 TransactionExtension, ValidateResult,
19};
20use sp_runtime::transaction_validity::{
21 InvalidTransaction, TransactionSource, TransactionValidityError, ValidTransaction,
22};
23
/// Number of calls the up-front weight of the balance transfer check is sized for.
///
/// `TransactionExtension::weight` charges the cost of inspecting this many calls, and
/// `prepare` later refunds the difference to the number of calls actually inspected.
/// NOTE(review): the refund is computed with `saturating_sub`, so a transaction that
/// needs more inspections than this gets a zero refund rather than an extra charge.
const MAXIMUM_NUMBER_OF_CALLS: u32 = 5000;
26
27pub trait WeightInfo {
29 fn balance_transfer_check_multiple(c: u32) -> Weight;
30 fn balance_transfer_check_utility(c: u32) -> Weight;
31 fn balance_transfer_check_multisig(c: u32) -> Weight;
32}
33
34pub trait MaybeBalancesCall<Runtime>
36where
37 Runtime: pallet_balances::Config,
38{
39 fn maybe_balance_call(&self) -> Option<&BalancesCall<Runtime>>;
40}
41
/// Runtime-provided switch consulted by `BalanceTransferCheckExtension`.
pub trait BalanceTransferChecks {
    /// Returns `true` when balance transfers are allowed.
    ///
    /// When this returns `false`, signed transactions are rejected if they contain a
    /// balances `transfer_allow_death`, `transfer_keep_alive` or `transfer_all` call,
    /// either directly or nested. The exception is `runtime-benchmarks` builds.
    fn is_balance_transferable() -> bool;
}
46
47#[derive(Debug, Encode, Decode, Clone, Eq, PartialEq, TypeInfo)]
49pub struct BalanceTransferCheckExtension<Runtime>(PhantomData<Runtime>);
50
51impl<Runtime> Default for BalanceTransferCheckExtension<Runtime>
52where
53 Runtime: BalanceTransferChecks + pallet_balances::Config,
54 RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
55{
56 fn default() -> Self {
57 Self(PhantomData)
58 }
59}
60
61impl<Runtime> BalanceTransferCheckExtension<Runtime>
62where
63 Runtime: BalanceTransferChecks + pallet_balances::Config,
64 RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
65{
66 fn do_validate_signed(
67 call: &RuntimeCallFor<Runtime>,
68 ) -> Result<(ValidTransaction, u32), TransactionValidityError> {
69 if Runtime::is_balance_transferable() {
70 return Ok((ValidTransaction::default(), 0));
71 }
72
73 let (contains_balance_call, calls) = Self::contains_balance_transfer(call);
75
76 let res = if contains_balance_call {
77 Err(InvalidTransaction::Call.into())
78 } else {
79 Ok((ValidTransaction::default(), calls))
80 };
81
82 if cfg!(feature = "runtime-benchmarks") {
84 Ok((ValidTransaction::default(), calls))
85 } else {
86 res
87 }
88 }
89
90 fn contains_balance_transfer(call: &RuntimeCallFor<Runtime>) -> (bool, u32) {
91 let mut calls = 0;
92 for call in nested_call_iter::<Runtime>(call) {
93 calls += 1;
94 if let Some(balance_call) = call.maybe_balance_call()
97 && matches!(
98 balance_call,
99 BalancesCall::transfer_allow_death { .. }
100 | BalancesCall::transfer_keep_alive { .. }
101 | BalancesCall::transfer_all { .. }
102 )
103 {
104 return (true, calls);
105 }
106 }
107
108 (false, calls)
109 }
110
111 fn get_weights(n: u32) -> Weight {
112 SubstrateWeightInfo::<Runtime>::balance_transfer_check_multisig(n)
113 .max(SubstrateWeightInfo::<Runtime>::balance_transfer_check_multiple(n))
114 .max(SubstrateWeightInfo::<Runtime>::balance_transfer_check_utility(n))
115 }
116}
117
118#[derive(RuntimeDebugNoBound)]
120pub enum Pre {
121 Refund(Weight),
122}
123
124#[derive(RuntimeDebugNoBound)]
126pub enum Val {
127 FullRefund,
128 PartialRefund(Option<u32>),
129}
130
131impl<Runtime> TransactionExtension<RuntimeCallFor<Runtime>>
132 for BalanceTransferCheckExtension<Runtime>
133where
134 Runtime: Config
135 + pallet_balances::Config
136 + scale_info::TypeInfo
137 + fmt::Debug
138 + Send
139 + Sync
140 + BalanceTransferChecks,
141 <RuntimeCallFor<Runtime> as Dispatchable>::RuntimeOrigin:
142 AsSystemOriginSigner<<Runtime as Config>::AccountId> + Clone,
143 RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
144{
145 const IDENTIFIER: &'static str = "BalanceTransferCheckExtension";
146 type Implicit = ();
147 type Val = Val;
148 type Pre = Pre;
149
150 fn weight(&self, _call: &RuntimeCallFor<Runtime>) -> Weight {
151 Self::get_weights(MAXIMUM_NUMBER_OF_CALLS)
152 }
153
154 fn validate(
155 &self,
156 origin: OriginFor<Runtime>,
157 call: &RuntimeCallFor<Runtime>,
158 _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
159 _len: usize,
160 _self_implicit: Self::Implicit,
161 _inherited_implication: &impl Encode,
162 _source: TransactionSource,
163 ) -> ValidateResult<Self::Val, RuntimeCallFor<Runtime>> {
164 let (validity, val) = if origin.as_system_origin_signer().is_some() {
165 let (valid, maybe_calls) =
166 Self::do_validate_signed(call).map(|(valid, calls)| (valid, Some(calls)))?;
167 (valid, Val::PartialRefund(maybe_calls))
168 } else {
169 (ValidTransaction::default(), Val::FullRefund)
170 };
171
172 Ok((validity, val, origin))
173 }
174
175 fn prepare(
176 self,
177 val: Self::Val,
178 _origin: &DispatchOriginOf<RuntimeCallFor<Runtime>>,
179 _call: &RuntimeCallFor<Runtime>,
180 _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
181 _len: usize,
182 ) -> Result<Self::Pre, TransactionValidityError> {
183 let total_weight = Self::get_weights(MAXIMUM_NUMBER_OF_CALLS);
184 match val {
185 Val::FullRefund => Ok(Pre::Refund(total_weight)),
187
188 Val::PartialRefund(maybe_calls) => {
191 let actual_weights = Self::get_weights(maybe_calls.unwrap_or(0));
192 Ok(Pre::Refund(total_weight.saturating_sub(actual_weights)))
193 }
194 }
195 }
196
197 fn post_dispatch_details(
198 pre: Self::Pre,
199 _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
200 _post_info: &PostDispatchInfoOf<RuntimeCallFor<Runtime>>,
201 _len: usize,
202 _result: &DispatchResult,
203 ) -> Result<Weight, TransactionValidityError> {
204 let Pre::Refund(weight) = pre;
205 Ok(weight)
206 }
207}