subspace_runtime_primitives/
extension.rs

1#[cfg(feature = "runtime-benchmarks")]
2pub mod benchmarking;
3pub mod weights;
4
5use crate::extension::weights::WeightInfo as SubstrateWeightInfo;
6use crate::utility::{MaybeNestedCall, nested_call_iter};
7use core::marker::PhantomData;
8use frame_support::RuntimeDebugNoBound;
9use frame_support::pallet_prelude::Weight;
10use frame_system::Config;
11use frame_system::pallet_prelude::{OriginFor, RuntimeCallFor};
12use pallet_balances::Call as BalancesCall;
13use parity_scale_codec::{Decode, Encode};
14use scale_info::TypeInfo;
15use scale_info::prelude::fmt;
16use sp_runtime::DispatchResult;
17use sp_runtime::traits::{
18    AsSystemOriginSigner, DispatchInfoOf, DispatchOriginOf, Dispatchable, PostDispatchInfoOf,
19    TransactionExtension, ValidateResult,
20};
21use sp_runtime::transaction_validity::{
22    InvalidTransaction, TransactionSource, TransactionValidityError, ValidTransaction,
23};
24
/// Largest number of (possibly nested) calls the balance transfer check was benchmarked with.
///
/// The extension charges the weight for this many calls up-front and refunds the unused part once
/// the actual number of inspected calls is known.
const MAXIMUM_NUMBER_OF_CALLS: u32 = 5000;
27
28/// Weights for the balance transfer check extension.
29pub trait WeightInfo {
30    fn balance_transfer_check_multiple(c: u32) -> Weight;
31    fn balance_transfer_check_utility(c: u32) -> Weight;
32    fn balance_transfer_check_multisig(c: u32) -> Weight;
33}
34
35/// Trait to convert Runtime call to possible Balance call.
36pub trait MaybeBalancesCall<Runtime>
37where
38    Runtime: pallet_balances::Config,
39{
40    fn maybe_balance_call(&self) -> Option<&BalancesCall<Runtime>>;
41}
42
/// Runtime hook reporting whether balance transfers are currently allowed.
pub trait BalanceTransferChecks {
    /// Returns `true` when signed balance transfers should be accepted.
    fn is_balance_transferable() -> bool;
}
47
48/// Disable balance transfers, if configured in the runtime.
49#[derive(Debug, Encode, Decode, Clone, Eq, PartialEq, TypeInfo)]
50pub struct BalanceTransferCheckExtension<Runtime>(PhantomData<Runtime>);
51
52impl<Runtime> Default for BalanceTransferCheckExtension<Runtime>
53where
54    Runtime: BalanceTransferChecks + pallet_balances::Config,
55    RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
56{
57    fn default() -> Self {
58        Self(PhantomData)
59    }
60}
61
62impl<Runtime> BalanceTransferCheckExtension<Runtime>
63where
64    Runtime: BalanceTransferChecks + pallet_balances::Config,
65    RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
66{
67    fn do_validate_signed(
68        call: &RuntimeCallFor<Runtime>,
69    ) -> Result<(ValidTransaction, u32), TransactionValidityError> {
70        if Runtime::is_balance_transferable() {
71            return Ok((ValidTransaction::default(), 0));
72        }
73
74        // Disable normal balance transfers.
75        let (contains_balance_call, calls) = Self::contains_balance_transfer(call);
76        if contains_balance_call {
77            Err(InvalidTransaction::Call.into())
78        } else {
79            Ok((ValidTransaction::default(), calls))
80        }
81    }
82
83    fn contains_balance_transfer(call: &RuntimeCallFor<Runtime>) -> (bool, u32) {
84        let mut calls = 0;
85        for call in nested_call_iter::<Runtime>(call) {
86            calls += 1;
87            // Any other calls might contain nested calls, so we can only return early if we find a
88            // balance transfer call.
89            if let Some(balance_call) = call.maybe_balance_call()
90                && matches!(
91                    balance_call,
92                    BalancesCall::transfer_allow_death { .. }
93                        | BalancesCall::transfer_keep_alive { .. }
94                        | BalancesCall::transfer_all { .. }
95                )
96            {
97                return (true, calls);
98            }
99        }
100
101        (false, calls)
102    }
103
104    fn get_weights(n: u32) -> Weight {
105        SubstrateWeightInfo::<Runtime>::balance_transfer_check_multisig(n)
106            .max(SubstrateWeightInfo::<Runtime>::balance_transfer_check_multiple(n))
107            .max(SubstrateWeightInfo::<Runtime>::balance_transfer_check_utility(n))
108    }
109}
110
111/// Data passed from prepare to post_dispatch.
112#[derive(RuntimeDebugNoBound)]
113pub enum Pre {
114    Refund(Weight),
115}
116
117/// Data passed from validate to prepare.
118#[derive(RuntimeDebugNoBound)]
119pub enum Val {
120    FullRefund,
121    PartialRefund(Option<u32>),
122}
123
124impl<Runtime> TransactionExtension<RuntimeCallFor<Runtime>>
125    for BalanceTransferCheckExtension<Runtime>
126where
127    Runtime: Config
128        + pallet_balances::Config
129        + scale_info::TypeInfo
130        + fmt::Debug
131        + Send
132        + Sync
133        + BalanceTransferChecks,
134    <RuntimeCallFor<Runtime> as Dispatchable>::RuntimeOrigin:
135        AsSystemOriginSigner<<Runtime as Config>::AccountId> + Clone,
136    RuntimeCallFor<Runtime>: MaybeBalancesCall<Runtime> + MaybeNestedCall<Runtime>,
137{
138    const IDENTIFIER: &'static str = "BalanceTransferCheckExtension";
139    type Implicit = ();
140    type Val = Val;
141    type Pre = Pre;
142
143    fn weight(&self, _call: &RuntimeCallFor<Runtime>) -> Weight {
144        Self::get_weights(MAXIMUM_NUMBER_OF_CALLS)
145    }
146
147    fn validate(
148        &self,
149        origin: OriginFor<Runtime>,
150        call: &RuntimeCallFor<Runtime>,
151        _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
152        _len: usize,
153        _self_implicit: Self::Implicit,
154        _inherited_implication: &impl Encode,
155        _source: TransactionSource,
156    ) -> ValidateResult<Self::Val, RuntimeCallFor<Runtime>> {
157        let (validity, val) = if origin.as_system_origin_signer().is_some() {
158            let (valid, maybe_calls) =
159                Self::do_validate_signed(call).map(|(valid, calls)| (valid, Some(calls)))?;
160            (valid, Val::PartialRefund(maybe_calls))
161        } else {
162            (ValidTransaction::default(), Val::FullRefund)
163        };
164
165        Ok((validity, val, origin))
166    }
167
168    fn prepare(
169        self,
170        val: Self::Val,
171        _origin: &DispatchOriginOf<RuntimeCallFor<Runtime>>,
172        _call: &RuntimeCallFor<Runtime>,
173        _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
174        _len: usize,
175    ) -> Result<Self::Pre, TransactionValidityError> {
176        let total_weight = Self::get_weights(MAXIMUM_NUMBER_OF_CALLS);
177        match val {
178            // not a signed transaction, so return full refund.
179            Val::FullRefund => Ok(Pre::Refund(total_weight)),
180
181            // signed transaction with a minimum of one read weight,
182            // so refund any extra call weight
183            Val::PartialRefund(maybe_calls) => {
184                let actual_weights = Self::get_weights(maybe_calls.unwrap_or(0));
185                Ok(Pre::Refund(total_weight.saturating_sub(actual_weights)))
186            }
187        }
188    }
189
190    fn post_dispatch_details(
191        pre: Self::Pre,
192        _info: &DispatchInfoOf<RuntimeCallFor<Runtime>>,
193        _post_info: &PostDispatchInfoOf<RuntimeCallFor<Runtime>>,
194        _len: usize,
195        _result: &DispatchResult,
196    ) -> Result<Weight, TransactionValidityError> {
197        let Pre::Refund(weight) = pre;
198        Ok(weight)
199    }
200}