       1             : #include "fd_wsample.h"
       2             : #include <math.h> /* For sqrt */
       3             : #if FD_HAS_AVX512
       4             : #include "../../util/simd/fd_avx512.h"
       5             : #endif
       6             : 
/* R: radix of the tree, i.e. the number of children of each internal
   node.  It must remain a macro (not an enum/const) because it is also
   tested in preprocessor conditions (R==9) that select the AVX-512
   code paths below. */
#define R 9
       8             : /* This sampling problem is an interesting one from a performance
       9             :    perspective.  There are lots of interesting approaches.  The
      10             :    header/implementation split is designed to give lots of flexibility
      11             :    for future optimization.  The current implementation uses radix 9
      12             :    tree with all the leaves on the bottom. */
      13             : 
      14             : /* I'm not sure exactly how to classify the tree that this
      15             :    implementation uses, but it's something like a B-tree with some
      16             :    tricks from binary heaps.  In particular, like a B-tree, each node
      17             :    stores several keys, where the keys are the cumulative sums of
      18             :    subtrees, called left sums below.  Like a binary heap, it is
      19             :    pointer-free, and it is stored implicitly in a flat array, the
      20             :    typical way that a binary heap is stored.  Specifically, the root is
      21             :    at index 0, and node i's children are at Ri+1, Ri+2, ... Ri+R, where
      22             :    R is the radix.  The leaves are all at the same level, at the bottom,
      23             :    stored implicitly, which means that a search can be done with almost
      24             :    no branch mispredictions.
      25             : 
      26             :    As an example, suppose R=5 and our tree contains weights 100, 80, 50,
      27             :    31, 27, 14, and 6.  Because that is 7 leaves, the height is set to
      28             :    ceil(log_5(7))=2.
      29             : 
      30             :                           Root (0)
      31             :                           /        \
      32             :                    Child (1)      Child (2)     -- -- --
      33             :                /  /   |  \   \     |  \
      34             :             100, 80, 50, 31, 27   14, 6
      35             : 
      36             :    The first child of the root, node 1, has left_sum values
      37             :    |100|180|230|261|.  Note that we only store R-1 values, and so the
      38             :    last child, 27, does not feature in the sums; the search process
      39             :    handles it implicitly, as we'll see.  The second child of the root,
      40             :    node 2, has left sum values |14|20|20|20|.  Then the root node, node
      41             :    0, has left sum values |288|308|308|308|.
      42             : 
      43             :    The total sum is 308, so we start by drawing a random value in the
      44             :    range [0, 308).  We see which child that falls under, adjust our
      45             :    random value, and repeat, until we reach a leaf.
      46             : 
      47             :    In general, if a node has children (left to right) with subtree
      48             :    weights a, b, c, d, e, then the left sums are |a|a+b|a+b+c|a+b+c+d|
      49             :    and the full sum S=a+b+c+d+e.  For a query value of x in [0, S), we
      50             :    want to pick:
      51             :       child 0   if             x < a
      52             :       child 1   if   a      <= x < a+b
      53             :       child 2   if   a+b    <= x < a+b+c
      54             :       child 3   if   a+b+c  <= x < a+b+c+d
      55             :       child 4   if   a+b+c+d<= x
      56             : 
      57             :    Which is equivalent to choosing child
      58             :           (a<=x) + (a+b<=x) + (a+b+c<=x) + (a+b+c+d<=x)
      59             :    which can be computed branchlessly.  The value of e only comes into
      60             :    play in the restriction that x is in [0, S), and e is included in the
      61             :    value of S.
      62             : 
      63             :    There are two details left to discuss in order to search recursively.
      64             :    First, in order to remove an element for sampling without
      65             :    replacement, we need to know the weight of the element we're
      66             :    removing.  As with e above, the weights may not be stored explicitly,
      67             :    but as long as we keep track of the weight of the current subtree as
      68             :    we search recursively, we can obtain the weight without much work.
      69             :    Thus, we need to know the sum S' of the chosen subtree, i in [0, R).
      70             :    For j in [-1, R), define
      71             :                             /  0             if j==-1
      72             :                     l[j] =  | left_sum[ j ]  if j in [0, R-1)
      73             :                             \  S             if j==R-1
      74             :    Essentially, we're computing the natural extension values for
      75             :    left_sum on the left and right side.  Then observe that S' = l[i] -
      76             :    l[i-1].
      77             : 
      78             :    Secondly, in order to search recursively, we need to adjust the query
      79             :    value to put it into [0, S').  Specifically, we need to subtract off
      80             :    the sum of all the children to the left of the child we've chosen.
      81             :    If we've chosen child i, then that's just l[i-1].
      82             : 
      83             :    All the above extends easily to any radix, but 3, 5, and 9 = 1+2^n
      84             :    are the natural ones.  I only tried 5 and 9, and 9 was substantially
      85             :    faster.
      86             : 
      87             :    It's worth noting that storing left sums means that you have to do
      88             :    the most work when you update the left-most child.  Because we
      89             :    typically store the weights largest to smallest, the left-most child
      90             :    is the one we delete the most frequently, so now our most common case
      91             :    (probability-wise) and our worst case (performance-wise) are the
      92             :    same, which seems bad.  I initially implemented this using right sums
      93             :    instead of left to address this problem.  They're less intuitive, but
      94             :    also work.  However, I found that having an unpredictable loop in the
      95             :    deletion method was far worse than just updating each element, which
      96             :    means that the "less work" advantage of right sums went away. */
      97             : 
      98             : struct __attribute__((aligned(8UL*(R-1UL)))) tree_ele {
      99             :   /* left_sum stores the cumulative weight of the subtrees at this
     100             :      node.  See the long note above for more information. */
     101             :   ulong left_sum[ R-1 ];
     102             : };
     103             : typedef struct tree_ele tree_ele_t;
     104             : 
/* fd_wsample_private: sampler state.  The header fields are followed in
   memory by the flattened tree (tree[] at the end of the struct). */
struct __attribute__((aligned(64UL))) fd_wsample_private {
  ulong              total_cnt;        /* elements added via fd_wsample_new_add */
  ulong              total_weight;     /* sum of the weights of all added elements */
  ulong              unremoved_cnt;    /* elements not currently removed */
  ulong              unremoved_weight; /* Initial value for S explained above */

  /* internal_node_cnt and height are both determined by the number of
     leaves in the original tree, via the following formulas:
     height = ceil(log_R(leaf_cnt))
     internal_node_cnt = sum_{i=0}^{height-1} R^i
     height and internal_node_cnt both exclude the leaves, which are
     only implicit.

     All the math seems to disallow leaf_cnt==0, but for convenience, we
     do allow it. height==internal_node_cnt==0 in that case.

     height is small (the leaf_cnt cap in compute_height keeps it at most
     11 for R==9), so it is stored as a uint rather than the more natural
     ulong; that keeps every header field through rng within the first
     64 bytes (assuming 8-byte ulongs and pointers).

     poisoned_weight is the value passed to fd_wsample_new_fini.
     If poisoned_mode==1, then all sample calls will return poisoned.
     restore_enabled is the 0/1-normalized flag from
     fd_wsample_new_init. */
  ulong              internal_node_cnt;
  ulong              poisoned_weight;
  uint               height;
  char               restore_enabled;
  char               poisoned_mode;
  /* Two bytes of padding here */

  fd_chacha20rng_t * rng; /* as passed to fd_wsample_new_init */

  /* tree: Here's where the actual tree is stored, at indices [0,
     internal_node_cnt).  The indexing scheme is explained in the long
     comment above.

     If restore_enabled==1, then indices [internal_node_cnt+1,
     2*internal_node_cnt+1) store a copy of the tree after construction
     but before any deletion so that restoring deleted elements can be
     implemented as a memcpy.

     The tree itself is surrounded by two dummy elements, dummy and
     tree[internal_node_cnt], that aren't actually used.  This is
     because searching the tree branchlessly involves some out of bounds
     reads (and brief, immediately undone writes), and although the
     value is immediately discarded, it's better to know where exactly
     those accesses might go. */
  tree_ele_t        dummy;
  tree_ele_t        tree[];
};

typedef struct fd_wsample_private fd_wsample_t;
     153             : 
     154             : 
     155             : FD_FN_CONST ulong
     156      433521 : fd_wsample_align( void ) {
     157      433521 :   return 64UL;
     158      433521 : }
     159             : 
     160             : /* Returns -1 on failure */
     161             : static inline int
     162             : compute_height( ulong   leaf_cnt,
     163             :                 ulong * out_height,
     164      767103 :                 ulong * out_internal_cnt ) {
     165             :   /* This max is a bit conservative.  The actual max is height <= 25,
     166             :      and leaf_cnt < 5^25 approx 2^58.  A tree that large would take an
     167             :      astronomical amount of memory, so we just retain this max for the
     168             :      moment. */
     169      767103 :   if( FD_UNLIKELY( leaf_cnt >= UINT_MAX-2UL ) ) return -1;
     170             : 
     171      767103 :   ulong height   = 0;
     172      767103 :   ulong internal = 0UL;
     173      767103 :   ulong powRh    = 1UL; /* = R^height */
     174     3429864 :   while( leaf_cnt>powRh ) {
     175     2662761 :     internal += powRh;
     176     2662761 :     powRh    *= R;
     177     2662761 :     height++;
     178     2662761 :   }
     179      767103 :   *out_height       = height;
     180      767103 :   *out_internal_cnt = internal;
     181      767103 :   return 0;
     182      767103 : }
     183             : 
     184             : FD_FN_CONST ulong
     185      446388 : fd_wsample_footprint( ulong ele_cnt, int restore_enabled ) {
     186      446388 :   ulong height;
     187      446388 :   ulong internal_cnt;
     188             :   /* Computing the closed form of the sum in compute_height, we get
     189             :      internal_cnt = 1/8 * (9^ceil(log_9( ele_cnt ) ) - 1)
     190             :                  x <= ceil( x ) < x+1
     191             :      1/8 * ele_cnt - 1/8 <= internal_cnt < 9/8 * ele_cnt - 1/8
     192             :   */
     193      446388 :   if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) return 0UL;
     194      446388 :   return sizeof(fd_wsample_t) + ((restore_enabled?2UL:1UL)*internal_cnt + 1UL)*sizeof(tree_ele_t);
     195      446388 : }
     196             : 
     197             : fd_wsample_t *
     198      320703 : fd_wsample_join( void * shmem  ) {
     199      320703 :   if( FD_UNLIKELY( !shmem ) ) {
     200           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     201           0 :     return NULL;
     202           0 :   }
     203             : 
     204      320703 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     205           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     206           0 :     return NULL;
     207           0 :   }
     208      320703 :   return (fd_wsample_t *)shmem;
     209      320703 : }
     210             : 
     211             : /* Note: The following optimization insights are not used in this
     212             :    high radix implementation.  Performance in the deletion case is much
     213             :    more important than in the non-deletion case, and it's not clear how
     214             :    to translate this.  I'm leaving the code and comment because it is a
     215             :    useful and non-trivial insight. */
     216             : #if 0
     217             : /* If we assume the probability of querying node i is proportional to
     218             :    1/i, then observe that the midpoint of the probability mass in the
     219             :    continuous approximation is the solution to (in Mathematica syntax):
     220             : 
     221             :         Integrate[ 1/i, {i, lo, hi}] = 2*Integrate[ 1/i, {i, lo, mid} ]
     222             : 
     223             :    which gives mid = sqrt(lo*hi).  This is in contrast to when the
     224             :    integrand is a constant, which gives the normal binary search rule:
     225             :    mid=(lo+hi)/2.
     226             : 
     227             :    Thus, we want the search to follow this modified binary search rule,
     228             :    since that'll approximately split the stake weight/probability mass
     229             :    in half at each step.
     230             : 
     231             :    This is almost as nice to work with mathematically as the normal
     232             :    binary search rule.  The jth entry from the left at level k is the
     233             :    region [ N^((1/2^k)*j), N^((1/2^k)*(j+1)) ).  We're basically doing
     234             :    binary search in the log domain.
     235             : 
     236             :    Rather than trying to compute these transcendental functions, this
     237             :    simple recursive implementation gives the treap the right shape by
     238             :    setting prio very carefully, since higher prio means closer to the
     239             :    root.  This may be a slight abuse of a treap, but it's easier than
     240             :    implementing a whole custom tree for just this purpose.
     241             : 
     242             :    We want to be careful to ensure that the recursion is tightly
     243             :    bounded.  From a continuous perspective, that's not a problem: the
     244             :    widest interval at level k is the last one, and we break when the
     245             :    interval's width is less than 1.  Solving 1=N-N^(1-(1/2^k)) for k
     246             :    yields k=-lg(1-log(N-1)/log(N)).  k is monotonically increasing as a
     247             :    function of N, which means that for all N,
     248             :            k <= -lg(1-log(N_max-1)/log(N_max)) < 37
     249             :    since N_max=2^32-1.
     250             : 
     251             :    The math is more complicated with rounding and finite precision, but
     252             :    sqrt(lo*hi) is very different from lo and hi unless lo and hi are
     253             :    approximately the same.  In that case, lo<mid and mid<hi ensures that
     254             :    both intervals are strictly smaller than the interval they came from,
     255             :    which prevents an infinite loop. */
/* seed_recursive (disabled by the surrounding #if 0; never compiled):
   gives pool element mid = round(sqrt(lo*hi)) priority prio, then
   recurses on [lo,mid) and [mid,hi) with prio-1.  Recursion stops once
   mid is no longer strictly between lo and hi.  lo is 1-based while pool
   is 0-based, hence the mid-1U.  treap_ele_t is not declared in the
   visible source. */
static inline void
seed_recursive( treap_ele_t * pool,
                uint lo,
                uint hi,
                uint prio ) {
  uint mid = (uint)(sqrtf( (float)lo*(float)hi ) + 0.5f); /* geometric midpoint, rounded */
  if( (lo<mid) & (mid<hi) ) {
    /* since we start with lo=1, shift by 1 */
    pool[mid-1U].prio = prio;
    seed_recursive( pool, lo,  mid, prio-1U );
    seed_recursive( pool, mid, hi,  prio-1U );
  }
}
     269             : #endif
     270             : 
     271             : 
     272             : void *
     273             : fd_wsample_new_init( void             * shmem,
     274             :                      fd_chacha20rng_t * rng,
     275             :                      ulong              ele_cnt,
     276             :                      int                restore_enabled,
     277      320715 :                      int                opt_hint ) {
     278      320715 :   if( FD_UNLIKELY( !shmem ) ) {
     279           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     280           0 :     return NULL;
     281           0 :   }
     282             : 
     283      320715 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     284           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     285           0 :     return NULL;
     286           0 :   }
     287             : 
     288      320715 :   ulong height;
     289      320715 :   ulong internal_cnt;
     290      320715 :   if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) {
     291           0 :     FD_LOG_WARNING(( "bad ele_cnt" ));
     292           0 :     return NULL;
     293           0 :   }
     294             : 
     295      320715 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     296             : 
     297      320715 :   sampler->total_weight      = 0UL;
     298      320715 :   sampler->unremoved_cnt     = 0UL;
     299      320715 :   sampler->unremoved_weight  = 0UL;
     300      320715 :   sampler->internal_node_cnt = internal_cnt;
     301      320715 :   sampler->poisoned_weight   = 0UL;
     302      320715 :   sampler->height            = (uint)height;
     303      320715 :   sampler->restore_enabled   = (char)!!restore_enabled;
     304      320715 :   sampler->poisoned_mode     = 0;
     305      320715 :   sampler->rng               = rng;
     306             : 
     307      320715 :   fd_memset( sampler->tree, (char)0, internal_cnt*sizeof(tree_ele_t) );
     308             : 
     309      320715 :   (void)opt_hint; /* Not used at the moment */
     310             : 
     311      320715 :   return shmem;
     312      320715 : }
     313             : 
     314             : void *
     315             : fd_wsample_new_add( void * shmem,
     316   201679101 :                     ulong  weight ) {
     317   201679101 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     318   201679101 :   if( FD_UNLIKELY( !sampler ) ) return NULL;
     319             : 
     320   201679101 :   if( FD_UNLIKELY( weight==0UL ) ) {
     321           0 :     FD_LOG_WARNING(( "zero weight entry found" ));
     322           0 :     return NULL;
     323           0 :   }
     324   201679101 :   if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
     325           0 :     FD_LOG_WARNING(( "total weight too large" ));
     326           0 :     return NULL;
     327           0 :   }
     328             : 
     329   201679101 :   tree_ele_t * tree = sampler->tree;
     330   201679101 :   ulong i = sampler->internal_node_cnt + sampler->unremoved_cnt;
     331             : 
     332  1007521080 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     333   805841979 :     ulong parent = (i-1UL)/R;
     334   805841979 :     ulong child_idx = i-1UL - R*parent; /* in [0, R) */
     335  4940519199 :     for( ulong k=child_idx; k<R-1UL; k++ )  tree[ parent ].left_sum[ k ] += weight;
     336   805841979 :     i = parent;
     337   805841979 :   }
     338             : 
     339   201679101 :   sampler->unremoved_cnt++;
     340   201679101 :   sampler->total_cnt++;
     341   201679101 :   sampler->unremoved_weight += weight;
     342   201679101 :   sampler->total_weight     += weight;
     343             : 
     344   201679101 :   return shmem;
     345   201679101 : }
     346             : 
     347             : void *
     348             : fd_wsample_new_fini( void * shmem,
     349      320715 :                      ulong  poisoned_weight ) {
     350      320715 :   fd_wsample_t *  sampler = (fd_wsample_t *)shmem;
     351      320715 :   if( FD_UNLIKELY( !sampler ) ) return NULL;
     352             : 
     353      320715 :   if( FD_UNLIKELY( sampler->total_weight+poisoned_weight<sampler->total_weight ) ) {
     354           0 :     FD_LOG_WARNING(( "poisoned_weight caused overflow" ));
     355           0 :     return NULL;
     356           0 :   }
     357             : 
     358      320715 :   sampler->poisoned_weight = poisoned_weight;
     359             : 
     360      320715 :   if( sampler->restore_enabled ) {
     361             :     /* Copy the sampler to make restore fast. */
     362      312951 :     fd_memcpy( sampler->tree+sampler->internal_node_cnt+1UL, sampler->tree, sampler->internal_node_cnt*sizeof(tree_ele_t) );
     363      312951 :   }
     364             : 
     365      320715 :   return (void *)sampler;
     366      320715 : }
     367             : 
     368             : void *
     369      320625 : fd_wsample_leave( fd_wsample_t * sampler ) {
     370      320625 :   if( FD_UNLIKELY( !sampler ) ) {
     371           0 :     FD_LOG_WARNING(( "NULL sampler" ));
     372           0 :     return NULL;
     373           0 :   }
     374             : 
     375      320625 :   return (void *)sampler;
     376      320625 : }
     377             : 
     378             : void *
     379      320625 : fd_wsample_delete( void * shmem  ) {
     380      320625 :   if( FD_UNLIKELY( !shmem ) ) {
     381           0 :     FD_LOG_WARNING(( "NULL shmem" ));
     382           0 :     return NULL;
     383           0 :   }
     384             : 
     385      320625 :   if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
     386           0 :     FD_LOG_WARNING(( "misaligned shmem" ));
     387           0 :     return NULL;
     388           0 :   }
     389      320625 :   return shmem;
     390      320625 : }
     391             : 
     392             : 
     393             : 
     394     4506723 : fd_chacha20rng_t * fd_wsample_get_rng( fd_wsample_t * sampler ) { return sampler->rng; }
     395             : 
     396             : 
     397             : /* TODO: Should this function exist at all? */
     398             : void
     399             : fd_wsample_seed_rng( fd_chacha20rng_t * rng,
     400     4506723 :                      uchar seed[static 32] ) {
     401     4506723 :   fd_chacha20rng_init( rng, seed );
     402     4506723 : }
     403             : 
     404             : 
     405             : fd_wsample_t *
     406     5414256 : fd_wsample_restore_all( fd_wsample_t * sampler ) {
     407     5414256 :   if( FD_UNLIKELY( !sampler->restore_enabled ) )  return NULL;
     408             : 
     409     5414253 :   sampler->unremoved_weight = sampler->total_weight;
     410     5414253 :   sampler->unremoved_cnt    = sampler->total_cnt;
     411     5414253 :   sampler->poisoned_mode    = 0;
     412             : 
     413     5414253 :   fd_memcpy( sampler->tree, sampler->tree+sampler->internal_node_cnt+1UL, sampler->internal_node_cnt*sizeof(tree_ele_t) );
     414     5414253 :   return sampler;
     415     5414256 : }
     416             : 
/* fd_ulong_if_force( c, t, f ): evaluates to t if c is nonzero, else f,
   via inline-asm cmovne so the compiler cannot turn the select into a
   branch.  x86-64 GNU inline asm only.  c is tested with testl, so it
   must be an int-sized expression (e.g. a comparison result).
   NOTE(review): the "i" alternative for f assumes any constant passed
   fits a movq immediate (sign-extended 32-bit) -- confirm before using
   with large constants.  In the visible source this is referenced only
   from disabled (#if 0 / #elif 0) branches. */
#define fd_ulong_if_force( c, t, f ) (__extension__({ \
      ulong result;                                   \
      __asm__( "testl  %1, %1; \n\t"                  \
               "movq   %3, %0; \n\t"                  \
               "cmovne %2, %0; \n\t"                  \
               : "=&r"(result)                        \
               : "r"(c), "rm"(t), "rmi"(f)            \
               : "cc" );                              \
      result;                                         \
      }))
     427             : 
     428             : /* Helper methods for sampling functions */
     429             : typedef struct { ulong idx; ulong weight; } idxw_pair_t; /* idx in [0, total_cnt) */
     430             : 
/* fd_wsample_map_sample_i: maps query to the leaf whose slice of the
   cumulative weight contains it, descending from the root and choosing
   a child at each level as described in the note at the top of this
   file.

   Assumes query in [0, unremoved_weight), which implies
   unremoved_weight>0, so the tree can't be empty.

   Returns .idx, the chosen leaf's index relative to the first leaf, and
   .weight, that leaf's weight (the subtree sum S after the last level).

   NOTE(review): although sampler is const-qualified, the active (#else)
   branch below briefly overwrites the ulong just before and the ulong
   just after the current node's left_sum array and then restores them,
   so concurrent calls on the same sampler could observe those
   temporary values. */
static inline idxw_pair_t
fd_wsample_map_sample_i( fd_wsample_t const * sampler,
                         ulong                query ) {
  tree_ele_t const * tree = sampler->tree;

  ulong cursor = 0UL;                       /* current node; starts at the root */
  ulong S      = sampler->unremoved_weight; /* total weight under cursor */
  for( ulong h=0UL; h<sampler->height; h++ ) {
    tree_ele_t const * e = tree+cursor;
    ulong x = query;
    ulong child_idx = 0UL;

    /* child_idx = number of left sums <= x, i.e. the index of the child
       whose range contains x. */
#if FD_HAS_AVX512 && R==9
    __mmask8 mask = _mm512_cmple_epu64_mask( wwv_ld( e->left_sum ), wwv_bcast( x ) );
    child_idx = (ulong)fd_uchar_popcnt( mask );
#else
    for( ulong i=0UL; i<R-1UL; i++ ) child_idx += (ulong)(e->left_sum[ i ]<=x);
#endif

    /* See the note at the top of this file for the explanation of l[i]
       and l[i-1].  Because this is fd_ulong_if and not a ternary, these
       can read/write out of what you would think the appropriate bounds
       are.  The dummy elements, as described along with tree makes this
       safe. */
#if 0
    ulong li  = fd_ulong_if( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
    ulong lm1 = fd_ulong_if( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
#elif 0
    ulong li  = fd_ulong_if_force( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
    ulong lm1 = fd_ulong_if_force( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
#else
    /* Temporarily store the extension values l[-1]=0 and l[R-1]=S in the
       words adjacent to left_sum so that l[child_idx] and l[child_idx-1]
       are both plain loads.  When child_idx==0, child_idx-1UL wraps to
       ULONG_MAX, which in practice addresses temp[-1].  For the root,
       temp[-1] is the last word of dummy; for the last internal node,
       temp[R-1] is the first word of tree[internal_node_cnt]; otherwise
       they are words of neighbouring nodes, which is why the originals
       are put back right away. */
    ulong * temp = (ulong *)e->left_sum;
    ulong orig_m1 = temp[ -1 ];    ulong orig_Rm1 = temp[ R-1UL ];
    temp[ -1 ] = 0UL;              temp[ R-1UL ] = S;
    ulong li  = temp[ child_idx     ];
    ulong lm1 = temp[ child_idx-1UL ];
    temp[ -1 ] = orig_m1;          temp[ R-1UL ] = orig_Rm1;
#endif

    /* Descend into the chosen child: rebase query into [0, li-lm1) and
       make that child's subtree weight the new S. */
    query -= lm1;
    S = li - lm1;
    cursor = R*cursor + child_idx + 1UL;
  }
  idxw_pair_t to_return = { .idx = cursor - sampler->internal_node_cnt, .weight = S };
  return to_return;
}
     479             : 
     480             : ulong
     481             : fd_wsample_map_sample( fd_wsample_t * sampler,
     482   762228741 :                        ulong          query ) {
     483   762228741 :   return fd_wsample_map_sample_i( sampler, query ).idx;
     484   762228741 : }
     485             : 
     486             : 
     487             : 
/* fd_wsample_remove: removes leaf to_remove.idx by subtracting
   to_remove.weight from every left sum that covers it, walking from
   the leaf's parent up to the root.  It then decrements unremoved_cnt
   and subtracts the weight from unremoved_weight.

   Also requires the tree to be non-empty.  to_remove.weight must be
   the leaf's current weight (e.g. as computed by
   fd_wsample_find_weight); it is not validated here. */
static inline void
fd_wsample_remove( fd_wsample_t * sampler,
                   idxw_pair_t    to_remove ) {
  ulong cursor = to_remove.idx + sampler->internal_node_cnt;
  tree_ele_t * tree = sampler->tree;

  for( ulong h=0UL; h<sampler->height; h++ ) {
    ulong parent = (cursor-1UL)/R;
    ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
    /* parent's left_sum[ k ] includes this child iff child_idx<=k, so
       exactly entries child_idx..R-2 drop by to_remove.weight
       (child_idx==R-1 changes no entry at this level). */
#if FD_HAS_AVX512 && R==9
    /* Lane k of the 16-bit compare holds k+1, so mask bit k is set iff
       child_idx<k+1, i.e. child_idx<=k; only those 64-bit lanes are
       decremented. */
    wwv_t weight = wwv_bcast( to_remove.weight );
    wwv_t left_sum = wwv_ld( tree[ parent ].left_sum );
    __m128i _child_idx = _mm_set1_epi16( (short) child_idx );
    __mmask8 mask = _mm_cmplt_epi16_mask( _child_idx, _mm_setr_epi16( 1, 2, 3, 4, 5, 6, 7, 8 ) );
    left_sum = _mm512_mask_sub_epi64( left_sum, mask, left_sum, weight );
    wwv_st( tree[ parent ].left_sum, left_sum );
#elif 0
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if( child_idx<=k, to_remove.weight, 0UL );
#elif 0
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if_force( child_idx<=k, to_remove.weight, 0UL );
#elif 1
    /* The compiler loves inserting a difficult to predict branch for
       fd_ulong_if, but this forces it not to do that.  child_idx-k-1 is
       negative as a long iff child_idx<=k; shifting right by 63 then
       yields all ones (else zero), which masks the weight.  This relies
       on >> of a negative long being an arithmetic shift, as it is with
       GCC and Clang. */
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= (ulong)(((long)(child_idx - k - 1UL))>>63) & to_remove.weight;
#else
    /* This version does the least work, but has a hard-to-predict
       branch.  The branchless versions are normally substantially
       faster. */
    for( ulong k=child_idx; k<R-1UL; k++ )  tree[ parent ].left_sum[ k ] -= to_remove.weight;
#endif
    cursor = parent;
  }
  sampler->unremoved_cnt--;
  sampler->unremoved_weight -= to_remove.weight;
}
     524             : 
     525             : static inline ulong
     526             : fd_wsample_find_weight( fd_wsample_t const * sampler,
     527     3707907 :                         ulong                idx /* in [0, total_cnt) */) {
     528             :   /* The fact we don't store the weights explicitly makes this function
     529             :      more complicated, but this is not used very frequently. */
     530     3707907 :   tree_ele_t const * tree = sampler->tree;
     531     3707907 :   ulong cursor = idx + sampler->internal_node_cnt;
     532             : 
     533             :   /* Initialize to the 0 height case */
     534     3707907 :   ulong lm1 = 0UL;
     535     3707907 :   ulong li  = sampler->unremoved_weight;
     536             : 
     537     3738189 :   for( ulong h=0UL; h<sampler->height; h++ ) {
     538     3737997 :     ulong parent = (cursor-1UL)/R;
     539     3737997 :     ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
     540             : 
     541             :     /* If child_idx < R-1, we can compute the weight easily.  If
     542             :        child_idx==R-1, the computation is S - left_sum[ R-2 ], but we
     543             :        don't know S, so we need to continue up the tree. */
     544     3737997 :     lm1  += fd_ulong_if( child_idx>0UL, tree[ parent ].left_sum[ child_idx-1UL ], 0UL );
     545     3737997 :     if( FD_LIKELY( child_idx<R-1UL ) ) {
     546     3707715 :       li = tree[ parent ].left_sum[ child_idx ];
     547     3707715 :       break;
     548     3707715 :     }
     549             : 
     550       30282 :     cursor = parent;
     551       30282 :   }
     552             : 
     553     3707907 :   return li - lm1;
     554     3707907 : }
     555             : 
     556             : void
     557             : fd_wsample_remove_idx( fd_wsample_t * sampler,
     558     3707907 :                        ulong          idx ) {
     559             : 
     560     3707907 :   ulong weight = fd_wsample_find_weight( sampler, idx );
     561     3707907 :   idxw_pair_t r = { .idx = idx, .weight = weight };
     562     3707907 :   fd_wsample_remove( sampler, r );
     563     3707907 : }
     564             : 
/* For now, the _many functions are simple loops over the single-sample
   logic; fd_wsample_sample_and_remove_many inlines it by hand.  A batched
   implementation could do better. */
     567             : 
     568             : void
     569             : fd_wsample_sample_many( fd_wsample_t * sampler,
     570             :                         ulong        * idxs,
     571     3694788 :                         ulong          cnt  ) {
     572   373172133 :   for( ulong i=0UL; i<cnt; i++ ) idxs[i] = fd_wsample_sample( sampler );
     573     3694788 : }
     574             : 
     575             : void
     576             : fd_wsample_sample_and_remove_many( fd_wsample_t * sampler,
     577             :                                    ulong        * idxs,
     578     4905405 :                                    ulong          cnt   ) {
     579             :   /* The compiler doesn't seem to like inlining the call to
     580             :      fd_wsample_sample_and_remove, which hurts performance by a few
     581             :      percent because it triggers worse behavior in the CPUs front end.
     582             :      To address this, we manually inline it here. */
     583   807446943 :   for( ulong i=0UL; i<cnt; i++ ) {
     584   802541538 :     if( FD_UNLIKELY( !sampler->unremoved_weight ) ) { idxs[ i ] = FD_WSAMPLE_EMPTY;         continue; }
     585   802317297 :     if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) { idxs[ i ] = FD_WSAMPLE_INDETERMINATE; continue; }
     586   801375105 :     ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     587   801375105 :     if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
     588        3879 :       idxs[ i ] = FD_WSAMPLE_INDETERMINATE;
     589        3879 :       sampler->poisoned_mode = 1;
     590        3879 :       continue;
     591        3879 :     }
     592   801371226 :     idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
     593   801371226 :     fd_wsample_remove( sampler, p );
     594   801371226 :     idxs[ i ] = p.idx;
     595   801371226 :   }
     596     4905405 : }
     597             : 
     598             : 
     599             : 
     600             : ulong
     601   762483447 : fd_wsample_sample( fd_wsample_t * sampler ) {
     602   762483447 :   if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
     603   762483441 :   if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
     604   762483441 :   ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     605   762483441 :   if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) return FD_WSAMPLE_INDETERMINATE;
     606   762228741 :   return (ulong)fd_wsample_map_sample( sampler, unif );
     607   762483441 : }
     608             : 
     609             : ulong
     610     3651417 : fd_wsample_sample_and_remove( fd_wsample_t * sampler ) {
     611     3651417 :   if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
     612     3651333 :   if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
     613     3648948 :   ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
     614     3648948 :   if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
     615         270 :     sampler->poisoned_mode = 1;
     616         270 :     return FD_WSAMPLE_INDETERMINATE;
     617         270 :   }
     618     3648678 :   idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
     619     3648678 :   fd_wsample_remove( sampler, p );
     620     3648678 :   return p.idx;
     621     3648948 : }

