1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
//! Wrapper for SCALE encoding/decoding that appends a Blake3 checksum when encoding and verifies it when decoding

#[cfg(test)]
mod tests;

use crate::Blake3Hash;
use core::mem;
use parity_scale_codec::{Decode, Encode, EncodeLike, Error, Input, Output};

/// Output adapter for the SCALE codec that feeds every written byte into a Blake3 hasher
/// before forwarding it to the wrapped output.
///
/// The digest of everything written is appended to the wrapped output when the adapter is
/// dropped, so the checksum always trails the payload.
struct Blake3ChecksumOutput<'a, O: Output + ?Sized> {
    /// Destination receiving the payload bytes, followed by the checksum.
    output: &'a mut O,
    /// Running hash of every byte passed to `write` so far.
    hasher: blake3::Hasher,
}

impl<O: Output + ?Sized> Drop for Blake3ChecksumOutput<'_, O> {
    /// Appends the Blake3 digest of all bytes written so far to the wrapped output.
    ///
    /// Emitting it on drop places the checksum after everything that was written through the
    /// adapter, provided the adapter is dropped only once the inner value has been encoded.
    #[inline]
    fn drop(&mut self) {
        let checksum = self.hasher.finalize();
        checksum.as_bytes().encode_to(self.output);
    }
}

impl<O: Output + ?Sized> Output for Blake3ChecksumOutput<'_, O> {
    /// Records `bytes` in the running checksum, then forwards them verbatim to the wrapped
    /// output.
    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        let hasher = &mut self.hasher;
        hasher.update(bytes);
        self.output.write(bytes);
    }
}

impl<'a, O: Output + ?Sized> Blake3ChecksumOutput<'a, O> {
    /// Wraps `output`, starting from an empty checksum state.
    fn new(output: &'a mut O) -> Self {
        let hasher = blake3::Hasher::new();
        Self { output, hasher }
    }
}

/// Input wrapper for the SCALE codec that hashes every byte read through it with Blake3.
///
/// After the wrapped value has been decoded, `finish` yields the digest of the consumed bytes,
/// so it can be compared against the checksum stored after them.
struct Blake3ChecksumInput<'a, I>
where
    I: Input,
{
    /// Source the encoded bytes are read from.
    input: &'a mut I,
    /// Running hash of every byte successfully returned by `read`.
    hasher: blake3::Hasher,
}

impl<'a, I> Input for Blake3ChecksumInput<'a, I>
where
    I: Input,
{
    #[inline]
    fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
        self.input.remaining_len()
    }

    #[inline]
    fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
        self.input.read(into)?;
        self.hasher.update(into);
        Ok(())
    }
}

impl<'a, I> Blake3ChecksumInput<'a, I>
where
    I: Input,
{
    /// Wraps `input`, starting from an empty checksum state.
    fn new(input: &'a mut I) -> Self {
        Self {
            input,
            hasher: blake3::Hasher::new(),
        }
    }

    /// Consumes the wrapper and returns the Blake3 digest of every byte read through it,
    /// together with the wrapped input so the caller can keep reading from it (e.g. the
    /// stored checksum).
    fn finish(self) -> (Blake3Hash, &'a mut I) {
        let hash = *self.hasher.finalize().as_bytes();
        (hash, self.input)
    }
}

/// Wrapper whose SCALE encoding is the inner value's encoding followed by a Blake3 checksum
/// of those bytes. Decoding recomputes the checksum and fails if it does not match.
#[derive(Clone, Debug)]
pub struct Blake3Checksummed<T>(pub T);

impl<T: Encode> Encode for Blake3Checksummed<T> {
    /// Size hint of the inner value plus the fixed size of the trailing checksum.
    #[inline]
    fn size_hint(&self) -> usize {
        let checksum_len = mem::size_of::<Blake3Hash>();
        self.0.size_hint() + checksum_len
    }

    /// Encodes the inner value through a hashing adapter that appends the checksum when it is
    /// dropped.
    #[inline]
    fn encode_to<O>(&self, dest: &mut O)
    where
        O: Output + ?Sized,
    {
        let mut checksummed_dest = Blake3ChecksumOutput::new(dest);
        self.0.encode_to(&mut checksummed_dest);
        // Dropping the adapter is what writes the checksum; make that step explicit.
        drop(checksummed_dest);
    }

    /// Exact encoded size: the inner value's encoding plus the checksum.
    #[inline]
    fn encoded_size(&self) -> usize {
        let checksum_len = mem::size_of::<Blake3Hash>();
        self.0.encoded_size() + checksum_len
    }
}

/// Lets `Blake3Checksummed<T>` be used where `EncodeLike<Blake3Checksummed<T>>` is required,
/// provided `T: EncodeLike`.
impl<T: EncodeLike> EncodeLike for Blake3Checksummed<T> {}

impl<T> Decode for Blake3Checksummed<T>
where
    T: Decode,
{
    #[inline]
    fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
        let mut input = Blake3ChecksumInput::new(input);
        let data = T::decode(&mut input)?;
        let (actual_hash, input) = input.finish();
        let expected_hash = Blake3Hash::decode(input)?;

        if actual_hash == expected_hash {
            Ok(Self(data))
        } else {
            Err(Error::from("Checksum mismatch"))
        }
    }
}