subspace_proof_of_time/aes/
x86_64.rs1use core::arch::x86_64::*;
2use core::array;
3use core::simd::{u8x16, u8x32, u8x64};
4use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
5
/// Length of the round-key array returned by `expand_key`: slot 0 holds the raw key and
/// slots 1..=10 hold the keys derived from it.
6const NUM_ROUND_KEYS: usize = 11;
7
8#[target_feature(enable = "aes")]
10#[inline]
11pub(super) fn create(
12 seed: &[u8; 16],
13 key: &[u8; 16],
14 checkpoint_iterations: u32,
15) -> PotCheckpoints {
16 let mut checkpoints = PotCheckpoints::default();
17
18 let keys = expand_key(key);
19 let xor_key = _mm_xor_si128(keys[10], keys[0]);
20 let mut seed = __m128i::from(u8x16::from_array(*seed));
21 seed = _mm_xor_si128(seed, keys[0]);
22 for checkpoint in checkpoints.iter_mut() {
23 for _ in 0..checkpoint_iterations {
24 seed = _mm_aesenc_si128(seed, keys[1]);
25 seed = _mm_aesenc_si128(seed, keys[2]);
26 seed = _mm_aesenc_si128(seed, keys[3]);
27 seed = _mm_aesenc_si128(seed, keys[4]);
28 seed = _mm_aesenc_si128(seed, keys[5]);
29 seed = _mm_aesenc_si128(seed, keys[6]);
30 seed = _mm_aesenc_si128(seed, keys[7]);
31 seed = _mm_aesenc_si128(seed, keys[8]);
32 seed = _mm_aesenc_si128(seed, keys[9]);
33 seed = _mm_aesenclast_si128(seed, xor_key);
34 }
35
36 let checkpoint_reg = _mm_xor_si128(seed, keys[0]);
37 **checkpoint = u8x16::from(checkpoint_reg).to_array();
38 }
39
40 checkpoints
41}
42
43#[target_feature(enable = "aes,sse4.1")]
45#[inline]
46pub(super) fn verify_sequential_aes_sse41(
47 seed: &[u8; 16],
48 key: &[u8; 16],
49 checkpoints: &PotCheckpoints,
50 checkpoint_iterations: u32,
51) -> bool {
52 let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
53
54 let keys = expand_key(key);
55 let xor_key = _mm_xor_si128(keys[10], keys[0]);
56
57 let mut inv_keys = keys;
61 for i in 1..10 {
62 inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
63 }
64
65 let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
66 __m128i::from(u8x16::from(*seed)),
67 __m128i::from(u8x16::from(checkpoints[0])),
68 __m128i::from(u8x16::from(checkpoints[1])),
69 __m128i::from(u8x16::from(checkpoints[2])),
70 __m128i::from(u8x16::from(checkpoints[3])),
71 __m128i::from(u8x16::from(checkpoints[4])),
72 __m128i::from(u8x16::from(checkpoints[5])),
73 __m128i::from(u8x16::from(checkpoints[6])),
74 ];
75
76 let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
77 __m128i::from(u8x16::from(checkpoints[0])),
78 __m128i::from(u8x16::from(checkpoints[1])),
79 __m128i::from(u8x16::from(checkpoints[2])),
80 __m128i::from(u8x16::from(checkpoints[3])),
81 __m128i::from(u8x16::from(checkpoints[4])),
82 __m128i::from(u8x16::from(checkpoints[5])),
83 __m128i::from(u8x16::from(checkpoints[6])),
84 __m128i::from(u8x16::from(checkpoints[7])),
85 ];
86
87 inputs = inputs.map(|input| _mm_xor_si128(input, keys[0]));
88 outputs = outputs.map(|output| _mm_xor_si128(output, keys[10]));
89
90 for _ in 0..checkpoint_iterations / 2 {
91 for i in 1..10 {
92 inputs = inputs.map(|input| _mm_aesenc_si128(input, keys[i]));
93 outputs = outputs.map(|output| _mm_aesdec_si128(output, inv_keys[i]));
94 }
95
96 inputs = inputs.map(|input| _mm_aesenclast_si128(input, xor_key));
97 outputs = outputs.map(|output| _mm_aesdeclast_si128(output, xor_key));
98 }
99
100 let all_ones = _mm_set1_epi8(-1);
102
103 inputs.into_iter().zip(outputs).all(|(input, output)| {
104 let diff = _mm_xor_si128(input, output);
105 let cmp = _mm_xor_si128(diff, xor_key);
106 _mm_test_all_zeros(cmp, all_ones) == 1
107 })
108}
109
110#[target_feature(enable = "avx2,vaes")]
112#[inline]
113pub(super) fn verify_sequential_avx2_vaes(
114 seed: &[u8; 16],
115 key: &[u8; 16],
116 checkpoints: &PotCheckpoints,
117 checkpoint_iterations: u32,
118) -> bool {
119 let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
120
121 let keys = expand_key(key);
122 let xor_key = _mm_xor_si128(keys[10], keys[0]);
123 let xor_key_256 = _mm256_broadcastsi128_si256(xor_key);
124
125 let mut inv_keys = keys;
129 for i in 1..10 {
130 inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
131 }
132
133 let keys_256 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(keys[i]));
134 let inv_keys_256 =
135 array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(inv_keys[i]));
136
137 let mut input_0 = [[0u8; 16]; 2];
138 input_0[0] = *seed;
139 input_0[1] = checkpoints[0];
140 let mut input_0 = __m256i::from(u8x32::from_slice(input_0.as_flattened()));
141
142 let mut input_1 = __m256i::from(u8x32::from_slice(checkpoints[1..3].as_flattened()));
143 let mut input_2 = __m256i::from(u8x32::from_slice(checkpoints[3..5].as_flattened()));
144 let mut input_3 = __m256i::from(u8x32::from_slice(checkpoints[5..7].as_flattened()));
145
146 let mut output_0 = __m256i::from(u8x32::from_slice(checkpoints[0..2].as_flattened()));
147 let mut output_1 = __m256i::from(u8x32::from_slice(checkpoints[2..4].as_flattened()));
148 let mut output_2 = __m256i::from(u8x32::from_slice(checkpoints[4..6].as_flattened()));
149 let mut output_3 = __m256i::from(u8x32::from_slice(checkpoints[6..8].as_flattened()));
150
151 input_0 = _mm256_xor_si256(input_0, keys_256[0]);
152 input_1 = _mm256_xor_si256(input_1, keys_256[0]);
153 input_2 = _mm256_xor_si256(input_2, keys_256[0]);
154 input_3 = _mm256_xor_si256(input_3, keys_256[0]);
155
156 output_0 = _mm256_xor_si256(output_0, keys_256[10]);
157 output_1 = _mm256_xor_si256(output_1, keys_256[10]);
158 output_2 = _mm256_xor_si256(output_2, keys_256[10]);
159 output_3 = _mm256_xor_si256(output_3, keys_256[10]);
160
161 for _ in 0..checkpoint_iterations / 2 {
162 unsafe {
164 for i in 1..10 {
165 input_0 = _mm256_aesenc_epi128(input_0, keys_256[i]);
166 input_1 = _mm256_aesenc_epi128(input_1, keys_256[i]);
167 input_2 = _mm256_aesenc_epi128(input_2, keys_256[i]);
168 input_3 = _mm256_aesenc_epi128(input_3, keys_256[i]);
169
170 output_0 = _mm256_aesdec_epi128(output_0, inv_keys_256[i]);
171 output_1 = _mm256_aesdec_epi128(output_1, inv_keys_256[i]);
172 output_2 = _mm256_aesdec_epi128(output_2, inv_keys_256[i]);
173 output_3 = _mm256_aesdec_epi128(output_3, inv_keys_256[i]);
174 }
175
176 input_0 = _mm256_aesenclast_epi128(input_0, xor_key_256);
177 input_1 = _mm256_aesenclast_epi128(input_1, xor_key_256);
178 input_2 = _mm256_aesenclast_epi128(input_2, xor_key_256);
179 input_3 = _mm256_aesenclast_epi128(input_3, xor_key_256);
180
181 output_0 = _mm256_aesdeclast_epi128(output_0, xor_key_256);
182 output_1 = _mm256_aesdeclast_epi128(output_1, xor_key_256);
183 output_2 = _mm256_aesdeclast_epi128(output_2, xor_key_256);
184 output_3 = _mm256_aesdeclast_epi128(output_3, xor_key_256);
185 }
186 }
187
188 let diff_0 = _mm256_xor_si256(input_0, output_0);
204 let diff_1 = _mm256_xor_si256(input_1, output_1);
205 let diff_2 = _mm256_xor_si256(input_2, output_2);
206 let diff_3 = _mm256_xor_si256(input_3, output_3);
207
208 let mask_0 = _mm256_cmpeq_epi64(diff_0, xor_key_256);
209 let mask_1 = _mm256_cmpeq_epi64(diff_1, xor_key_256);
210 let mask_2 = _mm256_cmpeq_epi64(diff_2, xor_key_256);
211 let mask_3 = _mm256_cmpeq_epi64(diff_3, xor_key_256);
212
213 let all_ones = _mm256_set1_epi64x(-1);
215
216 let match_0 = _mm256_testc_si256(mask_0, all_ones) != 0;
217 let match_1 = _mm256_testc_si256(mask_1, all_ones) != 0;
218 let match_2 = _mm256_testc_si256(mask_2, all_ones) != 0;
219 let match_3 = _mm256_testc_si256(mask_3, all_ones) != 0;
220
221 match_0 && match_1 && match_2 && match_3
222}
223
224#[target_feature(enable = "avx512f,vaes")]
226#[inline]
227pub(super) fn verify_sequential_avx512f_vaes(
228 seed: &[u8; 16],
229 key: &[u8; 16],
230 checkpoints: &PotCheckpoints,
231 checkpoint_iterations: u32,
232) -> bool {
233 let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
234
235 let keys = expand_key(key);
236 let xor_key = _mm_xor_si128(keys[10], keys[0]);
237 let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
238
239 let mut inv_keys = keys;
243 for i in 1..10 {
244 inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
245 }
246
247 let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys[i]));
248 let inv_keys_512 =
249 array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
250
251 let mut input_0 = [[0u8; 16]; 4];
252 input_0[0] = *seed;
253 input_0[1..].copy_from_slice(&checkpoints[..3]);
254 let mut input_0 = __m512i::from(u8x64::from_slice(input_0.as_flattened()));
255 let mut input_1 = __m512i::from(u8x64::from_slice(checkpoints[3..7].as_flattened()));
256
257 let mut output_0 = __m512i::from(u8x64::from_slice(checkpoints[0..4].as_flattened()));
258 let mut output_1 = __m512i::from(u8x64::from_slice(checkpoints[4..8].as_flattened()));
259
260 input_0 = _mm512_xor_si512(input_0, keys_512[0]);
261 input_1 = _mm512_xor_si512(input_1, keys_512[0]);
262
263 output_0 = _mm512_xor_si512(output_0, keys_512[10]);
264 output_1 = _mm512_xor_si512(output_1, keys_512[10]);
265
266 for _ in 0..checkpoint_iterations / 2 {
267 unsafe {
269 for i in 1..10 {
270 input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
271 input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
272
273 output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]);
274 output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]);
275 }
276
277 input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512);
278 input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512);
279
280 output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
281 output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
282 }
283 }
284
285 let diff_0 = _mm512_xor_si512(input_0, output_0);
295 let diff_1 = _mm512_xor_si512(input_1, output_1);
296
297 let mask_0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
298 let mask_1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
299
300 (mask_0 & mask_1) == u8::MAX
302}
303
304#[target_feature(enable = "aes")]
309fn expand_key(key: &[u8; 16]) -> [__m128i; NUM_ROUND_KEYS] {
310 #[target_feature(enable = "aes")]
311 fn expand_round<const RK: i32>(keys: &mut [__m128i; NUM_ROUND_KEYS], pos: usize) {
312 let mut t1 = keys[pos - 1];
313 let mut t2;
314 let mut t3;
315
316 t2 = _mm_aeskeygenassist_si128::<RK>(t1);
317 t2 = _mm_shuffle_epi32::<0xff>(t2);
318 t3 = _mm_slli_si128::<0x4>(t1);
319 t1 = _mm_xor_si128(t1, t3);
320 t3 = _mm_slli_si128::<0x4>(t3);
321 t1 = _mm_xor_si128(t1, t3);
322 t3 = _mm_slli_si128::<0x4>(t3);
323 t1 = _mm_xor_si128(t1, t3);
324 t1 = _mm_xor_si128(t1, t2);
325
326 keys[pos] = t1;
327 }
328
329 let mut keys = [_mm_setzero_si128(); NUM_ROUND_KEYS];
330 keys[0] = __m128i::from(u8x16::from(*key));
331
332 let kr = &mut keys;
333 expand_round::<0x01>(kr, 1);
334 expand_round::<0x02>(kr, 2);
335 expand_round::<0x04>(kr, 3);
336 expand_round::<0x08>(kr, 4);
337 expand_round::<0x10>(kr, 5);
338 expand_round::<0x20>(kr, 6);
339 expand_round::<0x40>(kr, 7);
340 expand_round::<0x80>(kr, 8);
341 expand_round::<0x1B>(kr, 9);
342 expand_round::<0x36>(kr, 10);
343
344 keys
345}