LCOV - code coverage report
Current view: top level - ballet/wsample - fd_wsample.c (source / functions) Hit Total Coverage
Test: cov.lcov Lines: 262 295 88.8 %
Date: 2026-01-23 05:02:40 Functions: 20 20 100.0 %

          Line data    Source code
       1             : #include "fd_wsample.h"
       2             : #include <math.h> /* For sqrt */
       3             : #if FD_HAS_AVX512
       4             : #include "../../util/simd/fd_avx512.h"
       5             : #endif
       6             : 
/* R: radix (branching factor) of the sampling tree.  Each internal node
   stores R-1 left sums (8 ulongs when R==9), and the AVX-512 fast paths
   in this file are only compiled when R==9. */
#define R 9
       8             : /* This sampling problem is an interesting one from a performance
       9             :    perspective.  There are lots of interesting approaches.  The
      10             :    header/implementation split is designed to give lots of flexibility
      11             :    for future optimization.  The current implementation uses radix 9
      12             :    tree with all the leaves on the bottom. */
      13             : 
      14             : /* I'm not sure exactly how to classify the tree that this
      15             :    implementation uses, but it's something like a B-tree with some
      16             :    tricks from binary heaps.  In particular, like a B-tree, each node
      17             :    stores several keys, where the keys are the cumulative sums of
      18             :    subtrees, called left sums below.  Like a binary heap, it is
      19             :    pointer-free, and it is stored implicitly in a flat array, the
      20             :    typical way that a binary heap is stored.  Specifically, the root is
      21             :    at index 0, and node i's children are at Ri+1, Ri+2, ... Ri+R, where
      22             :    R is the radix.  The leaves are all at the same level, at the bottom,
      23             :    stored implicitly, which means that a search can be done with almost
      24             :    no branch mispredictions.
      25             : 
      26             :    As an example, suppose R=5 and our tree contains weights 100, 80, 50,
      27             :    31, 27, 14, and 6.  Because that is 7 leaves, the height is set to
      28             :    ceil(log_5(7))=2.
      29             : 
      30             :                           Root (0)
      31             :                           /        \
      32             :                    Child (1)      Child (2)     -- -- --
      33             :                /  /   |  \   \     |  \
      34             :             100, 80, 50, 31, 27   14, 6
      35             : 
      36             :    The first child of the root, node 1, has left_sum values
      37             :    |100|180|230|261|.  Note that we only store R-1 values, and so the
      38             :    last child, 27, does not feature in the sums; the search process
      39             :    handles it implicitly, as we'll see.  The second child of the root,
      40             :    node 2, has left sum values |14|20|20|20|.  Then the root node, node
      41             :    0, has left sum values |288|308|308|308|.
      42             : 
      43             :    The total sum is 308, so we start by drawing a random value in the
      44             :    range [0, 308).  We see which child that falls under, adjust our
      45             :    random value, and repeat, until we reach a leaf.
      46             : 
      47             :    In general, if a node has children (left to right) with subtree
      48             :    weights a, b, c, d, e.  Then the left sums are |a|a+b|a+b+c|a+b+c+d|
      49             :    and the full sum S=a+b+c+d+e.  For a query value of x in [0, S), we
      50             :    want to pick:
      51             :       child 0   if             x < a
      52             :       child 1   if   a      <= x < a+b
      53             :       child 2   if   a+b    <= x < a+b+c
      54             :       child 3   if   a+b+c  <= x < a+b+c+d
      55             :       child 4   if   a+b+c+d<= x
      56             : 
      57             :    Which is equivalent to choosing child
      58             :           (a<=x) + (a+b<=x) + (a+b+c<=x) + (a+b+c+d<=x)
      59             :    which can be computed branchlessly.  The value of e only comes into
      60             :    play in the restriction that x is in [0, S), and e is included in the
      61             :    value of S.
      62             : 
      63             :    There are two details left to discuss in order to search recursively.
      64             :    First, in order to remove an element for sampling without
      65             :    replacement, we need to know the weight of the element we're
      66             :    removing.  As with e above, the weights may not be stored explicitly,
      67             :    but as long as we keep track of the weight of the current subtree as
      68             :    we search recursively, we can obtain the weight without much work.
      69             :    Thus, we need to know the sum S' of the chosen subtree, i in [0, R).
      70             :    For j in [-1, R), define
      71             :                             /  0             if j==-1
      72             :                     l[j] =  | left_sum[ j ]  if j in [0, R-1)
      73             :                             \  S             if j==R-1
      74             :    Essentially, we're computing the natural extension values for
      75             :    left_sum on the left and right side.  Then observe that S' = l[i] -
      76             :    l[i-1].
      77             : 
      78             :    Secondly, in order to search recursively, we need to adjust the query
      79             :    value to put it into [0, S').  Specifically, we need to subtract off
      80             :    the sum of all the children to the left of the child we've chosen.
      81             :    If we've chosen child i, then that's just l[i-1].
      82             : 
      83             :    All the above extends easily to any radix, but 3, 5, and 9 = 1+2^n
      84             :    are the natural ones.  I only tried 5 and 9, and 9 was substantially
      85             :    faster.
      86             : 
      87             :    It's worth noting that storing left sums means that you have to do
      88             :    the most work when you update the left-most child.  Because we
      89             :    typically store the weights largest to smallest, the left-most child
      90             :    is the one we delete the most frequently, so now our most common case
      91             :    (probability-wise) and our worst case (performance-wise) are the
      92             :    same, which seems bad.  I initially implemented this using right sums
      93             :    instead of left to address this problem.  They're less intuitive, but
      94             :    also work.  However, I found that having an unpredictable loop in the
      95             :    deletion method was far worse than just updating each element, which
      96             :    means that the "less work" advantage of right sums went away. */
      97             : 
      98             : struct __attribute__((aligned(8UL*(R-1UL)))) tree_ele {
      99             :   /* left_sum stores the cumulative weight of the subtrees at this
     100             :      node.  See the long note above for more information. */
     101             :   ulong left_sum[ R-1 ];
     102             : };
     103             : typedef struct tree_ele tree_ele_t;
     104             : 
     105             : struct __attribute__((aligned(64UL))) fd_wsample_private {
     106             :   ulong              total_cnt;
     107             :   ulong              total_weight;
     108             :   ulong              unremoved_cnt;
     109             :   ulong              unremoved_weight; /* Initial value for S explained above */
     110             : 
     111             :   /* internal_node_cnt and height are both determined by the number of
     112             :      leaves in the original tree, via the following formulas:
     113             :      height = ceil(log_r(leaf_cnt))
     114             :      internal_node_cnt = sum_{i=0}^{height-1} R^i
     115             :      height and internal_node_cnt both exclude the leaves, which are
     116             :      only implicit.
     117             : 
     118             :      All the math seems to disallow leaf_cnt==0, but for convenience, we
     119             :      do allow it. height==internal_node_cnt==0 in that case.
     120             : 
     121             :      height actually fits in a uchar.  Storing as ulong is more natural,
     122             :      but we don't want to spill over into another cache line.
     123             : 
     124             :      If poisoned_mode==1, then all sample calls will return poisoned.*/
     125             :   ulong              internal_node_cnt;
     126             :   ulong              poisoned_weight;
     127             :   uint               height;
     128             :   char               restore_enabled;
     129             :   char               poisoned_mode;
     130             :   /* Two bytes of padding here */
     131             : 
     132             :   fd_chacha_rng_t * rng;
     133             : 
     134             :   /* tree: Here's where the actual tree is stored, at indices [0,
     135             :      internal_node_cnt).  The indexing scheme is explained in the long
     136             :      comment above.
     137             : 
     138             :      If restore_enabled==1, then indices [internal_node_cnt+1,
     139             :      2*internal_node_cnt+1) store a copy of the tree after construction
     140             :      but before any deletion so that restoring deleted elements can be
     141             :      implemented as a memcpy.
     142             : 
     143             :      The tree itself is surrounded by two dummy elements, dummy, and
     144             :      tree[internal_node_cnt], that aren't actually used.  This is
     145             :      because searching the tree branchlessly involves some out of bounds
     146             :      reads, and although the value is immediately discarded, it's better
     147             :      to know where exactly those reads might go. */
     148             :   tree_ele_t        dummy;
     149             :   tree_ele_t        tree[];
     150             : };
     151             : 
     152             : typedef struct fd_wsample_private fd_wsample_t;
     153             : 
     154             : 
     155             : FD_FN_CONST ulong
     156     1520106 : fd_wsample_align( void ) {
     157     1520106 :   return 64UL;
     158     1520106 : }
     159             : 
     160             : /* Returns -1 on failure */
     161             : static inline int
     162             : compute_height( ulong   leaf_cnt,
     163             :                 ulong * out_height,
     164      769422 :                 ulong * out_internal_cnt ) {
     165             :   /* This max is a bit conservative.  The actual max is height <= 25,
     166             :      and leaf_cnt < 5^25 approx 2^58.  A tree that large would take an
     167             :      astronomical amount of memory, so we just retain this max for the
     168             :      moment. */
     169      769422 :   if( FD_UNLIKELY( leaf_cnt >= UINT_MAX-2UL ) ) return -1;
     170             : 
     171      769422 :   ulong height   = 0;
     172      769422 :   ulong internal = 0UL;
     173      769422 :   ulong powRh    = 1UL; /* = R^height */
     174     3433704 :   while( leaf_cnt>powRh ) {
     175     2664282 :     internal += powRh;
     176     2664282 :     powRh    *= R;
     177     2664282 :     height++;
     178     2664282 :   }
     179      769422 :   *out_height       = height;
     180      769422 :   *out_internal_cnt = internal;
     181      769422 :   return 0;
     182      769422 : }
     183             : 
     184             : FD_FN_CONST ulong
     185      447546 : fd_wsample_footprint( ulong ele_cnt, int restore_enabled ) {
     186      447546 :   ulong height;
     187      447546 :   ulong internal_cnt;
     188             :   /* Computing the closed form of the sum in compute_height, we get
     189             :      internal_cnt = 1/8 * (9^ceil(log_9( ele_cnt ) ) - 1)
     190             :                  x <= ceil( x ) < x+1
     191             :      1/8 * ele_cnt - 1/8 <= internal_cnt < 9/8 * ele_cnt - 1/8
     192             :   */
     193      447546 :   if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) return 0UL;
     194      447546 :   return sizeof(fd_wsample_t) + ((restore_enabled?2UL:1UL)*internal_cnt + 1UL)*sizeof(tree_ele_t);
     195      447546 : }
     196             : 
     197             : fd_wsample_t *
     198      321864 : fd_wsample_join( void * shmem  ) {
     199      321864 :   if( FD_UNLIKELY( !shmem ) ) {
     200           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     201           0 :     return NULL;
     202           0 :   }
     203             : 
     204      321864 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     205           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     206           0 :     return NULL;
     207           0 :   }
     208      321864 :   return (fd_wsample_t *)shmem;
     209      321864 : }
     210             : 
     211             : /* Note: The following optimization insights are not used in this
     212             :    high radix implementation.  Performance in the deletion case is much
     213             :    more important than in the non-deletion case, and it's not clear how
     214             :    to translate this.  I'm leaving the code and comment because it is a
     215             :    useful and non-trivial insight. */
     216             : #if 0
     217             : /* If we assume the probability of querying node i is proportional to
     218             :    1/i, then observe that the midpoint of the probability mass in the
     219             :    continuous approximation is the solution to (in Mathematica syntax):
     220             : 
     221             :         Integrate[ 1/i, {i, lo, hi}] = 2*Integrate[ 1/i, {i, lo, mid} ]
     222             : 
     223             :    which gives mid = sqrt(lo*hi).  This is in contrast to when the
     224             :    integrand is a constant, which gives the normal binary search rule:
     225             :    mid=(lo+hi)/2.
     226             : 
     227             :    Thus, we want the search to follow this modified binary search rule,
     228             :    since that'll approximately split the stake weight/probability mass
     229             :    in half at each step.
     230             : 
     231             :    This is almost as nice to work with mathematically as the normal
     232             :    binary search rule.  The jth entry from the left at level k is the
     233             :    region [ N^((1/2^k)*j), N^((1/2^k)*(j+1)) ).  We're basically doing
     234             :    binary search in the log domain.
     235             : 
     236             :    Rather than trying to compute these transcendental functions, this
     237             :    simple recursive implementation gives the treap the right shape by
     238             :    setting prio very carefully, since higher prio means closer to the
     239             :    root.  This may be a slight abuse of a treap, but it's easier than
     240             :    implementing a whole custom tree for just this purpose.
     241             : 
     242             :    We want to be careful to ensure that the recursion is tightly
     243             :    bounded.  From a continuous perspective, that's not a problem: the
     244             :    widest interval at level k is the last one, and we break when the
     245             :    interval's width is less than 1.  Solving 1=N-N^(1-(1/2^k)) for k
     246             :    yields k=-lg(1-log(N-1)/log(N)).  k is monotonically increasing as a
     247             :    function of N, which means that for all N,
     248             :            k <= -lg(1-log(N_max-1)/log(N_max)) < 37
     249             :    since N_max=2^32-1.
     250             : 
     251             :    The math is more complicated with rounding and finite precision, but
     252             :    sqrt(lo*hi) is very different from lo and hi unless lo and hi are
     253             :    approximately the same.  In that case, lo<mid and mid<hi ensures that
     254             :    both intervals are strictly smaller than the interval they came from,
     255             :    which prevents an infinite loop. */
     256             : static inline void
     257             : seed_recursive( treap_ele_t * pool,
     258             :                 uint lo,
     259             :                 uint hi,
     260             :                 uint prio ) {
     261             :   uint mid = (uint)(sqrtf( (float)lo*(float)hi ) + 0.5f);
     262             :   if( (lo<mid) & (mid<hi) ) {
     263             :     /* since we start with lo=1, shift by 1 */
     264             :     pool[mid-1U].prio = prio;
     265             :     seed_recursive( pool, lo,  mid, prio-1U );
     266             :     seed_recursive( pool, mid, hi,  prio-1U );
     267             :   }
     268             : }
     269             : #endif
     270             : 
     271             : 
     272             : void *
     273             : fd_wsample_new_init( void             * shmem,
     274             :                      fd_chacha_rng_t * rng,
     275             :                      ulong              ele_cnt,
     276             :                      int                restore_enabled,
     277      321876 :                      int                opt_hint ) {
     278      321876 :   if( FD_UNLIKELY( !shmem ) ) {
     279           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     280           0 :     return NULL;
     281           0 :   }
     282             : 
     283      321876 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     284           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     285           0 :     return NULL;
     286           0 :   }
     287             : 
     288      321876 :   ulong height;
     289      321876 :   ulong internal_cnt;
     290      321876 :   if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) {
     291           0 :     FD_LOG_WARNING(( "bad ele_cnt" ));
     292           0 :     return NULL;
     293           0 :   }
     294             : 
     295      321876 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     296             : 
     297      321876 :   sampler->total_weight      = 0UL;
     298      321876 :   sampler->unremoved_cnt     = 0UL;
     299      321876 :   sampler->unremoved_weight  = 0UL;
     300      321876 :   sampler->internal_node_cnt = internal_cnt;
     301      321876 :   sampler->poisoned_weight   = 0UL;
     302      321876 :   sampler->height            = (uint)height;
     303      321876 :   sampler->restore_enabled   = (char)!!restore_enabled;
     304      321876 :   sampler->poisoned_mode     = 0;
     305      321876 :   sampler->rng               = rng;
     306             : 
     307      321876 :   fd_memset( sampler->tree, (char)0, internal_cnt*sizeof(tree_ele_t) );
     308             : 
     309      321876 :   (void)opt_hint; /* Not used at the moment */
     310             : 
     311      321876 :   return shmem;
     312      321876 : }
     313             : 
     314             : void *
     315             : fd_wsample_new_add( void * shmem,
     316   202286043 :                     ulong  weight ) {
     317   202286043 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     318   202286043 :   if( FD_UNLIKELY( !sampler ) ) return NULL;
     319             : 
     320   202286043 :   if( FD_UNLIKELY( weight==0UL ) ) {
     321           0 :     FD_LOG_WARNING(( "zero weight entry found" ));
     322           0 :     return NULL;
     323           0 :   }
     324   202286043 :   if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
     325           0 :     FD_LOG_WARNING(( "total weight too large" ));
     326           0 :     return NULL;
     327           0 :   }
     328             : 
     329   202286043 :   tree_ele_t * tree = sampler->tree;
     330   202286043 :   ulong i = sampler->internal_node_cnt + sampler->unremoved_cnt;
     331             : 
     332  1011149181 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     333   808863138 :     ulong parent = (i-1UL)/R;
     334   808863138 :     ulong child_idx = i-1UL - R*parent; /* in [0, R) */
     335  4956555249 :     for( ulong k=child_idx; k<R-1UL; k++ )  tree[ parent ].left_sum[ k ] += weight;
     336   808863138 :     i = parent;
     337   808863138 :   }
     338             : 
     339   202286043 :   sampler->unremoved_cnt++;
     340   202286043 :   sampler->total_cnt++;
     341   202286043 :   sampler->unremoved_weight += weight;
     342   202286043 :   sampler->total_weight     += weight;
     343             : 
     344   202286043 :   return shmem;
     345   202286043 : }
     346             : 
     347             : void *
     348             : fd_wsample_new_fini( void * shmem,
     349      321876 :                      ulong  poisoned_weight ) {
     350      321876 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     351      321876 :   if( FD_UNLIKELY( !sampler ) ) return NULL;
     352             : 
     353      321876 :   if( FD_UNLIKELY( sampler->total_weight+poisoned_weight<sampler->total_weight ) ) {
     354           0 :     FD_LOG_WARNING(( "poisoned_weight caused overflow" ));
     355           0 :     return NULL;
     356           0 :   }
     357             : 
     358      321876 :   sampler->poisoned_weight = poisoned_weight;
     359             : 
     360      321876 :   if( sampler->restore_enabled ) {
     361             :     /* Copy the sampler to make restore fast. */
     362      313602 :     fd_memcpy( sampler->tree+sampler->internal_node_cnt+1UL, sampler->tree, sampler->internal_node_cnt*sizeof(tree_ele_t) );
     363      313602 :   }
     364             : 
     365      321876 :   return (void *)sampler;
     366      321876 : }
     367             : 
     368             : void *
     369      321582 : fd_wsample_leave( fd_wsample_t * sampler ) {
     370      321582 :   if( FD_UNLIKELY( !sampler ) ) {
     371           0 :     FD_LOG_WARNING(( "NULL sampler" ));
     372           0 :     return NULL;
     373           0 :   }
     374             : 
     375      321582 :   return (void *)sampler;
     376      321582 : }
     377             : 
     378             : void *
     379      321582 : fd_wsample_delete( void * shmem  ) {
     380      321582 :   if( FD_UNLIKELY( !shmem ) ) {
     381           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     382           0 :     return NULL;
     383           0 :   }
     384             : 
     385      321582 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     386           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     387           0 :     return NULL;
     388           0 :   }
     389      321582 :   return shmem;
     390      321582 : }
     391             : 
     392             : void
     393             : fd_wsample_seed_rng( fd_wsample_t * sampler,
     394             :                      uchar          seed[ 32 ],
     395     2586789 :                      int            use_chacha8 ) {
     396     2586789 :   fd_chacha_rng_init( sampler->rng, seed, use_chacha8 ? FD_CHACHA_RNG_ALGO_CHACHA8 : FD_CHACHA_RNG_ALGO_CHACHA20 );
     397     2586789 : }
     398             : 
     399             : fd_wsample_t *
     400     3494316 : fd_wsample_restore_all( fd_wsample_t * sampler ) {
     401     3494316 :   if( FD_UNLIKELY( !sampler->restore_enabled ) )  return NULL;
     402             : 
     403     3494313 :   sampler->unremoved_weight = sampler->total_weight;
     404     3494313 :   sampler->unremoved_cnt    = sampler->total_cnt;
     405     3494313 :   sampler->poisoned_mode    = 0;
     406             : 
     407     3494313 :   fd_memcpy( sampler->tree, sampler->tree+sampler->internal_node_cnt+1UL, sampler->internal_node_cnt*sizeof(tree_ele_t) );
     408     3494313 :   return sampler;
     409     3494316 : }
     410             : 
     411             : #define fd_ulong_if_force( c, t, f ) (__extension__({ \
     412             :       ulong result;                                   \
     413             :       __asm__( "testl  %1, %1; \n\t"                  \
     414             :                "movq   %3, %0; \n\t"                  \
     415             :                "cmovne %2, %0; \n\t"                  \
     416             :                : "=&r"(result)                        \
     417             :                : "r"(c), "rm"(t), "rmi"(f)            \
     418             :                : "cc" );                              \
     419             :       result;                                         \
     420             :       }))
     421             : 
     422             : /* Helper methods for sampling functions */
     423             : typedef struct { ulong idx; ulong weight; } idxw_pair_t; /* idx in [0, total_cnt) */
     424             : 
     425             : /* Assumes query in [0, unremoved_weight), which implies
     426             :    unremoved_weight>0, so the tree can't be empty. */
     427             : static inline idxw_pair_t
     428             : fd_wsample_map_sample_i( fd_wsample_t const * sampler,
     429  1053480876 :                          ulong                query ) {
     430  1053480876 :   tree_ele_t const * tree = sampler->tree;
     431             : 
     432  1053480876 :   ulong cursor = 0UL;
     433  1053480876 :   ulong S      = sampler->unremoved_weight;
     434  4663553676 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     435  3610072800 :     tree_ele_t const * e = tree+cursor;
     436  3610072800 :     ulong x = query;
     437  3610072800 :     ulong child_idx = 0UL;
     438             : 
     439   845406000 : #if FD_HAS_AVX512 && R==9
     440   845406000 :     __mmask8 mask = _mm512_cmple_epu64_mask( wwv_ld( e->left_sum ), wwv_bcast( x ) );
     441   845406000 :     child_idx = (ulong)fd_uchar_popcnt( mask );
     442             : #else
     443 24882001200 :     for( ulong i=0UL; i<R-1UL; i++ ) child_idx += (ulong)(e->left_sum[ i ]<=x);
     444  2764666800 : #endif
     445             : 
     446             :     /* See the note at the top of this file for the explanation of l[i]
     447             :        and l[i-1].  Because this is fd_ulong_if and not a ternary, these
     448             :        can read/write out of what you would think the appropriate bounds
     449             :        are.  The dummy elements, as described along with tree makes this
     450             :        safe. */
     451             : #if 0
     452             :     ulong li  = fd_ulong_if( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
     453             :     ulong lm1 = fd_ulong_if( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
     454             : #elif 0
     455             :     ulong li  = fd_ulong_if_force( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
     456             :     ulong lm1 = fd_ulong_if_force( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
     457             : #else
     458  3610072800 :     ulong * temp = (ulong *)e->left_sum;
     459  3610072800 :     ulong orig_m1 = temp[ -1 ];    ulong orig_Rm1 = temp[ R-1UL ];
     460  3610072800 :     temp[ -1 ] = 0UL;              temp[ R-1UL ] = S;
     461  3610072800 :     ulong li  = temp[ child_idx     ];
     462  3610072800 :     ulong lm1 = temp[ child_idx-1UL ];
     463  3610072800 :     temp[ -1 ] = orig_m1;          temp[ R-1UL ] = orig_Rm1;
     464  3610072800 : #endif
     465             : 
     466  3610072800 :     query -= lm1;
     467  3610072800 :     S = li - lm1;
     468  3610072800 :     cursor = R*cursor + child_idx + 1UL;
     469  3610072800 :   }
     470  1053480876 :   idxw_pair_t to_return = { .idx = cursor - sampler->internal_node_cnt, .weight = S };
     471  1053480876 :   return to_return;
     472  1053480876 : }
     473             : 
     474             : ulong
     475             : fd_wsample_map_sample( fd_wsample_t * sampler,
     476   768757023 :                        ulong          query ) {
     477   768757023 :   return fd_wsample_map_sample_i( sampler, query ).idx;
     478   768757023 : }
     479             : 
     480             : 
     481             : 
     482             : /* Also requires the tree to be non-empty */
     483             : static inline void
     484             : fd_wsample_remove( fd_wsample_t * sampler,
     485   286511820 :                    idxw_pair_t    to_remove ) {
     486   286511820 :   ulong cursor = to_remove.idx + sampler->internal_node_cnt;
     487   286511820 :   tree_ele_t * tree = sampler->tree;
     488             : 
     489  1396215531 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     490  1109703711 :     ulong parent = (cursor-1UL)/R;
     491  1109703711 :     ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
     492    11949637 : #if FD_HAS_AVX512 && R==9
     493    11949637 :     wwv_t weight = wwv_bcast( to_remove.weight );
     494    11949637 :     wwv_t left_sum = wwv_ld( tree[ parent ].left_sum );
     495    11949637 :     __m128i _child_idx = _mm_set1_epi16( (short) child_idx );
     496    11949637 :     __mmask8 mask = _mm_cmplt_epi16_mask( _child_idx, _mm_setr_epi16( 1, 2, 3, 4, 5, 6, 7, 8 ) );
     497    11949637 :     left_sum = _mm512_mask_sub_epi64( left_sum, mask, left_sum, weight );
     498    11949637 :     wwv_st( tree[ parent ].left_sum, left_sum );
     499             : #elif 0
     500             :     for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if( child_idx<=k, to_remove.weight, 0UL );
     501             : #elif 0
     502             :     for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if_force( child_idx<=k, to_remove.weight, 0UL );
     503             : #elif 1
     504             :     /* The compiler loves inserting a difficult to predict branch for
     505             :        fd_ulong_if, but this forces it not to do that. */
     506  9879786666 :     for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= (ulong)(((long)(child_idx - k - 1UL))>>63) & to_remove.weight;
     507             : #else
     508             :     /* This version does the least work, but has a hard-to-predict
     509             :        branch.  The branchless versions are normally substantially
     510             :        faster. */
     511             :     for( ulong k=child_idx; k<R-1UL; k++ )  tree[ parent ].left_sum[ k ] -= to_remove.weight;
     512             : #endif
     513  1109703711 :     cursor = parent;
     514  1109703711 :   }
     515   286511820 :   sampler->unremoved_cnt--;
     516   286511820 :   sampler->unremoved_weight -= to_remove.weight;
     517   286511820 : }
     518             : 
     519             : static inline ulong
     520             : fd_wsample_find_weight( fd_wsample_t const * sampler,
     521     1787967 :                         ulong                idx /* in [0, total_cnt) */) {
     522             :   /* The fact we don't store the weights explicitly makes this function
     523             :      more complicated, but this is not used very frequently. */
     524     1787967 :   tree_ele_t const * tree = sampler->tree;
     525     1787967 :   ulong cursor = idx + sampler->internal_node_cnt;
     526             : 
     527             :   /* Initialize to the 0 height case */
     528     1787967 :   ulong lm1 = 0UL;
     529     1787967 :   ulong li  = sampler->unremoved_weight;
     530             : 
     531     1818249 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     532     1818057 :     ulong parent = (cursor-1UL)/R;
     533     1818057 :     ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
     534             : 
     535             :     /* If child_idx < R-1, we can compute the weight easily.  If
     536             :        child_idx==R-1, the computation is S - left_sum[ R-2 ], but we
     537             :        don't know S, so we need to continue up the tree. */
     538     1818057 :     lm1  += fd_ulong_if( child_idx>0UL, tree[ parent ].left_sum[ child_idx-1UL ], 0UL );
     539     1818057 :     if( FD_LIKELY( child_idx<R-1UL ) ) {
     540     1787775 :       li = tree[ parent ].left_sum[ child_idx ];
     541     1787775 :       break;
     542     1787775 :     }
     543             : 
     544       30282 :     cursor = parent;
     545       30282 :   }
     546             : 
     547     1787967 :   return li - lm1;
     548     1787967 : }
     549             : 
     550             : void
     551             : fd_wsample_remove_idx( fd_wsample_t * sampler,
     552     1787967 :                        ulong          idx ) {
     553             : 
     554     1787967 :   ulong weight = fd_wsample_find_weight( sampler, idx );
     555     1787967 :   idxw_pair_t r = { .idx = idx, .weight = weight };
     556     1787967 :   fd_wsample_remove( sampler, r );
     557     1787967 : }
     558             : 
     559             : 
     560             : #if FD_HAS_AVX512
     561             : 
     562             : /* TRAVERSE_LEVEL Takes in and updates s (a vector with all elements set
     563             :    to the unremoved weight in the current subtree), x or xprime (a
     564             :    vector with all elements set to the current query value, or the
     565             :    value+1, respectively), and cursor (a ulong of the current position
     566             :    in the traversal).  See fd_wsample_map_sample_i for more about these
     567             :    values.  Calling this macro height times is the same as calling
     568             :    fd_wsample_map_sample_i, except for that it declares mask{i} and
     569             :    cursor{i}, to save work in PROPAGATE_LEVEL, which is like
     570             :    fd_wsample_remove.
     571             : 
     572             :    This only works in the R=9 case, but we'll explain it as if R=5 like
     573             :    before.  Then, v, a vector of tree elements has values
     574             :            ----------------------------------------
     575             :            |    a    |   a+b  |  a+b+c  | a+b+c+d |
     576             :            ----------------------------------------
     577             :    Then, like before, the desired state is
     578             :          If...                pick,   and  l_{i-1}    and l_i
     579             :               x < a         child 0        0             a
     580             :     a      <= x < a+b       child 1        a            a+b
     581             :     a+b    <= x < a+b+c     child 2       a+b          a+b+c
     582             :     a+b+c  <= x < a+b+c+d   child 3      a+b+c        a+b+c+d
     583             :     a+b+c+d<= x             child 4     a+b+c+d      a+b+c+d+e
     584             : 
     585             :    Below, there are two pretty different implementations, and this
     586             :    explains both:
     587             : 
     588             :    Implementation 1:
     589             :    We want to get information about v0<=x in as much as the vector as
     590             :    fast as possible.  We could do some kind of compress and then bcast,
     591             :    but operations with mask registers are frustratingly slow.  Instead,
     592             :    we first define x'=x+1, so then v0<=x is equivalent to v0-x'<0, or
     593             :    whether v0-x' has the high bit set.  Because of
     594             :    fd_chacha_rng_ulong_roll's contract, we know x<ULONG_MAX, so forming
     595             :    x' is safe. We then use a _mm512_permutexvar_epi8 to essentially
     596             :    broadcast the high byte from each of the ulongs (really we just need
     597             :    the high bit) to the whole vector.  A popcnt gives us our base
     598             :    values.
     599             : 
     600             :                   Case             compressed_signs   popcnt  popcnt-1
     601             :                    x < a                      0         0       -1
     602             :          a      <= x < a+b                 0x80         1        0
     603             :          a+b    <= x < a+b+c             0x8080         2        1
     604             :          a+b+c  <= x < a+b+c+d         0x808080         3        2
     605             :          a+b+c+d<= x                 0x80808080         4        3
     606             : 
     607             :    Then, keeping in mind that _mm512_permutex2var_epi64 will select an
     608             :    element from the second vector if the next highest bit is set, using
     609             :    popcnt-1 as the selector and (vec, 0) as (a,b) gives l_{i-1}; and
     610             :    using popcnt as the selector with (vec, s) gives l_i.
     611             : 
     612             :    Finally, we just need to produce the mask mask{i}, which we can do
     613             :    with wwv_lt, off the critical path.
     614             : 
     615             : 
     616             :    Implementation 2
     617             :    We first use wwv_le to compare v0<=x to form lmask.  Then We extract
     618             :    the most significant bit by lmask^(lmask>>1).  Note that below the
     619             :    bits are written least to most significant, which is backwards of the
     620             :    normal way.
     621             : 
     622             :                Case            lmask       single_bit     ~lmask
     623             :                 x < a         0 0 0 0       0 0 0 0       1 1 1 1
     624             :       a      <= x < a+b       1 0 0 0       1 0 0 0       0 1 1 1
     625             :       a+b    <= x < a+b+c     1 1 0 0       0 1 0 0       0 0 1 1
     626             :       a+b+c  <= x < a+b+c+d   1 1 1 0       0 0 1 0       0 0 0 1
     627             :       a+b+c+d<= x             1 1 1 1       0 0 0 1       0 0 0 0
     628             : 
     629             :    Then we use _mm512_mask_compress_epi64 to move the element
     630             :    corresponding to the first set bit to the 0th position, or the src
     631             :    vector if there aren't any set bits.  From there, we can use
     632             :    _mm512_broadcastq_epi64 to fill a vector with that value.  Applying
     633             :   this trick to single_bit gives us l_{i-1} and to ~lmask give l_i. */
     634             : #define FD_WSAMPLE_IMPLEMENTATION 2
     635             : 
     636             : #if FD_WSAMPLE_IMPLEMENTATION==1
     637             : #define PREPARE()                                                                            \
     638             :   ulong cursor           = 0UL;                                                              \
     639             :   wwv_t xprime           = wwv_bcast( unif+1UL );                                            \
     640             :   wwv_t s                = wwv_bcast( sampler->unremoved_weight );                           \
     641             :   wwv_t high_bit_mask    = wwv_bcast( 0x8000000000000000UL );                                \
     642             :   wwv_t gather_signs_idx = wwv_bcast( 0x070F171F272F373FUL )
     643             : 
     644             : #define TRAVERSE_LEVEL(i)                                                                    \
     645             :   __mmask8 mask##i;                                                                          \
     646             :   ulong    cursor##i = cursor;                                                               \
     647             :   wwv_t    vec##i;                                                                           \
     648             :   do {                                                                                       \
     649             :     wwv_t vec  = wwv_ld( tree[cursor].left_sum );                                            \
     650             :     wwv_t sign = wwv_and( high_bit_mask, wwv_sub( vec, xprime ) );                           \
     651             :     wwv_t compressed_signs = _mm512_permutexvar_epi8( gather_signs_idx, sign );              \
     652             :     wwv_t popcnt = _mm512_popcnt_epi64( compressed_signs );                                  \
     653             :     wwv_t li = _mm512_permutex2var_epi64( vec, popcnt, s );                                  \
     654             :     wwv_t lim1 = _mm512_permutex2var_epi64( vec, wwv_sub( popcnt, wwv_one() ), wwv_zero() ); \
     655             :     mask##i = _mm512_cmpge_epu64_mask( vec, xprime );                                        \
     656             :     xprime = wwv_sub( xprime, lim1 );                                                        \
     657             :     s = wwv_sub( li, lim1 );                                                                 \
     658             :     ulong child_idx = (ulong)_mm_extract_epi64( _mm512_castsi512_si128( popcnt ), 0 );       \
     659             :     cursor = R*cursor + child_idx + 1UL;                                                     \
     660             :     vec##i = vec;                                                                            \
     661             :   } while( 0 )
     662             : 
     663             : #define PROPAGATE_LEVEL(i)                                                                   \
     664             :   wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );
     665             : 
     666             : #define FINALIZE()                                                                              \
     667             :   do {                                                                                          \
     668             :     sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 ); \
     669             :     sampler->unremoved_cnt--;                                                                   \
     670             :   } while( 0 )
     671             : 
     672             : #elif FD_WSAMPLE_IMPLEMENTATION==2
     673             : 
     674             : #define PREPARE()                                                                            \
     675   134231850 :   ulong cursor           = 0UL;                                                              \
     676   134231850 :   wwv_t x                = wwv_bcast( unif );                                                \
     677   134231850 :   wwv_t s                = wwv_bcast( sampler->unremoved_weight );                           \
     678             : 
     679             : #define TRAVERSE_LEVEL(i)                                                                                             \
     680   536927400 :   __mmask8 mask##i;                                                                                                   \
     681   536927400 :   ulong    cursor##i = cursor;                                                                                        \
     682   536927400 :   wwv_t    vec##i;                                                                                                    \
     683   536927400 :   do {                                                                                                                \
     684   536927400 :     wwv_t vec  = wwv_ld( tree[cursor].left_sum );                                                                     \
     685   536927400 :     __mmask8 lmask = _mm512_cmple_epu64_mask( vec, x );                                                               \
     686   536927400 :     cursor = R*cursor + (ulong)fd_uchar_popcnt( lmask ) + 1UL;                                                        \
     687   536927400 :     __mmask8 single_bit = _kxor_mask8( lmask, _kshiftri_mask8( lmask, 1U ) );                                         \
     688   536927400 :     wwv_t lim1 = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_maskz_compress_epi64( single_bit, vec ) ) ); \
     689   536927400 :     x = wwv_sub( x, lim1 );                                                                                           \
     690   536927400 :     mask##i = _knot_mask8( lmask );                                                                                   \
     691   536927400 :     wwv_t li = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_mask_compress_epi64( s, mask##i, vec ) ) );    \
     692   536927400 :     s = wwv_sub( li, lim1 );                                                                                          \
     693   536927400 :     vec##i = vec;                                                                                                     \
     694   536927400 :   } while( 0 )
     695             : 
     696             : #define PROPAGATE_LEVEL(i)                                                                                          \
     697   536927400 :   wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );
     698             : 
     699             : #define FINALIZE()                                                                              \
     700   134231850 :   do {                                                                                          \
     701   134231850 :     sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 ); \
     702   134231850 :     sampler->unremoved_cnt--;                                                                   \
     703   134231850 :   } while( 0 )
     704             : 
     705             : #endif
     706             : #else
     707             : #define FD_WSAMPLE_IMPLEMENTATION 0
     708             : #endif
     709             : 
     710             : /* For now, implement the _many functions as loops over the single
     711             :    sample functions.  It is possible to do better though. */
     712             : 
     713             : void
     714             : fd_wsample_sample_many( fd_wsample_t * sampler,
     715             :                         ulong        * idxs,
     716     3694788 :                         ulong          cnt  ) {
     717   373172133 :   for( ulong i=0UL; i<cnt; i++ ) idxs[i] = fd_wsample_sample( sampler );
     718     3694788 : }
     719             : 
     720             : void
     721             : fd_wsample_sample_and_remove_many( fd_wsample_t * sampler,
     722             :                                    ulong        * idxs,
     723     2985462 :                                    ulong          cnt   ) {
     724             :   /* The compiler doesn't seem to like inlining the call to
     725             :      fd_wsample_sample_and_remove, which hurts performance by a few
     726             :      percent because it triggers worse behavior in the CPUs front end.
     727             :      To address this, we manually inline it here. */
     728   419462751 :   for( ulong i=0UL; i<cnt; i++ ) {
     729   416477289 :     if( FD_UNLIKELY( !sampler->unremoved_weight ) ) { idxs[ i ] = FD_WSAMPLE_EMPTY;         continue; }
     730   416253042 :     if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) { idxs[ i ] = FD_WSAMPLE_INDETERMINATE; continue; }
     731   415310850 :     ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     732   415310850 :     if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
     733        3879 :       idxs[ i ] = FD_WSAMPLE_INDETERMINATE;
     734        3879 :       sampler->poisoned_mode = 1;
     735        3879 :       continue;
     736        3879 :     }
     737   138435657 : #if FD_WSAMPLE_IMPLEMENTATION > 0
     738   138435657 :   if( FD_LIKELY( sampler->height==4UL ) ) {
     739   134229135 :     tree_ele_t * tree = sampler->tree;
     740   134229135 :     PREPARE();
     741             : 
     742   134229135 :     TRAVERSE_LEVEL(0);
     743   134229135 :     TRAVERSE_LEVEL(1);
     744   134229135 :     TRAVERSE_LEVEL(2);
     745   134229135 :     TRAVERSE_LEVEL(3);
     746   134229135 :     PROPAGATE_LEVEL(0);
     747   134229135 :     PROPAGATE_LEVEL(1);
     748   134229135 :     PROPAGATE_LEVEL(2);
     749   134229135 :     PROPAGATE_LEVEL(3);
     750             : 
     751   134229135 :     FINALIZE();
     752   134229135 :     idxs[ i ] = cursor - sampler->internal_node_cnt;
     753   134229135 :     continue;
     754   134229135 :   }
     755     4206522 : #endif
     756   281077836 :     idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
     757   281077836 :     fd_wsample_remove( sampler, p );
     758   281077836 :     idxs[ i ] = p.idx;
     759   281077836 :   }
     760     2985462 : }
     761             : 
     762             : 
     763             : 
     764             : ulong
     765   769011729 : fd_wsample_sample( fd_wsample_t * sampler ) {
     766   769011729 :   if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
     767   769011723 :   if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
     768   769011723 :   ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     769   769011723 :   if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) return FD_WSAMPLE_INDETERMINATE;
     770   768757023 :   return (ulong)fd_wsample_map_sample( sampler, unif );
     771   769011723 : }
     772             : 
     773             : ulong
     774     3651471 : fd_wsample_sample_and_remove( fd_wsample_t * sampler ) {
     775     3651471 :   if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
     776     3651387 :   if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
     777     3649002 :   ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     778     3649002 :   if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
     779         270 :     sampler->poisoned_mode = 1;
     780         270 :     return FD_WSAMPLE_INDETERMINATE;
     781         270 :   }
     782             : 
     783     1216244 : #if FD_WSAMPLE_IMPLEMENTATION > 0
     784     1216244 :   if( FD_LIKELY( sampler->height==4UL ) ) {
     785        2715 :     tree_ele_t * tree = sampler->tree;
     786        2715 :     PREPARE();
     787             : 
     788        2715 :     TRAVERSE_LEVEL(0);
     789        2715 :     TRAVERSE_LEVEL(1);
     790        2715 :     TRAVERSE_LEVEL(2);
     791        2715 :     TRAVERSE_LEVEL(3);
     792        2715 :     PROPAGATE_LEVEL(0);
     793        2715 :     PROPAGATE_LEVEL(1);
     794        2715 :     PROPAGATE_LEVEL(2);
     795        2715 :     PROPAGATE_LEVEL(3);
     796        2715 :     FINALIZE();
     797        2715 :     return cursor - sampler->internal_node_cnt;
     798        2715 :   }
     799     1213529 : #endif
     800     3646017 :   idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
     801     3646017 :   fd_wsample_remove( sampler, p );
     802     3646017 :   return p.idx;
     803     3648912 : }

Generated by: LCOV version 1.14