/*
 *------------------------------------------------------------------
 * Copyright (c) 2019 Cisco and/or its affiliates.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *------------------------------------------------------------------
 */

#include <vlib/vlib.h>
#include <vnet/plugin/plugin.h>
#include <vnet/crypto/crypto.h>
#include <crypto_native/crypto_native.h>
#include <crypto_native/aes.h>
#include <crypto_native/ghash.h>

#if __GNUC__ > 4 && !__clang__ && CLIB_DEBUG == 0
#pragma GCC optimize ("O3")
#endif

#ifdef __VAES__
#define NUM_HI 32
#else
#define NUM_HI 8
#endif

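/* Layout note (as used by the code below): Hi[] holds NUM_HI precomputed
   powers of the GHASH subkey H, stored highest power first so that
   Hi[NUM_HI - 1] is H itself; a batch of n blocks is hashed against
   kd->Hi + NUM_HI - n and folded with a single reduction.  Ke[] holds the
   expanded AES round keys (up to 15 entries for AES-256). */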
typedef struct
{
  /* pre-calculated hash key values */
  const u8x16 Hi[NUM_HI];
  /* extracted AES key */
  const u8x16 Ke[15];
#ifdef __VAES__
  const u8x64 Ke4[15];
#endif
} aes_gcm_key_data_t;

typedef struct
{
  u32 counter;
  union
  {
    u32x4 Y;
    u32x16 Y4;
  };
} aes_gcm_counter_t;

typedef enum
{
  AES_GCM_F_WITH_GHASH = (1 << 0),
  AES_GCM_F_LAST_ROUND = (1 << 1),
  AES_GCM_F_ENCRYPT = (1 << 2),
  AES_GCM_F_DECRYPT = (1 << 3),
} aes_gcm_flags_t;

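/* The counter block Y is kept in the byte order it is fed to AES, i.e. the
   32-bit block counter sits big-endian in the last lane, while ctr->counter
   tracks the same value in host order.  Adding ctr_inv_1 bumps the last
   byte of the block (the big-endian counter's least significant byte), so
   it only works while that byte does not wrap; the wrap case is handled
   separately in aes_gcm_enc_first_round. */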
static const u32x4 ctr_inv_1 = { 0, 0, 0, 1 << 24 };

static_always_inline void
aes_gcm_enc_first_round (u8x16 * r, aes_gcm_counter_t * ctr, u8x16 k,
                         int n_blocks)
{
  if (PREDICT_TRUE ((u8) ctr->counter < (256 - 2 * n_blocks)))
    {
      for (int i = 0; i < n_blocks; i++)
        {
          r[i] = k ^ (u8x16) ctr->Y;
          ctr->Y += ctr_inv_1;
        }
      ctr->counter += n_blocks;
    }
  else
    {
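      /* slow path - the low counter byte is about to wrap, so rebuild the
         big-endian counter lane from the full 32-bit counter every block */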
      for (int i = 0; i < n_blocks; i++)
        {
          r[i] = k ^ (u8x16) ctr->Y;
          ctr->counter++;
          ctr->Y[3] = clib_host_to_net_u32 (ctr->counter + 1);
        }
    }
}

static_always_inline void
aes_gcm_enc_round (u8x16 * r, u8x16 k, int n_blocks)
{
  for (int i = 0; i < n_blocks; i++)
    r[i] = aes_enc_round (r[i], k);
}

static_always_inline void
aes_gcm_enc_last_round (u8x16 * r, u8x16 * d, u8x16 const *k,
                        int rounds, int n_blocks)
{

  /* additional rounds for AES-192 and AES-256 */
  for (int i = 10; i < rounds; i++)
    aes_gcm_enc_round (r, k[i], n_blocks);

  for (int i = 0; i < n_blocks; i++)
    d[i] ^= aes_enc_last_round (r[i], k[rounds]);
}

static_always_inline u8x16
aes_gcm_ghash_blocks (u8x16 T, aes_gcm_key_data_t * kd,
                      u8x16u * in, int n_blocks)
{
  ghash_data_t _gd, *gd = &_gd;
  u8x16 *Hi = (u8x16 *) kd->Hi + NUM_HI - n_blocks;
  ghash_mul_first (gd, u8x16_reflect (in[0]) ^ T, Hi[0]);
  for (int i = 1; i < n_blocks; i++)
    ghash_mul_next (gd, u8x16_reflect ((in[i])), Hi[i]);
  ghash_reduce (gd);
  ghash_reduce2 (gd);
  return ghash_final (gd);
}

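/* hash n_left bytes (used for the AAD) in chunks of 8, 4, 2 and 1 blocks,
   each chunk folded with a single reduction; a trailing partial block is
   loaded with aes_load_partial and multiplied by H */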
static_always_inline u8x16
aes_gcm_ghash (u8x16 T, aes_gcm_key_data_t * kd, u8x16u * in, u32 n_left)
{

  while (n_left >= 128)
    {
      T = aes_gcm_ghash_blocks (T, kd, in, 8);
      n_left -= 128;
      in += 8;
    }

  if (n_left >= 64)
    {
      T = aes_gcm_ghash_blocks (T, kd, in, 4);
      n_left -= 64;
      in += 4;
    }

  if (n_left >= 32)
    {
      T = aes_gcm_ghash_blocks (T, kd, in, 2);
      n_left -= 32;
      in += 2;
    }

  if (n_left >= 16)
    {
      T = aes_gcm_ghash_blocks (T, kd, in, 1);
      n_left -= 16;
      in += 1;
    }

  if (n_left)
    {
      u8x16 r = aes_load_partial (in, n_left);
      T = ghash_mul (u8x16_reflect (r) ^ T, kd->Hi[NUM_HI - 1]);
    }
  return T;
}

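/* Encrypt/decrypt n (1..4) blocks while interleaving the AES rounds with
   GHASH multiplies to hide latency.  During encryption d[] still holds the
   ciphertext produced by the previous call, so ghash_blocks is fixed at 4;
   during decryption the ciphertext loaded in this call is hashed directly,
   so ghash_blocks equals n. */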
static_always_inline u8x16
aes_gcm_calc (u8x16 T, aes_gcm_key_data_t * kd, u8x16 * d,
              aes_gcm_counter_t * ctr, u8x16u * inv, u8x16u * outv,
              int rounds, int n, int last_block_bytes, aes_gcm_flags_t f)
{
  u8x16 r[n];
  ghash_data_t _gd = { }, *gd = &_gd;
  const u8x16 *rk = (u8x16 *) kd->Ke;
  int ghash_blocks = (f & AES_GCM_F_ENCRYPT) ? 4 : n, gc = 1;
  u8x16 *Hi = (u8x16 *) kd->Hi + NUM_HI - ghash_blocks;

  clib_prefetch_load (inv + 4);

  /* AES rounds 0 and 1 */
  aes_gcm_enc_first_round (r, ctr, rk[0], n);
  aes_gcm_enc_round (r, rk[1], n);

  /* load data - decrypt round */
  if (f & AES_GCM_F_DECRYPT)
    {
      for (int i = 0; i < n - ((f & AES_GCM_F_LAST_ROUND) != 0); i++)
        d[i] = inv[i];

      if (f & AES_GCM_F_LAST_ROUND)
        d[n - 1] = aes_load_partial (inv + n - 1, last_block_bytes);
    }

  /* GHASH multiply block 1 */
  if (f & AES_GCM_F_WITH_GHASH)
    ghash_mul_first (gd, u8x16_reflect (d[0]) ^ T, Hi[0]);

  /* AES rounds 2 and 3 */
  aes_gcm_enc_round (r, rk[2], n);
  aes_gcm_enc_round (r, rk[3], n);

  /* GHASH multiply block 2 */
  if ((f & AES_GCM_F_WITH_GHASH) && gc++ < ghash_blocks)
    ghash_mul_next (gd, u8x16_reflect (d[1]), Hi[1]);

  /* AES rounds 4 and 5 */
  aes_gcm_enc_round (r, rk[4], n);
  aes_gcm_enc_round (r, rk[5], n);

  /* GHASH multiply block 3 */
  if ((f & AES_GCM_F_WITH_GHASH) && gc++ < ghash_blocks)
    ghash_mul_next (gd, u8x16_reflect (d[2]), Hi[2]);

  /* AES rounds 6 and 7 */
  aes_gcm_enc_round (r, rk[6], n);
  aes_gcm_enc_round (r, rk[7], n);

  /* GHASH multiply block 4 */
  if ((f & AES_GCM_F_WITH_GHASH) && gc++ < ghash_blocks)
    ghash_mul_next (gd, u8x16_reflect (d[3]), Hi[3]);

  /* AES rounds 8 and 9 */
  aes_gcm_enc_round (r, rk[8], n);
  aes_gcm_enc_round (r, rk[9], n);

  /* GHASH reduce 1st step */
  if (f & AES_GCM_F_WITH_GHASH)
    ghash_reduce (gd);

  /* load data - encrypt round */
  if (f & AES_GCM_F_ENCRYPT)
    {
      for (int i = 0; i < n - ((f & AES_GCM_F_LAST_ROUND) != 0); i++)
        d[i] = inv[i];

      if (f & AES_GCM_F_LAST_ROUND)
        d[n - 1] = aes_load_partial (inv + n - 1, last_block_bytes);
    }

  /* GHASH reduce 2nd step */
  if (f & AES_GCM_F_WITH_GHASH)
    ghash_reduce2 (gd);

  /* AES last round(s) */
  aes_gcm_enc_last_round (r, d, rk, rounds, n);

  /* store data */
  for (int i = 0; i < n - ((f & AES_GCM_F_LAST_ROUND) != 0); i++)
    outv[i] = d[i];

  if (f & AES_GCM_F_LAST_ROUND)
    aes_store_partial (outv + n - 1, d[n - 1], last_block_bytes);

  /* GHASH final step */
  if (f & AES_GCM_F_WITH_GHASH)
    T = ghash_final (gd);

  return T;
}

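/* Same idea as aes_gcm_calc, but processes two batches of 4 blocks
   (128 bytes) per call so that all 8 GHASH multiplies share a single
   reduction. */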
static_always_inline u8x16
aes_gcm_calc_double (u8x16 T, aes_gcm_key_data_t * kd, u8x16 * d,
                     aes_gcm_counter_t * ctr, u8x16u * inv, u8x16u * outv,
                     int rounds, aes_gcm_flags_t f)
{
  u8x16 r[4];
  ghash_data_t _gd, *gd = &_gd;
  const u8x16 *rk = (u8x16 *) kd->Ke;
  u8x16 *Hi = (u8x16 *) kd->Hi + NUM_HI - 8;

  /* AES rounds 0 and 1 */
  aes_gcm_enc_first_round (r, ctr, rk[0], 4);
  aes_gcm_enc_round (r, rk[1], 4);

  /* load 4 blocks of data - decrypt round */
  if (f & AES_GCM_F_DECRYPT)
    {
      d[0] = inv[0];
      d[1] = inv[1];
      d[2] = inv[2];
      d[3] = inv[3];
    }

  /* GHASH multiply block 0 */
  ghash_mul_first (gd, u8x16_reflect (d[0]) ^ T, Hi[0]);

  /* AES rounds 2 and 3 */
  aes_gcm_enc_round (r, rk[2], 4);
  aes_gcm_enc_round (r, rk[3], 4);

  /* GHASH multiply block 1 */
  ghash_mul_next (gd, u8x16_reflect (d[1]), Hi[1]);

  /* AES rounds 4 and 5 */
  aes_gcm_enc_round (r, rk[4], 4);
  aes_gcm_enc_round (r, rk[5], 4);

  /* GHASH multiply block 2 */
  ghash_mul_next (gd, u8x16_reflect (d[2]), Hi[2]);

  /* AES rounds 6 and 7 */
  aes_gcm_enc_round (r, rk[6], 4);
  aes_gcm_enc_round (r, rk[7], 4);

  /* GHASH multiply block 3 */
  ghash_mul_next (gd, u8x16_reflect (d[3]), Hi[3]);

  /* AES rounds 8 and 9 */
  aes_gcm_enc_round (r, rk[8], 4);
  aes_gcm_enc_round (r, rk[9], 4);

  /* load 4 blocks of data - encrypt round */
  if (f & AES_GCM_F_ENCRYPT)
    {
      d[0] = inv[0];
      d[1] = inv[1];
      d[2] = inv[2];
      d[3] = inv[3];
    }

  /* AES last round(s) */
  aes_gcm_enc_last_round (r, d, rk, rounds, 4);

  /* store 4 blocks of data */
  outv[0] = d[0];
  outv[1] = d[1];
  outv[2] = d[2];
  outv[3] = d[3];

  /* load next 4 blocks of data - decrypt round */
  if (f & AES_GCM_F_DECRYPT)
    {
      d[0] = inv[4];
      d[1] = inv[5];
      d[2] = inv[6];
      d[3] = inv[7];
    }

  /* GHASH multiply block 4 */
  ghash_mul_next (gd, u8x16_reflect (d[0]), Hi[4]);

  /* AES rounds 0, 1 and 2 */
  aes_gcm_enc_first_round (r, ctr, rk[0], 4);
  aes_gcm_enc_round (r, rk[1], 4);
  aes_gcm_enc_round (r, rk[2], 4);

  /* GHASH multiply block 5 */
  ghash_mul_next (gd, u8x16_reflect (d[1]), Hi[5]);

  /* AES rounds 3 and 4 */
  aes_gcm_enc_round (r, rk[3], 4);
  aes_gcm_enc_round (r, rk[4], 4);

  /* GHASH multiply block 6 */
  ghash_mul_next (gd, u8x16_reflect (d[2]), Hi[6]);

  /* AES rounds 5 and 6 */
  aes_gcm_enc_round (r, rk[5], 4);
  aes_gcm_enc_round (r, rk[6], 4);

  /* GHASH multiply block 7 */
  ghash_mul_next (gd, u8x16_reflect (d[3]), Hi[7]);

  /* AES rounds 7 and 8 */
  aes_gcm_enc_round (r, rk[7], 4);
  aes_gcm_enc_round (r, rk[8], 4);

  /* GHASH reduce 1st step */
  ghash_reduce (gd);

  /* AES round 9 */
  aes_gcm_enc_round (r, rk[9], 4);

  /* load data - encrypt round */
  if (f & AES_GCM_F_ENCRYPT)
    {
      d[0] = inv[4];
      d[1] = inv[5];
      d[2] = inv[6];
      d[3] = inv[7];
    }

  /* GHASH reduce 2nd step */
  ghash_reduce2 (gd);

  /* AES last round(s) */
  aes_gcm_enc_last_round (r, d, rk, rounds, 4);

  /* store data */
  outv[4] = d[0];
  outv[5] = d[1];
  outv[6] = d[2];
  outv[7] = d[3];

  /* GHASH final step */
  return ghash_final (gd);
}

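/* hash the last 1..4 ciphertext blocks left in d[] after the final
   encryption call; a partial last block is masked down to n_bytes first */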
static_always_inline u8x16
aes_gcm_ghash_last (u8x16 T, aes_gcm_key_data_t * kd, u8x16 * d,
                    int n_blocks, int n_bytes)
{
  ghash_data_t _gd, *gd = &_gd;
  u8x16 *Hi = (u8x16 *) kd->Hi + NUM_HI - n_blocks;

  if (n_bytes)
    d[n_blocks - 1] = aes_byte_mask (d[n_blocks - 1], n_bytes);

  ghash_mul_first (gd, u8x16_reflect (d[0]) ^ T, Hi[0]);
  if (n_blocks > 1)
    ghash_mul_next (gd, u8x16_reflect (d[1]), Hi[1]);
  if (n_blocks > 2)
    ghash_mul_next (gd, u8x16_reflect (d[2]), Hi[2]);
  if (n_blocks > 3)
    ghash_mul_next (gd, u8x16_reflect (d[3]), Hi[3]);
  ghash_reduce (gd);
  ghash_reduce2 (gd);
  return ghash_final (gd);
}

#ifdef __VAES__
static const u32x16 ctr_inv_1234 = {
  0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
};

static const u32x16 ctr_inv_4444 = {
  0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24
};

static const u32x16 ctr_1234 = {
  1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0,
};
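/* ctr_inv_1234 seeds the four counter lanes with +1..+4 in byte-reflected
   form, ctr_inv_4444 advances all four lanes by 4, and ctr_1234 holds the
   host-order offsets used to rebuild the lanes when the low counter byte
   is about to wrap. */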

static_always_inline void
aes4_gcm_enc_first_round (u8x64 * r, aes_gcm_counter_t * ctr, u8x64 k, int n)
{
  u8 last_byte = (u8) ctr->counter;
  int i = 0;

  /* The counter is stored in network byte order for performance reasons,
     so normally only its least significant byte is incremented.  The
     exception is when that byte overflows; since four 512-bit registers
     are processed in parallel except in the last round, overflow can only
     happen when n == 4. */

  if (n == 4)
    for (; i < 2; i++)
      {
        r[i] = k ^ (u8x64) ctr->Y4;
        ctr->Y4 += ctr_inv_4444;
      }

  if (n == 4 && PREDICT_TRUE (last_byte == 241))
    {
      u32x16 Yc, Yr = (u32x16) u8x64_reflect_u8x16 ((u8x64) ctr->Y4);

      for (; i < n; i++)
        {
          r[i] = k ^ (u8x64) ctr->Y4;
          Yc = u32x16_splat (ctr->counter + 4 * (i + 1)) + ctr_1234;
          Yr = (u32x16) u32x16_mask_blend (Yr, Yc, 0x1111);
          ctr->Y4 = (u32x16) u8x64_reflect_u8x16 ((u8x64) Yr);
        }
    }
  else
    {
      for (; i < n; i++)
        {
          r[i] = k ^ (u8x64) ctr->Y4;
          ctr->Y4 += ctr_inv_4444;
        }
    }
  ctr->counter += n * 4;
}

static_always_inline void
aes4_gcm_enc_round (u8x64 * r, u8x64 k, int n_blocks)
{
  for (int i = 0; i < n_blocks; i++)
    r[i] = aes_enc_round_x4 (r[i], k);
}

static_always_inline void
aes4_gcm_enc_last_round (u8x64 * r, u8x64 * d, u8x64 const *k,
                         int rounds, int n_blocks)
{

  /* additional rounds for AES-192 and AES-256 */
  for (int i = 10; i < rounds; i++)
    aes4_gcm_enc_round (r, k[i], n_blocks);

  for (int i = 0; i < n_blocks; i++)
    d[i] ^= aes_enc_last_round_x4 (r[i], k[rounds]);
}

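/* 512-bit (VAES) variant of aes_gcm_calc: n (1..4) registers of four
   16-byte blocks each are processed per call, and the running GHASH state
   T is folded into lane 0 of the first reflected data register. */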
static_always_inline u8x16
aes4_gcm_calc (u8x16 T, aes_gcm_key_data_t * kd, u8x64 * d,
               aes_gcm_counter_t * ctr, u8x16u * in, u8x16u * out,
               int rounds, int n, int last_4block_bytes, aes_gcm_flags_t f)
{
  ghash4_data_t _gd, *gd = &_gd;
  const u8x64 *rk = (u8x64 *) kd->Ke4;
  int i, ghash_blocks, gc = 1;
  u8x64u *Hi4, *inv = (u8x64u *) in, *outv = (u8x64u *) out;
  u8x64 r[4];
  u64 byte_mask = _bextr_u64 (-1LL, 0, last_4block_bytes);

  if (f & AES_GCM_F_ENCRYPT)
    {
      /* during encryption we either hash four 512-bit blocks from the
         previous round or we don't hash at all */
      ghash_blocks = 4;
      Hi4 = (u8x64u *) (kd->Hi + NUM_HI - ghash_blocks * 4);
    }
  else
    {
      /* during decryption we hash 1..4 512-bit blocks from the current
         round */
      ghash_blocks = n;
      int n_128bit_blocks = n * 4;
      /* if this is the last round of decryption, the last 512-bit data
         block may contain fewer than four 128-bit blocks, so the Hi4
         pointer needs to be adjusted accordingly */
      if (f & AES_GCM_F_LAST_ROUND)
        n_128bit_blocks += ((last_4block_bytes + 15) >> 4) - 4;
      Hi4 = (u8x64u *) (kd->Hi + NUM_HI - n_128bit_blocks);
    }

  /* AES rounds 0 and 1 */
  aes4_gcm_enc_first_round (r, ctr, rk[0], n);
  aes4_gcm_enc_round (r, rk[1], n);

  /* load 4 blocks of data - decrypt round */
  if (f & AES_GCM_F_DECRYPT)
    {
      for (i = 0; i < n - ((f & AES_GCM_F_LAST_ROUND) != 0); i++)
        d[i] = inv[i];

      if (f & AES_GCM_F_LAST_ROUND)
        d[i] = u8x64_mask_load (u8x64_splat (0), inv + i, byte_mask);
    }

  /* GHASH multiply block 0 */
  if (f & AES_GCM_F_WITH_GHASH)
    ghash4_mul_first (gd, u8x64_reflect_u8x16 (d[0]) ^
                      u8x64_insert_u8x16 (u8x64_splat (0), T, 0), Hi4[0]);

  /* AES rounds 2 and 3 */
  aes4_gcm_enc_round (r, rk[2], n);
  aes4_gcm_enc_round (r, rk[3], n);

  /* GHASH multiply block 1 */
  if ((f & AES_GCM_F_WITH_GHASH) && gc++ < ghash_blocks)
    ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[1]), Hi4[1]);

  /* AES rounds 4 and 5 */
  aes4_gcm_enc_round (r, rk[4], n);
  aes4_gcm_enc_round (r, rk[5], n);

  /* GHASH multiply block 2 */
  if ((f & AES_GCM_F_WITH_GHASH) && gc++ < ghash_blocks)
    ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[2]), Hi4[2]);

  /* AES rounds 6 and 7 */
  aes4_gcm_enc_round (r, rk[6], n);
  aes4_gcm_enc_round (r, rk[7], n);

  /* GHASH multiply block 3 */
  if ((f & AES_GCM_F_WITH_GHASH) && gc++ < ghash_blocks)
    ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[3]), Hi4[3]);

  /* load 4 blocks of data - encrypt round */
  if (f & AES_GCM_F_ENCRYPT)
    {
      for (i = 0; i < n - ((f & AES_GCM_F_LAST_ROUND) != 0); i++)
        d[i] = inv[i];

      if (f & AES_GCM_F_LAST_ROUND)
        d[i] = u8x64_mask_load (u8x64_splat (0), inv + i, byte_mask);
    }

  /* AES rounds 8 and 9 */
  aes4_gcm_enc_round (r, rk[8], n);
  aes4_gcm_enc_round (r, rk[9], n);

  /* AES last round(s) */
  aes4_gcm_enc_last_round (r, d, rk, rounds, n);

  /* store 4 blocks of data */
  for (i = 0; i < n - ((f & AES_GCM_F_LAST_ROUND) != 0); i++)
    outv[i] = d[i];

  if (f & AES_GCM_F_LAST_ROUND)
    u8x64_mask_store (d[i], outv + i, byte_mask);

  /* GHASH reduce 1st step */
  ghash4_reduce (gd);

  /* GHASH reduce 2nd step */
  ghash4_reduce2 (gd);

  /* GHASH final step */
  return ghash4_final (gd);
}

static_always_inline u8x16
aes4_gcm_calc_double (u8x16 T, aes_gcm_key_data_t * kd, u8x64 * d,
                      aes_gcm_counter_t * ctr, u8x16u * in, u8x16u * out,
                      int rounds, aes_gcm_flags_t f)
{
  u8x64 r[4];
  ghash4_data_t _gd, *gd = &_gd;
  const u8x64 *rk = (u8x64 *) kd->Ke4;
  u8x64 *Hi4 = (u8x64 *) (kd->Hi + NUM_HI - 32);
  u8x64u *inv = (u8x64u *) in, *outv = (u8x64u *) out;

  /* AES rounds 0 and 1 */
  aes4_gcm_enc_first_round (r, ctr, rk[0], 4);
  aes4_gcm_enc_round (r, rk[1], 4);

  /* load 4 blocks of data - decrypt round */
  if (f & AES_GCM_F_DECRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = inv[i];

  /* GHASH multiply block 0 */
  ghash4_mul_first (gd, u8x64_reflect_u8x16 (d[0]) ^
                    u8x64_insert_u8x16 (u8x64_splat (0), T, 0), Hi4[0]);

  /* AES rounds 2 and 3 */
  aes4_gcm_enc_round (r, rk[2], 4);
  aes4_gcm_enc_round (r, rk[3], 4);

  /* GHASH multiply block 1 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[1]), Hi4[1]);

  /* AES rounds 4 and 5 */
  aes4_gcm_enc_round (r, rk[4], 4);
  aes4_gcm_enc_round (r, rk[5], 4);

  /* GHASH multiply block 2 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[2]), Hi4[2]);

  /* AES rounds 6 and 7 */
  aes4_gcm_enc_round (r, rk[6], 4);
  aes4_gcm_enc_round (r, rk[7], 4);

  /* GHASH multiply block 3 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[3]), Hi4[3]);

  /* AES rounds 8 and 9 */
  aes4_gcm_enc_round (r, rk[8], 4);
  aes4_gcm_enc_round (r, rk[9], 4);

  /* load 4 blocks of data - encrypt round */
  if (f & AES_GCM_F_ENCRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = inv[i];

  /* AES last round(s) */
  aes4_gcm_enc_last_round (r, d, rk, rounds, 4);

  /* store 4 blocks of data */
  for (int i = 0; i < 4; i++)
    outv[i] = d[i];

  /* load 4 blocks of data - decrypt round */
  if (f & AES_GCM_F_DECRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = inv[i + 4];

  /* GHASH multiply block 4 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[0]), Hi4[4]);

  /* AES rounds 0 and 1 */
  aes4_gcm_enc_first_round (r, ctr, rk[0], 4);
  aes4_gcm_enc_round (r, rk[1], 4);

  /* GHASH multiply block 5 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[1]), Hi4[5]);

  /* AES rounds 2 and 3 */
  aes4_gcm_enc_round (r, rk[2], 4);
  aes4_gcm_enc_round (r, rk[3], 4);

  /* GHASH multiply block 6 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[2]), Hi4[6]);

  /* AES rounds 4 and 5 */
  aes4_gcm_enc_round (r, rk[4], 4);
  aes4_gcm_enc_round (r, rk[5], 4);

  /* GHASH multiply block 7 */
  ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[3]), Hi4[7]);

  /* AES rounds 6 and 7 */
  aes4_gcm_enc_round (r, rk[6], 4);
  aes4_gcm_enc_round (r, rk[7], 4);

  /* GHASH reduce 1st step */
  ghash4_reduce (gd);

  /* AES rounds 8 and 9 */
  aes4_gcm_enc_round (r, rk[8], 4);
  aes4_gcm_enc_round (r, rk[9], 4);

  /* GHASH reduce 2nd step */
  ghash4_reduce2 (gd);

  /* load 4 blocks of data - encrypt round */
  if (f & AES_GCM_F_ENCRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = inv[i + 4];

  /* AES last round(s) */
  aes4_gcm_enc_last_round (r, d, rk, rounds, 4);

  /* store 4 blocks of data */
  for (int i = 0; i < 4; i++)
    outv[i + 4] = d[i];

  /* GHASH final step */
  return ghash4_final (gd);
}

static_always_inline u8x16
aes4_gcm_ghash_last (u8x16 T, aes_gcm_key_data_t * kd, u8x64 * d,
                     int n, int last_4block_bytes)
{
  ghash4_data_t _gd, *gd = &_gd;
  u8x64u *Hi4;
  int n_128bit_blocks;
  u64 byte_mask = _bextr_u64 (-1LL, 0, last_4block_bytes);
  n_128bit_blocks = (n - 1) * 4 + ((last_4block_bytes + 15) >> 4);
  Hi4 = (u8x64u *) (kd->Hi + NUM_HI - n_128bit_blocks);

  d[n - 1] = u8x64_mask_blend (u8x64_splat (0), d[n - 1], byte_mask);
  ghash4_mul_first (gd, u8x64_reflect_u8x16 (d[0]) ^
                    u8x64_insert_u8x16 (u8x64_splat (0), T, 0), Hi4[0]);
  if (n > 1)
    ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[1]), Hi4[1]);
  if (n > 2)
    ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[2]), Hi4[2]);
  if (n > 3)
    ghash4_mul_next (gd, u8x64_reflect_u8x16 (d[3]), Hi4[3]);
  ghash4_reduce (gd);
  ghash4_reduce2 (gd);
  return ghash4_final (gd);
}
#endif

static_always_inline u8x16
aes_gcm_enc (u8x16 T, aes_gcm_key_data_t * kd, aes_gcm_counter_t * ctr,
             u8x16u * inv, u8x16u * outv, u32 n_left, int rounds)
{
  u8x16 d[4];
  aes_gcm_flags_t f = AES_GCM_F_ENCRYPT;

  if (n_left == 0)
    return T;

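  /* The first calc call below runs without AES_GCM_F_WITH_GHASH since there
     is no previous ciphertext to hash yet; later calls hash the previous
     batch while encrypting the next one, and the *_ghash_last helpers fold
     in whatever ciphertext is still sitting in d[] (or d4[]) at the end. */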
#if __VAES__
  u8x64 d4[4];
  if (n_left < 256)
    {
      f |= AES_GCM_F_LAST_ROUND;
      if (n_left > 192)
        {
          n_left -= 192;
          aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 4, n_left, f);
          return aes4_gcm_ghash_last (T, kd, d4, 4, n_left);
        }
      else if (n_left > 128)
        {
          n_left -= 128;
          aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 3, n_left, f);
          return aes4_gcm_ghash_last (T, kd, d4, 3, n_left);
        }
      else if (n_left > 64)
        {
          n_left -= 64;
          aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 2, n_left, f);
          return aes4_gcm_ghash_last (T, kd, d4, 2, n_left);
        }
      else
        {
          aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 1, n_left, f);
          return aes4_gcm_ghash_last (T, kd, d4, 1, n_left);
        }
    }

  aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 4, 0, f);

  /* next */
  n_left -= 256;
  outv += 16;
  inv += 16;

  f |= AES_GCM_F_WITH_GHASH;

  while (n_left >= 512)
    {
      T = aes4_gcm_calc_double (T, kd, d4, ctr, inv, outv, rounds, f);

      /* next */
      n_left -= 512;
      outv += 32;
      inv += 32;
    }

  while (n_left >= 256)
    {
      T = aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 4, 0, f);

      /* next */
      n_left -= 256;
      outv += 16;
      inv += 16;
    }

  if (n_left == 0)
    return aes4_gcm_ghash_last (T, kd, d4, 4, 64);

  f |= AES_GCM_F_LAST_ROUND;

  if (n_left > 192)
    {
      n_left -= 192;
      T = aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 4, n_left, f);
      return aes4_gcm_ghash_last (T, kd, d4, 4, n_left);
    }

  if (n_left > 128)
    {
      n_left -= 128;
      T = aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 3, n_left, f);
      return aes4_gcm_ghash_last (T, kd, d4, 3, n_left);
    }

  if (n_left > 64)
    {
      n_left -= 64;
      T = aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 2, n_left, f);
      return aes4_gcm_ghash_last (T, kd, d4, 2, n_left);
    }

  T = aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 1, n_left, f);
  return aes4_gcm_ghash_last (T, kd, d4, 1, n_left);
#endif

  if (n_left < 64)
    {
      f |= AES_GCM_F_LAST_ROUND;
      if (n_left > 48)
        {
          n_left -= 48;
          aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 4, n_left, f);
          return aes_gcm_ghash_last (T, kd, d, 4, n_left);
        }
      else if (n_left > 32)
        {
          n_left -= 32;
          aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 3, n_left, f);
          return aes_gcm_ghash_last (T, kd, d, 3, n_left);
        }
      else if (n_left > 16)
        {
          n_left -= 16;
          aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 2, n_left, f);
          return aes_gcm_ghash_last (T, kd, d, 2, n_left);
        }
      else
        {
          aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 1, n_left, f);
          return aes_gcm_ghash_last (T, kd, d, 1, n_left);
        }
    }

  aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 4, 0, f);

  /* next */
  n_left -= 64;
  outv += 4;
  inv += 4;

  f |= AES_GCM_F_WITH_GHASH;

  while (n_left >= 128)
    {
      T = aes_gcm_calc_double (T, kd, d, ctr, inv, outv, rounds, f);

      /* next */
      n_left -= 128;
      outv += 8;
      inv += 8;
    }

  if (n_left >= 64)
    {
      T = aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 4, 0, f);

      /* next */
      n_left -= 64;
      outv += 4;
      inv += 4;
    }

  if (n_left == 0)
    return aes_gcm_ghash_last (T, kd, d, 4, 0);

  f |= AES_GCM_F_LAST_ROUND;

  if (n_left > 48)
    {
      n_left -= 48;
      T = aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 4, n_left, f);
      return aes_gcm_ghash_last (T, kd, d, 4, n_left);
    }

  if (n_left > 32)
    {
      n_left -= 32;
      T = aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 3, n_left, f);
      return aes_gcm_ghash_last (T, kd, d, 3, n_left);
    }

  if (n_left > 16)
    {
      n_left -= 16;
      T = aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 2, n_left, f);
      return aes_gcm_ghash_last (T, kd, d, 2, n_left);
    }

  T = aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 1, n_left, f);
  return aes_gcm_ghash_last (T, kd, d, 1, n_left);
}

static_always_inline u8x16
aes_gcm_dec (u8x16 T, aes_gcm_key_data_t * kd, aes_gcm_counter_t * ctr,
             u8x16u * inv, u8x16u * outv, u32 n_left, int rounds)
{
  aes_gcm_flags_t f = AES_GCM_F_WITH_GHASH | AES_GCM_F_DECRYPT;
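  /* decryption hashes the ciphertext it is currently processing, so GHASH
     is enabled from the first call and no separate final hash pass over
     leftover blocks is needed */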
#ifdef __VAES__
  u8x64 d4[4] = { };

  while (n_left >= 512)
    {
      T = aes4_gcm_calc_double (T, kd, d4, ctr, inv, outv, rounds, f);

      /* next */
      n_left -= 512;
      outv += 32;
      inv += 32;
    }

  while (n_left >= 256)
    {
      T = aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 4, 0, f);

      /* next */
      n_left -= 256;
      outv += 16;
      inv += 16;
    }

  if (n_left == 0)
    return T;

  f |= AES_GCM_F_LAST_ROUND;

  if (n_left > 192)
    return aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 4,
                          n_left - 192, f);
  if (n_left > 128)
    return aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 3,
                          n_left - 128, f);
  if (n_left > 64)
    return aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 2,
                          n_left - 64, f);
  return aes4_gcm_calc (T, kd, d4, ctr, inv, outv, rounds, 1, n_left, f);
#else
  u8x16 d[4];
  while (n_left >= 128)
    {
      T = aes_gcm_calc_double (T, kd, d, ctr, inv, outv, rounds, f);

      /* next */
      n_left -= 128;
      outv += 8;
      inv += 8;
    }

  if (n_left >= 64)
    {
      T = aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 4, 0, f);

      /* next */
      n_left -= 64;
      outv += 4;
      inv += 4;
    }

  if (n_left == 0)
    return T;

  f |= AES_GCM_F_LAST_ROUND;

  if (n_left > 48)
    return aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 4, n_left - 48, f);

  if (n_left > 32)
    return aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 3, n_left - 32, f);

  if (n_left > 16)
    return aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 2, n_left - 16, f);

  return aes_gcm_calc (T, kd, d, ctr, inv, outv, rounds, 1, n_left, f);
#endif
}

static_always_inline int
aes_gcm (u8x16u * in, u8x16u * out, u8x16u * addt, u8x16u * iv, u8x16u * tag,
         u32 data_bytes, u32 aad_bytes, u8 tag_len, aes_gcm_key_data_t * kd,
         int aes_rounds, int is_encrypt)
{
  int i;
  u8x16 r, T = { };
  u32x4 Y0;
  ghash_data_t _gd, *gd = &_gd;
  aes_gcm_counter_t _ctr, *ctr = &_ctr;

  clib_prefetch_load (iv);
  clib_prefetch_load (in);
  clib_prefetch_load (in + 4);

  /* calculate ghash for AAD - optimized for ipsec common cases */
  if (aad_bytes == 8)
    T = aes_gcm_ghash (T, kd, addt, 8);
  else if (aad_bytes == 12)
    T = aes_gcm_ghash (T, kd, addt, 12);
  else
    T = aes_gcm_ghash (T, kd, addt, aad_bytes);

  /* initialize counter */
  ctr->counter = 1;
  Y0 = (u32x4) aes_load_partial (iv, 12) + ctr_inv_1;
#ifdef __VAES__
  ctr->Y4 = u32x16_splat_u32x4 (Y0) + ctr_inv_1234;
#else
  ctr->Y = Y0 + ctr_inv_1;
#endif

  /* ghash and encrypt/decrypt */
  if (is_encrypt)
    T = aes_gcm_enc (T, kd, ctr, in, out, data_bytes, aes_rounds);
  else
    T = aes_gcm_dec (T, kd, ctr, in, out, data_bytes, aes_rounds);

  clib_prefetch_load (tag);

  /* Finalize ghash - data bytes and aad bytes converted to bits */
  /* *INDENT-OFF* */
  r = (u8x16) ((u64x2) {data_bytes, aad_bytes} << 3);
  /* *INDENT-ON* */
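  /* note: the little-endian u64x2 {data_bits, aad_bits} is already the
     byte-reflected image of the GCM length block len(A) || len(C), so no
     u8x16_reflect is needed before it is multiplied in below */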

  /* interleaved computation of final ghash and E(Y0, k) */
  ghash_mul_first (gd, r ^ T, kd->Hi[NUM_HI - 1]);
  r = kd->Ke[0] ^ (u8x16) Y0;
  for (i = 1; i < 5; i += 1)
    r = aes_enc_round (r, kd->Ke[i]);
  ghash_reduce (gd);
  ghash_reduce2 (gd);
  for (; i < 9; i += 1)
    r = aes_enc_round (r, kd->Ke[i]);
  T = ghash_final (gd);
  for (; i < aes_rounds; i += 1)
    r = aes_enc_round (r, kd->Ke[i]);
  r = aes_enc_last_round (r, kd->Ke[aes_rounds]);
  T = u8x16_reflect (T) ^ r;

  /* tag_len 16 -> 0 */
  tag_len &= 0xf;

  if (is_encrypt)
    {
      /* store tag */
      if (tag_len)
        aes_store_partial (tag, T, tag_len);
      else
        tag[0] = T;
    }
  else
    {
      /* check tag */
      u16 tag_mask = tag_len ? (1 << tag_len) - 1 : 0xffff;
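      /* tag[0] == T compares byte-wise; u8x16_msb_mask turns that into a
         16-bit match mask and only the first tag_len bytes (all 16 when
         tag_len was 16) must match */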
      if ((u8x16_msb_mask (tag[0] == T) & tag_mask) != tag_mask)
        return 0;
    }
  return 1;
}

static_always_inline u32
aes_ops_enc_aes_gcm (vlib_main_t * vm, vnet_crypto_op_t * ops[],
                     u32 n_ops, aes_key_size_t ks)
{
  crypto_native_main_t *cm = &crypto_native_main;
  vnet_crypto_op_t *op = ops[0];
  aes_gcm_key_data_t *kd;
  u32 n_left = n_ops;

next:
  kd = (aes_gcm_key_data_t *) cm->key_data[op->key_index];
  aes_gcm ((u8x16u *) op->src, (u8x16u *) op->dst, (u8x16u *) op->aad,
           (u8x16u *) op->iv, (u8x16u *) op->tag, op->len, op->aad_len,
           op->tag_len, kd, AES_KEY_ROUNDS (ks), /* is_encrypt */ 1);
  op->status = VNET_CRYPTO_OP_STATUS_COMPLETED;

  if (--n_left)
    {
      op += 1;
      goto next;
    }

  return n_ops;
}

static_always_inline u32
aes_ops_dec_aes_gcm (vlib_main_t * vm, vnet_crypto_op_t * ops[], u32 n_ops,
                     aes_key_size_t ks)
{
  crypto_native_main_t *cm = &crypto_native_main;
  vnet_crypto_op_t *op = ops[0];
  aes_gcm_key_data_t *kd;
  u32 n_left = n_ops;
  int rv;

next:
  kd = (aes_gcm_key_data_t *) cm->key_data[op->key_index];
  rv = aes_gcm ((u8x16u *) op->src, (u8x16u *) op->dst, (u8x16u *) op->aad,
                (u8x16u *) op->iv, (u8x16u *) op->tag, op->len,
                op->aad_len, op->tag_len, kd, AES_KEY_ROUNDS (ks),
                /* is_encrypt */ 0);

  if (rv)
    {
      op->status = VNET_CRYPTO_OP_STATUS_COMPLETED;
    }
  else
    {
      op->status = VNET_CRYPTO_OP_STATUS_FAIL_BAD_HMAC;
      n_ops--;
    }

  if (--n_left)
    {
      op += 1;
      goto next;
    }

  return n_ops;
}

static_always_inline void *
aes_gcm_key_exp (vnet_crypto_key_t * key, aes_key_size_t ks)
{
  aes_gcm_key_data_t *kd;
  u8x16 H;

  kd = clib_mem_alloc_aligned (sizeof (*kd), CLIB_CACHE_LINE_BYTES);

  /* expand AES key */
  aes_key_expand ((u8x16 *) kd->Ke, key->data, ks);

  /* pre-calculate H */
  H = aes_encrypt_block (u8x16_splat (0), kd->Ke, ks);
  H = u8x16_reflect (H);
  ghash_precompute (H, (u8x16 *) kd->Hi, NUM_HI);
#ifdef __VAES__
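  /* broadcast each 128-bit round key to all four lanes of a 512-bit
     register so the VAES code path can process four blocks per key */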
  u8x64 *Ke4 = (u8x64 *) kd->Ke4;
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    Ke4[i] = u8x64_splat_u8x16 (kd->Ke[i]);
#endif
  return kd;
}

#define foreach_aes_gcm_handler_type _(128) _(192) _(256)

#define _(x) \
static u32 aes_ops_dec_aes_gcm_##x                                         \
(vlib_main_t * vm, vnet_crypto_op_t * ops[], u32 n_ops)                      \
{ return aes_ops_dec_aes_gcm (vm, ops, n_ops, AES_KEY_##x); }              \
static u32 aes_ops_enc_aes_gcm_##x                                         \
(vlib_main_t * vm, vnet_crypto_op_t * ops[], u32 n_ops)                      \
{ return aes_ops_enc_aes_gcm (vm, ops, n_ops, AES_KEY_##x); }              \
static void * aes_gcm_key_exp_##x (vnet_crypto_key_t *key)                 \
{ return aes_gcm_key_exp (key, AES_KEY_##x); }

foreach_aes_gcm_handler_type;
#undef _

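/* This file is compiled once per supported instruction-set variant; the
   init function name below is selected to match the variant being built so
   the plugin can register the most suitable implementation at runtime. */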
clib_error_t *
#ifdef __VAES__
crypto_native_aes_gcm_init_icl (vlib_main_t * vm)
#elif __AVX512F__
crypto_native_aes_gcm_init_skx (vlib_main_t * vm)
#elif __AVX2__
crypto_native_aes_gcm_init_hsw (vlib_main_t * vm)
#elif __aarch64__
crypto_native_aes_gcm_init_neon (vlib_main_t * vm)
#else
crypto_native_aes_gcm_init_slm (vlib_main_t * vm)
#endif
{
  crypto_native_main_t *cm = &crypto_native_main;

#define _(x) \
  vnet_crypto_register_ops_handler (vm, cm->crypto_engine_index, \
                                    VNET_CRYPTO_OP_AES_##x##_GCM_ENC, \
                                    aes_ops_enc_aes_gcm_##x); \
  vnet_crypto_register_ops_handler (vm, cm->crypto_engine_index, \
                                    VNET_CRYPTO_OP_AES_##x##_GCM_DEC, \
                                    aes_ops_dec_aes_gcm_##x); \
  cm->key_fn[VNET_CRYPTO_ALG_AES_##x##_GCM] = aes_gcm_key_exp_##x;
  foreach_aes_gcm_handler_type;
#undef _
  return 0;
}

/*
 * fd.io coding-style-patch-verification: ON
 *
 * Local Variables:
 * eval: (c-set-style "gnu")
 * End:
 */