1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::collections::BTreeSet;
#[cfg(not(feature = "std"))]
use alloc::fmt;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use frame_support::PalletError;
use hash_db::Hasher;
#[cfg(feature = "std")]
use parity_scale_codec::Codec;
use parity_scale_codec::{Compact, Decode, Encode};
use scale_info::TypeInfo;
use sp_core::storage::StorageKey;
#[cfg(feature = "std")]
use sp_state_machine::prove_read;
#[cfg(feature = "std")]
use sp_state_machine::TrieBackendBuilder;
use sp_std::fmt::Debug;
use sp_std::marker::PhantomData;
use sp_trie::{read_trie_value, LayoutV1, StorageProof};
#[cfg(feature = "std")]
use std::collections::BTreeSet;
#[cfg(feature = "std")]
use std::fmt;
#[cfg(feature = "std")]
use trie_db::{DBValue, TrieDBMutBuilder, TrieLayout, TrieMut};

/// Errors returned by [`StorageProofVerifier`] when a value cannot be obtained from, or
/// checked against, a storage proof.
///
/// NOTE: the derived SCALE `Encode`/`Decode` identify variants by declaration index, so
/// reordering variants (or inserting one before existing ones) changes the encoding.
#[derive(Debug, PartialEq, Eq, Encode, Decode, PalletError, TypeInfo)]
pub enum VerificationError {
    /// Reading the requested key from the trie built out of the proof nodes returned an
    /// error, i.e. the proof cannot be resolved against the given state root.
    InvalidProof,
    /// The proof could be read, but the trie holds no value for the requested key.
    MissingValue,
    /// The value was found, but its bytes could not be SCALE-decoded into the requested type.
    FailedToDecode,
    /// The proof contains nodes that were not visited while reading the requested key,
    /// i.e. the proof is not minimal.
    UnusedNodesInTheProof,
}

impl fmt::Display for VerificationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            VerificationError::InvalidProof => write!(f, "Given storage proof is invalid"),
            VerificationError::MissingValue => {
                write!(f, "Value doesn't exist in the Db for this key")
            }
            VerificationError::FailedToDecode => write!(f, "Failed to decode value"),
            VerificationError::UnusedNodesInTheProof => write!(
                f,
                "Storage proof contains unused nodes after reading the necessary keys"
            ),
        }
    }
}

/// Stateless helper for reading values out of a storage proof against a known state root.
///
/// `H` is the hasher of the trie the proof was taken from.
pub struct StorageProofVerifier<H: Hasher>(PhantomData<H>);

impl<H: Hasher> StorageProofVerifier<H> {
    /// Reads the value stored under `key` in `proof` and SCALE-decodes it as `V`.
    ///
    /// Decoding uses `V::decode`, so trailing bytes after a successful decode are not rejected.
    ///
    /// # Errors
    ///
    /// Any error from [`Self::get_bare_value`], or [`VerificationError::FailedToDecode`] if the
    /// stored bytes do not decode as `V`.
    pub fn get_decoded_value<V: Decode>(
        state_root: &H::Out,
        proof: StorageProof,
        key: StorageKey,
    ) -> Result<V, VerificationError> {
        Self::get_bare_value(state_root, proof, key).and_then(|encoded| {
            V::decode(&mut encoded.as_slice()).map_err(|_| VerificationError::FailedToDecode)
        })
    }

    /// Returns the raw bytes stored under `key` in the trie rooted at `state_root`, using only
    /// the nodes contained in `proof`.
    ///
    /// The proof must be minimal: every node it carries has to be visited while reading `key`.
    ///
    /// # Errors
    ///
    /// - [`VerificationError::InvalidProof`] if reading `key` from the proof's trie fails.
    /// - [`VerificationError::MissingValue`] if the trie has no value under `key`.
    /// - [`VerificationError::UnusedNodesInTheProof`] if some proof nodes were never visited.
    pub fn get_bare_value(
        state_root: &H::Out,
        proof: StorageProof,
        key: StorageKey,
    ) -> Result<Vec<u8>, VerificationError> {
        // Must be counted before `into_memory_db` consumes the proof.
        let proof_node_count = proof.iter_nodes().count();
        let memory_db = proof.into_memory_db::<H>();
        let mut recorder = sp_trie::Recorder::<LayoutV1<H>>::new();

        let lookup = read_trie_value::<LayoutV1<H>, _>(
            &memory_db,
            state_root,
            key.as_ref(),
            Some(&mut recorder),
            None,
        );
        let value = match lookup {
            Ok(Some(value)) => value,
            Ok(None) => return Err(VerificationError::MissingValue),
            Err(_) => return Err(VerificationError::InvalidProof),
        };

        // The recorder may log the same node more than once; deduplicate before comparing
        // against the number of nodes the proof supplied.
        let distinct_visited = recorder
            .drain()
            .into_iter()
            .map(|record| record.data)
            .collect::<BTreeSet<_>>();

        if distinct_visited.len() == proof_node_count {
            Ok(value)
        } else {
            Err(VerificationError::UnusedNodesInTheProof)
        }
    }

    /// Builds the storage key of the element at position `index`: the SCALE compact encoding
    /// of `index`.
    pub fn enumerated_storage_key(index: u32) -> StorageKey {
        let encoded_index = Compact(index).encode();
        StorageKey(encoded_index)
    }
}

/// In-memory node store for a trie of layout `T`: uses `T`'s hasher, `memory_db::HashKey`
/// as the key function, and stores node data as `DBValue`. Used by `StorageProofProvider`
/// to build the trie a proof is generated from.
#[cfg(feature = "std")]
type MemoryDB<T> = memory_db::MemoryDB<
    <T as TrieLayout>::Hash,
    memory_db::HashKey<<T as TrieLayout>::Hash>,
    DBValue,
>;

/// Type that provides utilities to generate the storage proof.
///
/// The trie is keyed by each element's position, SCALE-compact-encoded as a `u32`. This is the
/// same encoding `StorageProofVerifier::enumerated_storage_key` produces.
#[cfg(feature = "std")]
pub struct StorageProofProvider<Layout>(PhantomData<Layout>);

#[cfg(feature = "std")]
impl<Layout> StorageProofProvider<Layout>
where
    Layout: TrieLayout,
    <Layout::Hash as Hasher>::Out: Codec,
{
    /// Generate storage proof for given index from the trie constructed from `input`.
    ///
    /// The trie maps `Compact(i).encode()` to `input[i]` for every position `i`, and the
    /// returned proof covers the read of the key for `index`.
    ///
    /// Returns `None` if:
    /// - `index` is out of range;
    /// - `input` has more elements than a `u32` position can address (the keys would
    ///   otherwise collide);
    /// - inserting into the trie or generating the proof fails.
    pub fn generate_enumerated_proof_of_inclusion(
        input: &[Vec<u8>],
        index: u32,
    ) -> Option<StorageProof> {
        if input.len() <= index as usize {
            return None;
        }

        let (db, root) = {
            let mut db = <MemoryDB<Layout>>::default();
            let mut root = Default::default();
            {
                // `trie` mutably borrows `db` and `root`. Dropping it at the end of this scope
                // commits the changes and releases both so they can be moved out below.
                let mut trie = <TrieDBMutBuilder<Layout>>::new(&mut db, &mut root).build();
                for (i, value) in input.iter().enumerate() {
                    // A wrapping `as u32` cast would let element `u32::MAX + 1` silently
                    // overwrite element 0 (and so on), corrupting the trie; refuse instead.
                    let key = Compact(u32::try_from(i).ok()?).encode();
                    trie.insert(&key, value).ok()?;
                }
            }
            (db, root)
        };

        let backend = TrieBackendBuilder::new(db, root).build();
        let key = Compact(index).encode();
        prove_read(backend, &[key]).ok()
    }
}