subspace_proof_of_time/aes/
x86_64.rs1use core::arch::x86_64::*;
2use core::array;
3use core::simd::{u8x16, u8x32, u8x64};
4use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
5
/// Number of round keys in the AES-128 key schedule: the cipher key itself (used for the initial
/// `AddRoundKey`) plus one key for each of the 10 rounds.
const NUM_ROUND_KEYS: usize = 1 + 10;
7
8#[target_feature(enable = "aes")]
10#[inline]
11pub(super) fn create(
12 seed: &[u8; 16],
13 key: &[u8; 16],
14 checkpoint_iterations: u32,
15) -> PotCheckpoints {
16 let mut checkpoints = PotCheckpoints::default();
17
18 let keys = expand_key(key);
19 let xor_key = _mm_xor_si128(keys[10], keys[0]);
20 let mut seed = __m128i::from(u8x16::from_array(*seed));
21 seed = _mm_xor_si128(seed, keys[0]);
22 for checkpoint in checkpoints.iter_mut() {
23 for _ in 0..checkpoint_iterations {
24 seed = _mm_aesenc_si128(seed, keys[1]);
25 seed = _mm_aesenc_si128(seed, keys[2]);
26 seed = _mm_aesenc_si128(seed, keys[3]);
27 seed = _mm_aesenc_si128(seed, keys[4]);
28 seed = _mm_aesenc_si128(seed, keys[5]);
29 seed = _mm_aesenc_si128(seed, keys[6]);
30 seed = _mm_aesenc_si128(seed, keys[7]);
31 seed = _mm_aesenc_si128(seed, keys[8]);
32 seed = _mm_aesenc_si128(seed, keys[9]);
33 seed = _mm_aesenclast_si128(seed, xor_key);
34 }
35
36 let checkpoint_reg = _mm_xor_si128(seed, keys[0]);
37 **checkpoint = u8x16::from(checkpoint_reg).to_array();
38 }
39
40 checkpoints
41}
42
43#[target_feature(enable = "aes,sse4.1")]
45#[inline]
46pub(super) fn verify_sequential_aes_sse41(
47 seed: &[u8; 16],
48 key: &[u8; 16],
49 checkpoints: &PotCheckpoints,
50 checkpoint_iterations: u32,
51) -> bool {
52 let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
53
54 let keys = expand_key(key);
55 let xor_key = _mm_xor_si128(keys[10], keys[0]);
56
57 let mut inv_keys = keys;
61 for i in 1..10 {
62 inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
63 }
64
65 let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
66 __m128i::from(u8x16::from(*seed)),
67 __m128i::from(u8x16::from(checkpoints[0])),
68 __m128i::from(u8x16::from(checkpoints[1])),
69 __m128i::from(u8x16::from(checkpoints[2])),
70 __m128i::from(u8x16::from(checkpoints[3])),
71 __m128i::from(u8x16::from(checkpoints[4])),
72 __m128i::from(u8x16::from(checkpoints[5])),
73 __m128i::from(u8x16::from(checkpoints[6])),
74 ];
75
76 let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
77 __m128i::from(u8x16::from(checkpoints[0])),
78 __m128i::from(u8x16::from(checkpoints[1])),
79 __m128i::from(u8x16::from(checkpoints[2])),
80 __m128i::from(u8x16::from(checkpoints[3])),
81 __m128i::from(u8x16::from(checkpoints[4])),
82 __m128i::from(u8x16::from(checkpoints[5])),
83 __m128i::from(u8x16::from(checkpoints[6])),
84 __m128i::from(u8x16::from(checkpoints[7])),
85 ];
86
87 inputs = inputs.map(|input| _mm_xor_si128(input, keys[0]));
88 outputs = outputs.map(|output| _mm_xor_si128(output, keys[10]));
89
90 for _ in 0..checkpoint_iterations / 2 {
91 for i in 1..10 {
92 inputs = inputs.map(|input| _mm_aesenc_si128(input, keys[i]));
93 outputs = outputs.map(|output| _mm_aesdec_si128(output, inv_keys[i]));
94 }
95
96 inputs = inputs.map(|input| _mm_aesenclast_si128(input, xor_key));
97 outputs = outputs.map(|output| _mm_aesdeclast_si128(output, xor_key));
98 }
99
100 let all_ones = _mm_set1_epi8(-1);
102
103 inputs.into_iter().zip(outputs).all(|(input, output)| {
104 let diff = _mm_xor_si128(input, output);
105 let cmp = _mm_xor_si128(diff, xor_key);
106 _mm_test_all_zeros(cmp, all_ones) == 1
107 })
108}
109
110#[target_feature(enable = "avx2,vaes")]
112#[inline]
113pub(super) fn verify_sequential_avx2_vaes(
114 seed: &[u8; 16],
115 key: &[u8; 16],
116 checkpoints: &PotCheckpoints,
117 checkpoint_iterations: u32,
118) -> bool {
119 let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
120
121 let keys = expand_key(key);
122 let xor_key = _mm_xor_si128(keys[10], keys[0]);
123 let xor_key_256 = _mm256_broadcastsi128_si256(xor_key);
124
125 let mut inv_keys = keys;
129 for i in 1..10 {
130 inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
131 }
132
133 let keys_256 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(keys[i]));
134 let inv_keys_256 =
135 array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(inv_keys[i]));
136
137 let mut input_0 = [[0u8; 16]; 2];
138 input_0[0] = *seed;
139 input_0[1] = checkpoints[0];
140 let mut input_0 = __m256i::from(u8x32::from_slice(input_0.as_flattened()));
141
142 let mut input_1 = __m256i::from(u8x32::from_slice(checkpoints[1..3].as_flattened()));
143 let mut input_2 = __m256i::from(u8x32::from_slice(checkpoints[3..5].as_flattened()));
144 let mut input_3 = __m256i::from(u8x32::from_slice(checkpoints[5..7].as_flattened()));
145
146 let mut output_0 = __m256i::from(u8x32::from_slice(checkpoints[0..2].as_flattened()));
147 let mut output_1 = __m256i::from(u8x32::from_slice(checkpoints[2..4].as_flattened()));
148 let mut output_2 = __m256i::from(u8x32::from_slice(checkpoints[4..6].as_flattened()));
149 let mut output_3 = __m256i::from(u8x32::from_slice(checkpoints[6..8].as_flattened()));
150
151 input_0 = _mm256_xor_si256(input_0, keys_256[0]);
152 input_1 = _mm256_xor_si256(input_1, keys_256[0]);
153 input_2 = _mm256_xor_si256(input_2, keys_256[0]);
154 input_3 = _mm256_xor_si256(input_3, keys_256[0]);
155
156 output_0 = _mm256_xor_si256(output_0, keys_256[10]);
157 output_1 = _mm256_xor_si256(output_1, keys_256[10]);
158 output_2 = _mm256_xor_si256(output_2, keys_256[10]);
159 output_3 = _mm256_xor_si256(output_3, keys_256[10]);
160
161 for _ in 0..checkpoint_iterations / 2 {
162 for i in 1..10 {
163 input_0 = _mm256_aesenc_epi128(input_0, keys_256[i]);
164 input_1 = _mm256_aesenc_epi128(input_1, keys_256[i]);
165 input_2 = _mm256_aesenc_epi128(input_2, keys_256[i]);
166 input_3 = _mm256_aesenc_epi128(input_3, keys_256[i]);
167
168 output_0 = _mm256_aesdec_epi128(output_0, inv_keys_256[i]);
169 output_1 = _mm256_aesdec_epi128(output_1, inv_keys_256[i]);
170 output_2 = _mm256_aesdec_epi128(output_2, inv_keys_256[i]);
171 output_3 = _mm256_aesdec_epi128(output_3, inv_keys_256[i]);
172 }
173
174 input_0 = _mm256_aesenclast_epi128(input_0, xor_key_256);
175 input_1 = _mm256_aesenclast_epi128(input_1, xor_key_256);
176 input_2 = _mm256_aesenclast_epi128(input_2, xor_key_256);
177 input_3 = _mm256_aesenclast_epi128(input_3, xor_key_256);
178
179 output_0 = _mm256_aesdeclast_epi128(output_0, xor_key_256);
180 output_1 = _mm256_aesdeclast_epi128(output_1, xor_key_256);
181 output_2 = _mm256_aesdeclast_epi128(output_2, xor_key_256);
182 output_3 = _mm256_aesdeclast_epi128(output_3, xor_key_256);
183 }
184
185 let diff_0 = _mm256_xor_si256(input_0, output_0);
201 let diff_1 = _mm256_xor_si256(input_1, output_1);
202 let diff_2 = _mm256_xor_si256(input_2, output_2);
203 let diff_3 = _mm256_xor_si256(input_3, output_3);
204
205 let mask_0 = _mm256_cmpeq_epi64(diff_0, xor_key_256);
206 let mask_1 = _mm256_cmpeq_epi64(diff_1, xor_key_256);
207 let mask_2 = _mm256_cmpeq_epi64(diff_2, xor_key_256);
208 let mask_3 = _mm256_cmpeq_epi64(diff_3, xor_key_256);
209
210 let all_ones = _mm256_set1_epi64x(-1);
212
213 let match_0 = _mm256_testc_si256(mask_0, all_ones) != 0;
214 let match_1 = _mm256_testc_si256(mask_1, all_ones) != 0;
215 let match_2 = _mm256_testc_si256(mask_2, all_ones) != 0;
216 let match_3 = _mm256_testc_si256(mask_3, all_ones) != 0;
217
218 match_0 && match_1 && match_2 && match_3
219}
220
221#[target_feature(enable = "avx512f,vaes")]
223#[inline]
224pub(super) fn verify_sequential_avx512f_vaes(
225 seed: &[u8; 16],
226 key: &[u8; 16],
227 checkpoints: &PotCheckpoints,
228 checkpoint_iterations: u32,
229) -> bool {
230 let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
231
232 let keys = expand_key(key);
233 let xor_key = _mm_xor_si128(keys[10], keys[0]);
234 let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
235
236 let mut inv_keys = keys;
240 for i in 1..10 {
241 inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
242 }
243
244 let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys[i]));
245 let inv_keys_512 =
246 array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
247
248 let mut input_0 = [[0u8; 16]; 4];
249 input_0[0] = *seed;
250 input_0[1..].copy_from_slice(&checkpoints[..3]);
251 let mut input_0 = __m512i::from(u8x64::from_slice(input_0.as_flattened()));
252 let mut input_1 = __m512i::from(u8x64::from_slice(checkpoints[3..7].as_flattened()));
253
254 let mut output_0 = __m512i::from(u8x64::from_slice(checkpoints[0..4].as_flattened()));
255 let mut output_1 = __m512i::from(u8x64::from_slice(checkpoints[4..8].as_flattened()));
256
257 input_0 = _mm512_xor_si512(input_0, keys_512[0]);
258 input_1 = _mm512_xor_si512(input_1, keys_512[0]);
259
260 output_0 = _mm512_xor_si512(output_0, keys_512[10]);
261 output_1 = _mm512_xor_si512(output_1, keys_512[10]);
262
263 for _ in 0..checkpoint_iterations / 2 {
264 for i in 1..10 {
265 input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
266 input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
267
268 output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]);
269 output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]);
270 }
271
272 input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512);
273 input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512);
274
275 output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
276 output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
277 }
278
279 let diff_0 = _mm512_xor_si512(input_0, output_0);
289 let diff_1 = _mm512_xor_si512(input_1, output_1);
290
291 let mask_0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
292 let mask_1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
293
294 (mask_0 & mask_1) == u8::MAX
296}
297
298#[target_feature(enable = "aes")]
303fn expand_key(key: &[u8; 16]) -> [__m128i; NUM_ROUND_KEYS] {
304 #[target_feature(enable = "aes")]
305 fn expand_round<const RK: i32>(keys: &mut [__m128i; NUM_ROUND_KEYS], pos: usize) {
306 let mut t1 = keys[pos - 1];
307 let mut t2;
308 let mut t3;
309
310 t2 = _mm_aeskeygenassist_si128::<RK>(t1);
311 t2 = _mm_shuffle_epi32::<0xff>(t2);
312 t3 = _mm_slli_si128::<0x4>(t1);
313 t1 = _mm_xor_si128(t1, t3);
314 t3 = _mm_slli_si128::<0x4>(t3);
315 t1 = _mm_xor_si128(t1, t3);
316 t3 = _mm_slli_si128::<0x4>(t3);
317 t1 = _mm_xor_si128(t1, t3);
318 t1 = _mm_xor_si128(t1, t2);
319
320 keys[pos] = t1;
321 }
322
323 let mut keys = [_mm_setzero_si128(); NUM_ROUND_KEYS];
324 keys[0] = __m128i::from(u8x16::from(*key));
325
326 let kr = &mut keys;
327 expand_round::<0x01>(kr, 1);
328 expand_round::<0x02>(kr, 2);
329 expand_round::<0x04>(kr, 3);
330 expand_round::<0x08>(kr, 4);
331 expand_round::<0x10>(kr, 5);
332 expand_round::<0x20>(kr, 6);
333 expand_round::<0x40>(kr, 7);
334 expand_round::<0x80>(kr, 8);
335 expand_round::<0x1B>(kr, 9);
336 expand_round::<0x36>(kr, 10);
337
338 keys
339}