LCOV - code coverage report
Current view: top level - util/math - fd_sqrt.h (source / functions) Hit Total Coverage
Test: cov.lcov Lines: 40 40 100.0 %
Date: 2025-01-08 12:08:44 Functions: 18 48 37.5 %

          Line data    Source code
       1             : #ifndef HEADER_fd_src_util_math_fd_sqrt_h
       2             : #define HEADER_fd_src_util_math_fd_sqrt_h
       3             : 
       4             : /* Portable robust integer sqrt.  Adapted from Pyth Oracle. */
       5             : 
       6             : #include "../bits/fd_bits.h"
       7             : 
       8             : FD_PROTOTYPES_BEGIN
       9             : 
      10             : /* Compute y = floor( sqrt( x ) ) for unsigned integers exactly.  This
      11             :    is based on the fixed point iteration:
      12             : 
      13             :      y' = (y + x/y) / 2
      14             : 
      15             :    In continuum math, this converges quadratically to sqrt(x).  This is
      16             :    a useful starting point for a method because we have a relatively low
      17             :    cost unsigned integer division in the machine model and the
      18             :    operations and intermediates in this calculation all have magnitudes
      19             :    smaller than x (so limited concern about overflow issues).
      20             : 
      21             :    We don't do this iteration in integer arithmetic directly because the
      22             :    iteration has two roundoff errors while the actual result only has
      23             :    one (the floor of the continuum value).  As such, even when it does
      24             :    converge in integer arithmetic, it would not necessarily converge to
      25             :    exactly floor( sqrt( x ) ).
      26             : 
      27             :    We instead combine the two divisions into one, yielding single round
      28             :    off error iteration:
      29             : 
      30             :      y' = floor( (y^2 + x) / (2 y) )
      31             : 
      32             :    As y has about half the width of x given a good initial guess, 2 y
      33             :    will not overflow and y^2 + x will be ~2 x and thus any potential
      34             :    intermediate overflow issues are cheap to handle.  If this converges,
      35             :    at convergence:
      36             : 
      37             :         y = (y^2 + x - r) / (2 y)
      38             :      -> 2 y^2 = y^2 + x - r
      39             :      -> y^2 = x - r
      40             : 
      41             :    for some r in [0,2 y-1].  We note that if y = floor( sqrt( x ) )
      42             :    exactly though:
      43             : 
      44             :         y^2 <= x < (y+1)^2
      45             :      -> y^2 <= x < y^2 + 2 y + 1
      46             :      -> y^2  = x - r'
      47             : 
      48             :    for some r' in [0,2 y].  r' is r with the element 2 y added.  And it
      49             :    is possible to have y^2 = x - 2 y.  Namely if x+1 = z^2 for integer
      50             :    z, this becomes y^2 + 2 y + 1 = z^2 -> y = z-1.  That is, when
      51             :    x = z^2 - 1 for integral z, the relationship can never converge.  If
      52             :    we instead used a denominator of 2y+1 in the iteration, r would have
      53             :    the necessary range:
      54             : 
      55             :      y' = floor( (y^2 + x) / (2 y + 1) )
      56             : 
      57             :    At convergence we have:
      58             : 
      59             :         y = (y^2 + x - r) / (2 y+1)
      60             :      -> 2 y^2 + y = (y^2 + x - r)
      61             :      -> y^2 = x-y-r
      62             : 
      63             :    for some r in [0,2 y].  This isn't quite right but we change the
      64             :    recurrence numerator to compensate:
      65             : 
      66             :      y' = floor( (y^2 + y + x) / (2 y + 1) )
      67             :   
      68             :    At convergence we now have:
      69             : 
      70             :      y^2 = x-r
      71             : 
      72             :    for some r in [0,2 y].  That is, at convergence y = floor( sqrt(x) )
      73             :    exactly!  The addition of y to the numerator has not made
      74             :    intermediate overflow much more difficult to deal with either as y
      75             :    <<< x for large x.  So to compute this without intermediate overflow,
      76             :    we compute the terms individually and then combine the remainders
      77             :    appropriately.  x/(2y+1) term is trivial.  The other term,
      78             :    (y^2+y)/(2y+1) is asymptotically approximately y/2.  Breaking it into
      79             :    its asymptotic and residual:
      80             : 
      81             :       (y^2 + y) / (2y+1) = y/2 + ( y^2 + y - (y/2)(2y+1) ) / (2y+1)
      82             :                          = y/2 + ( y^2 + y - y^2 - y/2   ) / (2y+1)
      83             :                          = y/2 + (y/2) / (2y+1)
      84             : 
      85             :    For even y, y/2 = y>>1 = yh and we have the partial quotient yh and
      86             :    remainder yh.  For odd y, we have:
      87             : 
      88             :                          = yh + (1/2) + (yh+(1/2)) / (2y+1)
      89             :                          = yh + ((1/2)(2y+1)+yh+(1/2)) / (2y+1)
      90             :                          = yh + (y+yh+1) / (2y+1)
      91             : 
      92             :    with partial quotient yh and remainder y+yh+1.  This yields the
      93             :    iteration:
      94             : 
      95             :      y ~ sqrt(x)                               // <<< INT_MAX for all x
      96             :      for(;;) {
      97             :        d  = 2*y + 1;
      98             :        qx = x / d; rx = x - qx*d;              // Compute x  /(2y+1), rx in [0,2y]
      99             :        qy = y>>1;  ry = (y&1) ? (y+qy+1) : qy; // Compute (y^2+y)/(2y+1), ry in [0,2y]
     100             :        q  = qx+qy; r  = rx+ry;                 // Combine partials, r in [0,4y]
     101             :        if( r>=d ) q++, r-=d;                   // Handle carry (at most 1), r in [0,2y]
     102             :        if( y==q ) break;                       // At convergence y = floor(sqrt(x))
     103             :        y = q;
     104             :      }
     105             : 
     106             :    The better the initial guess, the faster this will converge.  Since
     107             :    convergence is still quadratic though, it will converge even given
     108             :    very simple guesses.  We use:
     109             : 
     110             :      y = sqrt(x) = sqrt( 2^n + d ) ~ 2^(n/2)   (and sqrt(x) >= 2^(n/2))
     111             :    
     112             :    where n is the index of the MSB and d is in [0,2^n) (i.e. is n bits
     113             :    wide).  Thus:
     114             : 
     115             :      y ~ 2^(n>>1) if n is even and 2^(n>>1) sqrt(2) if n is odd
     116             : 
     117             :    and we can do a simple fixed point calculation to compute this.
     118             : 
     119             :    For small values of x (x<21), we encode a 21 entry 3-bit wide lookup
     120             :    table of floor( sqrt( x ) ) in a 64-bit constant and do a quick lookup.
     121             :    
     122             :    For types narrower than 64-bit, we can do the iteration portably in a
     123             :    wider type and simplify the operation.  We also do this if the
     124             :    underlying platform supports 128-bit wide types.
     125             : 
     126             :    FIXME: USE THE X86 FPU TO GET A REALLY GOOD INITIAL GUESS? */
     127             : 
     128             : FD_FN_CONST static inline uint
     129  2233602066 : fd_uint_sqrt( uint x ) {
     130  2233602066 :   if( x<21U ) return (uint)((0x49246db6da492248UL >> (3*(int)x)) & 7UL);
     131  1366708395 :   int  n = fd_uint_find_msb( x );
     132  1366708395 :   uint y = ( ((n & 1) ? 0xb504U /* floor( 2^15 sqrt(2) ) */ : 0x8000U /* 2^15 */) >> (15-(n>>1)) );
     133  1366708395 :   ulong _y = (ulong)y;
     134  1366708395 :   ulong _x = (ulong)x;
     135  3558165456 :   for(;;) {
     136  3558165456 :     ulong _z = (_y*_y + _y + _x) / ((_y<<1)+1UL);
     137  3558165456 :     if( _z==_y ) break;
     138  2191457061 :     _y = _z;
     139  2191457061 :   }
     140  1366708395 :   return (uint)_y;
     141  2233602066 : }
     142             : 
     143             : #if FD_HAS_INT128
     144             : 
     145             : FD_FN_CONST static inline ulong
     146  2578436145 : fd_ulong_sqrt( ulong x ) {
     147  2578436145 :   if( x<21UL ) return (0x49246db6da492248UL >> (3*(int)x)) & 7UL;
     148  2403729051 :   int   n = fd_ulong_find_msb( x );
     149  2403729051 :   ulong y = ((n & 1) ? 0xb504f333UL /* floor( 2^31 sqrt(2) ) */ : 0x80000000UL /* 2^31 */) >> (31-(n>>1));
     150  2403729051 :   uint128 _y = (uint128)y;
     151  2403729051 :   uint128 _x = (uint128)x;
     152 10700385480 :   for(;;) {
     153 10700385480 :     uint128 _z = (_y*_y + _y + _x) / ((_y<<1)+(uint128)1);
     154 10700385480 :     if( _z==_y ) break;
     155  8296656429 :     _y = _z;
     156  8296656429 :   }
     157  2403729051 :   return (ulong)_y;
     158  2578436145 : }
     159             : 
     160             : #else
     161             : 
     162             : FD_FN_CONST static inline ulong
     163             : fd_ulong_sqrt( ulong x ) {
     164             :   if( x<21UL ) return (0x49246db6da492248UL >> (3*(int)x)) & 7UL;
     165             :   int   n = fd_ulong_find_msb( x );
     166             :   ulong y = ((n & 1) ? 0xb504f333UL /* floor( 2^31 sqrt(2) ) */ : 0x80000000UL /* 2^31 */) >> (31-(n>>1));
     167             :   for(;;) {
     168             :     ulong d = (y<<1); d++;
     169             :     ulong qx = x / d; ulong rx = x - qx*d;
     170             :     ulong qy = y>>1;  ulong ry = fd_ulong_if( y & 1UL, y+qy+1UL, qy );
     171             :     ulong q  = qx+qy; ulong r  = rx+ry;
     172             :     q += (ulong)(r>=d);
     173             :     if( y==q ) break;
     174             :     y = q;
     175             :   }
     176             :   return y;
     177             : }
     178             : 
     179             : #endif
     180             : 
     181             : /* FIXME: CONSIDER USING A TABLE LOOKUP UCHAR AND, TO A LESSER EXTENT,
     182             :    USHORT FOR THESE */
     183             : 
     184   740625936 : FD_FN_CONST static inline uchar  fd_uchar_sqrt ( uchar  x ) { return (uchar )fd_uint_sqrt( (uint)x ); }
     185   745321893 : FD_FN_CONST static inline ushort fd_ushort_sqrt( ushort x ) { return (ushort)fd_uint_sqrt( (uint)x ); }
     186             : 
     187             : /* These return floor( sqrt( x ) ), undefined behavior for negative x. */
     188             : 
     189   140625936 : FD_FN_CONST static inline schar fd_schar_sqrt( schar x ) { return (schar)fd_uchar_sqrt ( (uchar )x ); }
     190   145321893 : FD_FN_CONST static inline short fd_short_sqrt( short x ) { return (short)fd_ushort_sqrt( (ushort)x ); }
     191   147654237 : FD_FN_CONST static inline int   fd_int_sqrt  ( int   x ) { return (int  )fd_uint_sqrt  ( (uint  )x ); }
     192   148829829 : FD_FN_CONST static inline long  fd_long_sqrt ( long  x ) { return (long )fd_ulong_sqrt ( (ulong )x ); }
     193             : 
     194             : /* These return the floor( re sqrt(x) ) */
     195             : 
     196   150000000 : FD_FN_CONST static inline schar fd_schar_re_sqrt( schar x ) { return fd_schar_if( x>(schar)0,  (schar)fd_uchar_sqrt ( (uchar )x ), (schar)0  ); }
     197   150000000 : FD_FN_CONST static inline short fd_short_re_sqrt( short x ) { return fd_short_if( x>(short)0,  (short)fd_ushort_sqrt( (ushort)x ), (short)0  ); }
     198   150000000 : FD_FN_CONST static inline int   fd_int_re_sqrt  ( int   x ) { return fd_int_if  ( x>       0,  (int  )fd_uint_sqrt  ( (uint  )x ),        0  ); }
     199   150000000 : FD_FN_CONST static inline long  fd_long_re_sqrt ( long  x ) { return fd_long_if ( x>       0L, (long )fd_ulong_sqrt ( (ulong )x ),        0L ); }
     200             : 
     201             : /* These return the floor( sqrt( |x| ) ) */
     202             : 
     203   150000000 : FD_FN_CONST static inline schar fd_schar_sqrt_abs( schar x ) { return (schar)fd_uchar_sqrt ( fd_schar_abs( x ) ); }
     204   150000000 : FD_FN_CONST static inline short fd_short_sqrt_abs( short x ) { return (short)fd_ushort_sqrt( fd_short_abs( x ) ); }
     205   150000000 : FD_FN_CONST static inline int   fd_int_sqrt_abs  ( int   x ) { return (int  )fd_uint_sqrt  ( fd_int_abs  ( x ) ); }
     206   150000000 : FD_FN_CONST static inline long  fd_long_sqrt_abs ( long  x ) { return (long )fd_ulong_sqrt ( fd_long_abs ( x ) ); }
     207             : 
     208             : FD_PROTOTYPES_END
     209             : 
     210             : #endif /* HEADER_fd_src_util_math_fd_sqrt_h */

Generated by: LCOV version 1.14