subspace_core_primitives/
checksum.rs

//! Module containing a wrapper for SCALE encoding/decoding with a Blake3 checksum.
2
3#[cfg(test)]
4mod tests;
5
6use crate::hashes::Blake3Hash;
7use parity_scale_codec::{Decode, Encode, EncodeLike, Error, Input, Output};
8
9/// Output wrapper for SCALE codec that will write Blake3 checksum at the end of the encoding
10struct Blake3ChecksumOutput<'a, O>
11where
12    O: Output + ?Sized,
13{
14    output: &'a mut O,
15    hasher: blake3::Hasher,
16}
17
18impl<O> Drop for Blake3ChecksumOutput<'_, O>
19where
20    O: Output + ?Sized,
21{
22    #[inline]
23    fn drop(&mut self) {
24        // Write checksum at the very end of encoding
25        let hash = *self.hasher.finalize().as_bytes();
26        hash.encode_to(self.output);
27    }
28}
29
30impl<O> Output for Blake3ChecksumOutput<'_, O>
31where
32    O: Output + ?Sized,
33{
34    #[inline]
35    fn write(&mut self, bytes: &[u8]) {
36        self.hasher.update(bytes);
37        self.output.write(bytes);
38    }
39}
40
41impl<'a, O> Blake3ChecksumOutput<'a, O>
42where
43    O: Output + ?Sized,
44{
45    fn new(output: &'a mut O) -> Self {
46        Self {
47            output,
48            hasher: blake3::Hasher::new(),
49        }
50    }
51}
52
53/// Input wrapper for SCALE codec that will write Blake3 checksum at the end of the encoding
54struct Blake3ChecksumInput<'a, I>
55where
56    I: Input,
57{
58    input: &'a mut I,
59    hasher: blake3::Hasher,
60}
61
62impl<I> Input for Blake3ChecksumInput<'_, I>
63where
64    I: Input,
65{
66    #[inline]
67    fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
68        self.input.remaining_len()
69    }
70
71    #[inline]
72    fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
73        self.input.read(into)?;
74        self.hasher.update(into);
75        Ok(())
76    }
77}
78
79impl<'a, I> Blake3ChecksumInput<'a, I>
80where
81    I: Input,
82{
83    fn new(output: &'a mut I) -> Self {
84        Self {
85            input: output,
86            hasher: blake3::Hasher::new(),
87        }
88    }
89
90    fn finish(self) -> (Blake3Hash, &'a mut I) {
91        // Compute checksum at the very end of decoding
92        let hash = *self.hasher.finalize().as_bytes();
93        (hash.into(), self.input)
94    }
95}
96
/// Wrapper whose SCALE encoding is the inner value's encoding followed by a Blake3 checksum of
/// those bytes.
///
/// Decoding recomputes the checksum over the bytes consumed by the inner value. It fails with a
/// "Checksum mismatch" error when the result differs from the stored checksum.
#[derive(Clone, Debug)]
pub struct Blake3Checksummed<T>(pub T);
100
101impl<T> Encode for Blake3Checksummed<T>
102where
103    T: Encode,
104{
105    #[inline]
106    fn size_hint(&self) -> usize {
107        self.0.size_hint() + Blake3Hash::SIZE
108    }
109
110    #[inline]
111    fn encode_to<O>(&self, dest: &mut O)
112    where
113        O: Output + ?Sized,
114    {
115        self.0.encode_to(&mut Blake3ChecksumOutput::new(dest));
116    }
117
118    #[inline]
119    fn encoded_size(&self) -> usize {
120        self.0.encoded_size() + Blake3Hash::SIZE
121    }
122}
123
124impl<T> EncodeLike for Blake3Checksummed<T> where T: EncodeLike {}
125
126impl<T> Decode for Blake3Checksummed<T>
127where
128    T: Decode,
129{
130    #[inline]
131    fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
132        let mut input = Blake3ChecksumInput::new(input);
133        let data = T::decode(&mut input)?;
134        let (actual_hash, input) = input.finish();
135        let expected_hash = Blake3Hash::decode(input)?;
136
137        if actual_hash == expected_hash {
138            Ok(Self(data))
139        } else {
140            Err(Error::from("Checksum mismatch"))
141        }
142    }
143}