subspace_proof_of_time/aes/
x86_64.rs

1use core::arch::x86_64::*;
2use core::array;
3use core::simd::{u8x16, u8x32, u8x64};
4use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
5
/// Number of round keys in the AES-128 key schedule: the initial whitening key plus one key for
/// each of the 10 rounds
const NUM_ROUND_KEYS: usize = 10 + 1;
7
8/// Create PoT proof with checkpoints
9#[target_feature(enable = "aes")]
10#[inline]
11pub(super) fn create(
12    seed: &[u8; 16],
13    key: &[u8; 16],
14    checkpoint_iterations: u32,
15) -> PotCheckpoints {
16    let mut checkpoints = PotCheckpoints::default();
17
18    let keys = expand_key(key);
19    let xor_key = _mm_xor_si128(keys[10], keys[0]);
20    let mut seed = __m128i::from(u8x16::from_array(*seed));
21    seed = _mm_xor_si128(seed, keys[0]);
22    for checkpoint in checkpoints.iter_mut() {
23        for _ in 0..checkpoint_iterations {
24            seed = _mm_aesenc_si128(seed, keys[1]);
25            seed = _mm_aesenc_si128(seed, keys[2]);
26            seed = _mm_aesenc_si128(seed, keys[3]);
27            seed = _mm_aesenc_si128(seed, keys[4]);
28            seed = _mm_aesenc_si128(seed, keys[5]);
29            seed = _mm_aesenc_si128(seed, keys[6]);
30            seed = _mm_aesenc_si128(seed, keys[7]);
31            seed = _mm_aesenc_si128(seed, keys[8]);
32            seed = _mm_aesenc_si128(seed, keys[9]);
33            seed = _mm_aesenclast_si128(seed, xor_key);
34        }
35
36        let checkpoint_reg = _mm_xor_si128(seed, keys[0]);
37        **checkpoint = u8x16::from(checkpoint_reg).to_array();
38    }
39
40    checkpoints
41}
42
43/// Verification mimics `create` function, but also has decryption half for better performance
44#[target_feature(enable = "aes,sse4.1")]
45#[inline]
46pub(super) fn verify_sequential_aes_sse41(
47    seed: &[u8; 16],
48    key: &[u8; 16],
49    checkpoints: &PotCheckpoints,
50    checkpoint_iterations: u32,
51) -> bool {
52    let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
53
54    let keys = expand_key(key);
55    let xor_key = _mm_xor_si128(keys[10], keys[0]);
56
57    // Invert keys for decryption, the first and last element is not used below, hence they are
58    // copied as is from encryption keys (otherwise the first and last element would need to be
59    // swapped)
60    let mut inv_keys = keys;
61    for i in 1..10 {
62        inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
63    }
64
65    let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
66        __m128i::from(u8x16::from(*seed)),
67        __m128i::from(u8x16::from(checkpoints[0])),
68        __m128i::from(u8x16::from(checkpoints[1])),
69        __m128i::from(u8x16::from(checkpoints[2])),
70        __m128i::from(u8x16::from(checkpoints[3])),
71        __m128i::from(u8x16::from(checkpoints[4])),
72        __m128i::from(u8x16::from(checkpoints[5])),
73        __m128i::from(u8x16::from(checkpoints[6])),
74    ];
75
76    let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
77        __m128i::from(u8x16::from(checkpoints[0])),
78        __m128i::from(u8x16::from(checkpoints[1])),
79        __m128i::from(u8x16::from(checkpoints[2])),
80        __m128i::from(u8x16::from(checkpoints[3])),
81        __m128i::from(u8x16::from(checkpoints[4])),
82        __m128i::from(u8x16::from(checkpoints[5])),
83        __m128i::from(u8x16::from(checkpoints[6])),
84        __m128i::from(u8x16::from(checkpoints[7])),
85    ];
86
87    inputs = inputs.map(|input| _mm_xor_si128(input, keys[0]));
88    outputs = outputs.map(|output| _mm_xor_si128(output, keys[10]));
89
90    for _ in 0..checkpoint_iterations / 2 {
91        for i in 1..10 {
92            inputs = inputs.map(|input| _mm_aesenc_si128(input, keys[i]));
93            outputs = outputs.map(|output| _mm_aesdec_si128(output, inv_keys[i]));
94        }
95
96        inputs = inputs.map(|input| _mm_aesenclast_si128(input, xor_key));
97        outputs = outputs.map(|output| _mm_aesdeclast_si128(output, xor_key));
98    }
99
100    // All bits set
101    let all_ones = _mm_set1_epi8(-1);
102
103    inputs.into_iter().zip(outputs).all(|(input, output)| {
104        let diff = _mm_xor_si128(input, output);
105        let cmp = _mm_xor_si128(diff, xor_key);
106        _mm_test_all_zeros(cmp, all_ones) == 1
107    })
108}
109
110/// Verification mimics `create` function, but also has decryption half for better performance
111#[target_feature(enable = "avx2,vaes")]
112#[inline]
113pub(super) fn verify_sequential_avx2_vaes(
114    seed: &[u8; 16],
115    key: &[u8; 16],
116    checkpoints: &PotCheckpoints,
117    checkpoint_iterations: u32,
118) -> bool {
119    let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
120
121    let keys = expand_key(key);
122    let xor_key = _mm_xor_si128(keys[10], keys[0]);
123    let xor_key_256 = _mm256_broadcastsi128_si256(xor_key);
124
125    // Invert keys for decryption, the first and last element is not used below, hence they are
126    // copied as is from encryption keys (otherwise the first and last element would need to be
127    // swapped)
128    let mut inv_keys = keys;
129    for i in 1..10 {
130        inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
131    }
132
133    let keys_256 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(keys[i]));
134    let inv_keys_256 =
135        array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(inv_keys[i]));
136
137    let mut input_0 = [[0u8; 16]; 2];
138    input_0[0] = *seed;
139    input_0[1] = checkpoints[0];
140    let mut input_0 = __m256i::from(u8x32::from_slice(input_0.as_flattened()));
141
142    let mut input_1 = __m256i::from(u8x32::from_slice(checkpoints[1..3].as_flattened()));
143    let mut input_2 = __m256i::from(u8x32::from_slice(checkpoints[3..5].as_flattened()));
144    let mut input_3 = __m256i::from(u8x32::from_slice(checkpoints[5..7].as_flattened()));
145
146    let mut output_0 = __m256i::from(u8x32::from_slice(checkpoints[0..2].as_flattened()));
147    let mut output_1 = __m256i::from(u8x32::from_slice(checkpoints[2..4].as_flattened()));
148    let mut output_2 = __m256i::from(u8x32::from_slice(checkpoints[4..6].as_flattened()));
149    let mut output_3 = __m256i::from(u8x32::from_slice(checkpoints[6..8].as_flattened()));
150
151    input_0 = _mm256_xor_si256(input_0, keys_256[0]);
152    input_1 = _mm256_xor_si256(input_1, keys_256[0]);
153    input_2 = _mm256_xor_si256(input_2, keys_256[0]);
154    input_3 = _mm256_xor_si256(input_3, keys_256[0]);
155
156    output_0 = _mm256_xor_si256(output_0, keys_256[10]);
157    output_1 = _mm256_xor_si256(output_1, keys_256[10]);
158    output_2 = _mm256_xor_si256(output_2, keys_256[10]);
159    output_3 = _mm256_xor_si256(output_3, keys_256[10]);
160
161    for _ in 0..checkpoint_iterations / 2 {
162        // TODO: Shouldn't be unsafe: https://github.com/rust-lang/rust/issues/141718
163        unsafe {
164            for i in 1..10 {
165                input_0 = _mm256_aesenc_epi128(input_0, keys_256[i]);
166                input_1 = _mm256_aesenc_epi128(input_1, keys_256[i]);
167                input_2 = _mm256_aesenc_epi128(input_2, keys_256[i]);
168                input_3 = _mm256_aesenc_epi128(input_3, keys_256[i]);
169
170                output_0 = _mm256_aesdec_epi128(output_0, inv_keys_256[i]);
171                output_1 = _mm256_aesdec_epi128(output_1, inv_keys_256[i]);
172                output_2 = _mm256_aesdec_epi128(output_2, inv_keys_256[i]);
173                output_3 = _mm256_aesdec_epi128(output_3, inv_keys_256[i]);
174            }
175
176            input_0 = _mm256_aesenclast_epi128(input_0, xor_key_256);
177            input_1 = _mm256_aesenclast_epi128(input_1, xor_key_256);
178            input_2 = _mm256_aesenclast_epi128(input_2, xor_key_256);
179            input_3 = _mm256_aesenclast_epi128(input_3, xor_key_256);
180
181            output_0 = _mm256_aesdeclast_epi128(output_0, xor_key_256);
182            output_1 = _mm256_aesdeclast_epi128(output_1, xor_key_256);
183            output_2 = _mm256_aesdeclast_epi128(output_2, xor_key_256);
184            output_3 = _mm256_aesdeclast_epi128(output_3, xor_key_256);
185        }
186    }
187
188    // Code below is a more efficient version of this:
189    // input_0 = _mm256_xor_si256(input_0, keys_256[0]);
190    // input_1 = _mm256_xor_si256(input_1, keys_256[0]);
191    // input_2 = _mm256_xor_si256(input_2, keys_256[0]);
192    // input_3 = _mm256_xor_si256(input_3, keys_256[0]);
193    // output_0 = _mm256_xor_si256(output_0, keys_256[10]);
194    // output_1 = _mm256_xor_si256(output_1, keys_256[10]);
195    // output_2 = _mm256_xor_si256(output_2, keys_256[10]);
196    // output_3 = _mm256_xor_si256(output_3, keys_256[10]);
197    //
198    // let mask_0 = _mm256_cmpeq_epi64(input_0, output_0);
199    // let mask_1 = _mm256_cmpeq_epi64(input_1, output_1);
200    // let mask_2 = _mm256_cmpeq_epi64(input_2, output_1);
201    // let mask_3 = _mm256_cmpeq_epi64(input_3, output_1);
202
203    let diff_0 = _mm256_xor_si256(input_0, output_0);
204    let diff_1 = _mm256_xor_si256(input_1, output_1);
205    let diff_2 = _mm256_xor_si256(input_2, output_2);
206    let diff_3 = _mm256_xor_si256(input_3, output_3);
207
208    let mask_0 = _mm256_cmpeq_epi64(diff_0, xor_key_256);
209    let mask_1 = _mm256_cmpeq_epi64(diff_1, xor_key_256);
210    let mask_2 = _mm256_cmpeq_epi64(diff_2, xor_key_256);
211    let mask_3 = _mm256_cmpeq_epi64(diff_3, xor_key_256);
212
213    // All bits set
214    let all_ones = _mm256_set1_epi64x(-1);
215
216    let match_0 = _mm256_testc_si256(mask_0, all_ones) != 0;
217    let match_1 = _mm256_testc_si256(mask_1, all_ones) != 0;
218    let match_2 = _mm256_testc_si256(mask_2, all_ones) != 0;
219    let match_3 = _mm256_testc_si256(mask_3, all_ones) != 0;
220
221    match_0 && match_1 && match_2 && match_3
222}
223
224/// Verification mimics `create` function, but also has decryption half for better performance
225#[target_feature(enable = "avx512f,vaes")]
226#[inline]
227pub(super) fn verify_sequential_avx512f_vaes(
228    seed: &[u8; 16],
229    key: &[u8; 16],
230    checkpoints: &PotCheckpoints,
231    checkpoint_iterations: u32,
232) -> bool {
233    let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
234
235    let keys = expand_key(key);
236    let xor_key = _mm_xor_si128(keys[10], keys[0]);
237    let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
238
239    // Invert keys for decryption, the first and last element is not used below, hence they are
240    // copied as is from encryption keys (otherwise the first and last element would need to be
241    // swapped)
242    let mut inv_keys = keys;
243    for i in 1..10 {
244        inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
245    }
246
247    let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys[i]));
248    let inv_keys_512 =
249        array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
250
251    let mut input_0 = [[0u8; 16]; 4];
252    input_0[0] = *seed;
253    input_0[1..].copy_from_slice(&checkpoints[..3]);
254    let mut input_0 = __m512i::from(u8x64::from_slice(input_0.as_flattened()));
255    let mut input_1 = __m512i::from(u8x64::from_slice(checkpoints[3..7].as_flattened()));
256
257    let mut output_0 = __m512i::from(u8x64::from_slice(checkpoints[0..4].as_flattened()));
258    let mut output_1 = __m512i::from(u8x64::from_slice(checkpoints[4..8].as_flattened()));
259
260    input_0 = _mm512_xor_si512(input_0, keys_512[0]);
261    input_1 = _mm512_xor_si512(input_1, keys_512[0]);
262
263    output_0 = _mm512_xor_si512(output_0, keys_512[10]);
264    output_1 = _mm512_xor_si512(output_1, keys_512[10]);
265
266    for _ in 0..checkpoint_iterations / 2 {
267        // TODO: Shouldn't be unsafe: https://github.com/rust-lang/rust/issues/141718
268        unsafe {
269            for i in 1..10 {
270                input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
271                input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
272
273                output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]);
274                output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]);
275            }
276
277            input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512);
278            input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512);
279
280            output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
281            output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
282        }
283    }
284
285    // Code below is a more efficient version of this:
286    // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
287    // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
288    // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
289    // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
290    //
291    // let mask_0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
292    // let mask_1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
293
294    let diff_0 = _mm512_xor_si512(input_0, output_0);
295    let diff_1 = _mm512_xor_si512(input_1, output_1);
296
297    let mask_0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
298    let mask_1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
299
300    // All inputs match outputs
301    (mask_0 & mask_1) == u8::MAX
302}
303
304// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by
305// Artyom Pavlov:
306// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/ni/expand.rs
307
308#[target_feature(enable = "aes")]
309fn expand_key(key: &[u8; 16]) -> [__m128i; NUM_ROUND_KEYS] {
310    #[target_feature(enable = "aes")]
311    fn expand_round<const RK: i32>(keys: &mut [__m128i; NUM_ROUND_KEYS], pos: usize) {
312        let mut t1 = keys[pos - 1];
313        let mut t2;
314        let mut t3;
315
316        t2 = _mm_aeskeygenassist_si128::<RK>(t1);
317        t2 = _mm_shuffle_epi32::<0xff>(t2);
318        t3 = _mm_slli_si128::<0x4>(t1);
319        t1 = _mm_xor_si128(t1, t3);
320        t3 = _mm_slli_si128::<0x4>(t3);
321        t1 = _mm_xor_si128(t1, t3);
322        t3 = _mm_slli_si128::<0x4>(t3);
323        t1 = _mm_xor_si128(t1, t3);
324        t1 = _mm_xor_si128(t1, t2);
325
326        keys[pos] = t1;
327    }
328
329    let mut keys = [_mm_setzero_si128(); NUM_ROUND_KEYS];
330    keys[0] = __m128i::from(u8x16::from(*key));
331
332    let kr = &mut keys;
333    expand_round::<0x01>(kr, 1);
334    expand_round::<0x02>(kr, 2);
335    expand_round::<0x04>(kr, 3);
336    expand_round::<0x08>(kr, 4);
337    expand_round::<0x10>(kr, 5);
338    expand_round::<0x20>(kr, 6);
339    expand_round::<0x40>(kr, 7);
340    expand_round::<0x80>(kr, 8);
341    expand_round::<0x1B>(kr, 9);
342    expand_round::<0x36>(kr, 10);
343
344    keys
345}