LCOV - code coverage report
Current view: top level - ballet/wsample - fd_wsample.c (source / functions) Hit Total Coverage
Test: cov.lcov Lines: 262 295 88.8 %
Date: 2026-06-29 05:51:35 Functions: 20 20 100.0 %

          Line data    Source code
       1             : #include "fd_wsample.h"
       2             : #include <math.h> /* For sqrt */
       3             : #if FD_HAS_AVX512
       4             : #include "../../util/simd/fd_avx512.h"
       5             : #endif
       6             : 
       7 48495815617 : #define R 9
       8             : /* This sampling problem is an interesting one from a performance
       9             :    perspective.  There are lots of interesting approaches.  The
      10             :    header/implementation split is designed to give lots of flexibility
      11             :    for future optimization.  The current implementation uses radix 9
      12             :    tree with all the leaves on the bottom. */
      13             : 
      14             : /* I'm not sure exactly how to classify the tree that this
      15             :    implementation uses, but it's something like a B-tree with some
      16             :    tricks from binary heaps.  In particular, like a B-tree, each node
      17             :    stores several keys, where the keys are the cumulative sums of
      18             :    subtrees, called left sums below.  Like a binary heap, it is
      19             :    pointer-free, and it is stored implicitly in a flat array, in the
      20             :    typical way that a binary heap is stored.  Specifically, the root is
      21             :    at index 0, and node i's children are at Ri+1, Ri+2, ... Ri+R, where
      22             :    R is the radix.  The leaves are all at the same level, at the bottom,
      23             :    stored implicitly, which means that a search can be done with almost
      24             :    no branch mispredictions.
      25             : 
      26             :    As an example, suppose R=5 and our tree contains weights 100, 80, 50,
      27             :    31, 27, 14, and 6.  Because that is 7 nodes, the height is set to
      28             :    ceil(log_5(7))=2.
      29             : 
      30             :                           Root (0)
      31             :                           /        \
      32             :                    Child (1)      Child (2)     -- -- --
      33             :                /  /   |  \   \     |  \
      34             :             100, 80, 50, 31, 27   14, 6
      35             : 
      36             :    The first child of the root, node 1, has left_sum values
      37             :    |100|180|230|261|.  Note that we only store R-1 values, and so the
      38             :    last child, 27, does not feature in the sums; the search process
      39             :    handles it implicitly, as we'll see.  The second child of the root,
      40             :    node 2, has left sum values |14|20|20|20|.  Then the root node, node
      41             :    0, has left sum values |288|308|308|308|.
      42             : 
      43             :    The total sum is 308, so we start by drawing a random value in the
      44             :    range [0, 308).  We see which child that falls under, adjust our
      45             :    random value, and repeat, until we reach a leaf.
      46             : 
      47             :    In general, if a node has children (left to right) with subtree
      48             :    weights a, b, c, d, e, then the left sums are |a|a+b|a+b+c|a+b+c+d|
      49             :    and the full sum S=a+b+c+d+e.  For a query value of x in [0, S), we
      50             :    want to pick:
      51             :       child 0   if             x < a
      52             :       child 1   if   a      <= x < a+b
      53             :       child 2   if   a+b    <= x < a+b+c
      54             :       child 3   if   a+b+c  <= x < a+b+c+d
      55             :       child 4   if   a+b+c+d<= x
      56             : 
      57             :    Which is equivalent to choosing child
      58             :           (a<=x) + (a+b<=x) + (a+b+c<=x) + (a+b+c+d<=x)
      59             :    which can be computed branchlessly.  The value of e only comes into
      60             :    play in the restriction that x is in [0, S), and e is included in the
      61             :    value of S.
      62             : 
      63             :    There are two details left to discuss in order to search recursively.
      64             :    First, in order to remove an element for sampling without
      65             :    replacement, we need to know the weight of the element we're
      66             :    removing.  As with e above, the weights may not be stored explicitly,
      67             :    but as long as we keep track of the weight of the current subtree as
      68             :    we search recursively, we can obtain the weight without much work.
      69             :    Thus, we need to know the sum S' of the chosen subtree, i in [0, R).
      70             :    For j in [-1, R), define
      71             :                             /  0             if j==-1
      72             :                     l[j] =  | left_sum[ j ]  if j in [0, R-1)
      73             :                             \  S             if j==R-1
      74             :    Essentially, we're computing the natural extension values for
      75             :    left_sum on the left and right side.  Then observe that S' = l[i] -
      76             :    l[i-1].
      77             : 
      78             :    Secondly, in order to search recursively, we need to adjust the query
      79             :    value to put it into [0, S').  Specifically, we need to subtract off
      80             :    the sum of all the children to the left of the child we've chosen.
      81             :    If we've chosen child i, then that's just l[i-1].
      82             : 
      83             :    All the above extends easily to any radix, but 3, 5, and 9 = 1+2^n
      84             :    are the natural ones.  I only tried 5 and 9, and 9 was substantially
      85             :    faster.
      86             : 
      87             :    It's worth noting that storing left sums means that you have to do
      88             :    the most work when you update the left-most child.  Because we
      89             :    typically store the weights largest to smallest, the left-most child
      90             :    is the one we delete the most frequently, so now our most common case
      91             :    (probability-wise) and our worst case (performance-wise) are the
      92             :    same, which seems bad.  I initially implemented this using right sums
      93             :    instead of left to address this problem.  They're less intuitive, but
      94             :    also work.  However, I found that having an unpredictable loop in the
      95             :    deletion method was far worse than just updating each element, which
      96             :    means that the "less work" advantage of right sums went away. */
      97             : 
      98             : struct __attribute__((aligned(8UL*(R-1UL)))) tree_ele {
      99             :   /* left_sum stores the cumulative weight of the subtrees at this
     100             :      node.  See the long note above for more information. */
     101             :   ulong left_sum[ R-1 ];
     102             : };
     103             : typedef struct tree_ele tree_ele_t;
     104             : 
     105             : struct __attribute__((aligned(64UL))) fd_wsample_private {
     106             :   ulong              total_cnt;
     107             :   ulong              total_weight;
     108             :   ulong              unremoved_cnt;
     109             :   ulong              unremoved_weight; /* Initial value for S explained above */
     110             : 
     111             :   /* internal_node_cnt and height are both determined by the number of
     112             :      leaves in the original tree, via the following formulas:
     113             :      height = ceil(log_R(leaf_cnt))
     114             :      internal_node_cnt = sum_{i=0}^{height-1} R^i
     115             :      height and internal_node_cnt both exclude the leaves, which are
     116             :      only implicit.
     117             : 
     118             :      All the math seems to disallow leaf_cnt==0, but for convenience, we
     119             :      do allow it. height==internal_node_cnt==0 in that case.
     120             : 
     121             :      height actually fits in a uchar.  Storing as ulong is more natural,
     122             :      but we don't want to spill over into another cache line.
     123             : 
     124             :      If poisoned_mode==1, then all sample calls will return poisoned. */
     125             :   ulong              internal_node_cnt;
     126             :   ulong              poisoned_weight;
     127             :   uint               height;
     128             :   char               restore_enabled;
     129             :   char               poisoned_mode;
     130             :   /* Two bytes of padding here */
     131             : 
     132             :   fd_chacha_rng_t * rng;
     133             : 
     134             :   /* tree: Here's where the actual tree is stored, at indices [0,
     135             :      internal_node_cnt).  The indexing scheme is explained in the long
     136             :      comment above.
     137             : 
     138             :      If restore_enabled==1, then indices [internal_node_cnt+1,
     139             :      2*internal_node_cnt+1) store a copy of the tree after construction
     140             :      but before any deletion so that restoring deleted elements can be
     141             :      implemented as a memcpy.
     142             : 
     143             :      The tree itself is surrounded by two dummy elements, dummy, and
     144             :      tree[internal_node_cnt], that aren't actually used.  This is
     145             :      because searching the tree branchlessly involves some out of bounds
     146             :      reads, and although the value is immediately discarded, it's better
     147             :      to know where exactly those reads might go. */
     148             :   tree_ele_t        dummy;
     149             :   tree_ele_t        tree[];
     150             : };
     151             : 
     152             : typedef struct fd_wsample_private fd_wsample_t;
     153             : 
     154             : 
     155             : FD_FN_CONST ulong
     156     1530864 : fd_wsample_align( void ) {
     157     1530864 :   return 64UL;
     158     1530864 : }
     159             : 
     160             : /* Returns -1 on failure */
     161             : static inline int
     162             : compute_height( ulong   leaf_cnt,
     163             :                 ulong * out_height,
     164      776628 :                 ulong * out_internal_cnt ) {
     165             :   /* This max is a bit conservative.  The actual max is height <= 25,
     166             :      and leaf_cnt < 5^25 approx 2^58.  A tree that large would take an
     167             :      astronomical amount of memory, so we just retain this max for the
     168             :      moment. */
     169      776628 :   if( FD_UNLIKELY( leaf_cnt >= UINT_MAX-2UL ) ) return -1;
     170             : 
     171      776628 :   ulong height   = 0;
     172      776628 :   ulong internal = 0UL;
     173      776628 :   ulong powRh    = 1UL; /* = R^height */
     174     3441153 :   while( leaf_cnt>powRh ) {
     175     2664525 :     internal += powRh;
     176     2664525 :     powRh    *= R;
     177     2664525 :     height++;
     178     2664525 :   }
     179      776628 :   *out_height       = height;
     180      776628 :   *out_internal_cnt = internal;
     181      776628 :   return 0;
     182      776628 : }
     183             : 
     184             : FD_FN_CONST ulong
     185      451152 : fd_wsample_footprint( ulong ele_cnt, int restore_enabled ) {
     186      451152 :   ulong height;
     187      451152 :   ulong internal_cnt;
     188             :   /* Computing the closed form of the sum in compute_height, we get
     189             :      internal_cnt = 1/8 * (9^ceil(log_9( ele_cnt ) ) - 1)
     190             :                  x <= ceil( x ) < x+1
     191             :      1/8 * ele_cnt - 1/8 <= internal_cnt < 9/8 * ele_cnt - 1/8
     192             :   */
     193      451152 :   if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) return 0UL;
     194      451152 :   return sizeof(fd_wsample_t) + ((restore_enabled?2UL:1UL)*internal_cnt + 1UL)*sizeof(tree_ele_t);
     195      451152 : }
     196             : 
     197             : fd_wsample_t *
     198      325464 : fd_wsample_join( void * shmem  ) {
     199      325464 :   if( FD_UNLIKELY( !shmem ) ) {
     200           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     201           0 :     return NULL;
     202           0 :   }
     203             : 
     204      325464 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     205           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     206           0 :     return NULL;
     207           0 :   }
     208      325464 :   return (fd_wsample_t *)shmem;
     209      325464 : }
     210             : 
     211             : /* Note: The following optimization insights are not used in this
     212             :    high radix implementation.  Performance in the deletion case is much
     213             :    more important than in the non-deletion case, and it's not clear how
     214             :    to translate this.  I'm leaving the code and comment because it is a
     215             :    useful and non-trivial insight. */
     216             : #if 0
     217             : /* If we assume the probability of querying node i is proportional to
     218             :    1/i, then observe that the midpoint of the probability mass in the
     219             :    continuous approximation is the solution to (in Mathematica syntax):
     220             : 
     221             :         Integrate[ 1/i, {i, lo, hi}] = 2*Integrate[ 1/i, {i, lo, mid} ]
     222             : 
     223             :    which gives mid = sqrt(lo*hi).  This is in contrast to when the
     224             :    integrand is a constant, which gives the normal binary search rule:
     225             :    mid=(lo+hi)/2.
     226             : 
     227             :    Thus, we want the search to follow this modified binary search rule,
     228             :    since that'll approximately split the stake weight/probability mass
     229             :    in half at each step.
     230             : 
     231             :    This is almost as nice to work with mathematically as the normal
     232             :    binary search rule.  The jth entry from the left at level k is the
     233             :    region [ N^((1/2^k)*j), N^((1/2^k)*(j+1)) ).  We're basically doing
     234             :    binary search in the log domain.
     235             : 
     236             :    Rather than trying to compute these transcendental functions, this
     237             :    simple recursive implementation gives the treap the right shape by
     238             :    setting prio very carefully, since higher prio means closer to the
     239             :    root.  This may be a slight abuse of a treap, but it's easier than
     240             :    implementing a whole custom tree for just this purpose.
     241             : 
     242             :    We want to be careful to ensure that the recursion is tightly
     243             :    bounded.  From a continuous perspective, that's not a problem: the
     244             :    widest interval at level k is the last one, and we break when the
     245             :    interval's width is less than 1.  Solving 1=N-N^(1-(1/2^k)) for k
     246             :    yields k=-lg(1-log(N-1)/log(N)).  k is monotonically increasing as a
     247             :    function of N, which means that for all N,
     248             :            k <= -lg(1-log(N_max-1)/log(N_max)) < 37
     249             :    since N_max=2^32-1.
     250             : 
     251             :    The math is more complicated with rounding and finite precision, but
     252             :    sqrt(lo*hi) is very different from lo and hi unless lo and hi are
     253             :    approximately the same.  In that case, lo<mid and mid<hi ensures that
     254             :    both intervals are strictly smaller than the interval they came from,
     255             :    which prevents an infinite loop. */
     256             : static inline void
     257             : seed_recursive( treap_ele_t * pool,
     258             :                 uint lo,
     259             :                 uint hi,
     260             :                 uint prio ) {
     261             :   uint mid = (uint)(sqrtf( (float)lo*(float)hi ) + 0.5f);
     262             :   if( (lo<mid) & (mid<hi) ) {
     263             :     /* since we start with lo=1, shift by 1 */
     264             :     pool[mid-1U].prio = prio;
     265             :     seed_recursive( pool, lo,  mid, prio-1U );
     266             :     seed_recursive( pool, mid, hi,  prio-1U );
     267             :   }
     268             : }
     269             : #endif
     270             : 
     271             : 
     272             : void *
     273             : fd_wsample_new_init( void             * shmem,
     274             :                      fd_chacha_rng_t * rng,
     275             :                      ulong              ele_cnt,
     276             :                      int                restore_enabled,
     277      325476 :                      int                opt_hint ) {
     278      325476 :   if( FD_UNLIKELY( !shmem ) ) {
     279           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     280           0 :     return NULL;
     281           0 :   }
     282             : 
     283      325476 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     284           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     285           0 :     return NULL;
     286           0 :   }
     287             : 
     288      325476 :   ulong height;
     289      325476 :   ulong internal_cnt;
     290      325476 :   if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) {
     291           0 :     FD_LOG_WARNING(( "bad ele_cnt" ));
     292           0 :     return NULL;
     293           0 :   }
     294             : 
     295      325476 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     296             : 
     297      325476 :   sampler->total_weight      = 0UL;
     298      325476 :   sampler->unremoved_cnt     = 0UL;
     299      325476 :   sampler->unremoved_weight  = 0UL;
     300      325476 :   sampler->internal_node_cnt = internal_cnt;
     301      325476 :   sampler->poisoned_weight   = 0UL;
     302      325476 :   sampler->height            = (uint)height;
     303      325476 :   sampler->restore_enabled   = (char)!!restore_enabled;
     304      325476 :   sampler->poisoned_mode     = 0;
     305      325476 :   sampler->rng               = rng;
     306             : 
     307      325476 :   fd_memset( sampler->tree, (char)0, internal_cnt*sizeof(tree_ele_t) );
     308             : 
     309      325476 :   (void)opt_hint; /* Not used at the moment */
     310             : 
     311      325476 :   return shmem;
     312      325476 : }
     313             : 
     314             : void *
     315             : fd_wsample_new_add( void * shmem,
     316   208844229 :                     ulong  weight ) {
     317   208844229 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     318   208844229 :   if( FD_UNLIKELY( !sampler ) ) return NULL;
     319             : 
     320   208844229 :   if( FD_UNLIKELY( weight==0UL ) ) {
     321           0 :     FD_LOG_WARNING(( "zero weight entry found" ));
     322           0 :     return NULL;
     323           0 :   }
     324   208844229 :   if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
     325           0 :     FD_LOG_WARNING(( "total weight too large" ));
     326           0 :     return NULL;
     327           0 :   }
     328             : 
     329   208844229 :   tree_ele_t * tree = sampler->tree;
     330   208844229 :   ulong i = sampler->internal_node_cnt + sampler->unremoved_cnt;
     331             : 
     332  1056315201 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     333   847470972 :     ulong parent = (i-1UL)/R;
     334   847470972 :     ulong child_idx = i-1UL - R*parent; /* in [0, R) */
     335  5167771182 :     for( ulong k=child_idx; k<R-1UL; k++ )  tree[ parent ].left_sum[ k ] += weight;
     336   847470972 :     i = parent;
     337   847470972 :   }
     338             : 
     339   208844229 :   sampler->unremoved_cnt++;
     340   208844229 :   sampler->total_cnt++;
     341   208844229 :   sampler->unremoved_weight += weight;
     342   208844229 :   sampler->total_weight     += weight;
     343             : 
     344   208844229 :   return shmem;
     345   208844229 : }
     346             : 
     347             : void *
     348             : fd_wsample_new_fini( void * shmem,
     349      325476 :                      ulong  poisoned_weight ) {
     350      325476 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     351      325476 :   if( FD_UNLIKELY( !sampler ) ) return NULL;
     352             : 
     353      325476 :   if( FD_UNLIKELY( sampler->total_weight+poisoned_weight<sampler->total_weight ) ) {
     354           0 :     FD_LOG_WARNING(( "poisoned_weight caused overflow" ));
     355           0 :     return NULL;
     356           0 :   }
     357             : 
     358      325476 :   sampler->poisoned_weight = poisoned_weight;
     359             : 
     360      325476 :   if( sampler->restore_enabled ) {
     361             :     /* Copy the sampler to make restore fast. */
     362      313506 :     fd_memcpy( sampler->tree+sampler->internal_node_cnt+1UL, sampler->tree, sampler->internal_node_cnt*sizeof(tree_ele_t) );
     363      313506 :   }
     364             : 
     365      325476 :   return (void *)sampler;
     366      325476 : }
     367             : 
     368             : void *
     369      325236 : fd_wsample_leave( fd_wsample_t * sampler ) {
     370      325236 :   if( FD_UNLIKELY( !sampler ) ) {
     371           0 :     FD_LOG_WARNING(( "NULL sampler" ));
     372           0 :     return NULL;
     373           0 :   }
     374             : 
     375      325236 :   return (void *)sampler;
     376      325236 : }
     377             : 
     378             : void *
     379      325236 : fd_wsample_delete( void * shmem  ) {
     380      325236 :   if( FD_UNLIKELY( !shmem ) ) {
     381           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     382           0 :     return NULL;
     383           0 :   }
     384             : 
     385      325236 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     386           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     387           0 :     return NULL;
     388           0 :   }
     389      325236 :   return shmem;
     390      325236 : }
     391             : 
     392             : void
     393             : fd_wsample_seed_rng( fd_wsample_t * sampler,
     394     1806753 :                      uchar          seed[ 32 ] ) {
     395     1806753 :   fd_chacha_rng_init( sampler->rng, seed, FD_CHACHA_RNG_ALGO_CHACHA8 );
     396     1806753 : }
     397             : 
     398             : fd_wsample_t *
     399     2714286 : fd_wsample_restore_all( fd_wsample_t * sampler ) {
     400     2714286 :   if( FD_UNLIKELY( !sampler->restore_enabled ) )  return NULL;
     401             : 
     402     2714283 :   sampler->unremoved_weight = sampler->total_weight;
     403     2714283 :   sampler->unremoved_cnt    = sampler->total_cnt;
     404     2714283 :   sampler->poisoned_mode    = 0;
     405             : 
     406     2714283 :   fd_memcpy( sampler->tree, sampler->tree+sampler->internal_node_cnt+1UL, sampler->internal_node_cnt*sizeof(tree_ele_t) );
     407     2714283 :   return sampler;
     408     2714286 : }
     409             : 
     410             : #define fd_ulong_if_force( c, t, f ) (__extension__({ \
     411             :       ulong result;                                   \
     412             :       __asm__( "testl  %1, %1; \n\t"                  \
     413             :                "movq   %3, %0; \n\t"                  \
     414             :                "cmovne %2, %0; \n\t"                  \
     415             :                : "=&r"(result)                        \
     416             :                : "r"(c), "rm"(t), "rmi"(f)            \
     417             :                : "cc" );                              \
     418             :       result;                                         \
     419             :       }))
     420             : 
     421             : /* Helper methods for sampling functions */
     422             : typedef struct { ulong idx; ulong weight; } idxw_pair_t; /* idx in [0, total_cnt) */
     423             : 
     424             : /* Assumes query in [0, unremoved_weight), which implies
     425             :    unremoved_weight>0, so the tree can't be empty. */
     426             : static inline idxw_pair_t
     427             : fd_wsample_map_sample_i( fd_wsample_t const * sampler,
     428   945767753 :                          ulong                query ) {
     429   945767753 :   tree_ele_t const * tree = sampler->tree;
     430             : 
     431   945767753 :   ulong cursor = 0UL;
     432   945767753 :   ulong S      = sampler->unremoved_weight;
     433  4131400765 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     434  3185633012 :     tree_ele_t const * e = tree+cursor;
     435  3185633012 :     ulong x = query;
     436  3185633012 :     ulong child_idx = 0UL;
     437             : 
     438   843255244 : #if FD_HAS_AVX512 && R==9
     439   843255244 :     __mmask8 mask = _mm512_cmple_epu64_mask( wwv_ld( e->left_sum ), wwv_bcast( x ) );
     440   843255244 :     child_idx = (ulong)fd_uchar_popcnt( mask );
     441             : #else
     442 21081399912 :     for( ulong i=0UL; i<R-1UL; i++ ) child_idx += (ulong)(e->left_sum[ i ]<=x);
     443  2342377768 : #endif
     444             : 
     445             :     /* See the note at the top of this file for the explanation of l[i]
     446             :        and l[i-1].  Because this is fd_ulong_if and not a ternary, these
     447             :        can read/write out of what you would think the appropriate bounds
     448             :        are.  The dummy elements, as described along with tree, make this
     449             :        safe. */
     450             : #if 0
     451             :     ulong li  = fd_ulong_if( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
     452             :     ulong lm1 = fd_ulong_if( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
     453             : #elif 0
     454             :     ulong li  = fd_ulong_if_force( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
     455             :     ulong lm1 = fd_ulong_if_force( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
     456             : #else
     457  3185633012 :     ulong * temp = (ulong *)e->left_sum;
     458  3185633012 :     ulong orig_m1 = temp[ -1 ];    ulong orig_Rm1 = temp[ R-1UL ];
     459  3185633012 :     temp[ -1 ] = 0UL;              temp[ R-1UL ] = S;
     460  3185633012 :     ulong li  = temp[ child_idx     ];
     461  3185633012 :     ulong lm1 = temp[ child_idx-1UL ];
     462  3185633012 :     temp[ -1 ] = orig_m1;          temp[ R-1UL ] = orig_Rm1;
     463  3185633012 : #endif
     464             : 
     465  3185633012 :     query -= lm1;
     466  3185633012 :     S = li - lm1;
     467  3185633012 :     cursor = R*cursor + child_idx + 1UL;
     468  3185633012 :   }
     469   945767753 :   idxw_pair_t to_return = { .idx = cursor - sampler->internal_node_cnt, .weight = S };
     470   945767753 :   return to_return;
     471   945767753 : }
     472             : 
     473             : ulong
     474             : fd_wsample_map_sample( fd_wsample_t * sampler,
     475   765532941 :                        ulong          query ) {
     476   765532941 :   return fd_wsample_map_sample_i( sampler, query ).idx;
     477   765532941 : }
     478             : 
     479             : 
     480             : 
     481             : /* Also requires the tree to be non-empty */
     482             : static inline void
     483             : fd_wsample_remove( fd_wsample_t * sampler,
     484   181242749 :                    idxw_pair_t    to_remove ) {
     485   181242749 :   ulong cursor = to_remove.idx + sampler->internal_node_cnt;
     486   181242749 :   tree_ele_t * tree = sampler->tree;
     487             : 
     488   869854558 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     489   688611809 :     ulong parent = (cursor-1UL)/R;
     490   688611809 :     ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
     491    10914843 : #if FD_HAS_AVX512 && R==9
     492    10914843 :     wwv_t weight = wwv_bcast( to_remove.weight );
     493    10914843 :     wwv_t left_sum = wwv_ld( tree[ parent ].left_sum );
     494    10914843 :     __m128i _child_idx = _mm_set1_epi16( (short) child_idx );
     495    10914843 :     __mmask8 mask = _mm_cmplt_epi16_mask( _child_idx, _mm_setr_epi16( 1, 2, 3, 4, 5, 6, 7, 8 ) );
     496    10914843 :     left_sum = _mm512_mask_sub_epi64( left_sum, mask, left_sum, weight );
     497    10914843 :     wwv_st( tree[ parent ].left_sum, left_sum );
     498             : #elif 0
     499             :     for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if( child_idx<=k, to_remove.weight, 0UL );
     500             : #elif 0
     501             :     for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if_force( child_idx<=k, to_remove.weight, 0UL );
     502             : #elif 1
     503             :     /* The compiler loves inserting a difficult to predict branch for
     504             :        fd_ulong_if, but this forces it not to do that. */
     505  6099272694 :     for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= (ulong)(((long)(child_idx - k - 1UL))>>63) & to_remove.weight;
     506             : #else
     507             :     /* This version does the least work, but has a hard-to-predict
     508             :        branch.  The branchless versions are normally substantially
     509             :        faster. */
     510             :     for( ulong k=child_idx; k<R-1UL; k++ )  tree[ parent ].left_sum[ k ] -= to_remove.weight;
     511             : #endif
     512   688611809 :     cursor = parent;
     513   688611809 :   }
     514   181242749 :   sampler->unremoved_cnt--;
     515   181242749 :   sampler->unremoved_weight -= to_remove.weight;
     516   181242749 : }
     517             : 
     518             : static inline ulong
     519             : fd_wsample_find_weight( fd_wsample_t const * sampler,
     520     1007937 :                         ulong                idx /* in [0, total_cnt) */) {
     521             :   /* The fact we don't store the weights explicitly makes this function
     522             :      more complicated, but this is not used very frequently. */
     523     1007937 :   tree_ele_t const * tree = sampler->tree;
     524     1007937 :   ulong cursor = idx + sampler->internal_node_cnt;
     525             : 
     526             :   /* Initialize to the 0 height case */
     527     1007937 :   ulong lm1 = 0UL;
     528     1007937 :   ulong li  = sampler->unremoved_weight;
     529             : 
     530     1038219 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     531     1038027 :     ulong parent = (cursor-1UL)/R;
     532     1038027 :     ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
     533             : 
     534             :     /* If child_idx < R-1, we can compute the weight easily.  If
     535             :        child_idx==R-1, the computation is S - left_sum[ R-2 ], but we
     536             :        don't know S, so we need to continue up the tree. */
     537     1038027 :     lm1  += fd_ulong_if( child_idx>0UL, tree[ parent ].left_sum[ child_idx-1UL ], 0UL );
     538     1038027 :     if( FD_LIKELY( child_idx<R-1UL ) ) {
     539     1007745 :       li = tree[ parent ].left_sum[ child_idx ];
     540     1007745 :       break;
     541     1007745 :     }
     542             : 
     543       30282 :     cursor = parent;
     544       30282 :   }
     545             : 
     546     1007937 :   return li - lm1;
     547     1007937 : }
     548             : 
     549             : void
     550             : fd_wsample_remove_idx( fd_wsample_t * sampler,
     551     1007937 :                        ulong          idx ) {
     552             : 
     553     1007937 :   ulong weight = fd_wsample_find_weight( sampler, idx );
     554     1007937 :   idxw_pair_t r = { .idx = idx, .weight = weight };
     555     1007937 :   fd_wsample_remove( sampler, r );
     556     1007937 : }
     557             : 
     558             : 
     559             : #if FD_HAS_AVX512
     560             : 
     561             : /* TRAVERSE_LEVEL takes in and updates s (a vector with all elements set
     562             :    to the unremoved weight in the current subtree), x or xprime (a
     563             :    vector with all elements set to the current query value, or the
     564             :    value+1, respectively), and cursor (a ulong of the current position
     565             :    in the traversal).  See fd_wsample_map_sample_i for more about these
     566             :    values.  Calling this macro height times is the same as calling
     567             :    fd_wsample_map_sample_i, except that it also declares mask{i},
     568             :    cursor{i} and vec{i}, to save work in PROPAGATE_LEVEL, which is like
     569             :    fd_wsample_remove.
     570             : 
     571             :    This only works in the R=9 case, but we'll explain it as if R=5 like
     572             :    before.  Then v, a vector of tree elements, has values
     573             :            ----------------------------------------
     574             :            |    a    |   a+b  |  a+b+c  | a+b+c+d |
     575             :            ----------------------------------------
     576             :    Then, like before, the desired state is
     577             :          If...                pick,   and  l_{i-1}    and l_i
     578             :               x < a         child 0        0             a
     579             :     a      <= x < a+b       child 1        a            a+b
     580             :     a+b    <= x < a+b+c     child 2       a+b          a+b+c
     581             :     a+b+c  <= x < a+b+c+d   child 3      a+b+c        a+b+c+d
     582             :     a+b+c+d<= x             child 4     a+b+c+d      a+b+c+d+e
     583             : 
     584             :    Below, there are two pretty different implementations, and this
     585             :    explains both:
     586             : 
     587             :    Implementation 1:
     588             :    We want to get information about v0<=x into as much of the vector as
     589             :    fast as possible.  We could do some kind of compress and then bcast,
     590             :    but operations with mask registers are frustratingly slow.  Instead,
     591             :    we first define x'=x+1, so then v0<=x is equivalent to v0-x'<0, or
     592             :    whether v0-x' has the high bit set.  Because of
     593             :    fd_chacha_rng_ulong_roll's contract, we know x<ULONG_MAX, so forming
     594             :    x' is safe.  We then use a _mm512_permutexvar_epi8 to essentially
     595             :    broadcast the high byte from each of the ulongs (really we just need
     596             :    the high bit) to the whole vector.  A popcnt gives us our base
     597             :    values.
     598             : 
     599             :                   Case             compressed_signs   popcnt  popcnt-1
     600             :                    x < a                      0         0       -1
     601             :          a      <= x < a+b                 0x80         1        0
     602             :          a+b    <= x < a+b+c             0x8080         2        1
     603             :          a+b+c  <= x < a+b+c+d         0x808080         3        2
     604             :          a+b+c+d<= x                 0x80808080         4        3
     605             : 
     606             :    Then, keeping in mind that _mm512_permutex2var_epi64 will select an
     607             :    element from the second vector if the next highest bit is set, using
     608             :    popcnt-1 as the selector and (vec, 0) as (a,b) gives l_{i-1}; and
     609             :    using popcnt as the selector with (vec, s) gives l_i.
     610             : 
     611             :    Finally, we just need to produce the mask mask{i}, which we can do
     612             :    with wwv_lt, off the critical path.
     613             : 
     614             : 
     615             :    Implementation 2
     616             :    We first use wwv_le to compare v0<=x to form lmask.  Then we extract
     617             :    the most significant set bit with lmask^(lmask>>1).  Note that below
     618             :    the bits are written least to most significant, which is the reverse
     619             :    of the normal way.
     620             : 
     621             :                Case            lmask       single_bit     ~lmask
     622             :                 x < a         0 0 0 0       0 0 0 0       1 1 1 1
     623             :       a      <= x < a+b       1 0 0 0       1 0 0 0       0 1 1 1
     624             :       a+b    <= x < a+b+c     1 1 0 0       0 1 0 0       0 0 1 1
     625             :       a+b+c  <= x < a+b+c+d   1 1 1 0       0 0 1 0       0 0 0 1
     626             :       a+b+c+d<= x             1 1 1 1       0 0 0 1       0 0 0 0
     627             : 
     628             :    Then we use _mm512_mask_compress_epi64 to move the element
     629             :    corresponding to the first set bit to the 0th position, or the src
     630             :    vector if there aren't any set bits.  From there, we can use
     631             :    _mm512_broadcastq_epi64 to fill a vector with that value.  Applying
     632             :    this trick to single_bit gives us l_{i-1} and to ~lmask gives l_i. */
     633             : #define FD_WSAMPLE_IMPLEMENTATION 2
     634             : 
     635             : #if FD_WSAMPLE_IMPLEMENTATION==1
     636             : #define PREPARE()                                                                            \
     637             :   ulong cursor           = 0UL;                                                              \
     638             :   wwv_t xprime           = wwv_bcast( unif+1UL );                                            \
     639             :   wwv_t s                = wwv_bcast( sampler->unremoved_weight );                           \
     640             :   wwv_t high_bit_mask    = wwv_bcast( 0x8000000000000000UL );                                \
     641             :   wwv_t gather_signs_idx = wwv_bcast( 0x070F171F272F373FUL )
     642             : 
     643             : #define TRAVERSE_LEVEL(i)                                                                    \
     644             :   __mmask8 mask##i;                                                                          \
     645             :   ulong    cursor##i = cursor;                                                               \
     646             :   wwv_t    vec##i;                                                                           \
     647             :   do {                                                                                       \
     648             :     wwv_t vec  = wwv_ld( tree[cursor].left_sum );                                            \
     649             :     wwv_t sign = wwv_and( high_bit_mask, wwv_sub( vec, xprime ) );                           \
     650             :     wwv_t compressed_signs = _mm512_permutexvar_epi8( gather_signs_idx, sign );              \
     651             :     wwv_t popcnt = _mm512_popcnt_epi64( compressed_signs );                                  \
     652             :     wwv_t li = _mm512_permutex2var_epi64( vec, popcnt, s );                                  \
     653             :     wwv_t lim1 = _mm512_permutex2var_epi64( vec, wwv_sub( popcnt, wwv_one() ), wwv_zero() ); \
     654             :     mask##i = _mm512_cmpge_epu64_mask( vec, xprime );                                        \
     655             :     xprime = wwv_sub( xprime, lim1 );                                                        \
     656             :     s = wwv_sub( li, lim1 );                                                                 \
     657             :     ulong child_idx = (ulong)_mm_extract_epi64( _mm512_castsi512_si128( popcnt ), 0 );       \
     658             :     cursor = R*cursor + child_idx + 1UL;                                                     \
     659             :     vec##i = vec;                                                                            \
     660             :   } while( 0 )
     661             : 
     662             : #define PROPAGATE_LEVEL(i)                                                                   \
     663             :   wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );
     664             : 
     665             : #define FINALIZE()                                                                              \
     666             :   do {                                                                                          \
     667             :     sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 ); \
     668             :     sampler->unremoved_cnt--;                                                                   \
     669             :   } while( 0 )
     670             : 
     671             : #elif FD_WSAMPLE_IMPLEMENTATION==2
     672             : 
     673             : #define PREPARE()                                                                            \
     674    81983410 :   ulong cursor           = 0UL;                                                              \
     675    81983410 :   wwv_t x                = wwv_bcast( unif );                                                \
     676    81983410 :   wwv_t s                = wwv_bcast( sampler->unremoved_weight );                           \
     677             : 
     678             : #define TRAVERSE_LEVEL(i)                                                                                             \
     679   327933640 :   __mmask8 mask##i;                                                                                                   \
     680   327933640 :   ulong    cursor##i = cursor;                                                                                        \
     681   327933640 :   wwv_t    vec##i;                                                                                                    \
     682   327933640 :   do {                                                                                                                \
     683   327933640 :     wwv_t vec  = wwv_ld( tree[cursor].left_sum );                                                                     \
     684   327933640 :     __mmask8 lmask = _mm512_cmple_epu64_mask( vec, x );                                                               \
     685   327933640 :     cursor = R*cursor + (ulong)fd_uchar_popcnt( lmask ) + 1UL;                                                        \
     686   327933640 :     __mmask8 single_bit = _kxor_mask8( lmask, _kshiftri_mask8( lmask, 1U ) );                                         \
     687   327933640 :     wwv_t lim1 = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_maskz_compress_epi64( single_bit, vec ) ) ); \
     688   327933640 :     x = wwv_sub( x, lim1 );                                                                                           \
     689   327933640 :     mask##i = _knot_mask8( lmask );                                                                                   \
     690   327933640 :     wwv_t li = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_mask_compress_epi64( s, mask##i, vec ) ) );    \
     691   327933640 :     s = wwv_sub( li, lim1 );                                                                                          \
     692   327933640 :     vec##i = vec;                                                                                                     \
     693   327933640 :   } while( 0 )
     694             : 
     695             : #define PROPAGATE_LEVEL(i)                                                                                          \
     696   327933640 :   wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );
     697             : 
     698             : #define FINALIZE()                                                                              \
     699    81983410 :   do {                                                                                          \
     700    81983410 :     sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 ); \
     701    81983410 :     sampler->unremoved_cnt--;                                                                   \
     702    81983410 :   } while( 0 )
     703             : 
     704             : #endif
     705             : #else
     706             : #define FD_WSAMPLE_IMPLEMENTATION 0
     707             : #endif
     708             : 
     709             : /* For now, implement the _many functions as loops over the single
     710             :    sample functions.  It is possible to do better though. */
     711             : 
     712             : void
     713             : fd_wsample_sample_many( fd_wsample_t * sampler,
     714             :                         ulong        * idxs,
     715     3694788 :                         ulong          cnt  ) {
     716   373172133 :   for( ulong i=0UL; i<cnt; i++ ) idxs[i] = fd_wsample_sample( sampler );
     717     3694788 : }
     718             : 
     719             : void
     720             : fd_wsample_sample_and_remove_many( fd_wsample_t * sampler,
     721             :                                    ulong        * idxs,
     722     2205876 :                                    ulong          cnt   ) {
     723             :   /* The compiler doesn't seem to like inlining the call to
     724             :      fd_wsample_sample_and_remove, which hurts performance by a few
     725             :      percent because it triggers worse behavior in the CPU's front end.
     726             :      To address this, we manually inline it here. */
     727   262004217 :   for( ulong i=0UL; i<cnt; i++ ) {
     728   259798341 :     if( FD_UNLIKELY( !sampler->unremoved_weight ) ) { idxs[ i ] = FD_WSAMPLE_EMPTY;         continue; }
     729   259574130 :     if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) { idxs[ i ] = FD_WSAMPLE_INDETERMINATE; continue; }
     730   258573453 :     ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     731   258573453 :     if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
     732        3909 :       idxs[ i ] = FD_WSAMPLE_INDETERMINATE;
     733        3909 :       sampler->poisoned_mode = 1;
     734        3909 :       continue;
     735        3909 :     }
     736    86189848 : #if FD_WSAMPLE_IMPLEMENTATION > 0
     737    86189848 :   if( FD_LIKELY( sampler->height==4UL ) ) {
     738    81980695 :     tree_ele_t * tree = sampler->tree;
     739    81980695 :     PREPARE();
     740             : 
     741    81980695 :     TRAVERSE_LEVEL(0);
     742    81980695 :     TRAVERSE_LEVEL(1);
     743    81980695 :     TRAVERSE_LEVEL(2);
     744    81980695 :     TRAVERSE_LEVEL(3);
     745    81980695 :     PROPAGATE_LEVEL(0);
     746    81980695 :     PROPAGATE_LEVEL(1);
     747    81980695 :     PROPAGATE_LEVEL(2);
     748    81980695 :     PROPAGATE_LEVEL(3);
     749             : 
     750    81980695 :     FINALIZE();
     751    81980695 :     idxs[ i ] = cursor - sampler->internal_node_cnt;
     752    81980695 :     continue;
     753    81980695 :   }
     754     4209153 : #endif
     755   176588849 :     idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
     756   176588849 :     fd_wsample_remove( sampler, p );
     757   176588849 :     idxs[ i ] = p.idx;
     758   176588849 :   }
     759     2205876 : }
     760             : 
     761             : 
     762             : 
     763             : ulong
     764   765787647 : fd_wsample_sample( fd_wsample_t * sampler ) {
     765   765787647 :   if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
     766   765787641 :   if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
     767   765787641 :   ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     768   765787641 :   if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) return FD_WSAMPLE_INDETERMINATE;
     769   765532941 :   return (ulong)fd_wsample_map_sample( sampler, unif );
     770   765787641 : }
     771             : 
     772             : ulong
     773     3651417 : fd_wsample_sample_and_remove( fd_wsample_t * sampler ) {
     774     3651417 :   if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
     775     3651333 :   if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
     776     3648948 :   ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     777     3648948 :   if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
     778         270 :     sampler->poisoned_mode = 1;
     779         270 :     return FD_WSAMPLE_INDETERMINATE;
     780         270 :   }
     781             : 
     782     1216226 : #if FD_WSAMPLE_IMPLEMENTATION > 0
     783     1216226 :   if( FD_LIKELY( sampler->height==4UL ) ) {
     784        2715 :     tree_ele_t * tree = sampler->tree;
     785        2715 :     PREPARE();
     786             : 
     787        2715 :     TRAVERSE_LEVEL(0);
     788        2715 :     TRAVERSE_LEVEL(1);
     789        2715 :     TRAVERSE_LEVEL(2);
     790        2715 :     TRAVERSE_LEVEL(3);
     791        2715 :     PROPAGATE_LEVEL(0);
     792        2715 :     PROPAGATE_LEVEL(1);
     793        2715 :     PROPAGATE_LEVEL(2);
     794        2715 :     PROPAGATE_LEVEL(3);
     795        2715 :     FINALIZE();
     796        2715 :     return cursor - sampler->internal_node_cnt;
     797        2715 :   }
     798     1213511 : #endif
     799     3645963 :   idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
     800     3645963 :   fd_wsample_remove( sampler, p );
     801     3645963 :   return p.idx;
     802     3648858 : }

Generated by: LCOV version 1.14