Line data Source code
1 : #include "fd_wsample.h"
2 : #include <math.h> /* For sqrt */
3 : #if FD_HAS_AVX512
4 : #include "../../util/simd/fd_avx512.h"
5 : #endif
6 :
/* R is the radix of the implicit tree: each internal node has R
   children and stores R-1 left sums in its left_sum array. */
#define R 9
/* This sampling problem is an interesting one from a performance
   perspective.  There are lots of interesting approaches.  The
   header/implementation split is designed to give lots of flexibility
   for future optimization.  The current implementation uses a radix 9
   tree with all the leaves on the bottom. */
13 :
/* I'm not sure exactly how to classify the tree that this
   implementation uses, but it's something like a B-tree with some
   tricks from binary heaps.  In particular, like a B-tree, each node
   stores several keys, where the keys are the cumulative sums of
   subtrees, called left sums below.  Like a binary heap, it is
   pointer-free, and it is stored implicitly in a flat array, the
   typical way that a binary heap is stored.  Specifically, the root is
   at index 0, and node i's children are at Ri+1, Ri+2, ... Ri+R, where
   R is the radix.  The leaves are all at the same level, at the bottom,
   stored implicitly, which means that a search can be done with almost
   no branch mispredictions.
25 :
26 : As an example, suppose R=5 and our tree contains weights 100, 80, 50,
27 : 31, 27, 14, and 6. Because that is 7 nodes, the height is set to
28 : ceil(log_5(7))=2.
29 :
30 : Root (0)
31 : / \
32 : Child (1) Child (2) -- -- --
33 : / / | \ \ | \
34 : 100, 80, 50, 31, 27 14, 6
35 :
36 : The first child of the root, node 1, has left_sum values
37 : |100|180|230|261|. Note that we only store R-1 values, and so the
38 : last child, 27, does not feature in the sums; the search process
39 : handles it implicitly, as we'll see. The second child of the root,
40 : node 2, has left sum values |14|20|20|20|. Then the root node, node
41 : 0, has left sum values |288|308|308|308|.
42 :
   The total sum is 308, so we start by drawing a random value in the
   range [0, 308).  We see which child that falls under, adjust our
   random value, and repeat, until we reach a leaf.
46 :
   In general, suppose a node has children (left to right) with subtree
   weights a, b, c, d, e.  Then the left sums are |a|a+b|a+b+c|a+b+c+d|
   and the full sum is S=a+b+c+d+e.  For a query value of x in [0, S),
   we want to pick:
51 : child 0 if x < a
52 : child 1 if a <= x < a+b
53 : child 2 if a+b <= x < a+b+c
54 : child 3 if a+b+c <= x < a+b+c+d
55 : child 4 if a+b+c+d<= x
56 :
57 : Which is equivalent to choosing child
58 : (a<=x) + (a+b<=x) + (a+b+c<=x) + (a+b+c+d<=x)
59 : which can be computed branchlessly. The value of e only comes into
60 : play in the restriction that x is in [0, S), and e is included in the
61 : value of S.
62 :
63 : There are two details left to discuss in order to search recursively.
64 : First, in order to remove an element for sampling without
65 : replacement, we need to know the weight of the element we're
66 : removing. As with e above, the weights may not be stored explicitly,
67 : but as long as we keep track of the weight of the current subtree as
68 : we search recursively, we can obtain the weight without much work.
69 : Thus, we need to know the sum S' of the chosen subtree, i in [0, R).
70 : For j in [-1, R), define
71 : / 0 if j==-1
72 : l[j] = | left_sum[ j ] if j in [0, R-1)
73 : \ S if j==R-1
74 : Essentially, we're computing the natural extension values for
75 : left_sum on the left and right side. Then observe that S' = l[i] -
76 : l[i-1].
77 :
78 : Secondly, in order to search recursively, we need to adjust the query
79 : value to put it into [0, S'). Specifically, we need to subtract off
80 : the sum of all the children to the left of the child we've chosen.
81 : If we've chosen child i, then that's just l[i-1].
82 :
83 : All the above extends easily to any radix, but 3, 5, and 9 = 1+2^n
84 : are the natural ones. I only tried 5 and 9, and 9 was substantially
85 : faster.
86 :
87 : It's worth noting that storing left sums means that you have to do
88 : the most work when you update the left-most child. Because we
89 : typically store the weights largest to smallest, the left-most child
90 : is the one we delete the most frequently, so now our most common case
91 : (probability-wise) and our worst case (performance-wise) are the
92 : same, which seems bad. I initially implemented this using right sums
93 : instead of left to address this problem. They're less intuitive, but
94 : also work. However, I found that having an unpredictable loop in the
95 : deletion method was far worse than just updating each element, which
96 : means that the "less work" advantage of right sums went away. */
97 :
/* tree_ele_t is one internal node of the implicit radix-R tree.  The
   requested alignment, 8*(R-1) bytes, matches the size of the left_sum
   array when ulong is 8 bytes wide (64 bytes for R==9). */
struct __attribute__((aligned(8UL*(R-1UL)))) tree_ele {
  /* left_sum stores the cumulative weight of the subtrees at this
     node: left_sum[ k ] is the total weight of children 0 through k.
     The last child's weight is not stored.  See the long note above
     for more information. */
  ulong left_sum[ R-1 ];
};
typedef struct tree_ele tree_ele_t;
104 :
/* fd_wsample_private is the in-memory layout of a sampler: a header of
   bookkeeping fields followed by the implicit tree. */
struct __attribute__((aligned(64UL))) fd_wsample_private {
  ulong total_cnt;        /* Incremented once per successful fd_wsample_new_add */
  ulong total_weight;     /* Sum of all weights passed to fd_wsample_new_add */
  ulong unremoved_cnt;    /* Decremented on each removal, reset to total_cnt by fd_wsample_restore_all */
  ulong unremoved_weight; /* Initial value for S explained above */

  /* internal_node_cnt and height are both determined by the number of
     leaves in the original tree, via the following formulas:
     height = ceil(log_r(leaf_cnt))
     internal_node_cnt = sum_{i=0}^{height-1} R^i
     height and internal_node_cnt both exclude the leaves, which are
     only implicit.

     All the math seems to disallow leaf_cnt==0, but for convenience, we
     do allow it.  height==internal_node_cnt==0 in that case.

     height actually fits in a uchar.  Storing as ulong is more natural,
     but we don't want to spill over into another cache line.

     If poisoned_mode==1, then every sample call on a non-empty sampler
     returns FD_WSAMPLE_INDETERMINATE until fd_wsample_restore_all
     clears the flag. */
  ulong internal_node_cnt;
  ulong poisoned_weight;  /* Extra weight added to the roll range; rolls landing in it are indeterminate */
  uint  height;
  char  restore_enabled;  /* 0 or 1; set from the restore_enabled argument of fd_wsample_new_init */
  char  poisoned_mode;
  /* Two bytes of padding here */

  fd_chacha_rng_t * rng;  /* Stored by fd_wsample_new_init, used for every roll */

  /* tree: Here's where the actual tree is stored, at indices [0,
     internal_node_cnt).  The indexing scheme is explained in the long
     comment above.

     If restore_enabled==1, then indices [internal_node_cnt+1,
     2*internal_node_cnt+1) store a copy of the tree after construction
     but before any deletion so that restoring deleted elements can be
     implemented as a memcpy.

     The tree itself is surrounded by two dummy elements, dummy, and
     tree[internal_node_cnt], that aren't actually used.  This is
     because searching the tree branchlessly involves some out of bounds
     reads, and although the value is immediately discarded, it's better
     to know where exactly those reads might go. */
  tree_ele_t dummy;
  tree_ele_t tree[];
};

typedef struct fd_wsample_private fd_wsample_t;
153 :
154 :
155 : FD_FN_CONST ulong
156 1520106 : fd_wsample_align( void ) {
157 1520106 : return 64UL;
158 1520106 : }
159 :
160 : /* Returns -1 on failure */
161 : static inline int
162 : compute_height( ulong leaf_cnt,
163 : ulong * out_height,
164 769422 : ulong * out_internal_cnt ) {
165 : /* This max is a bit conservative. The actual max is height <= 25,
166 : and leaf_cnt < 5^25 approx 2^58. A tree that large would take an
167 : astronomical amount of memory, so we just retain this max for the
168 : moment. */
169 769422 : if( FD_UNLIKELY( leaf_cnt >= UINT_MAX-2UL ) ) return -1;
170 :
171 769422 : ulong height = 0;
172 769422 : ulong internal = 0UL;
173 769422 : ulong powRh = 1UL; /* = R^height */
174 3433704 : while( leaf_cnt>powRh ) {
175 2664282 : internal += powRh;
176 2664282 : powRh *= R;
177 2664282 : height++;
178 2664282 : }
179 769422 : *out_height = height;
180 769422 : *out_internal_cnt = internal;
181 769422 : return 0;
182 769422 : }
183 :
184 : FD_FN_CONST ulong
185 447546 : fd_wsample_footprint( ulong ele_cnt, int restore_enabled ) {
186 447546 : ulong height;
187 447546 : ulong internal_cnt;
188 : /* Computing the closed form of the sum in compute_height, we get
189 : internal_cnt = 1/8 * (9^ceil(log_9( ele_cnt ) ) - 1)
190 : x <= ceil( x ) < x+1
191 : 1/8 * ele_cnt - 1/8 <= internal_cnt < 9/8 * ele_cnt - 1/8
192 : */
193 447546 : if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) return 0UL;
194 447546 : return sizeof(fd_wsample_t) + ((restore_enabled?2UL:1UL)*internal_cnt + 1UL)*sizeof(tree_ele_t);
195 447546 : }
196 :
197 : fd_wsample_t *
198 321864 : fd_wsample_join( void * shmem ) {
199 321864 : if( FD_UNLIKELY( !shmem ) ) {
200 0 : FD_LOG_WARNING(( "NULL shmem" ));
201 0 : return NULL;
202 0 : }
203 :
204 321864 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
205 0 : FD_LOG_WARNING(( "misaligned shmem" ));
206 0 : return NULL;
207 0 : }
208 321864 : return (fd_wsample_t *)shmem;
209 321864 : }
210 :
211 : /* Note: The following optimization insights are not used in this
212 : high radix implementation. Performance in the deletion case is much
213 : more important than in the non-deletion case, and it's not clear how
214 : to translate this. I'm leaving the code and comment because it is a
215 : useful and non-trivial insight. */
216 : #if 0
217 : /* If we assume the probability of querying node i is proportional to
218 : 1/i, then observe that the midpoint of the probability mass in the
219 : continuous approximation is the solution to (in Mathematica syntax):
220 :
221 : Integrate[ 1/i, {i, lo, hi}] = 2*Integrate[ 1/i, {i, lo, mid} ]
222 :
223 : which gives mid = sqrt(lo*hi). This is in contrast to when the
224 : integrand is a constant, which gives the normal binary search rule:
225 : mid=(lo+hi)/2.
226 :
227 : Thus, we want the search to follow this modified binary search rule,
228 : since that'll approximately split the stake weight/probability mass
229 : in half at each step.
230 :
231 : This is almost as nice to work with mathematically as the normal
232 : binary search rule. The jth entry from the left at level k is the
233 : region [ N^((1/2^k)*j), N^((1/2^k)*(j+1)) ). We're basically doing
234 : binary search in the log domain.
235 :
236 : Rather than trying to compute these transcendental functions, this
237 : simple recursive implementation gives the treap the right shape by
238 : setting prio very carefully, since higher prio means closer to the
239 : root. This may be a slight abuse of a treap, but it's easier than
240 : implementing a whole custom tree for just this purpose.
241 :
242 : We want to be careful to ensure that the recursion is tightly
243 : bounded. From a continuous perspective, that's not a problem: the
244 : widest interval at level k is the last one, and we break when the
245 : interval's width is less than 1. Solving 1=N-N^(1-(1/2^k)) for k
246 : yields k=-lg(1-log(N-1)/log(N)). k is monotonically increasing as a
247 : function of N, which means that for all N,
248 : k <= -lg(1-log(N_max-1)/log(N_max)) < 37
249 : since N_max=2^32-1.
250 :
251 : The math is more complicated with rounding and finite precision, but
252 : sqrt(lo*hi) is very different from lo and hi unless lo and hi are
253 : approximately the same. In that case, lo<mid and mid<hi ensures that
254 : both intervals are strictly smaller than the interval they came from,
255 : which prevents an infinite loop. */
/* seed_recursive assigns priorities over the index range (lo, hi) by
   repeatedly splitting at the rounded geometric mean.  mid is
   sqrt(lo*hi) computed in single precision and rounded to nearest.
   When lo<mid<hi, pool[mid-1] receives priority prio and both halves
   recurse with prio-1; otherwise the recursion stops.
   NOTE(review): this sits inside an #if 0 region and refers to
   treap_ele_t, which nothing else in this file uses, so it is kept for
   reference only. */
static inline void
seed_recursive( treap_ele_t * pool,
                uint          lo,
                uint          hi,
                uint          prio ) {
  uint mid = (uint)(sqrtf( (float)lo*(float)hi ) + 0.5f);
  /* Both comparisons are evaluated (bitwise &), so there is a single
     branch on the combined condition. */
  if( (lo<mid) & (mid<hi) ) {
    /* since we start with lo=1, shift by 1 */
    pool[mid-1U].prio = prio;
    seed_recursive( pool, lo,  mid, prio-1U );
    seed_recursive( pool, mid, hi,  prio-1U );
  }
}
269 : #endif
270 :
271 :
272 : void *
273 : fd_wsample_new_init( void * shmem,
274 : fd_chacha_rng_t * rng,
275 : ulong ele_cnt,
276 : int restore_enabled,
277 321876 : int opt_hint ) {
278 321876 : if( FD_UNLIKELY( !shmem ) ) {
279 0 : FD_LOG_WARNING(( "NULL shmem" ));
280 0 : return NULL;
281 0 : }
282 :
283 321876 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
284 0 : FD_LOG_WARNING(( "misaligned shmem" ));
285 0 : return NULL;
286 0 : }
287 :
288 321876 : ulong height;
289 321876 : ulong internal_cnt;
290 321876 : if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) {
291 0 : FD_LOG_WARNING(( "bad ele_cnt" ));
292 0 : return NULL;
293 0 : }
294 :
295 321876 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
296 :
297 321876 : sampler->total_weight = 0UL;
298 321876 : sampler->unremoved_cnt = 0UL;
299 321876 : sampler->unremoved_weight = 0UL;
300 321876 : sampler->internal_node_cnt = internal_cnt;
301 321876 : sampler->poisoned_weight = 0UL;
302 321876 : sampler->height = (uint)height;
303 321876 : sampler->restore_enabled = (char)!!restore_enabled;
304 321876 : sampler->poisoned_mode = 0;
305 321876 : sampler->rng = rng;
306 :
307 321876 : fd_memset( sampler->tree, (char)0, internal_cnt*sizeof(tree_ele_t) );
308 :
309 321876 : (void)opt_hint; /* Not used at the moment */
310 :
311 321876 : return shmem;
312 321876 : }
313 :
314 : void *
315 : fd_wsample_new_add( void * shmem,
316 202286043 : ulong weight ) {
317 202286043 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
318 202286043 : if( FD_UNLIKELY( !sampler ) ) return NULL;
319 :
320 202286043 : if( FD_UNLIKELY( weight==0UL ) ) {
321 0 : FD_LOG_WARNING(( "zero weight entry found" ));
322 0 : return NULL;
323 0 : }
324 202286043 : if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
325 0 : FD_LOG_WARNING(( "total weight too large" ));
326 0 : return NULL;
327 0 : }
328 :
329 202286043 : tree_ele_t * tree = sampler->tree;
330 202286043 : ulong i = sampler->internal_node_cnt + sampler->unremoved_cnt;
331 :
332 1011149181 : for( ulong h=0UL; h<sampler->height; h++ ) {
333 808863138 : ulong parent = (i-1UL)/R;
334 808863138 : ulong child_idx = i-1UL - R*parent; /* in [0, R) */
335 4956555249 : for( ulong k=child_idx; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] += weight;
336 808863138 : i = parent;
337 808863138 : }
338 :
339 202286043 : sampler->unremoved_cnt++;
340 202286043 : sampler->total_cnt++;
341 202286043 : sampler->unremoved_weight += weight;
342 202286043 : sampler->total_weight += weight;
343 :
344 202286043 : return shmem;
345 202286043 : }
346 :
347 : void *
348 : fd_wsample_new_fini( void * shmem,
349 321876 : ulong poisoned_weight ) {
350 321876 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
351 321876 : if( FD_UNLIKELY( !sampler ) ) return NULL;
352 :
353 321876 : if( FD_UNLIKELY( sampler->total_weight+poisoned_weight<sampler->total_weight ) ) {
354 0 : FD_LOG_WARNING(( "poisoned_weight caused overflow" ));
355 0 : return NULL;
356 0 : }
357 :
358 321876 : sampler->poisoned_weight = poisoned_weight;
359 :
360 321876 : if( sampler->restore_enabled ) {
361 : /* Copy the sampler to make restore fast. */
362 313602 : fd_memcpy( sampler->tree+sampler->internal_node_cnt+1UL, sampler->tree, sampler->internal_node_cnt*sizeof(tree_ele_t) );
363 313602 : }
364 :
365 321876 : return (void *)sampler;
366 321876 : }
367 :
368 : void *
369 321582 : fd_wsample_leave( fd_wsample_t * sampler ) {
370 321582 : if( FD_UNLIKELY( !sampler ) ) {
371 0 : FD_LOG_WARNING(( "NULL sampler" ));
372 0 : return NULL;
373 0 : }
374 :
375 321582 : return (void *)sampler;
376 321582 : }
377 :
378 : void *
379 321582 : fd_wsample_delete( void * shmem ) {
380 321582 : if( FD_UNLIKELY( !shmem ) ) {
381 0 : FD_LOG_WARNING(( "NULL shmem" ));
382 0 : return NULL;
383 0 : }
384 :
385 321582 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
386 0 : FD_LOG_WARNING(( "misaligned shmem" ));
387 0 : return NULL;
388 0 : }
389 321582 : return shmem;
390 321582 : }
391 :
392 : void
393 : fd_wsample_seed_rng( fd_wsample_t * sampler,
394 : uchar seed[ 32 ],
395 2586789 : int use_chacha8 ) {
396 2586789 : fd_chacha_rng_init( sampler->rng, seed, use_chacha8 ? FD_CHACHA_RNG_ALGO_CHACHA8 : FD_CHACHA_RNG_ALGO_CHACHA20 );
397 2586789 : }
398 :
399 : fd_wsample_t *
400 3494316 : fd_wsample_restore_all( fd_wsample_t * sampler ) {
401 3494316 : if( FD_UNLIKELY( !sampler->restore_enabled ) ) return NULL;
402 :
403 3494313 : sampler->unremoved_weight = sampler->total_weight;
404 3494313 : sampler->unremoved_cnt = sampler->total_cnt;
405 3494313 : sampler->poisoned_mode = 0;
406 :
407 3494313 : fd_memcpy( sampler->tree, sampler->tree+sampler->internal_node_cnt+1UL, sampler->internal_node_cnt*sizeof(tree_ele_t) );
408 3494313 : return sampler;
409 3494316 : }
410 :
/* fd_ulong_if_force( c, t, f ) evaluates to t if c is non-zero and to f
   otherwise, using inline assembly so the selection is a conditional
   move rather than a branch.  The sequence tests c as a 32-bit value,
   loads f into the result, then conditionally overwrites it with t.
   Both t and f are always evaluated.  The result register is marked
   early-clobber ("=&r") because it is written before t is read.
   This is x86-64 specific.  In this file it is only referenced from
   preprocessor branches that are currently disabled. */
#define fd_ulong_if_force( c, t, f ) (__extension__({ \
      ulong result;                                   \
      __asm__( "testl  %1, %1; \n\t"                  \
               "movq   %3, %0; \n\t"                  \
               "cmovne %2, %0; \n\t"                  \
               : "=&r"(result)                        \
               : "r"(c), "rm"(t), "rmi"(f)            \
               : "cc" );                              \
      result;                                         \
    }))
421 :
/* Helper methods for sampling functions */

/* idxw_pair_t pairs a leaf index with that leaf's current weight.  It
   is the result of a tree search and the input to a removal. */
typedef struct { ulong idx; ulong weight; } idxw_pair_t; /* idx in [0, total_cnt) */
424 :
/* fd_wsample_map_sample_i walks the tree from the root to the leaf that
   owns the query value and returns that leaf's index together with its
   current weight.

   Assumes query in [0, unremoved_weight), which implies
   unremoved_weight>0, so the tree can't be empty.

   At each level, S is the total weight of the current subtree and query
   is relative to the start of that subtree.

   NOTE(review): although sampler is const-qualified, the active code
   path below temporarily writes to the two ulongs bordering the current
   node's left_sum array and then restores them, so this is not a pure
   read of the sampler memory. */
static inline idxw_pair_t
fd_wsample_map_sample_i( fd_wsample_t const * sampler,
                         ulong                query ) {
  tree_ele_t const * tree = sampler->tree;

  ulong cursor = 0UL;
  ulong S      = sampler->unremoved_weight;
  for( ulong h=0UL; h<sampler->height; h++ ) {
    tree_ele_t const * e = tree+cursor;
    ulong x = query;
    ulong child_idx = 0UL;

    /* child_idx = number of left sums that are <= x, i.e. the index of
       the child whose weight range contains x. */
#if FD_HAS_AVX512 && R==9
    __mmask8 mask = _mm512_cmple_epu64_mask( wwv_ld( e->left_sum ), wwv_bcast( x ) );
    child_idx = (ulong)fd_uchar_popcnt( mask );
#else
    for( ulong i=0UL; i<R-1UL; i++ ) child_idx += (ulong)(e->left_sum[ i ]<=x);
#endif

    /* See the note at the top of this file for the explanation of l[i]
       and l[i-1].  Because this is fd_ulong_if and not a ternary, these
       can read/write out of what you would think the appropriate bounds
       are.  The dummy elements, as described along with tree makes this
       safe. */
#if 0
    ulong li  = fd_ulong_if( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
    ulong lm1 = fd_ulong_if( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
#elif 0
    ulong li  = fd_ulong_if_force( child_idx<R-1UL, e->left_sum[ child_idx     ], S   );
    ulong lm1 = fd_ulong_if_force( child_idx>0UL,   e->left_sum[ child_idx-1UL ], 0UL );
#else
    /* Active variant: extend left_sum in place with l[-1]=0 and
       l[R-1]=S by overwriting the ulong just before and just after the
       array, read l[i] and l[i-1] with plain indexing, then put the
       original neighbouring values back. */
    ulong * temp = (ulong *)e->left_sum;
    ulong orig_m1 = temp[ -1 ];   ulong orig_Rm1 = temp[ R-1UL ];
    temp[ -1 ] = 0UL;             temp[ R-1UL ] = S;
    ulong li  = temp[ child_idx     ];
    ulong lm1 = temp[ child_idx-1UL ];
    temp[ -1 ] = orig_m1;         temp[ R-1UL ] = orig_Rm1;
#endif

    /* Re-base the query into the chosen child's range [0, li-lm1) and
       descend to that child. */
    query -= lm1;
    S = li - lm1;
    cursor = R*cursor + child_idx + 1UL;
  }
  /* cursor is now a leaf position in the implicit array; leaves start
     right after the internal nodes. */
  idxw_pair_t to_return = { .idx = cursor - sampler->internal_node_cnt, .weight = S };
  return to_return;
}
473 :
474 : ulong
475 : fd_wsample_map_sample( fd_wsample_t * sampler,
476 768757023 : ulong query ) {
477 768757023 : return fd_wsample_map_sample_i( sampler, query ).idx;
478 768757023 : }
479 :
480 :
481 :
/* fd_wsample_remove removes leaf to_remove.idx from the tree by
   subtracting to_remove.weight from every left sum that covers it, on
   each level from the leaf's parent up to the root.  It then decrements
   unremoved_cnt and reduces unremoved_weight by the same weight.
   to_remove.weight must be the leaf's current weight.

   Also requires the tree to be non-empty */
static inline void
fd_wsample_remove( fd_wsample_t * sampler,
                   idxw_pair_t    to_remove ) {
  ulong cursor = to_remove.idx + sampler->internal_node_cnt;
  tree_ele_t * tree = sampler->tree;

  for( ulong h=0UL; h<sampler->height; h++ ) {
    ulong parent = (cursor-1UL)/R;
    ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
    /* In every variant below, left_sum[ k ] is reduced exactly when
       child_idx<=k. */
#if FD_HAS_AVX512 && R==9
    /* Mask bit k is set when child_idx < k+1; the masked subtract
       leaves the other lanes unchanged. */
    wwv_t weight = wwv_bcast( to_remove.weight );
    wwv_t left_sum = wwv_ld( tree[ parent ].left_sum );
    __m128i _child_idx = _mm_set1_epi16( (short) child_idx );
    __mmask8 mask = _mm_cmplt_epi16_mask( _child_idx, _mm_setr_epi16( 1, 2, 3, 4, 5, 6, 7, 8 ) );
    left_sum = _mm512_mask_sub_epi64( left_sum, mask, left_sum, weight );
    wwv_st( tree[ parent ].left_sum, left_sum );
#elif 0
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if( child_idx<=k, to_remove.weight, 0UL );
#elif 0
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if_force( child_idx<=k, to_remove.weight, 0UL );
#elif 1
    /* The compiler loves inserting a difficult to predict branch for
       fd_ulong_if, but this forces it not to do that.  child_idx-k-1 is
       negative as a signed value exactly when child_idx<=k, so the
       shift by 63 is intended to yield an all-ones or all-zeros mask
       (this relies on the compiler using an arithmetic right shift for
       signed values). */
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= (ulong)(((long)(child_idx - k - 1UL))>>63) & to_remove.weight;
#else
    /* This version does the least work, but has a hard-to-predict
       branch.  The branchless versions are normally substantially
       faster. */
    for( ulong k=child_idx; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= to_remove.weight;
#endif
    cursor = parent;
  }
  sampler->unremoved_cnt--;
  sampler->unremoved_weight -= to_remove.weight;
}
518 :
519 : static inline ulong
520 : fd_wsample_find_weight( fd_wsample_t const * sampler,
521 1787967 : ulong idx /* in [0, total_cnt) */) {
522 : /* The fact we don't store the weights explicitly makes this function
523 : more complicated, but this is not used very frequently. */
524 1787967 : tree_ele_t const * tree = sampler->tree;
525 1787967 : ulong cursor = idx + sampler->internal_node_cnt;
526 :
527 : /* Initialize to the 0 height case */
528 1787967 : ulong lm1 = 0UL;
529 1787967 : ulong li = sampler->unremoved_weight;
530 :
531 1818249 : for( ulong h=0UL; h<sampler->height; h++ ) {
532 1818057 : ulong parent = (cursor-1UL)/R;
533 1818057 : ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
534 :
535 : /* If child_idx < R-1, we can compute the weight easily. If
536 : child_idx==R-1, the computation is S - left_sum[ R-2 ], but we
537 : don't know S, so we need to continue up the tree. */
538 1818057 : lm1 += fd_ulong_if( child_idx>0UL, tree[ parent ].left_sum[ child_idx-1UL ], 0UL );
539 1818057 : if( FD_LIKELY( child_idx<R-1UL ) ) {
540 1787775 : li = tree[ parent ].left_sum[ child_idx ];
541 1787775 : break;
542 1787775 : }
543 :
544 30282 : cursor = parent;
545 30282 : }
546 :
547 1787967 : return li - lm1;
548 1787967 : }
549 :
550 : void
551 : fd_wsample_remove_idx( fd_wsample_t * sampler,
552 1787967 : ulong idx ) {
553 :
554 1787967 : ulong weight = fd_wsample_find_weight( sampler, idx );
555 1787967 : idxw_pair_t r = { .idx = idx, .weight = weight };
556 1787967 : fd_wsample_remove( sampler, r );
557 1787967 : }
558 :
559 :
560 : #if FD_HAS_AVX512
561 :
562 : /* TRAVERSE_LEVEL Takes in and updates s (a vector with all elements set
563 : to the unremoved weight in the current subtree), x or xprime (a
564 : vector with all elements set to the current query value, or the
565 : value+1, respectively), and cursor (a ulong of the current position
566 : in the traversal). See fd_wsample_map_sample_i for more about these
567 : values. Calling this macro height times is the same as calling
568 : fd_wsample_map_sample_i, except for that it declares mask{i} and
569 : cursor{i}, to save work in PROPAGATE_LEVEL, which is like
570 : fd_wsample_remove.
571 :
572 : This only works in the R=9 case, but we'll explain it as if R=5 like
573 : before. Then, v, a vector of tree elements has values
574 : ----------------------------------------
575 : | a | a+b | a+b+c | a+b+c+d |
576 : ----------------------------------------
577 : Then, like before, the desired state is
578 : If... pick, and l_{i-1} and l_i
579 : x < a child 0 0 a
580 : a <= x < a+b child 1 a a+b
581 : a+b <= x < a+b+c child 2 a+b a+b+c
582 : a+b+c <= x < a+b+c+d child 3 a+b+c a+b+c+d
583 : a+b+c+d<= x child 4 a+b+c+d a+b+c+d+e
584 :
585 : Below, there are two pretty different implementations, and this
586 : explains both:
587 :
588 : Implementation 1:
589 : We want to get information about v0<=x in as much as the vector as
590 : fast as possible. We could do some kind of compress and then bcast,
591 : but operations with mask registers are frustratingly slow. Instead,
592 : we first define x'=x+1, so then v0<=x is equivalent to v0-x'<0, or
593 : whether v0-x' has the high bit set. Because of
594 : fd_chacha_rng_ulong_roll's contract, we know x<ULONG_MAX, so forming
595 : x' is safe. We then use a _mm512_permutexvar_epi8 to essentially
596 : broadcast the high byte from each of the ulongs (really we just need
597 : the high bit) to the whole vector. A popcnt gives us our base
598 : values.
599 :
600 : Case compressed_signs popcnt popcnt-1
601 : x < a 0 0 -1
602 : a <= x < a+b 0x80 1 0
603 : a+b <= x < a+b+c 0x8080 2 1
604 : a+b+c <= x < a+b+c+d 0x808080 3 2
605 : a+b+c+d<= x 0x80808080 4 3
606 :
607 : Then, keeping in mind that _mm512_permutex2var_epi64 will select an
608 : element from the second vector if the next highest bit is set, using
609 : popcnt-1 as the selector and (vec, 0) as (a,b) gives l_{i-1}; and
610 : using popcnt as the selector with (vec, s) gives l_i.
611 :
612 : Finally, we just need to produce the mask mask{i}, which we can do
613 : with wwv_lt, off the critical path.
614 :
615 :
616 : Implementation 2
   We first use wwv_le to compare v0<=x to form lmask.  Then we extract
   the most significant bit by lmask^(lmask>>1).  Note that below the
   bits are written least to most significant, which is backwards of the
   normal way.
621 :
622 : Case lmask single_bit ~lmask
623 : x < a 0 0 0 0 0 0 0 0 1 1 1 1
624 : a <= x < a+b 1 0 0 0 1 0 0 0 0 1 1 1
625 : a+b <= x < a+b+c 1 1 0 0 0 1 0 0 0 0 1 1
626 : a+b+c <= x < a+b+c+d 1 1 1 0 0 0 1 0 0 0 0 1
627 : a+b+c+d<= x 1 1 1 1 0 0 0 1 0 0 0 0
628 :
629 : Then we use _mm512_mask_compress_epi64 to move the element
630 : corresponding to the first set bit to the 0th position, or the src
631 : vector if there aren't any set bits. From there, we can use
632 : _mm512_broadcastq_epi64 to fill a vector with that value. Applying
633 : this trick to single_bit gives us l_{i-1} and to ~lmask give l_i. */
634 : #define FD_WSAMPLE_IMPLEMENTATION 2
635 :
636 : #if FD_WSAMPLE_IMPLEMENTATION==1
/* PREPARE (implementation 1) declares the traversal state in the
   enclosing scope: cursor (current node, starting at the root), xprime
   (query+1 in every lane), s (current subtree weight in every lane),
   and two constants used by TRAVERSE_LEVEL.  gather_signs_idx lists the
   byte offsets 7, 15, ..., 63, i.e. the most significant byte of each
   64-bit lane.  Requires unif and sampler to be in scope at the point
   of expansion. */
#define PREPARE()                                               \
  ulong cursor = 0UL;                                           \
  wwv_t xprime = wwv_bcast( unif+1UL );                         \
  wwv_t s = wwv_bcast( sampler->unremoved_weight );             \
  wwv_t high_bit_mask = wwv_bcast( 0x8000000000000000UL );      \
  wwv_t gather_signs_idx = wwv_bcast( 0x070F171F272F373FUL )
643 :
/* TRAVERSE_LEVEL(i) (implementation 1) descends one level.  It declares
   mask<i>, cursor<i> and vec<i> in the enclosing scope, recording the
   node visited, its loaded left sums, and which lanes have
   left_sum >= xprime (i.e. left_sum > query), for later use by
   PROPAGATE_LEVEL(i).  Inside, popcnt holds the chosen child index in
   every lane; li and lim1 are l[i] and l[i-1] selected from (vec, s)
   and (vec, 0) respectively.  xprime, s and cursor are updated for the
   next level.  Requires tree and the PREPARE() state in scope. */
#define TRAVERSE_LEVEL(i)                                                                          \
  __mmask8 mask##i;                                                                                \
  ulong cursor##i = cursor;                                                                        \
  wwv_t vec##i;                                                                                    \
  do {                                                                                             \
    wwv_t vec = wwv_ld( tree[cursor].left_sum );                                                   \
    wwv_t sign = wwv_and( high_bit_mask, wwv_sub( vec, xprime ) );                                 \
    wwv_t compressed_signs = _mm512_permutexvar_epi8( gather_signs_idx, sign );                    \
    wwv_t popcnt = _mm512_popcnt_epi64( compressed_signs );                                        \
    wwv_t li = _mm512_permutex2var_epi64( vec, popcnt, s );                                        \
    wwv_t lim1 = _mm512_permutex2var_epi64( vec, wwv_sub( popcnt, wwv_one() ), wwv_zero() );       \
    mask##i = _mm512_cmpge_epu64_mask( vec, xprime );                                              \
    xprime = wwv_sub( xprime, lim1 );                                                              \
    s = wwv_sub( li, lim1 );                                                                       \
    ulong child_idx = (ulong)_mm_extract_epi64( _mm512_castsi512_si128( popcnt ), 0 );             \
    cursor = R*cursor + child_idx + 1UL;                                                           \
    vec##i = vec;                                                                                  \
  } while( 0 )
662 :
/* PROPAGATE_LEVEL(i) (implementation 1) writes back the node visited by
   TRAVERSE_LEVEL(i).  Lanes selected by mask<i> presumably become
   vec<i> - s and the rest keep vec<i> (wwv_sub_if is defined elsewhere
   -- confirm).  It must run after all levels have been traversed, so
   that s holds the weight of the selected leaf.  Note the expansion
   already ends in a semicolon. */
#define PROPAGATE_LEVEL(i)                                                         \
  wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );
665 :
/* FINALIZE (implementation 1) updates the sampler totals after a
   removal: lane 0 of s (the removed leaf's weight) is subtracted from
   unremoved_weight and unremoved_cnt is decremented. */
#define FINALIZE()                                                                                 \
  do {                                                                                             \
    sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 );    \
    sampler->unremoved_cnt--;                                                                      \
  } while( 0 )
671 :
672 : #elif FD_WSAMPLE_IMPLEMENTATION==2
673 :
/* PREPARE (implementation 2) declares the traversal state in the
   enclosing scope: cursor (current node, starting at the root), x (the
   query value in every lane) and s (the current subtree weight in every
   lane).  Requires unif and sampler to be in scope at the point of
   expansion.  The last line deliberately has no line continuation, so
   the macro ends here regardless of what follows it in the file. */
#define PREPARE()                                       \
  ulong cursor = 0UL;                                   \
  wwv_t x = wwv_bcast( unif );                          \
  wwv_t s = wwv_bcast( sampler->unremoved_weight );
678 :
/* TRAVERSE_LEVEL(i) (implementation 2) descends one level.  It declares
   mask<i>, cursor<i> and vec<i> in the enclosing scope for later use by
   PROPAGATE_LEVEL(i).  lmask has bit k set when left_sum[k] <= x, so
   its popcount is the chosen child index.  single_bit isolates the
   highest set bit of lmask; compressing vec with it yields l[i-1] (or 0
   when lmask is empty).  mask<i> is the complement of lmask;
   compressing vec with it, with s as the fallback source, yields l[i]
   (or s when the last child is chosen).  x and s are then re-based for
   the next level.  Requires tree and the PREPARE() state in scope. */
#define TRAVERSE_LEVEL(i)                                                                                                     \
  __mmask8 mask##i;                                                                                                           \
  ulong cursor##i = cursor;                                                                                                   \
  wwv_t vec##i;                                                                                                               \
  do {                                                                                                                        \
    wwv_t vec = wwv_ld( tree[cursor].left_sum );                                                                              \
    __mmask8 lmask = _mm512_cmple_epu64_mask( vec, x );                                                                       \
    cursor = R*cursor + (ulong)fd_uchar_popcnt( lmask ) + 1UL;                                                                \
    __mmask8 single_bit = _kxor_mask8( lmask, _kshiftri_mask8( lmask, 1U ) );                                                 \
    wwv_t lim1 = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_maskz_compress_epi64( single_bit, vec ) ) );         \
    x = wwv_sub( x, lim1 );                                                                                                   \
    mask##i = _knot_mask8( lmask );                                                                                           \
    wwv_t li = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_mask_compress_epi64( s, mask##i, vec ) ) );            \
    s = wwv_sub( li, lim1 );                                                                                                  \
    vec##i = vec;                                                                                                             \
  } while( 0 )
695 :
/* PROPAGATE_LEVEL(i) (implementation 2) writes back the node visited by
   TRAVERSE_LEVEL(i).  Lanes selected by mask<i> (left sums greater than
   the query at that level) presumably become vec<i> - s and the rest
   keep vec<i> (wwv_sub_if is defined elsewhere -- confirm).  It must
   run after all levels have been traversed, so that s holds the weight
   of the selected leaf.  Note the expansion already ends in a
   semicolon. */
#define PROPAGATE_LEVEL(i)                                                         \
  wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );
698 :
/* FINALIZE (implementation 2) updates the sampler totals after a
   removal: lane 0 of s (the removed leaf's weight) is subtracted from
   unremoved_weight and unremoved_cnt is decremented. */
#define FINALIZE()                                                                                 \
  do {                                                                                             \
    sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 );    \
    sampler->unremoved_cnt--;                                                                      \
  } while( 0 )
704 :
705 : #endif
706 : #else
707 : #define FD_WSAMPLE_IMPLEMENTATION 0
708 : #endif
709 :
710 : /* For now, implement the _many functions as loops over the single
711 : sample functions. It is possible to do better though. */
712 :
713 : void
714 : fd_wsample_sample_many( fd_wsample_t * sampler,
715 : ulong * idxs,
716 3694788 : ulong cnt ) {
717 373172133 : for( ulong i=0UL; i<cnt; i++ ) idxs[i] = fd_wsample_sample( sampler );
718 3694788 : }
719 :
/* fd_wsample_sample_and_remove_many draws cnt samples without
   replacement, writing one result per slot of idxs.  For each slot, in
   order:
     - FD_WSAMPLE_EMPTY if no unremoved weight is left;
     - FD_WSAMPLE_INDETERMINATE if the sampler is in poisoned mode;
     - otherwise a roll in [0, unremoved_weight+poisoned_weight) is
       drawn.  A roll landing in the poisoned part stores
       FD_WSAMPLE_INDETERMINATE and switches the sampler into poisoned
       mode.  Any other roll selects an element, removes it, and stores
       its index.
   Nothing short-circuits the loop: once empty or poisoned, the
   remaining slots are still filled with the corresponding sentinel. */
void
fd_wsample_sample_and_remove_many( fd_wsample_t * sampler,
                                   ulong *        idxs,
                                   ulong          cnt ) {
  /* The compiler doesn't seem to like inlining the call to
     fd_wsample_sample_and_remove, which hurts performance by a few
     percent because it triggers worse behavior in the CPUs front end.
     To address this, we manually inline it here. */
  for( ulong i=0UL; i<cnt; i++ ) {
    if( FD_UNLIKELY( !sampler->unremoved_weight ) ) { idxs[ i ] = FD_WSAMPLE_EMPTY;         continue; }
    if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) { idxs[ i ] = FD_WSAMPLE_INDETERMINATE; continue; }
    ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
    if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
      idxs[ i ] = FD_WSAMPLE_INDETERMINATE;
      sampler->poisoned_mode = 1;
      continue;
    }
#if FD_WSAMPLE_IMPLEMENTATION > 0
    /* Vectorized fast path, unrolled for trees of height exactly 4:
       search all four levels first, then write the updated left sums
       back level by level. */
    if( FD_LIKELY( sampler->height==4UL ) ) {
      tree_ele_t * tree = sampler->tree;
      PREPARE();

      TRAVERSE_LEVEL(0);
      TRAVERSE_LEVEL(1);
      TRAVERSE_LEVEL(2);
      TRAVERSE_LEVEL(3);
      PROPAGATE_LEVEL(0);
      PROPAGATE_LEVEL(1);
      PROPAGATE_LEVEL(2);
      PROPAGATE_LEVEL(3);

      FINALIZE();
      idxs[ i ] = cursor - sampler->internal_node_cnt;
      continue;
    }
#endif
    /* General path: scalar (or per-level vectorized) search followed by
       a separate removal pass. */
    idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
    fd_wsample_remove( sampler, p );
    idxs[ i ] = p.idx;
  }
}
761 :
762 :
763 :
764 : ulong
765 769011729 : fd_wsample_sample( fd_wsample_t * sampler ) {
766 769011729 : if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
767 769011723 : if( FD_UNLIKELY( sampler->poisoned_mode ) ) return FD_WSAMPLE_INDETERMINATE;
768 769011723 : ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
769 769011723 : if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) return FD_WSAMPLE_INDETERMINATE;
770 768757023 : return (ulong)fd_wsample_map_sample( sampler, unif );
771 769011723 : }
772 :
/* fd_wsample_sample_and_remove draws one element without replacement.
   Returns FD_WSAMPLE_EMPTY if no unremoved weight is left and
   FD_WSAMPLE_INDETERMINATE if the sampler is in poisoned mode.
   Otherwise a roll in [0, unremoved_weight+poisoned_weight) is drawn.
   A roll landing in the poisoned part switches the sampler into
   poisoned mode and returns FD_WSAMPLE_INDETERMINATE.  Any other roll
   selects an element, removes it from the tree, and returns its
   index. */
ulong
fd_wsample_sample_and_remove( fd_wsample_t * sampler ) {
  if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
  if( FD_UNLIKELY(  sampler->poisoned_mode    ) ) return FD_WSAMPLE_INDETERMINATE;
  ulong unif = fd_chacha_rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
  if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
    sampler->poisoned_mode = 1;
    return FD_WSAMPLE_INDETERMINATE;
  }

#if FD_WSAMPLE_IMPLEMENTATION > 0
  /* Vectorized fast path, unrolled for trees of height exactly 4:
     search all four levels first, then write the updated left sums back
     level by level. */
  if( FD_LIKELY( sampler->height==4UL ) ) {
    tree_ele_t * tree = sampler->tree;
    PREPARE();

    TRAVERSE_LEVEL(0);
    TRAVERSE_LEVEL(1);
    TRAVERSE_LEVEL(2);
    TRAVERSE_LEVEL(3);
    PROPAGATE_LEVEL(0);
    PROPAGATE_LEVEL(1);
    PROPAGATE_LEVEL(2);
    PROPAGATE_LEVEL(3);
    FINALIZE();
    return cursor - sampler->internal_node_cnt;
  }
#endif
  /* General path: search, then a separate removal pass. */
  idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
  fd_wsample_remove( sampler, p );
  return p.idx;
}
|