Line data Source code
1 : #include "fd_wsample.h"
2 : #include <math.h> /* For sqrt */
3 : #if FD_HAS_AVX512
4 : #include "../../util/simd/fd_avx512.h"
5 : #endif
6 :
7 88933333236 : #define R 9
8 : /* This sampling problem is an interesting one from a performance
9 : perspective. There are lots of interesting approaches. The
10 : header/implementation split is designed to give lots of flexibility
11 : for future optimization. The current implementation uses radix 9
12 : tree with all the leaves on the bottom. */
13 :
14 : /* I'm not sure exactly how to classify the tree that this
15 : implementation uses, but it's something like a B-tree with some
16 : tricks from binary heaps. In particular, like a B-tree, each node
17 : stores several keys, where the keys are the cumulative sums of
18 : subtrees, called left sums below. Like a binary heap, it is
19 : pointer-free, and it is stored implicitly in a flat array, the
20 : typical way that a binary heap is stored. Specifically, the root is
21 : at index 0, and node i's children are at Ri+1, Ri+2, ... Ri+R, where
22 : R is the radix. The leaves are all at the same level, at the bottom,
23 : stored implicitly, which means that a search can be done with almost
24 : no branch mispredictions.
25 :
26 : As an example, suppose R=5 and our tree contains weights 100, 80, 50,
27 : 31, 27, 14, and 6. Because that is 7 nodes, the height is set to
28 : ceil(log_5(7))=2.
29 :
30 : Root (0)
31 : / \
32 : Child (1) Child (2) -- -- --
33 : / / | \ \ | \
34 : 100, 80, 50, 31, 27 14, 6
35 :
36 : The first child of the root, node 1, has left_sum values
37 : |100|180|230|261|. Note that we only store R-1 values, and so the
38 : last child, 27, does not feature in the sums; the search process
39 : handles it implicitly, as we'll see. The second child of the root,
40 : node 2, has left sum values |14|20|20|20|. Then the root node, node
41 : 0, has left sum values |288|308|308|308|.
42 :
43 : The total sum is 308, so we start by drawing a random value in the
44 : range [0, 308). We see which child that falls under, adjust our
45 : random value, and repeat, until we reach a leaf.
46 :
47 : In general, suppose a node has children (left to right) with subtree
48 : weights a, b, c, d, e. Then the left sums are |a|a+b|a+b+c|a+b+c+d|
49 : and the full sum S=a+b+c+d+e. For a query value of x in [0, S), we
50 : want to pick:
51 : child 0 if x < a
52 : child 1 if a <= x < a+b
53 : child 2 if a+b <= x < a+b+c
54 : child 3 if a+b+c <= x < a+b+c+d
55 : child 4 if a+b+c+d<= x
56 :
57 : Which is equivalent to choosing child
58 : (a<=x) + (a+b<=x) + (a+b+c<=x) + (a+b+c+d<=x)
59 : which can be computed branchlessly. The value of e only comes into
60 : play in the restriction that x is in [0, S), and e is included in the
61 : value of S.
62 :
63 : There are two details left to discuss in order to search recursively.
64 : First, in order to remove an element for sampling without
65 : replacement, we need to know the weight of the element we're
66 : removing. As with e above, the weights may not be stored explicitly,
67 : but as long as we keep track of the weight of the current subtree as
68 : we search recursively, we can obtain the weight without much work.
69 : Thus, we need to know the sum S' of the chosen subtree, i in [0, R).
70 : For j in [-1, R), define
71 : / 0 if j==-1
72 : l[j] = | left_sum[ j ] if j in [0, R-1)
73 : \ S if j==R-1
74 : Essentially, we're computing the natural extension values for
75 : left_sum on the left and right side. Then observe that S' = l[i] -
76 : l[i-1].
77 :
78 : Secondly, in order to search recursively, we need to adjust the query
79 : value to put it into [0, S'). Specifically, we need to subtract off
80 : the sum of all the children to the left of the child we've chosen.
81 : If we've chosen child i, then that's just l[i-1].
82 :
83 : All the above extends easily to any radix, but 3, 5, and 9 = 1+2^n
84 : are the natural ones. I only tried 5 and 9, and 9 was substantially
85 : faster.
86 :
87 : It's worth noting that storing left sums means that you have to do
88 : the most work when you update the left-most child. Because we
89 : typically store the weights largest to smallest, the left-most child
90 : is the one we delete the most frequently, so now our most common case
91 : (probability-wise) and our worst case (performance-wise) are the
92 : same, which seems bad. I initially implemented this using right sums
93 : instead of left to address this problem. They're less intuitive, but
94 : also work. However, I found that having an unpredictable loop in the
95 : deletion method was far worse than just updating each element, which
96 : means that the "less work" advantage of right sums went away. */
97 :
98 : struct __attribute__((aligned(8UL*(R-1UL)))) tree_ele {
99 : /* left_sum stores the cumulative weight of the subtrees at this
100 : node. See the long note above for more information. */
101 : ulong left_sum[ R-1 ];
102 : };
103 : typedef struct tree_ele tree_ele_t;
104 :
105 : struct __attribute__((aligned(64UL))) fd_wsample_private {
106 : ulong total_cnt;
107 : ulong total_weight;
108 : ulong unremoved_cnt;
109 : ulong unremoved_weight; /* Initial value for S explained above */
110 :
111 : /* internal_node_cnt and height are both determined by the number of
112 : leaves in the original tree, via the following formulas:
113 : height = ceil(log_r(leaf_cnt))
114 : internal_node_cnt = sum_{i=0}^{height-1} R^i
115 : height and internal_node_cnt both exclude the leaves, which are
116 : only implicit.
117 :
118 : All the math seems to disallow leaf_cnt==0, but for conveniece, we
119 : do allow it. height==internal_node_cnt==0 in that case.
120 :
121 : height actually fits in a uchar. Storing as ulong is more natural,
122 : but we don't want to spill over into another cache line.
123 :
124 : If poisoned_mode==1, then all sample calls will return poisoned.*/
125 : ulong internal_node_cnt;
126 : ulong poisoned_weight;
127 : uint height;
128 : char restore_enabled;
129 : char poisoned_mode;
130 : /* Two bytes of padding here */
131 :
132 : fd_chacha20rng_t * rng;
133 :
134 : /* tree: Here's where the actual tree is stored, at indices [0,
135 : internal_node_cnt). The indexing scheme is explained in the long
136 : comment above.
137 :
138 : If restore_enabled==1, then indices [internal_node_cnt+1,
139 : 2*internal_node_cnt+1) store a copy of the tree after construction
140 : but before any deletion so that restoring deleted elements can be
141 : implemented as a memcpy.
142 :
143 : The tree iteself is surrounded by two dummy elements, dummy, and
144 : tree[internal_node_cnt], that aren't actually used. This is
145 : because searching the tree branchlessly involves some out of bounds
146 : reads, and although the value is immediately discarded, it's better
147 : to know where exactly those reads might go. */
148 : tree_ele_t dummy;
149 : tree_ele_t tree[];
150 : };
151 :
152 : typedef struct fd_wsample_private fd_wsample_t;
153 :
154 :
155 : FD_FN_CONST ulong
156 433521 : fd_wsample_align( void ) {
157 433521 : return 64UL;
158 433521 : }
159 :
160 : /* Returns -1 on failure */
161 : static inline int
162 : compute_height( ulong leaf_cnt,
163 : ulong * out_height,
164 767103 : ulong * out_internal_cnt ) {
165 : /* This max is a bit conservative. The actual max is height <= 25,
166 : and leaf_cnt < 5^25 approx 2^58. A tree that large would take an
167 : astronomical amount of memory, so we just retain this max for the
168 : moment. */
169 767103 : if( FD_UNLIKELY( leaf_cnt >= UINT_MAX-2UL ) ) return -1;
170 :
171 767103 : ulong height = 0;
172 767103 : ulong internal = 0UL;
173 767103 : ulong powRh = 1UL; /* = R^height */
174 3429864 : while( leaf_cnt>powRh ) {
175 2662761 : internal += powRh;
176 2662761 : powRh *= R;
177 2662761 : height++;
178 2662761 : }
179 767103 : *out_height = height;
180 767103 : *out_internal_cnt = internal;
181 767103 : return 0;
182 767103 : }
183 :
184 : FD_FN_CONST ulong
185 446388 : fd_wsample_footprint( ulong ele_cnt, int restore_enabled ) {
186 446388 : ulong height;
187 446388 : ulong internal_cnt;
188 : /* Computing the closed form of the sum in compute_height, we get
189 : internal_cnt = 1/8 * (9^ceil(log_9( ele_cnt ) ) - 1)
190 : x <= ceil( x ) < x+1
191 : 1/8 * ele_cnt - 1/8 <= internal_cnt < 9/8 * ele_cnt - 1/8
192 : */
193 446388 : if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) return 0UL;
194 446388 : return sizeof(fd_wsample_t) + ((restore_enabled?2UL:1UL)*internal_cnt + 1UL)*sizeof(tree_ele_t);
195 446388 : }
196 :
197 : fd_wsample_t *
198 320703 : fd_wsample_join( void * shmem ) {
199 320703 : if( FD_UNLIKELY( !shmem ) ) {
200 0 : FD_LOG_WARNING(( "NULL shmem" ));
201 0 : return NULL;
202 0 : }
203 :
204 320703 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
205 0 : FD_LOG_WARNING(( "misaligned shmem" ));
206 0 : return NULL;
207 0 : }
208 320703 : return (fd_wsample_t *)shmem;
209 320703 : }
210 :
211 : /* Note: The following optimization insights are not used in this
212 : high radix implementation. Performance in the deletion case is much
213 : more important than in the non-deletion case, and it's not clear how
214 : to translate this. I'm leaving the code and comment because it is a
215 : useful and non-trivial insight. */
216 : #if 0
217 : /* If we assume the probability of querying node i is proportional to
218 : 1/i, then observe that the midpoint of the probability mass in the
219 : continuous approximation is the solution to (in Mathematica syntax):
220 :
221 : Integrate[ 1/i, {i, lo, hi}] = 2*Integrate[ 1/i, {i, lo, mid} ]
222 :
223 : which gives mid = sqrt(lo*hi). This is in contrast to when the
224 : integrand is a constant, which gives the normal binary search rule:
225 : mid=(lo+hi)/2.
226 :
227 : Thus, we want the search to follow this modified binary search rule,
228 : since that'll approximately split the stake weight/probability mass
229 : in half at each step.
230 :
231 : This is almost as nice to work with mathematically as the normal
232 : binary search rule. The jth entry from the left at level k is the
233 : region [ N^((1/2^k)*j), N^((1/2^k)*(j+1)) ). We're basically doing
234 : binary search in the log domain.
235 :
236 : Rather than trying to compute these transcendental functions, this
237 : simple recursive implementation gives the treap the right shape by
238 : setting prio very carefully, since higher prio means closer to the
239 : root. This may be a slight abuse of a treap, but it's easier than
240 : implementing a whole custom tree for just this purpose.
241 :
242 : We want to be careful to ensure that the recursion is tightly
243 : bounded. From a continuous perspective, that's not a problem: the
244 : widest interval at level k is the last one, and we break when the
245 : interval's width is less than 1. Solving 1=N-N^(1-(1/2^k)) for k
246 : yields k=-lg(1-log(N-1)/log(N)). k is monotonically increasing as a
247 : function of N, which means that for all N,
248 : k <= -lg(1-log(N_max-1)/log(N_max)) < 37
249 : since N_max=2^32-1.
250 :
251 : The math is more complicated with rounding and finite precision, but
252 : sqrt(lo*hi) is very different from lo and hi unless lo and hi are
253 : approximately the same. In that case, lo<mid and mid<hi ensures that
254 : both intervals are strictly smaller than the interval they came from,
255 : which prevents an infinite loop. */
256 : static inline void
257 : seed_recursive( treap_ele_t * pool,
258 : uint lo,
259 : uint hi,
260 : uint prio ) {
261 : uint mid = (uint)(sqrtf( (float)lo*(float)hi ) + 0.5f);
262 : if( (lo<mid) & (mid<hi) ) {
263 : /* since we start with lo=1, shift by 1 */
264 : pool[mid-1U].prio = prio;
265 : seed_recursive( pool, lo, mid, prio-1U );
266 : seed_recursive( pool, mid, hi, prio-1U );
267 : }
268 : }
269 : #endif
270 :
271 :
272 : void *
273 : fd_wsample_new_init( void * shmem,
274 : fd_chacha20rng_t * rng,
275 : ulong ele_cnt,
276 : int restore_enabled,
277 320715 : int opt_hint ) {
278 320715 : if( FD_UNLIKELY( !shmem ) ) {
279 0 : FD_LOG_WARNING(( "NULL shmem" ));
280 0 : return NULL;
281 0 : }
282 :
283 320715 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
284 0 : FD_LOG_WARNING(( "misaligned shmem" ));
285 0 : return NULL;
286 0 : }
287 :
288 320715 : ulong height;
289 320715 : ulong internal_cnt;
290 320715 : if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) {
291 0 : FD_LOG_WARNING(( "bad ele_cnt" ));
292 0 : return NULL;
293 0 : }
294 :
295 320715 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
296 :
297 320715 : sampler->total_weight = 0UL;
298 320715 : sampler->unremoved_cnt = 0UL;
299 320715 : sampler->unremoved_weight = 0UL;
300 320715 : sampler->internal_node_cnt = internal_cnt;
301 320715 : sampler->poisoned_weight = 0UL;
302 320715 : sampler->height = (uint)height;
303 320715 : sampler->restore_enabled = (char)!!restore_enabled;
304 320715 : sampler->poisoned_mode = 0;
305 320715 : sampler->rng = rng;
306 :
307 320715 : fd_memset( sampler->tree, (char)0, internal_cnt*sizeof(tree_ele_t) );
308 :
309 320715 : (void)opt_hint; /* Not used at the moment */
310 :
311 320715 : return shmem;
312 320715 : }
313 :
314 : void *
315 : fd_wsample_new_add( void * shmem,
316 201679101 : ulong weight ) {
317 201679101 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
318 201679101 : if( FD_UNLIKELY( !sampler ) ) return NULL;
319 :
320 201679101 : if( FD_UNLIKELY( weight==0UL ) ) {
321 0 : FD_LOG_WARNING(( "zero weight entry found" ));
322 0 : return NULL;
323 0 : }
324 201679101 : if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
325 0 : FD_LOG_WARNING(( "total weight too large" ));
326 0 : return NULL;
327 0 : }
328 :
329 201679101 : tree_ele_t * tree = sampler->tree;
330 201679101 : ulong i = sampler->internal_node_cnt + sampler->unremoved_cnt;
331 :
332 1007521080 : for( ulong h=0UL; h<sampler->height; h++ ) {
333 805841979 : ulong parent = (i-1UL)/R;
334 805841979 : ulong child_idx = i-1UL - R*parent; /* in [0, R) */
335 4940519199 : for( ulong k=child_idx; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] += weight;
336 805841979 : i = parent;
337 805841979 : }
338 :
339 201679101 : sampler->unremoved_cnt++;
340 201679101 : sampler->total_cnt++;
341 201679101 : sampler->unremoved_weight += weight;
342 201679101 : sampler->total_weight += weight;
343 :
344 201679101 : return shmem;
345 201679101 : }
346 :
347 : void *
348 : fd_wsample_new_fini( void * shmem,
349 320715 : ulong poisoned_weight ) {
350 320715 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
351 320715 : if( FD_UNLIKELY( !sampler ) ) return NULL;
352 :
353 320715 : if( FD_UNLIKELY( sampler->total_weight+poisoned_weight<sampler->total_weight ) ) {
354 0 : FD_LOG_WARNING(( "poisoned_weight caused overflow" ));
355 0 : return NULL;
356 0 : }
357 :
358 320715 : sampler->poisoned_weight = poisoned_weight;
359 :
360 320715 : if( sampler->restore_enabled ) {
361 : /* Copy the sampler to make restore fast. */
362 312951 : fd_memcpy( sampler->tree+sampler->internal_node_cnt+1UL, sampler->tree, sampler->internal_node_cnt*sizeof(tree_ele_t) );
363 312951 : }
364 :
365 320715 : return (void *)sampler;
366 320715 : }
367 :
368 : void *
369 320625 : fd_wsample_leave( fd_wsample_t * sampler ) {
370 320625 : if( FD_UNLIKELY( !sampler ) ) {
371 0 : FD_LOG_WARNING(( "NULL sampler" ));
372 0 : return NULL;
373 0 : }
374 :
375 320625 : return (void *)sampler;
376 320625 : }
377 :
378 : void *
379 320625 : fd_wsample_delete( void * shmem ) {
380 320625 : if( FD_UNLIKELY( !shmem ) ) {
381 0 : FD_LOG_WARNING(( "NULL shmem" ));
382 0 : return NULL;
383 0 : }
384 :
385 320625 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
386 0 : FD_LOG_WARNING(( "misaligned shmem" ));
387 0 : return NULL;
388 0 : }
389 320625 : return shmem;
390 320625 : }
391 :
392 :
393 :
394 4506723 : fd_chacha20rng_t * fd_wsample_get_rng( fd_wsample_t * sampler ) { return sampler->rng; }
395 :
396 :
397 : /* TODO: Should this function exist at all? */
398 : void
399 : fd_wsample_seed_rng( fd_chacha20rng_t * rng,
400 4506723 : uchar seed[static 32] ) {
401 4506723 : fd_chacha20rng_init( rng, seed );
402 4506723 : }
403 :
404 :
405 : fd_wsample_t *
406 5414256 : fd_wsample_restore_all( fd_wsample_t * sampler ) {
407 5414256 : if( FD_UNLIKELY( !sampler->restore_enabled ) ) return NULL;
408 :
409 5414253 : sampler->unremoved_weight = sampler->total_weight;
410 5414253 : sampler->unremoved_cnt = sampler->total_cnt;
411 5414253 : sampler->poisoned_mode = 0;
412 :
413 5414253 : fd_memcpy( sampler->tree, sampler->tree+sampler->internal_node_cnt+1UL, sampler->internal_node_cnt*sizeof(tree_ele_t) );
414 5414253 : return sampler;
415 5414256 : }
416 :
/* fd_ulong_if_force( c, t, f ) evaluates to t if c is non-zero and to f
   otherwise.  The selection is written as inline assembly (test, move,
   conditional move) so that it is performed with a cmov instead of a
   branch.  c is tested with a 32-bit test instruction.  The output is
   marked early-clobber because it is written before t is read. */
#define fd_ulong_if_force( c, t, f ) (__extension__({ \
    ulong fd_ulong_if_force_out;                      \
    __asm__( "testl %1, %1; \n\t"                     \
             "movq %3, %0; \n\t"                      \
             "cmovne %2, %0; \n\t"                    \
             : "=&r"(fd_ulong_if_force_out)           \
             : "r"(c), "rm"(t), "rmi"(f)              \
             : "cc" );                                \
    fd_ulong_if_force_out;                            \
  }))
427 :
428 : /* Helper methods for sampling functions */
429 : typedef struct { ulong idx; ulong weight; } idxw_pair_t; /* idx in [0, total_cnt) */
430 :
431 : /* Assumes query in [0, unremoved_weight), which implies
432 : unremoved_weight>0, so the tree can't be empty. */
433 : static inline idxw_pair_t
434 : fd_wsample_map_sample_i( fd_wsample_t const * sampler,
435 1567248645 : ulong query ) {
436 1567248645 : tree_ele_t const * tree = sampler->tree;
437 :
438 1567248645 : ulong cursor = 0UL;
439 1567248645 : ulong S = sampler->unremoved_weight;
440 7245492351 : for( ulong h=0UL; h<sampler->height; h++ ) {
441 5678243706 : tree_ele_t const * e = tree+cursor;
442 5678243706 : ulong x = query;
443 5678243706 : ulong child_idx = 0UL;
444 :
445 1892747902 : #if FD_HAS_AVX512 && R==9
446 1892747902 : __mmask8 mask = _mm512_cmple_epu64_mask( wwv_ld( e->left_sum ), wwv_bcast( x ) );
447 1892747902 : child_idx = (ulong)fd_uchar_popcnt( mask );
448 : #else
449 34069462236 : for( ulong i=0UL; i<R-1UL; i++ ) child_idx += (ulong)(e->left_sum[ i ]<=x);
450 3785495804 : #endif
451 :
452 : /* See the note at the top of this file for the explanation of l[i]
453 : and l[i-1]. Because this is fd_ulong_if and not a ternary, these
454 : can read/write out of what you would think the appropriate bounds
455 : are. The dummy elements, as described along with tree makes this
456 : safe. */
457 : #if 0
458 : ulong li = fd_ulong_if( child_idx<R-1UL, e->left_sum[ child_idx ], S );
459 : ulong lm1 = fd_ulong_if( child_idx>0UL, e->left_sum[ child_idx-1UL ], 0UL );
460 : #elif 0
461 : ulong li = fd_ulong_if_force( child_idx<R-1UL, e->left_sum[ child_idx ], S );
462 : ulong lm1 = fd_ulong_if_force( child_idx>0UL, e->left_sum[ child_idx-1UL ], 0UL );
463 : #else
464 5678243706 : ulong * temp = (ulong *)e->left_sum;
465 5678243706 : ulong orig_m1 = temp[ -1 ]; ulong orig_Rm1 = temp[ R-1UL ];
466 5678243706 : temp[ -1 ] = 0UL; temp[ R-1UL ] = S;
467 5678243706 : ulong li = temp[ child_idx ];
468 5678243706 : ulong lm1 = temp[ child_idx-1UL ];
469 5678243706 : temp[ -1 ] = orig_m1; temp[ R-1UL ] = orig_Rm1;
470 5678243706 : #endif
471 :
472 5678243706 : query -= lm1;
473 5678243706 : S = li - lm1;
474 5678243706 : cursor = R*cursor + child_idx + 1UL;
475 5678243706 : }
476 1567248645 : idxw_pair_t to_return = { .idx = cursor - sampler->internal_node_cnt, .weight = S };
477 1567248645 : return to_return;
478 1567248645 : }
479 :
480 : ulong
481 : fd_wsample_map_sample( fd_wsample_t * sampler,
482 762228741 : ulong query ) {
483 762228741 : return fd_wsample_map_sample_i( sampler, query ).idx;
484 762228741 : }
485 :
486 :
487 :
488 : /* Also requires the tree to be non-empty */
489 : static inline void
490 : fd_wsample_remove( fd_wsample_t * sampler,
491 808727811 : idxw_pair_t to_remove ) {
492 808727811 : ulong cursor = to_remove.idx + sampler->internal_node_cnt;
493 808727811 : tree_ele_t * tree = sampler->tree;
494 :
495 4007297094 : for( ulong h=0UL; h<sampler->height; h++ ) {
496 3198569283 : ulong parent = (cursor-1UL)/R;
497 3198569283 : ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
498 1066189761 : #if FD_HAS_AVX512 && R==9
499 1066189761 : wwv_t weight = wwv_bcast( to_remove.weight );
500 1066189761 : wwv_t left_sum = wwv_ld( tree[ parent ].left_sum );
501 1066189761 : __m128i _child_idx = _mm_set1_epi16( (short) child_idx );
502 1066189761 : __mmask8 mask = _mm_cmplt_epi16_mask( _child_idx, _mm_setr_epi16( 1, 2, 3, 4, 5, 6, 7, 8 ) );
503 1066189761 : left_sum = _mm512_mask_sub_epi64( left_sum, mask, left_sum, weight );
504 1066189761 : wwv_st( tree[ parent ].left_sum, left_sum );
505 : #elif 0
506 : for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if( child_idx<=k, to_remove.weight, 0UL );
507 : #elif 0
508 : for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if_force( child_idx<=k, to_remove.weight, 0UL );
509 : #elif 1
510 : /* The compiler loves inserting a difficult to predict branch for
511 : fd_ulong_if, but this forces it not to do that. */
512 19191415698 : for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= (ulong)(((long)(child_idx - k - 1UL))>>63) & to_remove.weight;
513 : #else
514 : /* This version does the least work, but has a hard-to-predict
515 : branch. The branchless versions are normally substantially
516 : faster. */
517 : for( ulong k=child_idx; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= to_remove.weight;
518 : #endif
519 3198569283 : cursor = parent;
520 3198569283 : }
521 808727811 : sampler->unremoved_cnt--;
522 808727811 : sampler->unremoved_weight -= to_remove.weight;
523 808727811 : }
524 :
525 : static inline ulong
526 : fd_wsample_find_weight( fd_wsample_t const * sampler,
527 3707907 : ulong idx /* in [0, total_cnt) */) {
528 : /* The fact we don't store the weights explicitly makes this function
529 : more complicated, but this is not used very frequently. */
530 3707907 : tree_ele_t const * tree = sampler->tree;
531 3707907 : ulong cursor = idx + sampler->internal_node_cnt;
532 :
533 : /* Initialize to the 0 height case */
534 3707907 : ulong lm1 = 0UL;
535 3707907 : ulong li = sampler->unremoved_weight;
536 :
537 3738189 : for( ulong h=0UL; h<sampler->height; h++ ) {
538 3737997 : ulong parent = (cursor-1UL)/R;
539 3737997 : ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
540 :
541 : /* If child_idx < R-1, we can compute the weight easily. If
542 : child_idx==R-1, the computation is S - left_sum[ R-2 ], but we
543 : don't know S, so we need to continue up the tree. */
544 3737997 : lm1 += fd_ulong_if( child_idx>0UL, tree[ parent ].left_sum[ child_idx-1UL ], 0UL );
545 3737997 : if( FD_LIKELY( child_idx<R-1UL ) ) {
546 3707715 : li = tree[ parent ].left_sum[ child_idx ];
547 3707715 : break;
548 3707715 : }
549 :
550 30282 : cursor = parent;
551 30282 : }
552 :
553 3707907 : return li - lm1;
554 3707907 : }
555 :
556 : void
557 : fd_wsample_remove_idx( fd_wsample_t * sampler,
558 3707907 : ulong idx ) {
559 :
560 3707907 : ulong weight = fd_wsample_find_weight( sampler, idx );
561 3707907 : idxw_pair_t r = { .idx = idx, .weight = weight };
562 3707907 : fd_wsample_remove( sampler, r );
563 3707907 : }
564 :
565 : /* For now, implement the _many functions as loops over the single
566 : sample functions. It is possible to do better though. */
567 :
568 : void
569 : fd_wsample_sample_many( fd_wsample_t * sampler,
570 : ulong * idxs,
571 3694788 : ulong cnt ) {
572 373172133 : for( ulong i=0UL; i<cnt; i++ ) idxs[i] = fd_wsample_sample( sampler );
573 3694788 : }
574 :
575 : void
576 : fd_wsample_sample_and_remove_many( fd_wsample_t * sampler,
577 : ulong * idxs,
578 4905405 : ulong cnt ) {
579 : /* The compiler doesn't seem to like inlining the call to
580 : fd_wsample_sample_and_remove, which hurts performance by a few
581 : percent because it triggers worse behavior in the CPUs front end.
582 : To address this, we manually inline it here. */
583 807446943 : for( ulong i=0UL; i<cnt; i++ ) {
584 802541538 : if( FD_UNLIKELY( !sampler->unremoved_weight ) ) { idxs[ i ] = FD_WSAMPLE_EMPTY; continue; }
585 802317297 : if( FD_UNLIKELY( sampler->poisoned_mode ) ) { idxs[ i ] = FD_WSAMPLE_INDETERMINATE; continue; }
586 801375105 : ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
587 801375105 : if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
588 3879 : idxs[ i ] = FD_WSAMPLE_INDETERMINATE;
589 3879 : sampler->poisoned_mode = 1;
590 3879 : continue;
591 3879 : }
592 801371226 : idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
593 801371226 : fd_wsample_remove( sampler, p );
594 801371226 : idxs[ i ] = p.idx;
595 801371226 : }
596 4905405 : }
597 :
598 :
599 :
600 : ulong
601 762483447 : fd_wsample_sample( fd_wsample_t * sampler ) {
602 762483447 : if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
603 762483441 : if( FD_UNLIKELY( sampler->poisoned_mode ) ) return FD_WSAMPLE_INDETERMINATE;
604 762483441 : ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
605 762483441 : if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) return FD_WSAMPLE_INDETERMINATE;
606 762228741 : return (ulong)fd_wsample_map_sample( sampler, unif );
607 762483441 : }
608 :
609 : ulong
610 3651417 : fd_wsample_sample_and_remove( fd_wsample_t * sampler ) {
611 3651417 : if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
612 3651333 : if( FD_UNLIKELY( sampler->poisoned_mode ) ) return FD_WSAMPLE_INDETERMINATE;
613 3648948 : ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
614 3648948 : if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
615 270 : sampler->poisoned_mode = 1;
616 270 : return FD_WSAMPLE_INDETERMINATE;
617 270 : }
618 3648678 : idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
619 3648678 : fd_wsample_remove( sampler, p );
620 3648678 : return p.idx;
621 3648948 : }
|