subspace_runtime_primitives/
extension.rs

1#[cfg(feature = "runtime-benchmarks")]
2pub mod benchmarking;
3
4use crate::utility::{MaybeNestedCall, nested_call_iter};
5use crate::weights::balance_transfer_check_extension::WeightInfo as SubstrateWeightInfo;
6use core::marker::PhantomData;
7use frame_support::RuntimeDebugNoBound;
8use frame_support::pallet_prelude::Weight;
9use frame_system::Config;
10use frame_system::pallet_prelude::{OriginFor, RuntimeCallFor};
11use pallet_balances::Call as BalancesCall;
12use parity_scale_codec::{Decode, Encode};
13use scale_info::TypeInfo;
14use scale_info::prelude::fmt;
15use sp_runtime::DispatchResult;
16use sp_runtime::traits::{
17    AsSystemOriginSigner, DispatchInfoOf, DispatchOriginOf, Dispatchable, PostDispatchInfoOf,
18    TransactionExtension, ValidateResult,
19};
20use sp_runtime::transaction_validity::{
21    InvalidTransaction, TransactionSource, TransactionValidityError, ValidTransaction,
22};
23
/// Largest number of calls the balance transfer check was benchmarked for.
///
/// The extension charges the weight of checking this many calls up front (see
/// `TransactionExtension::weight`) and refunds the unused part after dispatch.
/// A transaction containing more nested calls than this is still walked in full
/// by the check. However, the refund is computed with a saturating subtraction,
/// so that extra work is not charged.
/// NOTE(review): confirm whether the number of nested calls is bounded elsewhere.
const MAXIMUM_NUMBER_OF_CALLS: u32 = 5_000;
26
/// Weight functions for the balance transfer check extension.
///
/// Each function takes `c`, the number of calls the check iterates over. The
/// extension charges the largest of the three results for a given `c`.
pub trait WeightInfo {
    /// Weight of checking `c` calls. The name suggests several calls in one
    /// transaction; confirm the exact shape against the benchmarking module.
    fn balance_transfer_check_multiple(c: u32) -> Weight;
    /// Weight of checking `c` calls. The name suggests nesting via a utility
    /// (batch) call; confirm against the benchmarking module.
    fn balance_transfer_check_utility(c: u32) -> Weight;
    /// Weight of checking `c` calls. The name suggests nesting via a multisig
    /// call; confirm against the benchmarking module.
    fn balance_transfer_check_multisig(c: u32) -> Weight;
}
33
/// Converts a runtime call into a `pallet_balances` call, if it is one.
///
/// The balance transfer check calls this on the outer call and on every nested
/// call. It then matches the result against the transfer variants.
pub trait MaybeBalancesCall<Runtime>
where
    Runtime: pallet_balances::Config,
{
    /// Returns the inner balances call, or `None` if this call does not belong to
    /// `pallet_balances`. (Expected contract for implementors — confirm per runtime.)
    fn maybe_balance_call(&self) -> Option<&BalancesCall<Runtime>>;
}
41
/// Runtime hook that reports whether balance transfers are currently enabled.
pub trait BalanceTransferChecks {
    /// Returns `true` when balance transfers are allowed.
    ///
    /// When this returns `false`, the extension rejects signed transactions that
    /// contain a transfer call, including transfers nested inside other calls.
    fn is_balance_transferable() -> bool;
}
46
/// Transaction extension that disables balance transfers when the runtime says so.
///
/// While `BalanceTransferChecks::is_balance_transferable` returns `false`, a signed
/// transaction is rejected if it contains `transfer_allow_death`,
/// `transfer_keep_alive` or `transfer_all`, either directly or nested inside
/// another call. The struct carries no data, only a `PhantomData` marker.
///
/// Note that the standard derives add matching bounds on `Runtime` itself (for
/// example, `Debug` requires `Runtime: Debug`). This is presumably why the
/// `TransactionExtension` impl requires `Runtime: fmt::Debug + scale_info::TypeInfo`.
#[derive(Debug, Encode, Decode, Clone, Eq, PartialEq, TypeInfo)]
pub struct BalanceTransferCheckExtension<Runtime>(PhantomData<Runtime>);
50
51impl<Runtime> Default for BalanceTransferCheckExtension<Runtime>
52where
53    Runtime: BalanceTransferChecks + pallet_balances::Config,
54    RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
55{
56    fn default() -> Self {
57        Self(PhantomData)
58    }
59}
60
61impl<Runtime> BalanceTransferCheckExtension<Runtime>
62where
63    Runtime: BalanceTransferChecks + pallet_balances::Config,
64    RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
65{
66    fn do_validate_signed(
67        call: &RuntimeCallFor<Runtime>,
68    ) -> Result<(ValidTransaction, u32), TransactionValidityError> {
69        if Runtime::is_balance_transferable() {
70            return Ok((ValidTransaction::default(), 0));
71        }
72
73        // Disable normal balance transfers.
74        let (contains_balance_call, calls) = Self::contains_balance_transfer(call);
75
76        let res = if contains_balance_call {
77            Err(InvalidTransaction::Call.into())
78        } else {
79            Ok((ValidTransaction::default(), calls))
80        };
81
82        // For benchmarks, always return success, so benchmarks run transfers.
83        if cfg!(feature = "runtime-benchmarks") {
84            Ok((ValidTransaction::default(), calls))
85        } else {
86            res
87        }
88    }
89
90    fn contains_balance_transfer(call: &RuntimeCallFor<Runtime>) -> (bool, u32) {
91        let mut calls = 0;
92        for call in nested_call_iter::<Runtime>(call) {
93            calls += 1;
94            // Any other calls might contain nested calls, so we can only return early if we find a
95            // balance transfer call.
96            if let Some(balance_call) = call.maybe_balance_call()
97                && matches!(
98                    balance_call,
99                    BalancesCall::transfer_allow_death { .. }
100                        | BalancesCall::transfer_keep_alive { .. }
101                        | BalancesCall::transfer_all { .. }
102                )
103            {
104                return (true, calls);
105            }
106        }
107
108        (false, calls)
109    }
110
111    fn get_weights(n: u32) -> Weight {
112        SubstrateWeightInfo::<Runtime>::balance_transfer_check_multisig(n)
113            .max(SubstrateWeightInfo::<Runtime>::balance_transfer_check_multiple(n))
114            .max(SubstrateWeightInfo::<Runtime>::balance_transfer_check_utility(n))
115    }
116}
117
/// Data passed from `prepare` to `post_dispatch_details`.
#[derive(RuntimeDebugNoBound)]
pub enum Pre {
    /// Weight to give back after dispatch. `post_dispatch_details` returns it
    /// unchanged.
    Refund(Weight),
}
123
/// Data passed from `validate` to `prepare`.
#[derive(RuntimeDebugNoBound)]
pub enum Val {
    /// The origin was not a system signed origin, so no check ran and all of the
    /// up-front weight is refunded.
    FullRefund,
    /// The origin was signed. Carries the number of calls the check inspected.
    /// `validate` always sets `Some`; `prepare` treats `None` as zero calls.
    PartialRefund(Option<u32>),
}
130
131impl<Runtime> TransactionExtension<RuntimeCallFor<Runtime>>
132    for BalanceTransferCheckExtension<Runtime>
133where
134    Runtime: Config
135        + pallet_balances::Config
136        + scale_info::TypeInfo
137        + fmt::Debug
138        + Send
139        + Sync
140        + BalanceTransferChecks,
141    <RuntimeCallFor<Runtime> as Dispatchable>::RuntimeOrigin:
142        AsSystemOriginSigner<<Runtime as Config>::AccountId> + Clone,
143    RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
144{
145    const IDENTIFIER: &'static str = "BalanceTransferCheckExtension";
146    type Implicit = ();
147    type Val = Val;
148    type Pre = Pre;
149
150    fn weight(&self, _call: &RuntimeCallFor<Runtime>) -> Weight {
151        Self::get_weights(MAXIMUM_NUMBER_OF_CALLS)
152    }
153
154    fn validate(
155        &self,
156        origin: OriginFor<Runtime>,
157        call: &RuntimeCallFor<Runtime>,
158        _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
159        _len: usize,
160        _self_implicit: Self::Implicit,
161        _inherited_implication: &impl Encode,
162        _source: TransactionSource,
163    ) -> ValidateResult<Self::Val, RuntimeCallFor<Runtime>> {
164        let (validity, val) = if origin.as_system_origin_signer().is_some() {
165            let (valid, maybe_calls) =
166                Self::do_validate_signed(call).map(|(valid, calls)| (valid, Some(calls)))?;
167            (valid, Val::PartialRefund(maybe_calls))
168        } else {
169            (ValidTransaction::default(), Val::FullRefund)
170        };
171
172        Ok((validity, val, origin))
173    }
174
175    fn prepare(
176        self,
177        val: Self::Val,
178        _origin: &DispatchOriginOf<RuntimeCallFor<Runtime>>,
179        _call: &RuntimeCallFor<Runtime>,
180        _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
181        _len: usize,
182    ) -> Result<Self::Pre, TransactionValidityError> {
183        let total_weight = Self::get_weights(MAXIMUM_NUMBER_OF_CALLS);
184        match val {
185            // not a signed transaction, so return full refund.
186            Val::FullRefund => Ok(Pre::Refund(total_weight)),
187
188            // signed transaction with a minimum of one read weight,
189            // so refund any extra call weight
190            Val::PartialRefund(maybe_calls) => {
191                let actual_weights = Self::get_weights(maybe_calls.unwrap_or(0));
192                Ok(Pre::Refund(total_weight.saturating_sub(actual_weights)))
193            }
194        }
195    }
196
197    fn post_dispatch_details(
198        pre: Self::Pre,
199        _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
200        _post_info: &PostDispatchInfoOf<RuntimeCallFor<Runtime>>,
201        _len: usize,
202        _result: &DispatchResult,
203    ) -> Result<Weight, TransactionValidityError> {
204        let Pre::Refund(weight) = pre;
205        Ok(weight)
206    }
207}