LCOV - code coverage report
Current view: top level - ballet/blake3 - blake3_avx512.c (source / functions) Hit Total Coverage
Test: cov.lcov Lines: 201 1016 19.8 %
Date: 2025-01-08 12:08:44 Functions: 15 53 28.3 %

          Line data    Source code
       1             : 
       2             : // Source originally from https://github.com/BLAKE3-team/BLAKE3
       3             : // From commit: c0ea395cf91d242f078c23d5f8d87eb9dd5f7b78
       4             : 
#include "blake3_impl.h"

#include <immintrin.h>
#include <stdint.h>
       8             : 
/* Two-source 32-bit lane shuffle for integer vectors: the result's low two
 * lanes are picked from a and its high two lanes from b, as selected by the
 * _MM_SHUFFLE-style immediate c. It goes through the float domain because
 * _mm_shuffle_ps is the SSE shuffle that takes two sources; the cast
 * intrinsics only reinterpret the bits.
 * NOTE(review): identifiers starting with an underscore are reserved at file
 * scope (C11 7.1.3) and the _mm_ prefix is the compiler's intrinsic
 * namespace; a rename would have to touch every caller. */
#define _mm_shuffle_ps2(a, b, c)                                               \
  (_mm_castps_si128(                                                           \
      _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c))))
      12             : 
      13          40 : INLINE __m128i loadu_128(const uint8_t src[16]) {
      14          40 :   return _mm_loadu_si128((const __m128i *)src);
      15          40 : }
      16             : 
      17           0 : INLINE __m256i loadu_256(const uint8_t src[32]) {
      18           0 :   return _mm256_loadu_si256((const __m256i *)src);
      19           0 : }
      20             : 
      21           0 : INLINE __m512i loadu_512(const uint8_t src[64]) {
      22           0 :   return _mm512_loadu_si512((const __m512i *)src);
      23           0 : }
      24             : 
      25          20 : INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
      26          20 :   _mm_storeu_si128((__m128i *)dest, src);
      27          20 : }
      28             : 
      29           0 : INLINE void storeu_256(__m256i src, uint8_t dest[16]) {
      30           0 :   _mm256_storeu_si256((__m256i *)dest, src);
      31           0 : }
      32             : 
      33         420 : INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
      34             : 
      35           0 : INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
      36             : 
      37           0 : INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }
      38             : 
      39         300 : INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }
      40             : 
      41           0 : INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }
      42             : 
      43           0 : INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }
      44             : 
      45           0 : INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }
      46             : 
      47           0 : INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }
      48             : 
      49           0 : INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }
      50             : 
      51          10 : INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
      52          10 :   return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
      53          10 : }
      54             : 
      55          70 : INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }
      56             : 
      57           0 : INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }
      58             : 
      59           0 : INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }
      60             : 
      61          70 : INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }
      62             : 
      63           0 : INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }
      64             : 
      65           0 : INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }
      66             : 
      67          70 : INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }
      68             : 
      69           0 : INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }
      70             : 
      71           0 : INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }
      72             : 
      73          70 : INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }
      74             : 
      75           0 : INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }
      76             : 
      77           0 : INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
      78             : 
      79             : /*
      80             :  * ----------------------------------------------------------------------------
      81             :  * compress_avx512
      82             :  * ----------------------------------------------------------------------------
      83             :  */
      84             : 
      85             : INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
      86          70 :                __m128i m) {
      87          70 :   *row0 = add_128(add_128(*row0, m), *row1);
      88          70 :   *row3 = xor_128(*row3, *row0);
      89          70 :   *row3 = rot16_128(*row3);
      90          70 :   *row2 = add_128(*row2, *row3);
      91          70 :   *row1 = xor_128(*row1, *row2);
      92          70 :   *row1 = rot12_128(*row1);
      93          70 : }
      94             : 
      95             : INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
      96          70 :                __m128i m) {
      97          70 :   *row0 = add_128(add_128(*row0, m), *row1);
      98          70 :   *row3 = xor_128(*row3, *row0);
      99          70 :   *row3 = rot8_128(*row3);
     100          70 :   *row2 = add_128(*row2, *row3);
     101          70 :   *row1 = xor_128(*row1, *row2);
     102          70 :   *row1 = rot7_128(*row1);
     103          70 : }
     104             : 
     105             : // Note the optimization here of leaving row1 as the unrotated row, rather than
     106             : // row0. All the message loads below are adjusted to compensate for this. See
     107             : // discussion at https://github.com/sneves/blake2-avx2/pull/4
     108          35 : INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
     109          35 :   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
     110          35 :   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
     111          35 :   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
     112          35 : }
     113             : 
     114          35 : INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
     115          35 :   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
     116          35 :   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
     117          35 :   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
     118          35 : }
     119             : 
     120             : INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
     121             :                          const uint8_t block[BLAKE3_BLOCK_LEN],
     122           5 :                          uint8_t block_len, uint64_t counter, uint8_t flags) {
     123           5 :   rows[0] = loadu_128((uint8_t *)&cv[0]);
     124           5 :   rows[1] = loadu_128((uint8_t *)&cv[4]);
     125           5 :   rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
     126           5 :   rows[3] = set4(counter_low(counter), counter_high(counter),
     127           5 :                  (uint32_t)block_len, (uint32_t)flags);
     128             : 
     129           5 :   __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
     130           5 :   __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
     131           5 :   __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
     132           5 :   __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
     133             : 
     134           5 :   __m128i t0, t1, t2, t3, tt;
     135             : 
     136             :   // Round 1. The first round permutes the message words from the original
     137             :   // input order, into the groups that get mixed in parallel.
     138           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); //  6  4  2  0
     139           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     140           5 :   t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); //  7  5  3  1
     141           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     142           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     143           5 :   t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10  8
     144           5 :   t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3));   // 12 10  8 14
     145           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     146           5 :   t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11  9
     147           5 :   t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3));   // 13 11  9 15
     148           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     149           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     150           5 :   m0 = t0;
     151           5 :   m1 = t1;
     152           5 :   m2 = t2;
     153           5 :   m3 = t3;
     154             : 
     155             :   // Round 2. This round and all following rounds apply a fixed permutation
     156             :   // to the message words from the round before.
     157           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
     158           5 :   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
     159           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     160           5 :   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
     161           5 :   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
     162           5 :   t1 = _mm_blend_epi16(tt, t1, 0xCC);
     163           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     164           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     165           5 :   t2 = _mm_unpacklo_epi64(m3, m1);
     166           5 :   tt = _mm_blend_epi16(t2, m2, 0xC0);
     167           5 :   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
     168           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     169           5 :   t3 = _mm_unpackhi_epi32(m1, m3);
     170           5 :   tt = _mm_unpacklo_epi32(m2, t3);
     171           5 :   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
     172           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     173           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     174           5 :   m0 = t0;
     175           5 :   m1 = t1;
     176           5 :   m2 = t2;
     177           5 :   m3 = t3;
     178             : 
     179             :   // Round 3
     180           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
     181           5 :   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
     182           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     183           5 :   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
     184           5 :   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
     185           5 :   t1 = _mm_blend_epi16(tt, t1, 0xCC);
     186           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     187           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     188           5 :   t2 = _mm_unpacklo_epi64(m3, m1);
     189           5 :   tt = _mm_blend_epi16(t2, m2, 0xC0);
     190           5 :   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
     191           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     192           5 :   t3 = _mm_unpackhi_epi32(m1, m3);
     193           5 :   tt = _mm_unpacklo_epi32(m2, t3);
     194           5 :   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
     195           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     196           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     197           5 :   m0 = t0;
     198           5 :   m1 = t1;
     199           5 :   m2 = t2;
     200           5 :   m3 = t3;
     201             : 
     202             :   // Round 4
     203           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
     204           5 :   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
     205           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     206           5 :   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
     207           5 :   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
     208           5 :   t1 = _mm_blend_epi16(tt, t1, 0xCC);
     209           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     210           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     211           5 :   t2 = _mm_unpacklo_epi64(m3, m1);
     212           5 :   tt = _mm_blend_epi16(t2, m2, 0xC0);
     213           5 :   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
     214           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     215           5 :   t3 = _mm_unpackhi_epi32(m1, m3);
     216           5 :   tt = _mm_unpacklo_epi32(m2, t3);
     217           5 :   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
     218           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     219           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     220           5 :   m0 = t0;
     221           5 :   m1 = t1;
     222           5 :   m2 = t2;
     223           5 :   m3 = t3;
     224             : 
     225             :   // Round 5
     226           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
     227           5 :   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
     228           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     229           5 :   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
     230           5 :   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
     231           5 :   t1 = _mm_blend_epi16(tt, t1, 0xCC);
     232           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     233           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     234           5 :   t2 = _mm_unpacklo_epi64(m3, m1);
     235           5 :   tt = _mm_blend_epi16(t2, m2, 0xC0);
     236           5 :   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
     237           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     238           5 :   t3 = _mm_unpackhi_epi32(m1, m3);
     239           5 :   tt = _mm_unpacklo_epi32(m2, t3);
     240           5 :   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
     241           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     242           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     243           5 :   m0 = t0;
     244           5 :   m1 = t1;
     245           5 :   m2 = t2;
     246           5 :   m3 = t3;
     247             : 
     248             :   // Round 6
     249           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
     250           5 :   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
     251           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     252           5 :   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
     253           5 :   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
     254           5 :   t1 = _mm_blend_epi16(tt, t1, 0xCC);
     255           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     256           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     257           5 :   t2 = _mm_unpacklo_epi64(m3, m1);
     258           5 :   tt = _mm_blend_epi16(t2, m2, 0xC0);
     259           5 :   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
     260           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     261           5 :   t3 = _mm_unpackhi_epi32(m1, m3);
     262           5 :   tt = _mm_unpacklo_epi32(m2, t3);
     263           5 :   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
     264           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     265           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     266           5 :   m0 = t0;
     267           5 :   m1 = t1;
     268           5 :   m2 = t2;
     269           5 :   m3 = t3;
     270             : 
     271             :   // Round 7
     272           5 :   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
     273           5 :   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
     274           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
     275           5 :   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
     276           5 :   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
     277           5 :   t1 = _mm_blend_epi16(tt, t1, 0xCC);
     278           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
     279           5 :   diagonalize(&rows[0], &rows[2], &rows[3]);
     280           5 :   t2 = _mm_unpacklo_epi64(m3, m1);
     281           5 :   tt = _mm_blend_epi16(t2, m2, 0xC0);
     282           5 :   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
     283           5 :   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
     284           5 :   t3 = _mm_unpackhi_epi32(m1, m3);
     285           5 :   tt = _mm_unpacklo_epi32(m2, t3);
     286           5 :   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
     287           5 :   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
     288           5 :   undiagonalize(&rows[0], &rows[2], &rows[3]);
     289           5 : }
     290             : 
     291             : void fd_blake3_compress_xof_avx512(const uint32_t cv[8],
     292             :                                    const uint8_t block[BLAKE3_BLOCK_LEN],
     293             :                                    uint8_t block_len, uint64_t counter,
     294           5 :                                    uint8_t flags, uint8_t out[64]) {
     295           5 :   __m128i rows[4];
     296           5 :   compress_pre(rows, cv, block, block_len, counter, flags);
     297           5 :   storeu_128(xor_128(rows[0], rows[2]), &out[0]);
     298           5 :   storeu_128(xor_128(rows[1], rows[3]), &out[16]);
     299           5 :   storeu_128(xor_128(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
     300           5 :   storeu_128(xor_128(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
     301           5 : }
     302             : 
     303             : void fd_blake3_compress_in_place_avx512(uint32_t cv[8],
     304             :                                         const uint8_t block[BLAKE3_BLOCK_LEN],
     305             :                                         uint8_t block_len, uint64_t counter,
     306           0 :                                         uint8_t flags) {
     307           0 :   __m128i rows[4];
     308           0 :   compress_pre(rows, cv, block, block_len, counter, flags);
     309           0 :   storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
     310           0 :   storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
     311           0 : }
     312             : 
     313             : /*
     314             :  * ----------------------------------------------------------------------------
     315             :  * hash4_avx512
     316             :  * ----------------------------------------------------------------------------
     317             :  */
     318             : 
     319           0 : INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
     320           0 :   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
     321           0 :   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
     322           0 :   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
     323           0 :   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
     324           0 :   v[0] = add_128(v[0], v[4]);
     325           0 :   v[1] = add_128(v[1], v[5]);
     326           0 :   v[2] = add_128(v[2], v[6]);
     327           0 :   v[3] = add_128(v[3], v[7]);
     328           0 :   v[12] = xor_128(v[12], v[0]);
     329           0 :   v[13] = xor_128(v[13], v[1]);
     330           0 :   v[14] = xor_128(v[14], v[2]);
     331           0 :   v[15] = xor_128(v[15], v[3]);
     332           0 :   v[12] = rot16_128(v[12]);
     333           0 :   v[13] = rot16_128(v[13]);
     334           0 :   v[14] = rot16_128(v[14]);
     335           0 :   v[15] = rot16_128(v[15]);
     336           0 :   v[8] = add_128(v[8], v[12]);
     337           0 :   v[9] = add_128(v[9], v[13]);
     338           0 :   v[10] = add_128(v[10], v[14]);
     339           0 :   v[11] = add_128(v[11], v[15]);
     340           0 :   v[4] = xor_128(v[4], v[8]);
     341           0 :   v[5] = xor_128(v[5], v[9]);
     342           0 :   v[6] = xor_128(v[6], v[10]);
     343           0 :   v[7] = xor_128(v[7], v[11]);
     344           0 :   v[4] = rot12_128(v[4]);
     345           0 :   v[5] = rot12_128(v[5]);
     346           0 :   v[6] = rot12_128(v[6]);
     347           0 :   v[7] = rot12_128(v[7]);
     348           0 :   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
     349           0 :   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
     350           0 :   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
     351           0 :   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
     352           0 :   v[0] = add_128(v[0], v[4]);
     353           0 :   v[1] = add_128(v[1], v[5]);
     354           0 :   v[2] = add_128(v[2], v[6]);
     355           0 :   v[3] = add_128(v[3], v[7]);
     356           0 :   v[12] = xor_128(v[12], v[0]);
     357           0 :   v[13] = xor_128(v[13], v[1]);
     358           0 :   v[14] = xor_128(v[14], v[2]);
     359           0 :   v[15] = xor_128(v[15], v[3]);
     360           0 :   v[12] = rot8_128(v[12]);
     361           0 :   v[13] = rot8_128(v[13]);
     362           0 :   v[14] = rot8_128(v[14]);
     363           0 :   v[15] = rot8_128(v[15]);
     364           0 :   v[8] = add_128(v[8], v[12]);
     365           0 :   v[9] = add_128(v[9], v[13]);
     366           0 :   v[10] = add_128(v[10], v[14]);
     367           0 :   v[11] = add_128(v[11], v[15]);
     368           0 :   v[4] = xor_128(v[4], v[8]);
     369           0 :   v[5] = xor_128(v[5], v[9]);
     370           0 :   v[6] = xor_128(v[6], v[10]);
     371           0 :   v[7] = xor_128(v[7], v[11]);
     372           0 :   v[4] = rot7_128(v[4]);
     373           0 :   v[5] = rot7_128(v[5]);
     374           0 :   v[6] = rot7_128(v[6]);
     375           0 :   v[7] = rot7_128(v[7]);
     376             : 
     377           0 :   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
     378           0 :   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
     379           0 :   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
     380           0 :   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
     381           0 :   v[0] = add_128(v[0], v[5]);
     382           0 :   v[1] = add_128(v[1], v[6]);
     383           0 :   v[2] = add_128(v[2], v[7]);
     384           0 :   v[3] = add_128(v[3], v[4]);
     385           0 :   v[15] = xor_128(v[15], v[0]);
     386           0 :   v[12] = xor_128(v[12], v[1]);
     387           0 :   v[13] = xor_128(v[13], v[2]);
     388           0 :   v[14] = xor_128(v[14], v[3]);
     389           0 :   v[15] = rot16_128(v[15]);
     390           0 :   v[12] = rot16_128(v[12]);
     391           0 :   v[13] = rot16_128(v[13]);
     392           0 :   v[14] = rot16_128(v[14]);
     393           0 :   v[10] = add_128(v[10], v[15]);
     394           0 :   v[11] = add_128(v[11], v[12]);
     395           0 :   v[8] = add_128(v[8], v[13]);
     396           0 :   v[9] = add_128(v[9], v[14]);
     397           0 :   v[5] = xor_128(v[5], v[10]);
     398           0 :   v[6] = xor_128(v[6], v[11]);
     399           0 :   v[7] = xor_128(v[7], v[8]);
     400           0 :   v[4] = xor_128(v[4], v[9]);
     401           0 :   v[5] = rot12_128(v[5]);
     402           0 :   v[6] = rot12_128(v[6]);
     403           0 :   v[7] = rot12_128(v[7]);
     404           0 :   v[4] = rot12_128(v[4]);
     405           0 :   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
     406           0 :   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
     407           0 :   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
     408           0 :   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
     409           0 :   v[0] = add_128(v[0], v[5]);
     410           0 :   v[1] = add_128(v[1], v[6]);
     411           0 :   v[2] = add_128(v[2], v[7]);
     412           0 :   v[3] = add_128(v[3], v[4]);
     413           0 :   v[15] = xor_128(v[15], v[0]);
     414           0 :   v[12] = xor_128(v[12], v[1]);
     415           0 :   v[13] = xor_128(v[13], v[2]);
     416           0 :   v[14] = xor_128(v[14], v[3]);
     417           0 :   v[15] = rot8_128(v[15]);
     418           0 :   v[12] = rot8_128(v[12]);
     419           0 :   v[13] = rot8_128(v[13]);
     420           0 :   v[14] = rot8_128(v[14]);
     421           0 :   v[10] = add_128(v[10], v[15]);
     422           0 :   v[11] = add_128(v[11], v[12]);
     423           0 :   v[8] = add_128(v[8], v[13]);
     424           0 :   v[9] = add_128(v[9], v[14]);
     425           0 :   v[5] = xor_128(v[5], v[10]);
     426           0 :   v[6] = xor_128(v[6], v[11]);
     427           0 :   v[7] = xor_128(v[7], v[8]);
     428           0 :   v[4] = xor_128(v[4], v[9]);
     429           0 :   v[5] = rot7_128(v[5]);
     430           0 :   v[6] = rot7_128(v[6]);
     431           0 :   v[7] = rot7_128(v[7]);
     432           0 :   v[4] = rot7_128(v[4]);
     433           0 : }
     434             : 
     435           0 : INLINE void transpose_vecs_128(__m128i vecs[4]) {
     436             :   // Interleave 32-bit lanes. The low unpack is lanes 00/11 and the high is
     437             :   // 22/33. Note that this doesn't split the vector into two lanes, as the
     438             :   // AVX2 counterparts do.
     439           0 :   __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
     440           0 :   __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
     441           0 :   __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
     442           0 :   __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
     443             : 
     444             :   // Interleave 64-bit lanes.
     445           0 :   __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
     446           0 :   __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
     447           0 :   __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
     448           0 :   __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
     449             : 
     450           0 :   vecs[0] = abcd_0;
     451           0 :   vecs[1] = abcd_1;
     452           0 :   vecs[2] = abcd_2;
     453           0 :   vecs[3] = abcd_3;
     454           0 : }
     455             : 
     456             : INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
     457           0 :                                 size_t block_offset, __m128i out[16]) {
     458           0 :   out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
     459           0 :   out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
     460           0 :   out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
     461           0 :   out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
     462           0 :   out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
     463           0 :   out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
     464           0 :   out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
     465           0 :   out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
     466           0 :   out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
     467           0 :   out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
     468           0 :   out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
     469           0 :   out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
     470           0 :   out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
     471           0 :   out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
     472           0 :   out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
     473           0 :   out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
     474           0 :   for (size_t i = 0; i < 4; ++i) {
     475           0 :     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
     476           0 :   }
     477           0 :   transpose_vecs_128(&out[0]);
     478           0 :   transpose_vecs_128(&out[4]);
     479           0 :   transpose_vecs_128(&out[8]);
     480           0 :   transpose_vecs_128(&out[12]);
     481           0 : }
     482             : 
     483             : INLINE void load_counters4(uint64_t counter, bool increment_counter,
     484           0 :                            __m128i *out_lo, __m128i *out_hi) {
     485           0 :   int64_t mask = (increment_counter ? ~0 : 0);
     486           0 :   __m256i mask_vec = _mm256_set1_epi64x(mask);
     487           0 :   __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
     488           0 :   deltas = _mm256_and_si256(mask_vec, deltas);
     489           0 :   __m256i counters =
     490           0 :       _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
     491           0 :   *out_lo = _mm256_cvtepi64_epi32(counters);
     492           0 :   *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
     493           0 : }
     494             : 
     495             : static
     496             : void fd_blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
     497             :                             const uint32_t key[8], uint64_t counter,
     498             :                             bool increment_counter, uint8_t flags,
     499           0 :                             uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
     500           0 :   __m128i h_vecs[8] = {
     501           0 :       set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
     502           0 :       set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
     503           0 :   };
     504           0 :   __m128i counter_low_vec, counter_high_vec;
     505           0 :   load_counters4(counter, increment_counter, &counter_low_vec,
     506           0 :                  &counter_high_vec);
     507           0 :   uint8_t block_flags = flags | flags_start;
     508             : 
     509           0 :   for (size_t block = 0; block < blocks; block++) {
     510           0 :     if (block + 1 == blocks) {
     511           0 :       block_flags |= flags_end;
     512           0 :     }
     513           0 :     __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
     514           0 :     __m128i block_flags_vec = set1_128(block_flags);
     515           0 :     __m128i msg_vecs[16];
     516           0 :     transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
     517             : 
     518           0 :     __m128i v[16] = {
     519           0 :         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
     520           0 :         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
     521           0 :         set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
     522           0 :         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
     523           0 :     };
     524           0 :     round_fn4(v, msg_vecs, 0);
     525           0 :     round_fn4(v, msg_vecs, 1);
     526           0 :     round_fn4(v, msg_vecs, 2);
     527           0 :     round_fn4(v, msg_vecs, 3);
     528           0 :     round_fn4(v, msg_vecs, 4);
     529           0 :     round_fn4(v, msg_vecs, 5);
     530           0 :     round_fn4(v, msg_vecs, 6);
     531           0 :     h_vecs[0] = xor_128(v[0], v[8]);
     532           0 :     h_vecs[1] = xor_128(v[1], v[9]);
     533           0 :     h_vecs[2] = xor_128(v[2], v[10]);
     534           0 :     h_vecs[3] = xor_128(v[3], v[11]);
     535           0 :     h_vecs[4] = xor_128(v[4], v[12]);
     536           0 :     h_vecs[5] = xor_128(v[5], v[13]);
     537           0 :     h_vecs[6] = xor_128(v[6], v[14]);
     538           0 :     h_vecs[7] = xor_128(v[7], v[15]);
     539             : 
     540           0 :     block_flags = flags;
     541           0 :   }
     542             : 
     543           0 :   transpose_vecs_128(&h_vecs[0]);
     544           0 :   transpose_vecs_128(&h_vecs[4]);
     545             :   // The first four vecs now contain the first half of each output, and the
     546             :   // second four vecs contain the second half of each output.
     547           0 :   storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
     548           0 :   storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
     549           0 :   storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
     550           0 :   storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
     551           0 :   storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
     552           0 :   storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
     553           0 :   storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
     554           0 :   storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
     555           0 : }
     556             : 
     557             : /*
     558             :  * ----------------------------------------------------------------------------
     559             :  * hash8_avx512
     560             :  * ----------------------------------------------------------------------------
     561             :  */
     562             : 
     563           0 : INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
     564           0 :   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
     565           0 :   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
     566           0 :   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
     567           0 :   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
     568           0 :   v[0] = add_256(v[0], v[4]);
     569           0 :   v[1] = add_256(v[1], v[5]);
     570           0 :   v[2] = add_256(v[2], v[6]);
     571           0 :   v[3] = add_256(v[3], v[7]);
     572           0 :   v[12] = xor_256(v[12], v[0]);
     573           0 :   v[13] = xor_256(v[13], v[1]);
     574           0 :   v[14] = xor_256(v[14], v[2]);
     575           0 :   v[15] = xor_256(v[15], v[3]);
     576           0 :   v[12] = rot16_256(v[12]);
     577           0 :   v[13] = rot16_256(v[13]);
     578           0 :   v[14] = rot16_256(v[14]);
     579           0 :   v[15] = rot16_256(v[15]);
     580           0 :   v[8] = add_256(v[8], v[12]);
     581           0 :   v[9] = add_256(v[9], v[13]);
     582           0 :   v[10] = add_256(v[10], v[14]);
     583           0 :   v[11] = add_256(v[11], v[15]);
     584           0 :   v[4] = xor_256(v[4], v[8]);
     585           0 :   v[5] = xor_256(v[5], v[9]);
     586           0 :   v[6] = xor_256(v[6], v[10]);
     587           0 :   v[7] = xor_256(v[7], v[11]);
     588           0 :   v[4] = rot12_256(v[4]);
     589           0 :   v[5] = rot12_256(v[5]);
     590           0 :   v[6] = rot12_256(v[6]);
     591           0 :   v[7] = rot12_256(v[7]);
     592           0 :   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
     593           0 :   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
     594           0 :   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
     595           0 :   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
     596           0 :   v[0] = add_256(v[0], v[4]);
     597           0 :   v[1] = add_256(v[1], v[5]);
     598           0 :   v[2] = add_256(v[2], v[6]);
     599           0 :   v[3] = add_256(v[3], v[7]);
     600           0 :   v[12] = xor_256(v[12], v[0]);
     601           0 :   v[13] = xor_256(v[13], v[1]);
     602           0 :   v[14] = xor_256(v[14], v[2]);
     603           0 :   v[15] = xor_256(v[15], v[3]);
     604           0 :   v[12] = rot8_256(v[12]);
     605           0 :   v[13] = rot8_256(v[13]);
     606           0 :   v[14] = rot8_256(v[14]);
     607           0 :   v[15] = rot8_256(v[15]);
     608           0 :   v[8] = add_256(v[8], v[12]);
     609           0 :   v[9] = add_256(v[9], v[13]);
     610           0 :   v[10] = add_256(v[10], v[14]);
     611           0 :   v[11] = add_256(v[11], v[15]);
     612           0 :   v[4] = xor_256(v[4], v[8]);
     613           0 :   v[5] = xor_256(v[5], v[9]);
     614           0 :   v[6] = xor_256(v[6], v[10]);
     615           0 :   v[7] = xor_256(v[7], v[11]);
     616           0 :   v[4] = rot7_256(v[4]);
     617           0 :   v[5] = rot7_256(v[5]);
     618           0 :   v[6] = rot7_256(v[6]);
     619           0 :   v[7] = rot7_256(v[7]);
     620             : 
     621           0 :   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
     622           0 :   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
     623           0 :   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
     624           0 :   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
     625           0 :   v[0] = add_256(v[0], v[5]);
     626           0 :   v[1] = add_256(v[1], v[6]);
     627           0 :   v[2] = add_256(v[2], v[7]);
     628           0 :   v[3] = add_256(v[3], v[4]);
     629           0 :   v[15] = xor_256(v[15], v[0]);
     630           0 :   v[12] = xor_256(v[12], v[1]);
     631           0 :   v[13] = xor_256(v[13], v[2]);
     632           0 :   v[14] = xor_256(v[14], v[3]);
     633           0 :   v[15] = rot16_256(v[15]);
     634           0 :   v[12] = rot16_256(v[12]);
     635           0 :   v[13] = rot16_256(v[13]);
     636           0 :   v[14] = rot16_256(v[14]);
     637           0 :   v[10] = add_256(v[10], v[15]);
     638           0 :   v[11] = add_256(v[11], v[12]);
     639           0 :   v[8] = add_256(v[8], v[13]);
     640           0 :   v[9] = add_256(v[9], v[14]);
     641           0 :   v[5] = xor_256(v[5], v[10]);
     642           0 :   v[6] = xor_256(v[6], v[11]);
     643           0 :   v[7] = xor_256(v[7], v[8]);
     644           0 :   v[4] = xor_256(v[4], v[9]);
     645           0 :   v[5] = rot12_256(v[5]);
     646           0 :   v[6] = rot12_256(v[6]);
     647           0 :   v[7] = rot12_256(v[7]);
     648           0 :   v[4] = rot12_256(v[4]);
     649           0 :   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
     650           0 :   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
     651           0 :   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
     652           0 :   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
     653           0 :   v[0] = add_256(v[0], v[5]);
     654           0 :   v[1] = add_256(v[1], v[6]);
     655           0 :   v[2] = add_256(v[2], v[7]);
     656           0 :   v[3] = add_256(v[3], v[4]);
     657           0 :   v[15] = xor_256(v[15], v[0]);
     658           0 :   v[12] = xor_256(v[12], v[1]);
     659           0 :   v[13] = xor_256(v[13], v[2]);
     660           0 :   v[14] = xor_256(v[14], v[3]);
     661           0 :   v[15] = rot8_256(v[15]);
     662           0 :   v[12] = rot8_256(v[12]);
     663           0 :   v[13] = rot8_256(v[13]);
     664           0 :   v[14] = rot8_256(v[14]);
     665           0 :   v[10] = add_256(v[10], v[15]);
     666           0 :   v[11] = add_256(v[11], v[12]);
     667           0 :   v[8] = add_256(v[8], v[13]);
     668           0 :   v[9] = add_256(v[9], v[14]);
     669           0 :   v[5] = xor_256(v[5], v[10]);
     670           0 :   v[6] = xor_256(v[6], v[11]);
     671           0 :   v[7] = xor_256(v[7], v[8]);
     672           0 :   v[4] = xor_256(v[4], v[9]);
     673           0 :   v[5] = rot7_256(v[5]);
     674           0 :   v[6] = rot7_256(v[6]);
     675           0 :   v[7] = rot7_256(v[7]);
     676           0 :   v[4] = rot7_256(v[4]);
     677           0 : }
     678             : 
     679           0 : INLINE void transpose_vecs_256(__m256i vecs[8]) {
     680             :   // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
     681             :   // is 22/33/66/77.
     682           0 :   __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
     683           0 :   __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
     684           0 :   __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
     685           0 :   __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
     686           0 :   __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
     687           0 :   __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
     688           0 :   __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
     689           0 :   __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);
     690             : 
     691             :   // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is
     692             :   // 11/33.
     693           0 :   __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
     694           0 :   __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
     695           0 :   __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
     696           0 :   __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
     697           0 :   __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
     698           0 :   __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
     699           0 :   __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
     700           0 :   __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
     701             : 
     702             :   // Interleave 128-bit lanes.
     703           0 :   vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
     704           0 :   vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
     705           0 :   vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
     706           0 :   vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
     707           0 :   vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
     708           0 :   vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
     709           0 :   vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
     710           0 :   vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
     711           0 : }
     712             : 
     713             : INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
     714           0 :                                 size_t block_offset, __m256i out[16]) {
     715           0 :   out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
     716           0 :   out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
     717           0 :   out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
     718           0 :   out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
     719           0 :   out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
     720           0 :   out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
     721           0 :   out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
     722           0 :   out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
     723           0 :   out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
     724           0 :   out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
     725           0 :   out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
     726           0 :   out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
     727           0 :   out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
     728           0 :   out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
     729           0 :   out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
     730           0 :   out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
     731           0 :   for (size_t i = 0; i < 8; ++i) {
     732           0 :     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
     733           0 :   }
     734           0 :   transpose_vecs_256(&out[0]);
     735           0 :   transpose_vecs_256(&out[8]);
     736           0 : }
     737             : 
     738             : INLINE void load_counters8(uint64_t counter, bool increment_counter,
     739           0 :                            __m256i *out_lo, __m256i *out_hi) {
     740           0 :   int64_t mask = (increment_counter ? ~0 : 0);
     741           0 :   __m512i mask_vec = _mm512_set1_epi64(mask);
     742           0 :   __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
     743           0 :   deltas = _mm512_and_si512(mask_vec, deltas);
     744           0 :   __m512i counters =
     745           0 :       _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
     746           0 :   *out_lo = _mm512_cvtepi64_epi32(counters);
     747           0 :   *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
     748           0 : }
     749             : 
     750             : static
     751             : void fd_blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
     752             :                             const uint32_t key[8], uint64_t counter,
     753             :                             bool increment_counter, uint8_t flags,
     754           0 :                             uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
     755           0 :   __m256i h_vecs[8] = {
     756           0 :       set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
     757           0 :       set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
     758           0 :   };
     759           0 :   __m256i counter_low_vec, counter_high_vec;
     760           0 :   load_counters8(counter, increment_counter, &counter_low_vec,
     761           0 :                  &counter_high_vec);
     762           0 :   uint8_t block_flags = flags | flags_start;
     763             : 
     764           0 :   for (size_t block = 0; block < blocks; block++) {
     765           0 :     if (block + 1 == blocks) {
     766           0 :       block_flags |= flags_end;
     767           0 :     }
     768           0 :     __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
     769           0 :     __m256i block_flags_vec = set1_256(block_flags);
     770           0 :     __m256i msg_vecs[16];
     771           0 :     transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
     772             : 
     773           0 :     __m256i v[16] = {
     774           0 :         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
     775           0 :         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
     776           0 :         set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
     777           0 :         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
     778           0 :     };
     779           0 :     round_fn8(v, msg_vecs, 0);
     780           0 :     round_fn8(v, msg_vecs, 1);
     781           0 :     round_fn8(v, msg_vecs, 2);
     782           0 :     round_fn8(v, msg_vecs, 3);
     783           0 :     round_fn8(v, msg_vecs, 4);
     784           0 :     round_fn8(v, msg_vecs, 5);
     785           0 :     round_fn8(v, msg_vecs, 6);
     786           0 :     h_vecs[0] = xor_256(v[0], v[8]);
     787           0 :     h_vecs[1] = xor_256(v[1], v[9]);
     788           0 :     h_vecs[2] = xor_256(v[2], v[10]);
     789           0 :     h_vecs[3] = xor_256(v[3], v[11]);
     790           0 :     h_vecs[4] = xor_256(v[4], v[12]);
     791           0 :     h_vecs[5] = xor_256(v[5], v[13]);
     792           0 :     h_vecs[6] = xor_256(v[6], v[14]);
     793           0 :     h_vecs[7] = xor_256(v[7], v[15]);
     794             : 
     795           0 :     block_flags = flags;
     796           0 :   }
     797             : 
     798           0 :   transpose_vecs_256(h_vecs);
     799           0 :   storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
     800           0 :   storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
     801           0 :   storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
     802           0 :   storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
     803           0 :   storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
     804           0 :   storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
     805           0 :   storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
     806           0 :   storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
     807           0 : }
     808             : 
     809             : /*
     810             :  * ----------------------------------------------------------------------------
     811             :  * hash16_avx512
     812             :  * ----------------------------------------------------------------------------
     813             :  */
     814             : 
     815           0 : INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
     816           0 :   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
     817           0 :   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
     818           0 :   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
     819           0 :   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
     820           0 :   v[0] = add_512(v[0], v[4]);
     821           0 :   v[1] = add_512(v[1], v[5]);
     822           0 :   v[2] = add_512(v[2], v[6]);
     823           0 :   v[3] = add_512(v[3], v[7]);
     824           0 :   v[12] = xor_512(v[12], v[0]);
     825           0 :   v[13] = xor_512(v[13], v[1]);
     826           0 :   v[14] = xor_512(v[14], v[2]);
     827           0 :   v[15] = xor_512(v[15], v[3]);
     828           0 :   v[12] = rot16_512(v[12]);
     829           0 :   v[13] = rot16_512(v[13]);
     830           0 :   v[14] = rot16_512(v[14]);
     831           0 :   v[15] = rot16_512(v[15]);
     832           0 :   v[8] = add_512(v[8], v[12]);
     833           0 :   v[9] = add_512(v[9], v[13]);
     834           0 :   v[10] = add_512(v[10], v[14]);
     835           0 :   v[11] = add_512(v[11], v[15]);
     836           0 :   v[4] = xor_512(v[4], v[8]);
     837           0 :   v[5] = xor_512(v[5], v[9]);
     838           0 :   v[6] = xor_512(v[6], v[10]);
     839           0 :   v[7] = xor_512(v[7], v[11]);
     840           0 :   v[4] = rot12_512(v[4]);
     841           0 :   v[5] = rot12_512(v[5]);
     842           0 :   v[6] = rot12_512(v[6]);
     843           0 :   v[7] = rot12_512(v[7]);
     844           0 :   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
     845           0 :   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
     846           0 :   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
     847           0 :   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
     848           0 :   v[0] = add_512(v[0], v[4]);
     849           0 :   v[1] = add_512(v[1], v[5]);
     850           0 :   v[2] = add_512(v[2], v[6]);
     851           0 :   v[3] = add_512(v[3], v[7]);
     852           0 :   v[12] = xor_512(v[12], v[0]);
     853           0 :   v[13] = xor_512(v[13], v[1]);
     854           0 :   v[14] = xor_512(v[14], v[2]);
     855           0 :   v[15] = xor_512(v[15], v[3]);
     856           0 :   v[12] = rot8_512(v[12]);
     857           0 :   v[13] = rot8_512(v[13]);
     858           0 :   v[14] = rot8_512(v[14]);
     859           0 :   v[15] = rot8_512(v[15]);
     860           0 :   v[8] = add_512(v[8], v[12]);
     861           0 :   v[9] = add_512(v[9], v[13]);
     862           0 :   v[10] = add_512(v[10], v[14]);
     863           0 :   v[11] = add_512(v[11], v[15]);
     864           0 :   v[4] = xor_512(v[4], v[8]);
     865           0 :   v[5] = xor_512(v[5], v[9]);
     866           0 :   v[6] = xor_512(v[6], v[10]);
     867           0 :   v[7] = xor_512(v[7], v[11]);
     868           0 :   v[4] = rot7_512(v[4]);
     869           0 :   v[5] = rot7_512(v[5]);
     870           0 :   v[6] = rot7_512(v[6]);
     871           0 :   v[7] = rot7_512(v[7]);
     872             : 
     873           0 :   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
     874           0 :   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
     875           0 :   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
     876           0 :   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
     877           0 :   v[0] = add_512(v[0], v[5]);
     878           0 :   v[1] = add_512(v[1], v[6]);
     879           0 :   v[2] = add_512(v[2], v[7]);
     880           0 :   v[3] = add_512(v[3], v[4]);
     881           0 :   v[15] = xor_512(v[15], v[0]);
     882           0 :   v[12] = xor_512(v[12], v[1]);
     883           0 :   v[13] = xor_512(v[13], v[2]);
     884           0 :   v[14] = xor_512(v[14], v[3]);
     885           0 :   v[15] = rot16_512(v[15]);
     886           0 :   v[12] = rot16_512(v[12]);
     887           0 :   v[13] = rot16_512(v[13]);
     888           0 :   v[14] = rot16_512(v[14]);
     889           0 :   v[10] = add_512(v[10], v[15]);
     890           0 :   v[11] = add_512(v[11], v[12]);
     891           0 :   v[8] = add_512(v[8], v[13]);
     892           0 :   v[9] = add_512(v[9], v[14]);
     893           0 :   v[5] = xor_512(v[5], v[10]);
     894           0 :   v[6] = xor_512(v[6], v[11]);
     895           0 :   v[7] = xor_512(v[7], v[8]);
     896           0 :   v[4] = xor_512(v[4], v[9]);
     897           0 :   v[5] = rot12_512(v[5]);
     898           0 :   v[6] = rot12_512(v[6]);
     899           0 :   v[7] = rot12_512(v[7]);
     900           0 :   v[4] = rot12_512(v[4]);
     901           0 :   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
     902           0 :   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
     903           0 :   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
     904           0 :   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
     905           0 :   v[0] = add_512(v[0], v[5]);
     906           0 :   v[1] = add_512(v[1], v[6]);
     907           0 :   v[2] = add_512(v[2], v[7]);
     908           0 :   v[3] = add_512(v[3], v[4]);
     909           0 :   v[15] = xor_512(v[15], v[0]);
     910           0 :   v[12] = xor_512(v[12], v[1]);
     911           0 :   v[13] = xor_512(v[13], v[2]);
     912           0 :   v[14] = xor_512(v[14], v[3]);
     913           0 :   v[15] = rot8_512(v[15]);
     914           0 :   v[12] = rot8_512(v[12]);
     915           0 :   v[13] = rot8_512(v[13]);
     916           0 :   v[14] = rot8_512(v[14]);
     917           0 :   v[10] = add_512(v[10], v[15]);
     918           0 :   v[11] = add_512(v[11], v[12]);
     919           0 :   v[8] = add_512(v[8], v[13]);
     920           0 :   v[9] = add_512(v[9], v[14]);
     921           0 :   v[5] = xor_512(v[5], v[10]);
     922           0 :   v[6] = xor_512(v[6], v[11]);
     923           0 :   v[7] = xor_512(v[7], v[8]);
     924           0 :   v[4] = xor_512(v[4], v[9]);
     925           0 :   v[5] = rot7_512(v[5]);
     926           0 :   v[6] = rot7_512(v[6]);
     927           0 :   v[7] = rot7_512(v[7]);
     928           0 :   v[4] = rot7_512(v[4]);
     929           0 : }
     930             : 
     931             : // 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
     932             : #define LO_IMM8 0x88
     933             : 
     934           0 : INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
     935           0 :   return _mm512_shuffle_i32x4(a, b, LO_IMM8);
     936           0 : }
     937             : 
     938             : // 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
     939             : #define HI_IMM8 0xdd
     940             : 
     941           0 : INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
     942           0 :   return _mm512_shuffle_i32x4(a, b, HI_IMM8);
     943           0 : }
     944             : 
     945           0 : INLINE void transpose_vecs_512(__m512i vecs[16]) {
     946             :   // Interleave 32-bit lanes. The _0 unpack is lanes
     947             :   // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
     948             :   // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
     949           0 :   __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
     950           0 :   __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
     951           0 :   __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
     952           0 :   __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
     953           0 :   __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
     954           0 :   __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
     955           0 :   __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
     956           0 :   __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
     957           0 :   __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
     958           0 :   __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
     959           0 :   __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
     960           0 :   __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
     961           0 :   __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
     962           0 :   __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
     963           0 :   __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
     964           0 :   __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
     965             : 
     966             :   // Interleave 64-bit lanes. The _0 unpack is lanes
     967             :   // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
     968             :   // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
     969             :   // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
     970             :   // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
     971           0 :   __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
     972           0 :   __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
     973           0 :   __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
     974           0 :   __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
     975           0 :   __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
     976           0 :   __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
     977           0 :   __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
     978           0 :   __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
     979           0 :   __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
     980           0 :   __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
     981           0 :   __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
     982           0 :   __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
     983           0 :   __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
     984           0 :   __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
     985           0 :   __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
     986           0 :   __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
     987             : 
     988             :   // Interleave 128-bit lanes. The _0 unpack is
     989             :   // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
     990             :   // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
     991           0 :   __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
     992           0 :   __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
     993           0 :   __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
     994           0 :   __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
     995           0 :   __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
     996           0 :   __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
     997           0 :   __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
     998           0 :   __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
     999           0 :   __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
    1000           0 :   __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
    1001           0 :   __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
    1002           0 :   __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
    1003           0 :   __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
    1004           0 :   __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
    1005           0 :   __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
    1006           0 :   __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
    1007             : 
    1008             :   // Interleave 128-bit lanes again for the final outputs.
    1009           0 :   vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
    1010           0 :   vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
    1011           0 :   vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
    1012           0 :   vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
    1013           0 :   vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
    1014           0 :   vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
    1015           0 :   vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
    1016           0 :   vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
    1017           0 :   vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
    1018           0 :   vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
    1019           0 :   vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
    1020           0 :   vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
    1021           0 :   vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
    1022           0 :   vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
    1023           0 :   vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
    1024           0 :   vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
    1025           0 : }
    1026             : 
    1027             : INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
    1028           0 :                                  size_t block_offset, __m512i out[16]) {
    1029           0 :   out[0] = loadu_512(&inputs[0][block_offset]);
    1030           0 :   out[1] = loadu_512(&inputs[1][block_offset]);
    1031           0 :   out[2] = loadu_512(&inputs[2][block_offset]);
    1032           0 :   out[3] = loadu_512(&inputs[3][block_offset]);
    1033           0 :   out[4] = loadu_512(&inputs[4][block_offset]);
    1034           0 :   out[5] = loadu_512(&inputs[5][block_offset]);
    1035           0 :   out[6] = loadu_512(&inputs[6][block_offset]);
    1036           0 :   out[7] = loadu_512(&inputs[7][block_offset]);
    1037           0 :   out[8] = loadu_512(&inputs[8][block_offset]);
    1038           0 :   out[9] = loadu_512(&inputs[9][block_offset]);
    1039           0 :   out[10] = loadu_512(&inputs[10][block_offset]);
    1040           0 :   out[11] = loadu_512(&inputs[11][block_offset]);
    1041           0 :   out[12] = loadu_512(&inputs[12][block_offset]);
    1042           0 :   out[13] = loadu_512(&inputs[13][block_offset]);
    1043           0 :   out[14] = loadu_512(&inputs[14][block_offset]);
    1044           0 :   out[15] = loadu_512(&inputs[15][block_offset]);
    1045           0 :   for (size_t i = 0; i < 16; ++i) {
    1046           0 :     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
    1047           0 :   }
    1048           0 :   transpose_vecs_512(out);
    1049           0 : }
    1050             : 
    1051             : INLINE void load_counters16(uint64_t counter, bool increment_counter,
    1052           0 :                             __m512i *out_lo, __m512i *out_hi) {
    1053           0 :   const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
    1054           0 :   const __m512i deltas = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
    1055           0 :   const __m512i masked_deltas = _mm512_and_si512(deltas, mask);
    1056           0 :   const __m512i low_words = _mm512_add_epi32(
    1057           0 :     _mm512_set1_epi32((int32_t)counter),
    1058           0 :     masked_deltas);
    1059             :   // The carry bit is 1 if the high bit of the word was 1 before addition and is
    1060             :   // 0 after.
    1061             :   // NOTE: It would be a bit more natural to use _mm512_cmp_epu32_mask to
    1062             :   // compute the carry bits here, and originally we did, but that intrinsic is
    1063             :   // broken under GCC 5.4. See https://github.com/BLAKE3-team/BLAKE3/issues/271.
    1064           0 :   const __m512i carries = _mm512_srli_epi32(
    1065           0 :     _mm512_andnot_si512(
    1066           0 :         low_words, // 0 after (gets inverted by andnot)
    1067           0 :         _mm512_set1_epi32((int32_t)counter)), // and 1 before
    1068           0 :     31);
    1069           0 :   const __m512i high_words = _mm512_add_epi32(
    1070           0 :     _mm512_set1_epi32((int32_t)(counter >> 32)),
    1071           0 :     carries);
    1072           0 :   *out_lo = low_words;
    1073           0 :   *out_hi = high_words;
    1074           0 : }
    1075             : 
    1076             : static
    1077             : void fd_blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
    1078             :                              const uint32_t key[8], uint64_t counter,
    1079             :                              bool increment_counter, uint8_t flags,
    1080             :                              uint8_t flags_start, uint8_t flags_end,
    1081           0 :                              uint8_t *out) {
    1082           0 :   __m512i h_vecs[8] = {
    1083           0 :       set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
    1084           0 :       set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
    1085           0 :   };
    1086           0 :   __m512i counter_low_vec, counter_high_vec;
    1087           0 :   load_counters16(counter, increment_counter, &counter_low_vec,
    1088           0 :                   &counter_high_vec);
    1089           0 :   uint8_t block_flags = flags | flags_start;
    1090             : 
    1091           0 :   for (size_t block = 0; block < blocks; block++) {
    1092           0 :     if (block + 1 == blocks) {
    1093           0 :       block_flags |= flags_end;
    1094           0 :     }
    1095           0 :     __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
    1096           0 :     __m512i block_flags_vec = set1_512(block_flags);
    1097           0 :     __m512i msg_vecs[16];
    1098           0 :     transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
    1099             : 
    1100           0 :     __m512i v[16] = {
    1101           0 :         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
    1102           0 :         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
    1103           0 :         set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
    1104           0 :         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
    1105           0 :     };
    1106           0 :     round_fn16(v, msg_vecs, 0);
    1107           0 :     round_fn16(v, msg_vecs, 1);
    1108           0 :     round_fn16(v, msg_vecs, 2);
    1109           0 :     round_fn16(v, msg_vecs, 3);
    1110           0 :     round_fn16(v, msg_vecs, 4);
    1111           0 :     round_fn16(v, msg_vecs, 5);
    1112           0 :     round_fn16(v, msg_vecs, 6);
    1113           0 :     h_vecs[0] = xor_512(v[0], v[8]);
    1114           0 :     h_vecs[1] = xor_512(v[1], v[9]);
    1115           0 :     h_vecs[2] = xor_512(v[2], v[10]);
    1116           0 :     h_vecs[3] = xor_512(v[3], v[11]);
    1117           0 :     h_vecs[4] = xor_512(v[4], v[12]);
    1118           0 :     h_vecs[5] = xor_512(v[5], v[13]);
    1119           0 :     h_vecs[6] = xor_512(v[6], v[14]);
    1120           0 :     h_vecs[7] = xor_512(v[7], v[15]);
    1121             : 
    1122           0 :     block_flags = flags;
    1123           0 :   }
    1124             : 
    1125             :   // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
    1126             :   // state vectors. Pad the matrix with zeros. After transposition, store the
    1127             :   // lower half of each vector.
    1128           0 :   __m512i padded[16] = {
    1129           0 :       h_vecs[0],   h_vecs[1],   h_vecs[2],   h_vecs[3],
    1130           0 :       h_vecs[4],   h_vecs[5],   h_vecs[6],   h_vecs[7],
    1131           0 :       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
    1132           0 :       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
    1133           0 :   };
    1134           0 :   transpose_vecs_512(padded);
    1135           0 :   _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
    1136           0 :   _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
    1137           0 :   _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
    1138           0 :   _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
    1139           0 :   _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
    1140           0 :   _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
    1141           0 :   _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
    1142           0 :   _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
    1143           0 :   _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
    1144           0 :   _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
    1145           0 :   _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
    1146           0 :   _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
    1147           0 :   _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
    1148           0 :   _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
    1149           0 :   _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
    1150           0 :   _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
    1151           0 : }
    1152             : 
    1153             : /*
    1154             :  * ----------------------------------------------------------------------------
    1155             :  * hash_many_avx512
    1156             :  * ----------------------------------------------------------------------------
    1157             :  */
    1158             : 
    1159             : INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
    1160             :                             const uint32_t key[8], uint64_t counter,
    1161             :                             uint8_t flags, uint8_t flags_start,
    1162           0 :                             uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
    1163           0 :   uint32_t cv[8];
    1164           0 :   memcpy(cv, key, BLAKE3_KEY_LEN);
    1165           0 :   uint8_t block_flags = flags | flags_start;
    1166           0 :   while (blocks > 0) {
    1167           0 :     if (blocks == 1) {
    1168           0 :       block_flags |= flags_end;
    1169           0 :     }
    1170           0 :     fd_blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
    1171           0 :                                        block_flags);
    1172           0 :     input = &input[BLAKE3_BLOCK_LEN];
    1173           0 :     blocks -= 1;
    1174           0 :     block_flags = flags;
    1175           0 :   }
    1176           0 :   memcpy(out, cv, BLAKE3_OUT_LEN);
    1177           0 : }
    1178             : 
    1179             : void fd_blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
    1180             :                                 size_t blocks, const uint32_t key[8],
    1181             :                                 uint64_t counter, bool increment_counter,
    1182             :                                 uint8_t flags, uint8_t flags_start,
    1183           0 :                                 uint8_t flags_end, uint8_t *out) {
    1184           0 :   while (num_inputs >= 16) {
    1185           0 :     fd_blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
    1186           0 :                             flags_start, flags_end, out);
    1187           0 :     if (increment_counter) {
    1188           0 :       counter += 16;
    1189           0 :     }
    1190           0 :     inputs += 16;
    1191           0 :     num_inputs -= 16;
    1192           0 :     out = &out[16 * BLAKE3_OUT_LEN];
    1193           0 :   }
    1194           0 :   while (num_inputs >= 8) {
    1195           0 :     fd_blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
    1196           0 :                            flags_start, flags_end, out);
    1197           0 :     if (increment_counter) {
    1198           0 :       counter += 8;
    1199           0 :     }
    1200           0 :     inputs += 8;
    1201           0 :     num_inputs -= 8;
    1202           0 :     out = &out[8 * BLAKE3_OUT_LEN];
    1203           0 :   }
    1204           0 :   while (num_inputs >= 4) {
    1205           0 :     fd_blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
    1206           0 :                            flags_start, flags_end, out);
    1207           0 :     if (increment_counter) {
    1208           0 :       counter += 4;
    1209           0 :     }
    1210           0 :     inputs += 4;
    1211           0 :     num_inputs -= 4;
    1212           0 :     out = &out[4 * BLAKE3_OUT_LEN];
    1213           0 :   }
    1214           0 :   while (num_inputs > 0) {
    1215           0 :     hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
    1216           0 :                     flags_end, out);
    1217           0 :     if (increment_counter) {
    1218           0 :       counter += 1;
    1219           0 :     }
    1220           0 :     inputs += 1;
    1221           0 :     num_inputs -= 1;
    1222           0 :     out = &out[BLAKE3_OUT_LEN];
    1223           0 :   }
    1224           0 : }

Generated by: LCOV version 1.14