Line data Source code
1 : #include "fd_wsample.h"
2 : #include <math.h> /* For sqrt */
3 : #if FD_HAS_AVX512
4 : #include "../../util/simd/fd_avx512.h"
5 : #endif
6 :
/* R is the radix (fan-out) of the implicit tree: every internal node
   has R children and stores R-1 left sums.  It has to remain a
   preprocessor macro because it is tested with #if (R==9) below. */
#define R (9)
8 : /* This sampling problem is an interesting one from a performance
9 : perspective. There are lots of interesting approaches. The
10 : header/implementation split is designed to give lots of flexibility
11 : for future optimization. The current implementation uses radix 9
12 : tree with all the leaves on the bottom. */
13 :
14 : /* I'm not sure exactly how to classify the tree that this
15 : implementation uses, but it's something like a B-tree with some
16 : tricks from binary heaps. In particular, like a B-tree, each node
17 : stores several keys, where the keys are the cumulative sums of
18 : subtrees, called left sums below. Like a binary heap, it is
19 : pointer-free, and it is stored implicitly in a flat array. the
20 : typical way that a binary heap is stored. Specifically, the root is
21 : at index 0, and node i's children are at Ri+1, Ri+2, ... Ri+R, where
22 : R is the radix. The leaves are all at the same level, at the bottom,
23 : stored implicitly, which means that a search can be done with almost
24 : no branch mispredictions.
25 :
26 : As an example, suppose R=5 and our tree contains weights 100, 80, 50,
27 : 31, 27, 14, and 6. Because that is 7 nodes, the height is set to
28 : ceil(log_5(7))=2.
29 :
30 : Root (0)
31 : / \
32 : Child (1) Child (2) -- -- --
33 : / / | \ \ | \
34 : 100, 80, 50, 31, 27 14, 6
35 :
36 : The first child of the root, node 1, has left_sum values
37 : |100|180|230|261|. Note that we only store R-1 values, and so the
38 : last child, 27, does not feature in the sums; the search process
39 : handles it implicitly, as we'll see. The second child of the root,
40 : node 2, has left sum values |14|20|20|20|. Then the root node, node
41 : 0, has left sum values |288|308|308|308|.
42 :
43 : The total sum is 308, so we start by drawing a random values in the
44 : range [0, 308). We see which child that falls under, adjust our
45 : random value, and repeat, until we reach a leaf.
46 :
47 : In general, if a node has children (left to right) with subtree
48 : weights a, b, c, d, e. Then the left sums are |a|a+b|a+b+c|a+b+c+d|
49 : and the full sum S=a+b+c+d+e. For a query value of x in [0, S), we
50 : want to pick:
51 : child 0 if x < a
52 : child 1 if a <= x < a+b
53 : child 2 if a+b <= x < a+b+c
54 : child 3 if a+b+c <= x < a+b+c+d
55 : child 4 if a+b+c+d<= x
56 :
57 : Which is equivalent to choosing child
58 : (a<=x) + (a+b<=x) + (a+b+c<=x) + (a+b+c+d<=x)
59 : which can be computed branchlessly. The value of e only comes into
60 : play in the restriction that x is in [0, S), and e is inlcuded in the
61 : value of S.
62 :
63 : There are two details left to discuss in order to search recursively.
64 : First, in order to remove an element for sampling without
65 : replacement, we need to know the weight of the element we're
66 : removing. As with e above, the weights may not be stored explicitly,
67 : but as long as we keep track of the weight of the current subtree as
68 : we search recursively, we can obtain the weight without much work.
69 : Thus, we need to know the sum S' of the chosen subtree, i in [0, R).
70 : For j in [-1, R), define
71 : / 0 if j==-1
72 : l[j] = | left_sum[ j ] if j in [0, R-1)
73 : \ S if j==R-1
74 : Essentially, we're computing the natural extension values for
75 : left_sum on the left and right side. Then observe that S' = l[i] -
76 : l[i-1].
77 :
78 : Secondly, in order to search recursively, we need to adjust the query
79 : value to put it into [0, S'). Specifically, we need to subtract off
80 : the sum of all the children to the left of the child we've chosen.
81 : If we've chosen child i, then that's just l[i-1].
82 :
83 : All the above extends easily to any radix, but 3, 5, and 9 = 1+2^n
84 : are the natural ones. I only tried 5 and 9, and 9 was substantially
85 : faster.
86 :
87 : It's worth noting that storing left sums means that you have to do
88 : the most work when you update the left-most child. Because we
89 : typically store the weights largest to smallest, the left-most child
90 : is the one we delete the most frequently, so now our most common case
91 : (probability-wise) and our worst case (performance-wise) are the
92 : same, which seems bad. I initially implemented this using right sums
93 : instead of left to address this problem. They're less intuitive, but
94 : also work. However, I found that having an unpredictable loop in the
95 : deletion method was far worse than just updating each element, which
96 : means that the "less work" advantage of right sums went away. */
97 :
98 : struct __attribute__((aligned(8UL*(R-1UL)))) tree_ele {
99 : /* left_sum stores the cumulative weight of the subtrees at this
100 : node. See the long note above for more information. */
101 : ulong left_sum[ R-1 ];
102 : };
103 : typedef struct tree_ele tree_ele_t;
104 :
105 : struct __attribute__((aligned(64UL))) fd_wsample_private {
106 : ulong total_cnt;
107 : ulong total_weight;
108 : ulong unremoved_cnt;
109 : ulong unremoved_weight; /* Initial value for S explained above */
110 :
111 : /* internal_node_cnt and height are both determined by the number of
112 : leaves in the original tree, via the following formulas:
113 : height = ceil(log_r(leaf_cnt))
114 : internal_node_cnt = sum_{i=0}^{height-1} R^i
115 : height and internal_node_cnt both exclude the leaves, which are
116 : only implicit.
117 :
118 : All the math seems to disallow leaf_cnt==0, but for conveniece, we
119 : do allow it. height==internal_node_cnt==0 in that case.
120 :
121 : height actually fits in a uchar. Storing as ulong is more natural,
122 : but we don't want to spill over into another cache line.
123 :
124 : If poisoned_mode==1, then all sample calls will return poisoned.*/
125 : ulong internal_node_cnt;
126 : ulong poisoned_weight;
127 : uint height;
128 : char restore_enabled;
129 : char poisoned_mode;
130 : /* Two bytes of padding here */
131 :
132 : fd_chacha20rng_t * rng;
133 :
134 : /* tree: Here's where the actual tree is stored, at indices [0,
135 : internal_node_cnt). The indexing scheme is explained in the long
136 : comment above.
137 :
138 : If restore_enabled==1, then indices [internal_node_cnt+1,
139 : 2*internal_node_cnt+1) store a copy of the tree after construction
140 : but before any deletion so that restoring deleted elements can be
141 : implemented as a memcpy.
142 :
143 : The tree iteself is surrounded by two dummy elements, dummy, and
144 : tree[internal_node_cnt], that aren't actually used. This is
145 : because searching the tree branchlessly involves some out of bounds
146 : reads, and although the value is immediately discarded, it's better
147 : to know where exactly those reads might go. */
148 : tree_ele_t dummy;
149 : tree_ele_t tree[];
150 : };
151 :
152 : typedef struct fd_wsample_private fd_wsample_t;
153 :
154 :
155 : FD_FN_CONST ulong
156 1829604 : fd_wsample_align( void ) {
157 1829604 : return 64UL;
158 1829604 : }
159 :
160 : /* Returns -1 on failure */
161 : static inline int
162 : compute_height( ulong leaf_cnt,
163 : ulong * out_height,
164 767424 : ulong * out_internal_cnt ) {
165 : /* This max is a bit conservative. The actual max is height <= 25,
166 : and leaf_cnt < 5^25 approx 2^58. A tree that large would take an
167 : astronomical amount of memory, so we just retain this max for the
168 : moment. */
169 767424 : if( FD_UNLIKELY( leaf_cnt >= UINT_MAX-2UL ) ) return -1;
170 :
171 767424 : ulong height = 0;
172 767424 : ulong internal = 0UL;
173 767424 : ulong powRh = 1UL; /* = R^height */
174 3430491 : while( leaf_cnt>powRh ) {
175 2663067 : internal += powRh;
176 2663067 : powRh *= R;
177 2663067 : height++;
178 2663067 : }
179 767424 : *out_height = height;
180 767424 : *out_internal_cnt = internal;
181 767424 : return 0;
182 767424 : }
183 :
184 : FD_FN_CONST ulong
185 446550 : fd_wsample_footprint( ulong ele_cnt, int restore_enabled ) {
186 446550 : ulong height;
187 446550 : ulong internal_cnt;
188 : /* Computing the closed form of the sum in compute_height, we get
189 : internal_cnt = 1/8 * (9^ceil(log_9( ele_cnt ) ) - 1)
190 : x <= ceil( x ) < x+1
191 : 1/8 * ele_cnt - 1/8 <= internal_cnt < 9/8 * ele_cnt - 1/8
192 : */
193 446550 : if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) return 0UL;
194 446550 : return sizeof(fd_wsample_t) + ((restore_enabled?2UL:1UL)*internal_cnt + 1UL)*sizeof(tree_ele_t);
195 446550 : }
196 :
197 : fd_wsample_t *
198 320862 : fd_wsample_join( void * shmem ) {
199 320862 : if( FD_UNLIKELY( !shmem ) ) {
200 0 : FD_LOG_WARNING(( "NULL shmem" ));
201 0 : return NULL;
202 0 : }
203 :
204 320862 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
205 0 : FD_LOG_WARNING(( "misaligned shmem" ));
206 0 : return NULL;
207 0 : }
208 320862 : return (fd_wsample_t *)shmem;
209 320862 : }
210 :
/* Note: The following optimization insights are not used in this
   high radix implementation.  Performance in the deletion case is much
   more important than in the non-deletion case, and it's not clear how
   to translate this.  I'm leaving the code and comment because it is a
   useful and non-trivial insight.

   NOTE(review): the block below is compiled out (#if 0) and refers to
   treap_ele_t, which is not defined anywhere in this file; re-enabling
   it would require bringing that type (with a prio field) into scope. */
#if 0
/* If we assume the probability of querying node i is proportional to
   1/i, then observe that the midpoint of the probability mass in the
   continuous approximation is the solution to (in Mathematica syntax):

      Integrate[ 1/i, {i, lo, hi}] = 2*Integrate[ 1/i, {i, lo, mid} ]

   which gives mid = sqrt(lo*hi).  This is in contrast to when the
   integrand is a constant, which gives the normal binary search rule:
   mid=(lo+hi)/2.

   Thus, we want the search to follow this modified binary search rule,
   since that'll approximately split the stake weight/probability mass
   in half at each step.

   This is almost as nice to work with mathematically as the normal
   binary search rule.  The jth entry from the left at level k is the
   region [ N^((1/2^k)*j), N^((1/2^k)*(j+1)) ).  We're basically doing
   binary search in the log domain.

   Rather than trying to compute these transcendental functions, this
   simple recursive implementation gives the treap the right shape by
   setting prio very carefully, since higher prio means closer to the
   root.  This may be a slight abuse of a treap, but it's easier than
   implementing a whole custom tree for just this purpose.

   We want to be careful to ensure that the recursion is tightly
   bounded.  From a continuous perspective, that's not a problem: the
   widest interval at level k is the last one, and we break when the
   interval's width is less than 1.  Solving 1=N-N^(1-(1/2^k)) for k
   yields k=-lg(1-log(N-1)/log(N)).  k is monotonically increasing as a
   function of N, which means that for all N,
          k <= -lg(1-log(N_max-1)/log(N_max)) < 37
   since N_max=2^32-1.

   The math is more complicated with rounding and finite precision, but
   sqrt(lo*hi) is very different from lo and hi unless lo and hi are
   approximately the same.  In that case, lo<mid and mid<hi ensures that
   both intervals are strictly smaller than the interval they came from,
   which prevents an infinite loop. */
static inline void
seed_recursive( treap_ele_t * pool,
                uint          lo,
                uint          hi,
                uint          prio ) {
  uint mid = (uint)(sqrtf( (float)lo*(float)hi ) + 0.5f);
  if( (lo<mid) & (mid<hi) ) {
    /* since we start with lo=1, shift by 1 */
    pool[mid-1U].prio = prio;
    seed_recursive( pool, lo,  mid, prio-1U );
    seed_recursive( pool, mid, hi,  prio-1U );
  }
}
#endif
270 :
271 :
272 : void *
273 : fd_wsample_new_init( void * shmem,
274 : fd_chacha20rng_t * rng,
275 : ulong ele_cnt,
276 : int restore_enabled,
277 320874 : int opt_hint ) {
278 320874 : if( FD_UNLIKELY( !shmem ) ) {
279 0 : FD_LOG_WARNING(( "NULL shmem" ));
280 0 : return NULL;
281 0 : }
282 :
283 320874 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
284 0 : FD_LOG_WARNING(( "misaligned shmem" ));
285 0 : return NULL;
286 0 : }
287 :
288 320874 : ulong height;
289 320874 : ulong internal_cnt;
290 320874 : if( FD_UNLIKELY( compute_height( ele_cnt, &height, &internal_cnt ) ) ) {
291 0 : FD_LOG_WARNING(( "bad ele_cnt" ));
292 0 : return NULL;
293 0 : }
294 :
295 320874 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
296 :
297 320874 : sampler->total_weight = 0UL;
298 320874 : sampler->unremoved_cnt = 0UL;
299 320874 : sampler->unremoved_weight = 0UL;
300 320874 : sampler->internal_node_cnt = internal_cnt;
301 320874 : sampler->poisoned_weight = 0UL;
302 320874 : sampler->height = (uint)height;
303 320874 : sampler->restore_enabled = (char)!!restore_enabled;
304 320874 : sampler->poisoned_mode = 0;
305 320874 : sampler->rng = rng;
306 :
307 320874 : fd_memset( sampler->tree, (char)0, internal_cnt*sizeof(tree_ele_t) );
308 :
309 320874 : (void)opt_hint; /* Not used at the moment */
310 :
311 320874 : return shmem;
312 320874 : }
313 :
314 : void *
315 : fd_wsample_new_add( void * shmem,
316 202282431 : ulong weight ) {
317 202282431 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
318 202282431 : if( FD_UNLIKELY( !sampler ) ) return NULL;
319 :
320 202282431 : if( FD_UNLIKELY( weight==0UL ) ) {
321 0 : FD_LOG_WARNING(( "zero weight entry found" ));
322 0 : return NULL;
323 0 : }
324 202282431 : if( FD_UNLIKELY( sampler->total_weight+weight<weight ) ) {
325 0 : FD_LOG_WARNING(( "total weight too large" ));
326 0 : return NULL;
327 0 : }
328 :
329 202282431 : tree_ele_t * tree = sampler->tree;
330 202282431 : ulong i = sampler->internal_node_cnt + sampler->unremoved_cnt;
331 :
332 1011139644 : for( ulong h=0UL; h<sampler->height; h++ ) {
333 808857213 : ulong parent = (i-1UL)/R;
334 808857213 : ulong child_idx = i-1UL - R*parent; /* in [0, R) */
335 4956512352 : for( ulong k=child_idx; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] += weight;
336 808857213 : i = parent;
337 808857213 : }
338 :
339 202282431 : sampler->unremoved_cnt++;
340 202282431 : sampler->total_cnt++;
341 202282431 : sampler->unremoved_weight += weight;
342 202282431 : sampler->total_weight += weight;
343 :
344 202282431 : return shmem;
345 202282431 : }
346 :
347 : void *
348 : fd_wsample_new_fini( void * shmem,
349 320874 : ulong poisoned_weight ) {
350 320874 : fd_wsample_t * sampler = (fd_wsample_t *)shmem;
351 320874 : if( FD_UNLIKELY( !sampler ) ) return NULL;
352 :
353 320874 : if( FD_UNLIKELY( sampler->total_weight+poisoned_weight<sampler->total_weight ) ) {
354 0 : FD_LOG_WARNING(( "poisoned_weight caused overflow" ));
355 0 : return NULL;
356 0 : }
357 :
358 320874 : sampler->poisoned_weight = poisoned_weight;
359 :
360 320874 : if( sampler->restore_enabled ) {
361 : /* Copy the sampler to make restore fast. */
362 312975 : fd_memcpy( sampler->tree+sampler->internal_node_cnt+1UL, sampler->tree, sampler->internal_node_cnt*sizeof(tree_ele_t) );
363 312975 : }
364 :
365 320874 : return (void *)sampler;
366 320874 : }
367 :
368 : void *
369 320778 : fd_wsample_leave( fd_wsample_t * sampler ) {
370 320778 : if( FD_UNLIKELY( !sampler ) ) {
371 0 : FD_LOG_WARNING(( "NULL sampler" ));
372 0 : return NULL;
373 0 : }
374 :
375 320778 : return (void *)sampler;
376 320778 : }
377 :
378 : void *
379 320778 : fd_wsample_delete( void * shmem ) {
380 320778 : if( FD_UNLIKELY( !shmem ) ) {
381 0 : FD_LOG_WARNING(( "NULL shmem" ));
382 0 : return NULL;
383 0 : }
384 :
385 320778 : if( FD_UNLIKELY( !fd_ulong_is_aligned( (ulong)shmem, fd_wsample_align() ) ) ) {
386 0 : FD_LOG_WARNING(( "misaligned shmem" ));
387 0 : return NULL;
388 0 : }
389 320778 : return shmem;
390 320778 : }
391 :
392 :
393 :
394 4506723 : fd_chacha20rng_t * fd_wsample_get_rng( fd_wsample_t * sampler ) { return sampler->rng; }
395 :
396 :
397 : /* TODO: Should this function exist at all? */
398 : void
399 : fd_wsample_seed_rng( fd_chacha20rng_t * rng,
400 4506723 : uchar seed[static 32] ) {
401 4506723 : fd_chacha20rng_init( rng, seed );
402 4506723 : }
403 :
404 :
405 : fd_wsample_t *
406 5414256 : fd_wsample_restore_all( fd_wsample_t * sampler ) {
407 5414256 : if( FD_UNLIKELY( !sampler->restore_enabled ) ) return NULL;
408 :
409 5414253 : sampler->unremoved_weight = sampler->total_weight;
410 5414253 : sampler->unremoved_cnt = sampler->total_cnt;
411 5414253 : sampler->poisoned_mode = 0;
412 :
413 5414253 : fd_memcpy( sampler->tree, sampler->tree+sampler->internal_node_cnt+1UL, sampler->internal_node_cnt*sizeof(tree_ele_t) );
414 5414253 : return sampler;
415 5414256 : }
416 :
/* fd_ulong_if_force( c, t, f ) evaluates to t if c is nonzero and to f
   otherwise.  It uses x86-64 inline asm (movq of f, then cmovne of t),
   so the compiler cannot turn the selection into a branch.  Both t and
   f are always evaluated.  In this file it is only referenced from the
   disabled (#elif 0) alternatives below and is kept for experimentation.
   NOTE(review): testl operates on a 32-bit register, so c presumably
   needs to be an int-sized value (e.g. a comparison result); confirm
   before passing wider types. */
#define fd_ulong_if_force( c, t, f ) (__extension__({ \
    ulong result;                                     \
    __asm__( "testl %1, %1; \n\t"                     \
             "movq %3, %0; \n\t"                      \
             "cmovne %2, %0; \n\t"                    \
             : "=&r"(result)                          \
             : "r"(c), "rm"(t), "rmi"(f)              \
             : "cc" );                                \
    result;                                           \
  }))
427 :
428 : /* Helper methods for sampling functions */
429 : typedef struct { ulong idx; ulong weight; } idxw_pair_t; /* idx in [0, total_cnt) */
430 :
/* fd_wsample_map_sample_i maps query to the leaf whose weight interval
   contains it, descending from the root one level at a time.  It
   returns the leaf index (in [0, total_cnt)) together with that leaf's
   current weight, which fd_wsample_remove needs.

   Assumes query in [0, unremoved_weight), which implies
   unremoved_weight>0, so the tree can't be empty. */
static inline idxw_pair_t
fd_wsample_map_sample_i( fd_wsample_t const * sampler,
                         ulong                query ) {
  tree_ele_t const * tree = sampler->tree;

  ulong cursor = 0UL;                       /* current node, starting at the root */
  ulong S      = sampler->unremoved_weight; /* total weight of the subtree at cursor */
  for( ulong h=0UL; h<sampler->height; h++ ) {
    tree_ele_t const * e = tree+cursor;
    ulong x = query;
    ulong child_idx = 0UL;

    /* child_idx = number of left sums <= x, i.e. the index of the child
       whose interval contains x. */
#if FD_HAS_AVX512 && R==9
    __mmask8 mask = _mm512_cmple_epu64_mask( wwv_ld( e->left_sum ), wwv_bcast( x ) );
    child_idx = (ulong)fd_uchar_popcnt( mask );
#else
    for( ulong i=0UL; i<R-1UL; i++ ) child_idx += (ulong)(e->left_sum[ i ]<=x);
#endif

    /* See the note at the top of this file for the explanation of l[i]
       and l[i-1].  Because this is fd_ulong_if and not a ternary, these
       can read/write out of what you would think the appropriate bounds
       are.  The dummy elements, as described along with tree, make this
       safe. */
#if 0
    ulong li  = fd_ulong_if( child_idx<R-1UL, e->left_sum[ child_idx ], S );
    ulong lm1 = fd_ulong_if( child_idx>0UL, e->left_sum[ child_idx-1UL ], 0UL );
#elif 0
    ulong li  = fd_ulong_if_force( child_idx<R-1UL, e->left_sum[ child_idx ], S );
    ulong lm1 = fd_ulong_if_force( child_idx>0UL, e->left_sum[ child_idx-1UL ], 0UL );
#else
    /* Temporarily overwrite the words just before and just after
       e->left_sum (left_sum[-1] and left_sum[R-1], which belong to the
       neighboring tree element or a dummy) with l[-1]=0 and l[R-1]=S.
       Both loads below are then unconditional table lookups, and the
       original words are put back afterwards.
       NOTE(review): this writes through a pointer derived from a const
       sampler, so concurrent calls on the same sampler from different
       threads would race even though the function looks read-only. */
    ulong * temp = (ulong *)e->left_sum;
    ulong orig_m1 = temp[ -1 ]; ulong orig_Rm1 = temp[ R-1UL ];
    temp[ -1 ] = 0UL; temp[ R-1UL ] = S;
    ulong li  = temp[ child_idx ];
    ulong lm1 = temp[ child_idx-1UL ];
    temp[ -1 ] = orig_m1; temp[ R-1UL ] = orig_Rm1;
#endif

    query -= lm1;                        /* re-base query into the chosen child's [0, S') */
    S = li - lm1;                        /* S' = l[i] - l[i-1], the chosen child's weight */
    cursor = R*cursor + child_idx + 1UL; /* implicit-heap index of that child */
  }
  idxw_pair_t to_return = { .idx = cursor - sampler->internal_node_cnt, .weight = S };
  return to_return;
}
479 :
480 : ulong
481 : fd_wsample_map_sample( fd_wsample_t * sampler,
482 762249045 : ulong query ) {
483 762249045 : return fd_wsample_map_sample_i( sampler, query ).idx;
484 762249045 : }
485 :
486 :
487 :
/* fd_wsample_remove removes leaf to_remove.idx, whose current weight
   must be to_remove.weight.  It subtracts that weight from every left
   sum covering the leaf, on each node from the leaf's parent up to the
   root, and then updates unremoved_cnt and unremoved_weight.

   Also requires the tree to be non-empty */
static inline void
fd_wsample_remove( fd_wsample_t * sampler,
                   idxw_pair_t    to_remove ) {
  ulong cursor = to_remove.idx + sampler->internal_node_cnt; /* implicit leaf position */
  tree_ele_t * tree = sampler->tree;

  for( ulong h=0UL; h<sampler->height; h++ ) {
    ulong parent = (cursor-1UL)/R;
    ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
    /* Every variant subtracts the weight from left_sum[k] for each
       k>=child_idx, since exactly those left sums include this child. */
#if FD_HAS_AVX512 && R==9
    wwv_t weight = wwv_bcast( to_remove.weight );
    wwv_t left_sum = wwv_ld( tree[ parent ].left_sum );
    __m128i _child_idx = _mm_set1_epi16( (short) child_idx );
    /* lane k is selected iff child_idx < k+1, i.e. child_idx <= k */
    __mmask8 mask = _mm_cmplt_epi16_mask( _child_idx, _mm_setr_epi16( 1, 2, 3, 4, 5, 6, 7, 8 ) );
    left_sum = _mm512_mask_sub_epi64( left_sum, mask, left_sum, weight );
    wwv_st( tree[ parent ].left_sum, left_sum );
#elif 0
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if( child_idx<=k, to_remove.weight, 0UL );
#elif 0
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= fd_ulong_if_force( child_idx<=k, to_remove.weight, 0UL );
#elif 1
    /* The compiler loves inserting a difficult to predict branch for
       fd_ulong_if, but this forces it not to do that.  The arithmetic
       shift turns "child_idx-k-1 is negative" (i.e. child_idx<=k) into
       an all-ones mask. */
    for( ulong k=0UL; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= (ulong)(((long)(child_idx - k - 1UL))>>63) & to_remove.weight;
#else
    /* This version does the least work, but has a hard-to-predict
       branch.  The branchless versions are normally substantially
       faster. */
    for( ulong k=child_idx; k<R-1UL; k++ ) tree[ parent ].left_sum[ k ] -= to_remove.weight;
#endif
    cursor = parent;
  }
  sampler->unremoved_cnt--;
  sampler->unremoved_weight -= to_remove.weight;
}
524 :
525 : static inline ulong
526 : fd_wsample_find_weight( fd_wsample_t const * sampler,
527 3707907 : ulong idx /* in [0, total_cnt) */) {
528 : /* The fact we don't store the weights explicitly makes this function
529 : more complicated, but this is not used very frequently. */
530 3707907 : tree_ele_t const * tree = sampler->tree;
531 3707907 : ulong cursor = idx + sampler->internal_node_cnt;
532 :
533 : /* Initialize to the 0 height case */
534 3707907 : ulong lm1 = 0UL;
535 3707907 : ulong li = sampler->unremoved_weight;
536 :
537 3738189 : for( ulong h=0UL; h<sampler->height; h++ ) {
538 3737997 : ulong parent = (cursor-1UL)/R;
539 3737997 : ulong child_idx = cursor-1UL - R*parent; /* in [0, R) */
540 :
541 : /* If child_idx < R-1, we can compute the weight easily. If
542 : child_idx==R-1, the computation is S - left_sum[ R-2 ], but we
543 : don't know S, so we need to continue up the tree. */
544 3737997 : lm1 += fd_ulong_if( child_idx>0UL, tree[ parent ].left_sum[ child_idx-1UL ], 0UL );
545 3737997 : if( FD_LIKELY( child_idx<R-1UL ) ) {
546 3707715 : li = tree[ parent ].left_sum[ child_idx ];
547 3707715 : break;
548 3707715 : }
549 :
550 30282 : cursor = parent;
551 30282 : }
552 :
553 3707907 : return li - lm1;
554 3707907 : }
555 :
556 : void
557 : fd_wsample_remove_idx( fd_wsample_t * sampler,
558 3707907 : ulong idx ) {
559 :
560 3707907 : ulong weight = fd_wsample_find_weight( sampler, idx );
561 3707907 : idxw_pair_t r = { .idx = idx, .weight = weight };
562 3707907 : fd_wsample_remove( sampler, r );
563 3707907 : }
564 :
565 :
#if FD_HAS_AVX512

/* TRAVERSE_LEVEL takes in and updates s (a vector with all elements set
   to the unremoved weight in the current subtree), x or xprime (a
   vector with all elements set to the current query value, or the
   value+1, respectively), and cursor (a ulong of the current position
   in the traversal).  See fd_wsample_map_sample_i for more about these
   values.  Calling this macro height times is the same as calling
   fd_wsample_map_sample_i, except that it also declares mask{i},
   cursor{i} and vec{i}, to save work in PROPAGATE_LEVEL, which is like
   fd_wsample_remove.

   This only works in the R=9 case, but we'll explain it as if R=5 like
   before.  Then v, a vector of tree elements, has values
     -------------------------------------------
     |   a   |  a+b   |  a+b+c   |  a+b+c+d   |
     -------------------------------------------
   Then, like before, the desired state is
     If...                   pick,     and l_{i-1}   and l_i
     x < a                   child 0   0             a
     a <= x < a+b            child 1   a             a+b
     a+b <= x < a+b+c        child 2   a+b           a+b+c
     a+b+c <= x < a+b+c+d    child 3   a+b+c         a+b+c+d
     a+b+c+d <= x            child 4   a+b+c+d       a+b+c+d+e

   Below, there are two pretty different implementations, and this
   explains both:

   Implementation 1:
   We want to get information about v0<=x into as much of the vector as
   possible, as fast as possible.  We could do some kind of compress and
   then bcast, but operations with mask registers are frustratingly
   slow.  Instead, we first define x'=x+1, so then v0<=x is equivalent
   to v0-x'<0, or whether v0-x' has the high bit set.  Because of
   fd_chacha20rng_ulong_roll's contract, we know x<ULONG_MAX, so forming
   x' is safe.  We then use a _mm512_permutexvar_epi8 to essentially
   broadcast the high byte from each of the ulongs (really we just need
   the high bit) to the whole vector.  A popcnt gives us our base
   values.

     Case                    compressed_signs   popcnt   popcnt-1
     x < a                   0                  0        -1
     a <= x < a+b            0x80               1        0
     a+b <= x < a+b+c        0x8080             2        1
     a+b+c <= x < a+b+c+d    0x808080           3        2
     a+b+c+d <= x            0x80808080         4        3

   Then, keeping in mind that _mm512_permutex2var_epi64 will select an
   element from the second vector if the next highest bit is set, using
   popcnt-1 as the selector and (vec, 0) as (a,b) gives l_{i-1}; and
   using popcnt as the selector with (vec, s) gives l_i.

   Finally, we just need to produce the mask mask{i}, which we do with
   an unsigned vec>=x' (i.e. vec>x) comparison, off the critical path.

   Implementation 2:
   We first use an unsigned <= comparison of v0 against x to form lmask.
   Then we extract the most significant set bit by lmask^(lmask>>1).
   Note that below the bits are written least to most significant, which
   is backwards of the normal way.

     Case                    lmask     single_bit   ~lmask
     x < a                   0 0 0 0   0 0 0 0      1 1 1 1
     a <= x < a+b            1 0 0 0   1 0 0 0      0 1 1 1
     a+b <= x < a+b+c        1 1 0 0   0 1 0 0      0 0 1 1
     a+b+c <= x < a+b+c+d    1 1 1 0   0 0 1 0      0 0 0 1
     a+b+c+d <= x            1 1 1 1   0 0 0 1      0 0 0 0

   Then we use _mm512_mask_compress_epi64 to move the element
   corresponding to the first set bit to the 0th position, or the src
   vector if there aren't any set bits.  (For single_bit the zero-masking
   form, _mm512_maskz_compress_epi64, is used, so the result is 0 when no
   bit is set.)  From there, we can use _mm512_broadcastq_epi64 to fill a
   vector with that value.  Applying this trick to single_bit gives us
   l_{i-1}, and applying it to ~lmask gives l_i.

   PROPAGATE_LEVEL(i) subtracts s from the lanes of level i's node
   selected by mask{i} (lanes whose left sum is > the query, i.e. lanes
   at or after the chosen child).  It must be invoked only after all
   TRAVERSE_LEVEL calls, when s holds the selected leaf's weight.
   FINALIZE updates the sampler's unremoved totals. */

/* Selects which of the two AVX-512 implementations below is compiled.
   The value 0 (set in the #else branch when AVX-512 is unavailable)
   disables the unrolled fast path in the sampling functions. */
#define FD_WSAMPLE_IMPLEMENTATION 2

#if FD_WSAMPLE_IMPLEMENTATION==1
#define PREPARE()                                                          \
  ulong cursor = 0UL;                                                      \
  wwv_t xprime = wwv_bcast( unif+1UL );                                    \
  wwv_t s = wwv_bcast( sampler->unremoved_weight );                        \
  wwv_t high_bit_mask = wwv_bcast( 0x8000000000000000UL );                 \
  wwv_t gather_signs_idx = wwv_bcast( 0x070F171F272F373FUL )

#define TRAVERSE_LEVEL(i)                                                                       \
  __mmask8 mask##i;                                                                             \
  ulong cursor##i = cursor;                                                                     \
  wwv_t vec##i;                                                                                 \
  do {                                                                                          \
    wwv_t vec = wwv_ld( tree[cursor].left_sum );                                                \
    wwv_t sign = wwv_and( high_bit_mask, wwv_sub( vec, xprime ) );                              \
    wwv_t compressed_signs = _mm512_permutexvar_epi8( gather_signs_idx, sign );                 \
    wwv_t popcnt = _mm512_popcnt_epi64( compressed_signs );                                     \
    wwv_t li = _mm512_permutex2var_epi64( vec, popcnt, s );                                     \
    wwv_t lim1 = _mm512_permutex2var_epi64( vec, wwv_sub( popcnt, wwv_one() ), wwv_zero() );    \
    mask##i = _mm512_cmpge_epu64_mask( vec, xprime );                                           \
    xprime = wwv_sub( xprime, lim1 );                                                           \
    s = wwv_sub( li, lim1 );                                                                    \
    ulong child_idx = (ulong)_mm_extract_epi64( _mm512_castsi512_si128( popcnt ), 0 );          \
    cursor = R*cursor + child_idx + 1UL;                                                        \
    vec##i = vec;                                                                               \
  } while( 0 )

#define PROPAGATE_LEVEL(i) \
  wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );

#define FINALIZE()                                                                              \
  do {                                                                                          \
    sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 ); \
    sampler->unremoved_cnt--;                                                                   \
  } while( 0 )

#elif FD_WSAMPLE_IMPLEMENTATION==2

/* Note: the last line of PREPARE ends in a backslash; the blank line
   that follows it is what terminates the macro. */
#define PREPARE()                                                          \
  ulong cursor = 0UL;                                                      \
  wwv_t x = wwv_bcast( unif );                                             \
  wwv_t s = wwv_bcast( sampler->unremoved_weight );                        \

#define TRAVERSE_LEVEL(i)                                                                                                 \
  __mmask8 mask##i;                                                                                                       \
  ulong cursor##i = cursor;                                                                                               \
  wwv_t vec##i;                                                                                                           \
  do {                                                                                                                    \
    wwv_t vec = wwv_ld( tree[cursor].left_sum );                                                                          \
    __mmask8 lmask = _mm512_cmple_epu64_mask( vec, x );                                                                   \
    cursor = R*cursor + (ulong)fd_uchar_popcnt( lmask ) + 1UL;                                                            \
    __mmask8 single_bit = _kxor_mask8( lmask, _kshiftri_mask8( lmask, 1U ) );                                             \
    wwv_t lim1 = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_maskz_compress_epi64( single_bit, vec ) ) );     \
    x = wwv_sub( x, lim1 );                                                                                               \
    mask##i = _knot_mask8( lmask );                                                                                       \
    wwv_t li = _mm512_broadcastq_epi64( _mm512_castsi512_si128( _mm512_mask_compress_epi64( s, mask##i, vec ) ) );        \
    s = wwv_sub( li, lim1 );                                                                                              \
    vec##i = vec;                                                                                                         \
  } while( 0 )

#define PROPAGATE_LEVEL(i) \
  wwv_st( tree[cursor##i].left_sum, wwv_sub_if( mask##i, vec##i, s, vec##i ) );

#define FINALIZE()                                                                              \
  do {                                                                                          \
    sampler->unremoved_weight -= (ulong)_mm256_extract_epi64( _mm512_castsi512_si256( s ), 0 ); \
    sampler->unremoved_cnt--;                                                                   \
  } while( 0 )

#endif
#else
#define FD_WSAMPLE_IMPLEMENTATION 0
#endif
715 :
716 : /* For now, implement the _many functions as loops over the single
717 : sample functions. It is possible to do better though. */
718 :
719 : void
720 : fd_wsample_sample_many( fd_wsample_t * sampler,
721 : ulong * idxs,
722 3694788 : ulong cnt ) {
723 373172133 : for( ulong i=0UL; i<cnt; i++ ) idxs[i] = fd_wsample_sample( sampler );
724 3694788 : }
725 :
/* fd_wsample_sample_and_remove_many fills idxs[0..cnt) with cnt draws
   without replacement.  Each entry matches what
   fd_wsample_sample_and_remove would return at that point:
   - FD_WSAMPLE_EMPTY once no weight is left,
   - FD_WSAMPLE_INDETERMINATE once the sampler is in poisoned mode (the
     draw that lands in the poisoned weight also yields it and enters
     poisoned mode),
   - otherwise the index of the element that was drawn and removed. */
void
fd_wsample_sample_and_remove_many( fd_wsample_t * sampler,
                                   ulong        * idxs,
                                   ulong          cnt ) {
  /* The compiler doesn't seem to like inlining the call to
     fd_wsample_sample_and_remove, which hurts performance by a few
     percent because it triggers worse behavior in the CPU's front end.
     To address this, we manually inline it here. */
  for( ulong i=0UL; i<cnt; i++ ) {
    if( FD_UNLIKELY( !sampler->unremoved_weight ) ) { idxs[ i ] = FD_WSAMPLE_EMPTY;         continue; }
    if( FD_UNLIKELY( sampler->poisoned_mode     ) ) { idxs[ i ] = FD_WSAMPLE_INDETERMINATE; continue; }
    /* unif is drawn from [0, unremoved_weight+poisoned_weight)
       (presumably uniformly; see fd_chacha20rng_ulong_roll).  Values at
       or above unremoved_weight fall in the poisoned region. */
    ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
    if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
      idxs[ i ] = FD_WSAMPLE_INDETERMINATE;
      sampler->poisoned_mode = 1;
      continue;
    }
#if FD_WSAMPLE_IMPLEMENTATION > 0
    /* AVX-512 fast path, fully unrolled for a tree of height 4.  The
       macros rely on the locals tree, unif and sampler, and PREPARE
       declares cursor. */
    if( FD_LIKELY( sampler->height==4UL ) ) {
      tree_ele_t * tree = sampler->tree;
      PREPARE();

      TRAVERSE_LEVEL(0);
      TRAVERSE_LEVEL(1);
      TRAVERSE_LEVEL(2);
      TRAVERSE_LEVEL(3);
      PROPAGATE_LEVEL(0);
      PROPAGATE_LEVEL(1);
      PROPAGATE_LEVEL(2);
      PROPAGATE_LEVEL(3);

      FINALIZE();
      idxs[ i ] = cursor - sampler->internal_node_cnt;
      continue;
    }
#endif
    /* Generic path for any height. */
    idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
    fd_wsample_remove( sampler, p );
    idxs[ i ] = p.idx;
  }
}
767 :
768 :
769 :
770 : ulong
771 762503751 : fd_wsample_sample( fd_wsample_t * sampler ) {
772 762503751 : if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
773 762503745 : if( FD_UNLIKELY( sampler->poisoned_mode ) ) return FD_WSAMPLE_INDETERMINATE;
774 762503745 : ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
775 762503745 : if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) return FD_WSAMPLE_INDETERMINATE;
776 762249045 : return (ulong)fd_wsample_map_sample( sampler, unif );
777 762503745 : }
778 :
/* fd_wsample_sample_and_remove draws one element without replacement.
   It returns FD_WSAMPLE_EMPTY if no weight remains, and
   FD_WSAMPLE_INDETERMINATE if the sampler is already in poisoned mode.
   If the draw lands in the poisoned weight, it enters poisoned mode and
   returns FD_WSAMPLE_INDETERMINATE.  Otherwise it removes the drawn
   element and returns its index. */
ulong
fd_wsample_sample_and_remove( fd_wsample_t * sampler ) {
  if( FD_UNLIKELY( !sampler->unremoved_weight ) ) return FD_WSAMPLE_EMPTY;
  if( FD_UNLIKELY( sampler->poisoned_mode     ) ) return FD_WSAMPLE_INDETERMINATE;
  /* unif is drawn from [0, unremoved_weight+poisoned_weight); values at
     or above unremoved_weight fall in the poisoned region. */
  ulong unif = fd_chacha20rng_ulong_roll( sampler->rng, sampler->unremoved_weight+sampler->poisoned_weight );
  if( FD_UNLIKELY( unif>=sampler->unremoved_weight ) ) {
    sampler->poisoned_mode = 1;
    return FD_WSAMPLE_INDETERMINATE;
  }

#if FD_WSAMPLE_IMPLEMENTATION > 0
  /* AVX-512 fast path, fully unrolled for a tree of height 4.  The
     macros rely on the locals tree, unif and sampler, and PREPARE
     declares cursor. */
  if( FD_LIKELY( sampler->height==4UL ) ) {
    tree_ele_t * tree = sampler->tree;
    PREPARE();

    TRAVERSE_LEVEL(0);
    TRAVERSE_LEVEL(1);
    TRAVERSE_LEVEL(2);
    TRAVERSE_LEVEL(3);
    PROPAGATE_LEVEL(0);
    PROPAGATE_LEVEL(1);
    PROPAGATE_LEVEL(2);
    PROPAGATE_LEVEL(3);
    FINALIZE();
    return cursor - sampler->internal_node_cnt;
  }
#endif
  /* Generic path for any height. */
  idxw_pair_t p = fd_wsample_map_sample_i( sampler, unif );
  fd_wsample_remove( sampler, p );
  return p.idx;
}
|