// subspace_proof_of_time/aes/x86_64.rs

1use core::arch::x86_64::*;
2use core::array;
3use core::simd::{u8x16, u8x32, u8x64};
4use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
5
/// Number of AES-128 round keys: one initial whitening key plus one key for each of the 10
/// rounds.
const NUM_ROUND_KEYS: usize = 1 + 10;
7
8/// Create PoT proof with checkpoints
9#[target_feature(enable = "aes")]
10#[inline]
11pub(super) fn create(
12    seed: &[u8; 16],
13    key: &[u8; 16],
14    checkpoint_iterations: u32,
15) -> PotCheckpoints {
16    let mut checkpoints = PotCheckpoints::default();
17
18    let keys = expand_key(key);
19    let xor_key = _mm_xor_si128(keys[10], keys[0]);
20    let mut seed = __m128i::from(u8x16::from_array(*seed));
21    seed = _mm_xor_si128(seed, keys[0]);
22    for checkpoint in checkpoints.iter_mut() {
23        for _ in 0..checkpoint_iterations {
24            seed = _mm_aesenc_si128(seed, keys[1]);
25            seed = _mm_aesenc_si128(seed, keys[2]);
26            seed = _mm_aesenc_si128(seed, keys[3]);
27            seed = _mm_aesenc_si128(seed, keys[4]);
28            seed = _mm_aesenc_si128(seed, keys[5]);
29            seed = _mm_aesenc_si128(seed, keys[6]);
30            seed = _mm_aesenc_si128(seed, keys[7]);
31            seed = _mm_aesenc_si128(seed, keys[8]);
32            seed = _mm_aesenc_si128(seed, keys[9]);
33            seed = _mm_aesenclast_si128(seed, xor_key);
34        }
35
36        let checkpoint_reg = _mm_xor_si128(seed, keys[0]);
37        **checkpoint = u8x16::from(checkpoint_reg).to_array();
38    }
39
40    checkpoints
41}
42
43/// Verification mimics `create` function, but also has decryption half for better performance
44#[target_feature(enable = "aes,sse4.1")]
45#[inline]
46pub(super) fn verify_sequential_aes_sse41(
47    seed: &[u8; 16],
48    key: &[u8; 16],
49    checkpoints: &PotCheckpoints,
50    checkpoint_iterations: u32,
51) -> bool {
52    let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
53
54    let keys = expand_key(key);
55    let xor_key = _mm_xor_si128(keys[10], keys[0]);
56
57    // Invert keys for decryption, the first and last element is not used below, hence they are
58    // copied as is from encryption keys (otherwise the first and last element would need to be
59    // swapped)
60    let mut inv_keys = keys;
61    for i in 1..10 {
62        inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
63    }
64
65    let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
66        __m128i::from(u8x16::from(*seed)),
67        __m128i::from(u8x16::from(checkpoints[0])),
68        __m128i::from(u8x16::from(checkpoints[1])),
69        __m128i::from(u8x16::from(checkpoints[2])),
70        __m128i::from(u8x16::from(checkpoints[3])),
71        __m128i::from(u8x16::from(checkpoints[4])),
72        __m128i::from(u8x16::from(checkpoints[5])),
73        __m128i::from(u8x16::from(checkpoints[6])),
74    ];
75
76    let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
77        __m128i::from(u8x16::from(checkpoints[0])),
78        __m128i::from(u8x16::from(checkpoints[1])),
79        __m128i::from(u8x16::from(checkpoints[2])),
80        __m128i::from(u8x16::from(checkpoints[3])),
81        __m128i::from(u8x16::from(checkpoints[4])),
82        __m128i::from(u8x16::from(checkpoints[5])),
83        __m128i::from(u8x16::from(checkpoints[6])),
84        __m128i::from(u8x16::from(checkpoints[7])),
85    ];
86
87    inputs = inputs.map(|input| _mm_xor_si128(input, keys[0]));
88    outputs = outputs.map(|output| _mm_xor_si128(output, keys[10]));
89
90    for _ in 0..checkpoint_iterations / 2 {
91        for i in 1..10 {
92            inputs = inputs.map(|input| _mm_aesenc_si128(input, keys[i]));
93            outputs = outputs.map(|output| _mm_aesdec_si128(output, inv_keys[i]));
94        }
95
96        inputs = inputs.map(|input| _mm_aesenclast_si128(input, xor_key));
97        outputs = outputs.map(|output| _mm_aesdeclast_si128(output, xor_key));
98    }
99
100    // All bits set
101    let all_ones = _mm_set1_epi8(-1);
102
103    inputs.into_iter().zip(outputs).all(|(input, output)| {
104        let diff = _mm_xor_si128(input, output);
105        let cmp = _mm_xor_si128(diff, xor_key);
106        _mm_test_all_zeros(cmp, all_ones) == 1
107    })
108}
109
110/// Verification mimics `create` function, but also has decryption half for better performance
111#[target_feature(enable = "avx2,vaes")]
112#[inline]
113pub(super) fn verify_sequential_avx2_vaes(
114    seed: &[u8; 16],
115    key: &[u8; 16],
116    checkpoints: &PotCheckpoints,
117    checkpoint_iterations: u32,
118) -> bool {
119    let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
120
121    let keys = expand_key(key);
122    let xor_key = _mm_xor_si128(keys[10], keys[0]);
123    let xor_key_256 = _mm256_broadcastsi128_si256(xor_key);
124
125    // Invert keys for decryption, the first and last element is not used below, hence they are
126    // copied as is from encryption keys (otherwise the first and last element would need to be
127    // swapped)
128    let mut inv_keys = keys;
129    for i in 1..10 {
130        inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
131    }
132
133    let keys_256 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(keys[i]));
134    let inv_keys_256 =
135        array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(inv_keys[i]));
136
137    let mut input_0 = [[0u8; 16]; 2];
138    input_0[0] = *seed;
139    input_0[1] = checkpoints[0];
140    let mut input_0 = __m256i::from(u8x32::from_slice(input_0.as_flattened()));
141
142    let mut input_1 = __m256i::from(u8x32::from_slice(checkpoints[1..3].as_flattened()));
143    let mut input_2 = __m256i::from(u8x32::from_slice(checkpoints[3..5].as_flattened()));
144    let mut input_3 = __m256i::from(u8x32::from_slice(checkpoints[5..7].as_flattened()));
145
146    let mut output_0 = __m256i::from(u8x32::from_slice(checkpoints[0..2].as_flattened()));
147    let mut output_1 = __m256i::from(u8x32::from_slice(checkpoints[2..4].as_flattened()));
148    let mut output_2 = __m256i::from(u8x32::from_slice(checkpoints[4..6].as_flattened()));
149    let mut output_3 = __m256i::from(u8x32::from_slice(checkpoints[6..8].as_flattened()));
150
151    input_0 = _mm256_xor_si256(input_0, keys_256[0]);
152    input_1 = _mm256_xor_si256(input_1, keys_256[0]);
153    input_2 = _mm256_xor_si256(input_2, keys_256[0]);
154    input_3 = _mm256_xor_si256(input_3, keys_256[0]);
155
156    output_0 = _mm256_xor_si256(output_0, keys_256[10]);
157    output_1 = _mm256_xor_si256(output_1, keys_256[10]);
158    output_2 = _mm256_xor_si256(output_2, keys_256[10]);
159    output_3 = _mm256_xor_si256(output_3, keys_256[10]);
160
161    for _ in 0..checkpoint_iterations / 2 {
162        for i in 1..10 {
163            input_0 = _mm256_aesenc_epi128(input_0, keys_256[i]);
164            input_1 = _mm256_aesenc_epi128(input_1, keys_256[i]);
165            input_2 = _mm256_aesenc_epi128(input_2, keys_256[i]);
166            input_3 = _mm256_aesenc_epi128(input_3, keys_256[i]);
167
168            output_0 = _mm256_aesdec_epi128(output_0, inv_keys_256[i]);
169            output_1 = _mm256_aesdec_epi128(output_1, inv_keys_256[i]);
170            output_2 = _mm256_aesdec_epi128(output_2, inv_keys_256[i]);
171            output_3 = _mm256_aesdec_epi128(output_3, inv_keys_256[i]);
172        }
173
174        input_0 = _mm256_aesenclast_epi128(input_0, xor_key_256);
175        input_1 = _mm256_aesenclast_epi128(input_1, xor_key_256);
176        input_2 = _mm256_aesenclast_epi128(input_2, xor_key_256);
177        input_3 = _mm256_aesenclast_epi128(input_3, xor_key_256);
178
179        output_0 = _mm256_aesdeclast_epi128(output_0, xor_key_256);
180        output_1 = _mm256_aesdeclast_epi128(output_1, xor_key_256);
181        output_2 = _mm256_aesdeclast_epi128(output_2, xor_key_256);
182        output_3 = _mm256_aesdeclast_epi128(output_3, xor_key_256);
183    }
184
185    // Code below is a more efficient version of this:
186    // input_0 = _mm256_xor_si256(input_0, keys_256[0]);
187    // input_1 = _mm256_xor_si256(input_1, keys_256[0]);
188    // input_2 = _mm256_xor_si256(input_2, keys_256[0]);
189    // input_3 = _mm256_xor_si256(input_3, keys_256[0]);
190    // output_0 = _mm256_xor_si256(output_0, keys_256[10]);
191    // output_1 = _mm256_xor_si256(output_1, keys_256[10]);
192    // output_2 = _mm256_xor_si256(output_2, keys_256[10]);
193    // output_3 = _mm256_xor_si256(output_3, keys_256[10]);
194    //
195    // let mask_0 = _mm256_cmpeq_epi64(input_0, output_0);
196    // let mask_1 = _mm256_cmpeq_epi64(input_1, output_1);
197    // let mask_2 = _mm256_cmpeq_epi64(input_2, output_1);
198    // let mask_3 = _mm256_cmpeq_epi64(input_3, output_1);
199
200    let diff_0 = _mm256_xor_si256(input_0, output_0);
201    let diff_1 = _mm256_xor_si256(input_1, output_1);
202    let diff_2 = _mm256_xor_si256(input_2, output_2);
203    let diff_3 = _mm256_xor_si256(input_3, output_3);
204
205    let mask_0 = _mm256_cmpeq_epi64(diff_0, xor_key_256);
206    let mask_1 = _mm256_cmpeq_epi64(diff_1, xor_key_256);
207    let mask_2 = _mm256_cmpeq_epi64(diff_2, xor_key_256);
208    let mask_3 = _mm256_cmpeq_epi64(diff_3, xor_key_256);
209
210    // All bits set
211    let all_ones = _mm256_set1_epi64x(-1);
212
213    let match_0 = _mm256_testc_si256(mask_0, all_ones) != 0;
214    let match_1 = _mm256_testc_si256(mask_1, all_ones) != 0;
215    let match_2 = _mm256_testc_si256(mask_2, all_ones) != 0;
216    let match_3 = _mm256_testc_si256(mask_3, all_ones) != 0;
217
218    match_0 && match_1 && match_2 && match_3
219}
220
221/// Verification mimics `create` function, but also has decryption half for better performance
222#[target_feature(enable = "avx512f,vaes")]
223#[inline]
224pub(super) fn verify_sequential_avx512f_vaes(
225    seed: &[u8; 16],
226    key: &[u8; 16],
227    checkpoints: &PotCheckpoints,
228    checkpoint_iterations: u32,
229) -> bool {
230    let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
231
232    let keys = expand_key(key);
233    let xor_key = _mm_xor_si128(keys[10], keys[0]);
234    let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
235
236    // Invert keys for decryption, the first and last element is not used below, hence they are
237    // copied as is from encryption keys (otherwise the first and last element would need to be
238    // swapped)
239    let mut inv_keys = keys;
240    for i in 1..10 {
241        inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
242    }
243
244    let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys[i]));
245    let inv_keys_512 =
246        array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
247
248    let mut input_0 = [[0u8; 16]; 4];
249    input_0[0] = *seed;
250    input_0[1..].copy_from_slice(&checkpoints[..3]);
251    let mut input_0 = __m512i::from(u8x64::from_slice(input_0.as_flattened()));
252    let mut input_1 = __m512i::from(u8x64::from_slice(checkpoints[3..7].as_flattened()));
253
254    let mut output_0 = __m512i::from(u8x64::from_slice(checkpoints[0..4].as_flattened()));
255    let mut output_1 = __m512i::from(u8x64::from_slice(checkpoints[4..8].as_flattened()));
256
257    input_0 = _mm512_xor_si512(input_0, keys_512[0]);
258    input_1 = _mm512_xor_si512(input_1, keys_512[0]);
259
260    output_0 = _mm512_xor_si512(output_0, keys_512[10]);
261    output_1 = _mm512_xor_si512(output_1, keys_512[10]);
262
263    for _ in 0..checkpoint_iterations / 2 {
264        for i in 1..10 {
265            input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
266            input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
267
268            output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]);
269            output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]);
270        }
271
272        input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512);
273        input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512);
274
275        output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
276        output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
277    }
278
279    // Code below is a more efficient version of this:
280    // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
281    // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
282    // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
283    // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
284    //
285    // let mask_0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
286    // let mask_1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
287
288    let diff_0 = _mm512_xor_si512(input_0, output_0);
289    let diff_1 = _mm512_xor_si512(input_1, output_1);
290
291    let mask_0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
292    let mask_1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
293
294    // All inputs match outputs
295    (mask_0 & mask_1) == u8::MAX
296}
297
298// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by
299// Artyom Pavlov:
300// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/ni/expand.rs
301
/// Expand a 16-byte AES-128 key into the [`NUM_ROUND_KEYS`] round keys used by the AES
/// instructions in this file.
///
/// `keys[0]` is `key` itself. Each of `keys[1]`..=`keys[10]` is derived from the preceding round
/// key by `expand_round`, using the round constants 0x01, 0x02, 0x04, …, 0x80, 0x1B, 0x36.
#[target_feature(enable = "aes")]
fn expand_key(key: &[u8; 16]) -> [__m128i; NUM_ROUND_KEYS] {
    /// Compute round key `keys[pos]` from `keys[pos - 1]` with round constant `RK`.
    ///
    /// `pos` must be at least 1, since `keys[pos - 1]` is read.
    #[target_feature(enable = "aes")]
    fn expand_round<const RK: i32>(keys: &mut [__m128i; NUM_ROUND_KEYS], pos: usize) {
        let mut t1 = keys[pos - 1];
        let mut t2;
        let mut t3;

        // `aeskeygenassist` applies SubWord/RotWord to the previous key's words and mixes in
        // `RK`. Shuffling with 0xff broadcasts its highest 32-bit word to all four words.
        t2 = _mm_aeskeygenassist_si128::<RK>(t1);
        t2 = _mm_shuffle_epi32::<0xff>(t2);
        // Three shift-by-one-word-and-XOR steps turn t1 into the running XOR of the previous
        // key's words: (w0, w0^w1, w0^w1^w2, w0^w1^w2^w3)
        t3 = _mm_slli_si128::<0x4>(t1);
        t1 = _mm_xor_si128(t1, t3);
        t3 = _mm_slli_si128::<0x4>(t3);
        t1 = _mm_xor_si128(t1, t3);
        t3 = _mm_slli_si128::<0x4>(t3);
        t1 = _mm_xor_si128(t1, t3);
        // Mix in the broadcast word to obtain the next round key
        t1 = _mm_xor_si128(t1, t2);

        keys[pos] = t1;
    }

    let mut keys = [_mm_setzero_si128(); NUM_ROUND_KEYS];
    keys[0] = __m128i::from(u8x16::from(*key));

    // One expansion step per round, each with its own round constant
    let kr = &mut keys;
    expand_round::<0x01>(kr, 1);
    expand_round::<0x02>(kr, 2);
    expand_round::<0x04>(kr, 3);
    expand_round::<0x08>(kr, 4);
    expand_round::<0x10>(kr, 5);
    expand_round::<0x20>(kr, 6);
    expand_round::<0x40>(kr, 7);
    expand_round::<0x80>(kr, 8);
    expand_round::<0x1B>(kr, 9);
    expand_round::<0x36>(kr, 10);

    keys
}