LCOV - code coverage report
Current view: top level - flamenco/runtime/program/zksdk/rangeproofs - fd_rangeproofs.c (source / functions) Hit Total Coverage
Test: cov.lcov Lines: 180 211 85.3 %
Date: 2024-11-13 11:58:15 Functions: 3 3 100.0 %

          Line data    Source code
       1             : #include "fd_rangeproofs.h"
       2             : 
       3             : static inline int
       4        1128 : batched_range_proof_validate_bits( ulong bit_length ) {
       5        1128 :   if ( FD_LIKELY(
       6        1128 :     bit_length==1  || bit_length==2  || bit_length==4  || bit_length==8 ||
       7        1128 :     bit_length==16 || bit_length==32 || bit_length==64 || bit_length==128
       8        1128 :   ) ) {
       9        1128 :     return FD_RANGEPROOFS_SUCCESS;
      10        1128 :   }
      11           0 :   return FD_RANGEPROOFS_ERROR;
      12        1128 : }
      13             : 
      14             : void
      15             : fd_rangeproofs_delta(
      16             :   uchar       delta[ 32 ],
      17             :   ulong const nm,
      18             :   uchar const y[ 32 ],
      19             :   uchar const z[ 32 ],
      20             :   uchar const zz[ 32 ],
      21             :   uchar const bit_lengths[ 1 ],
      22             :   uchar const batch_len
      23         123 : ) {
      24         123 :   uchar exp_y[ 32 ];
      25         123 :   uchar sum_of_powers_y[ 32 ];
      26         123 :   fd_memcpy( exp_y, y, 32 );
      27         123 :   fd_curve25519_scalar_add( sum_of_powers_y, y, fd_curve25519_scalar_one );
      28         855 :   for( ulong i=nm; i>2; i/=2 ) {
      29         732 :     fd_curve25519_scalar_mul   ( exp_y, exp_y, exp_y );
      30         732 :     fd_curve25519_scalar_muladd( sum_of_powers_y, exp_y, sum_of_powers_y, sum_of_powers_y );
      31         732 :   }
      32         123 :   fd_curve25519_scalar_sub( delta, z, zz );
      33         123 :   fd_curve25519_scalar_mul( delta, delta, sum_of_powers_y );
      34             : 
      35         123 :   uchar neg_exp_z[ 32 ];
      36         123 :   uchar sum_2[ 32 ];
      37         123 :   fd_curve25519_scalar_neg( neg_exp_z, zz );
      38        1107 :   for( ulong i=0; i<batch_len; i++ ) {
      39         984 :     fd_memset( sum_2, 0, 32 );
      40             :     //TODO currently assuming that bit_length[i] is multiple of 8 - need to fix cases: 1, 2, 4
      41         984 :     fd_memset( sum_2, 0xFF, bit_lengths[i] / 8 );
      42         984 :     fd_curve25519_scalar_mul   ( neg_exp_z, neg_exp_z, z );
      43         984 :     fd_curve25519_scalar_muladd( delta, neg_exp_z, sum_2, delta );
      44         984 :   }
      45         123 : }
      46             : 
      47             : int
      48             : fd_rangeproofs_verify(
      49             :   fd_rangeproofs_range_proof_t const * range_proof,
      50             :   fd_rangeproofs_ipp_proof_t const *   ipp_proof,
      51             :   uchar const                          commitments [ 32 ],
      52             :   uchar const                          bit_lengths [ 1 ],
      53             :   uchar const                          batch_len,
      54         141 :   fd_merlin_transcript_t *             transcript ) {
      55             : 
      56             :   /* https://github.com/anza-xyz/agave/blob/v2.0.1/zk-sdk/src/range_proof/mod.rs#L288
      57             : 
      58             :     We need to verify a range proof, by computing a large MSM.
      59             : 
      60             :     We store points in the following array.
      61             :     Indexes are the common example of u128 batch range proof with batch_len==4,
      62             :     used in SPL confidential transfers.
      63             : 
      64             :            points
      65             :       0    G
      66             :       1    H
      67             :       2    S
      68             :       3    T_1
      69             :       4    T_2
      70             :       5    commitments[ 0 ]
      71             :            ...
      72             :       8    commitments[ 3 ]    // 4 == batch_len (example)
      73             :       9    L_vec[ 0 ]
      74             :            ...
      75             :      15    L_vec[ 6 ]          // 7 == log2( 128 )
      76             :      16    R_vec[ 0 ]
      77             :            ...
      78             :      22    R_vec[ 6 ]          // 7 == log2( 128 )
      79             :      23    generators_H[ 0 ]
      80             :            ...
      81             :     150    generators_H[ 127 ] // 128 generators
      82             :     151    generators_G[ 0 ]
      83             :            ...
      84             :     278    generators_G[ 127 ] // 128 generators
      85             :     ------------------------------------------------------ MSM
      86             :            A
      87             : 
      88             :     As final check we test that the result of the MSM == -A.
      89             :     We could negate all scalars, but that'd make it more complex to debug
      90             :     against Rust rangeproofs / Solana, in case of issues, and the marginal
      91             :     cost of negating A is negligible.
      92             : 
      93             :     This implementation has a few differences compared to the Rust implementation.
      94             : 
      95             :     - We need to support batched range proofs for u64, u128 and u256.
      96             :       Rust does dynamic allocations. This implementation statically allocates
      97             :       for u256 (a total of <64kB) and dynamically handles u64 and u128.
      98             : 
      99             :     - This implementation limits memory copies.
     100             :       Input data arrives from the Solana tx in a certain order and essentially
     101             :       includes compressed points and scalars.
     102             :       We allocate enough scalars and (uncompressed) points for the MSM.
     103             :       As we parse input data, we compute scalars and decompress points
     104             :       directly into the memory region used by MSM (layout shown above).
     105             : 
     106             :     - Points and scalars are in a different order compared to Rust,
     107             :       but their value is the same. The order has no particular meaning,
     108             :       it just seemed more convenient.
     109             : 
     110             :     - Range proof depends interally on innerproduct proof (ipp).
     111             :       ipp needs to invert logn elements (called u_i).
     112             :       range proof, in addition, needs to invert y.
     113             :       Rust uses batch inversion to invert all u_i more efficiently.
     114             :       We also include y in the batch, to save 1 inversion (~300 mul).
     115             : 
     116             :     - ipp generates n scalars s_i, from which range proof derives 2n scalars
     117             :       for generators_G and generators_H.
     118             :       The scalars for generators_G are just a rescaling of s_i,
     119             :       while the scalars for generators_H are a bit more complex.
     120             :       We store s_i in the same memory region of generators_G scalars,
     121             :       then use them to compute generators_H scalars, and finally we do
     122             :       the rescaling. This saves 8kB of stack.
     123             :   */
     124             : 
     125             :   /* Capital LOGN, N are used to allocate memory.
     126             :      Lowercase logn, n are used at runtime.
     127             :      This implementation allocates memory to support u256, and
     128             :      at runtime can verify u64, u128 and u256 range proofs. */
     129         141 : #define LOGN 8
     130         141 : #define N (1 << LOGN)
     131         141 : #define MAX (2*N + 2*LOGN + 5 + FD_RANGEPROOFS_MAX_COMMITMENTS)
     132             : 
     133         141 :   const ulong logn = ipp_proof->logn;
     134         141 :   const ulong n = 1UL << logn;
     135             : 
     136             :   /* https://github.com/anza-xyz/agave/blob/v2.0.1/zk-sdk/src/range_proof/mod.rs#L294-L306
     137             :      total bit length (nm) should be a power of 2, and <= 256 == size of our generators table. */
     138         141 :   ulong nm = 0;
     139        1269 :   for( uchar i=0; i<batch_len; i++ ) {
     140        1128 :     if( FD_UNLIKELY( batched_range_proof_validate_bits( bit_lengths[i] ) != FD_RANGEPROOFS_SUCCESS ) ) {
     141           0 :       return FD_RANGEPROOFS_ERROR;
     142           0 :     }
     143        1128 :     nm += bit_lengths[i];
     144        1128 :   }
     145         141 :   if( FD_UNLIKELY( nm != n ) ) {
     146           0 :     return FD_RANGEPROOFS_ERROR;
     147           0 :   }
     148             : 
     149             :   /* Validate all inputs */
     150         141 :   uchar scalars[ MAX*32 ];
     151         141 :   fd_ristretto255_point_t points[ MAX ];
     152         141 :   fd_ristretto255_point_t a_res[ 1 ];
     153         141 :   fd_ristretto255_point_t res[ 1 ];
     154             : 
     155         141 :   if( FD_UNLIKELY( fd_curve25519_scalar_validate( range_proof->tx )==NULL ) ) {
     156           0 :     return FD_RANGEPROOFS_ERROR;
     157           0 :   }
     158         141 :   if( FD_UNLIKELY( fd_curve25519_scalar_validate( range_proof->tx_blinding )==NULL ) ) {
     159           0 :     return FD_RANGEPROOFS_ERROR;
     160           0 :   }
     161         141 :   if( FD_UNLIKELY( fd_curve25519_scalar_validate( range_proof->e_blinding )==NULL ) ) {
     162           0 :     return FD_RANGEPROOFS_ERROR;
     163           0 :   }
     164         141 :   if( FD_UNLIKELY( fd_curve25519_scalar_validate( ipp_proof->a )==NULL ) ) {
     165           0 :     return FD_RANGEPROOFS_ERROR;
     166           0 :   }
     167         141 :   if( FD_UNLIKELY( fd_curve25519_scalar_validate( ipp_proof->b )==NULL ) ) {
     168           0 :     return FD_RANGEPROOFS_ERROR;
     169           0 :   }
     170             : 
     171         141 :   fd_ristretto255_point_set( &points[0], fd_rangeproofs_basepoint_G );
     172         141 :   fd_ristretto255_point_set( &points[1], fd_rangeproofs_basepoint_H );
     173         141 :   if( FD_UNLIKELY( fd_ristretto255_point_decompress( a_res, range_proof->a )==NULL ) ) {
     174           0 :     return FD_RANGEPROOFS_ERROR;
     175           0 :   }
     176         141 :   if( FD_UNLIKELY( fd_ristretto255_point_decompress( &points[2], range_proof->s )==NULL ) ) {
     177           0 :     return FD_RANGEPROOFS_ERROR;
     178           0 :   }
     179         141 :   if( FD_UNLIKELY( fd_ristretto255_point_decompress( &points[3], range_proof->t1 )==NULL ) ) {
     180           0 :     return FD_RANGEPROOFS_ERROR;
     181           0 :   }
     182         141 :   if( FD_UNLIKELY( fd_ristretto255_point_decompress( &points[4], range_proof->t2 )==NULL ) ) {
     183           0 :     return FD_RANGEPROOFS_ERROR;
     184           0 :   }
     185         141 :   ulong idx = 5;
     186        1125 :   for( ulong i=0; i<batch_len; i++, idx++ ) {
     187        1002 :     if( FD_UNLIKELY( fd_ristretto255_point_decompress( &points[ idx ], &commitments[ i*32 ] )==NULL ) ) {
     188          18 :       return FD_RANGEPROOFS_ERROR;
     189          18 :     }
     190        1002 :   }
     191         978 :   for( ulong i=0; i<logn; i++, idx++ ) {
     192         855 :     if( FD_UNLIKELY( fd_ristretto255_point_decompress( &points[ idx ], ipp_proof->vecs[ i ].l )==NULL ) ) {
     193           0 :       return FD_RANGEPROOFS_ERROR;
     194           0 :     }
     195         855 :   }
     196         978 :   for( ulong i=0; i<logn; i++, idx++ ) {
     197         855 :     if( FD_UNLIKELY( fd_ristretto255_point_decompress( &points[ idx ], ipp_proof->vecs[ i ].r )==NULL ) ) {
     198           0 :       return FD_RANGEPROOFS_ERROR;
     199           0 :     }
     200         855 :   }
     201         123 :   fd_memcpy( &points[ idx ],   fd_rangeproofs_generators_H, n*sizeof(fd_ristretto255_point_t) );
     202         123 :   fd_memcpy( &points[ idx+n ], fd_rangeproofs_generators_G, n*sizeof(fd_ristretto255_point_t) );
     203             : 
     204             :   /* Finalize transcript and extract challenges */
     205         123 :   int val = FD_TRANSCRIPT_SUCCESS;
     206         123 :   fd_rangeproofs_transcript_domsep_range_proof( transcript, nm );
     207             : 
     208         123 :   val |= fd_rangeproofs_transcript_validate_and_append_point( transcript, FD_TRANSCRIPT_LITERAL("A"), range_proof->a);
     209         123 :   val |= fd_rangeproofs_transcript_validate_and_append_point( transcript, FD_TRANSCRIPT_LITERAL("S"), range_proof->s);
     210             : 
     211         123 :   uchar batchinv_in [ 32*(1+LOGN) ];
     212         123 :   uchar batchinv_out[ 32*(1+LOGN) ];
     213         123 :   uchar allinv[ 32 ];
     214         123 :   uchar *y = batchinv_in;
     215         123 :   uchar *y_inv = batchinv_out;
     216         123 :   uchar z[ 32 ];
     217         123 :   fd_rangeproofs_transcript_challenge_scalar( y, transcript, FD_TRANSCRIPT_LITERAL("y") );
     218         123 :   fd_rangeproofs_transcript_challenge_scalar( z, transcript, FD_TRANSCRIPT_LITERAL("z") );
     219             : 
     220         123 :   val |= fd_rangeproofs_transcript_validate_and_append_point( transcript, FD_TRANSCRIPT_LITERAL("T_1"), range_proof->t1);
     221         123 :   val |= fd_rangeproofs_transcript_validate_and_append_point( transcript, FD_TRANSCRIPT_LITERAL("T_2"), range_proof->t2);
     222         123 :   if( FD_UNLIKELY( val != FD_TRANSCRIPT_SUCCESS ) ) {
     223           0 :     return FD_RANGEPROOFS_ERROR;
     224           0 :   }
     225             : 
     226         123 :   uchar x[ 32 ];
     227         123 :   fd_rangeproofs_transcript_challenge_scalar( x, transcript, FD_TRANSCRIPT_LITERAL("x") );
     228             : 
     229         123 :   fd_rangeproofs_transcript_append_scalar( transcript, FD_TRANSCRIPT_LITERAL("t_x"), range_proof->tx);
     230         123 :   fd_rangeproofs_transcript_append_scalar( transcript, FD_TRANSCRIPT_LITERAL("t_x_blinding"), range_proof->tx_blinding);
     231         123 :   fd_rangeproofs_transcript_append_scalar( transcript, FD_TRANSCRIPT_LITERAL("e_blinding"), range_proof->e_blinding);
     232             : 
     233         123 :   uchar w[ 32 ];
     234         123 :   uchar c[ 32 ];
     235         123 :   fd_rangeproofs_transcript_challenge_scalar( w, transcript, FD_TRANSCRIPT_LITERAL("w") );
     236         123 :   fd_rangeproofs_transcript_challenge_scalar( c, transcript, FD_TRANSCRIPT_LITERAL("c") );
     237             : 
     238             :   /* Inner Product (sub)Proof */
     239         123 :   fd_rangeproofs_transcript_domsep_inner_product( transcript, nm );
     240             : 
     241         123 :   uchar *u =     &batchinv_in [ 32 ]; // skip y
     242         123 :   uchar *u_inv = &batchinv_out[ 32 ]; // skip y_inv
     243         978 :   for( ulong i=0; i<logn; i++ ) {
     244         855 :     val |= fd_rangeproofs_transcript_validate_and_append_point( transcript, FD_TRANSCRIPT_LITERAL("L"), ipp_proof->vecs[ i ].l);
     245         855 :     val |= fd_rangeproofs_transcript_validate_and_append_point( transcript, FD_TRANSCRIPT_LITERAL("R"), ipp_proof->vecs[ i ].r);
     246         855 :     if( FD_UNLIKELY( val != FD_TRANSCRIPT_SUCCESS ) ) {
     247           0 :       return FD_RANGEPROOFS_ERROR;
     248           0 :     }
     249         855 :     fd_rangeproofs_transcript_challenge_scalar( &u[ i*32 ], transcript, FD_TRANSCRIPT_LITERAL("u") );
     250         855 :   }
     251         123 :   fd_curve25519_scalar_batch_inv( batchinv_out, allinv, batchinv_in, logn+1 );
     252             : 
     253             :   /* Compute scalars */
     254             : 
     255             :   // H: - ( eb + c t_xb )
     256         123 :   uchar const *eb = range_proof->e_blinding;
     257         123 :   uchar const *txb = range_proof->tx_blinding;
     258         123 :   fd_curve25519_scalar_muladd( &scalars[ 1*32 ], c, txb, eb );
     259         123 :   fd_curve25519_scalar_neg(    &scalars[ 1*32 ], &scalars[ 1*32 ] );
     260             : 
     261             :   // S:   x
     262             :   // T_1: c x
     263             :   // T_2: c x^2
     264         123 :   fd_curve25519_scalar_set(    &scalars[ 2*32 ], x );
     265         123 :   fd_curve25519_scalar_mul(    &scalars[ 3*32 ], c, x );
     266         123 :   fd_curve25519_scalar_mul(    &scalars[ 4*32 ], &scalars[ 3*32 ], x );
     267             : 
     268             :   // commitments: c z^2, c z^3 ...
     269         123 :   uchar zz[ 32 ];
     270         123 :   fd_curve25519_scalar_mul(    zz, z, z );
     271         123 :   fd_curve25519_scalar_mul(    &scalars[ 5*32 ], zz, c );
     272         123 :   idx = 6;
     273         984 :   for( ulong i=1; i<batch_len; i++, idx++ ) {
     274         861 :     fd_curve25519_scalar_mul(  &scalars[ idx*32 ], &scalars[ (idx-1)*32 ], z );
     275         861 :   }
     276             : 
     277             :   // L_vec: u0^2, u1^2...
     278             :   // R_vec: 1/u0^2, 1/u1^2...
     279         123 :   uchar *u_sq = &scalars[ idx*32 ];
     280         978 :   for( ulong i=0; i<logn; i++, idx++ ) {
     281         855 :     fd_curve25519_scalar_mul(  &scalars[ idx*32 ], &u[ i*32 ], &u[ i*32 ] );
     282         855 :   }
     283         978 :   for( ulong i=0; i<logn; i++, idx++ ) {
     284         855 :     fd_curve25519_scalar_mul(  &scalars[ idx*32 ], &u_inv[ i*32 ], &u_inv[ i*32 ] );
     285         855 :   }
     286             : 
     287             :   // s_i for generators_G, generators_H
     288         123 :   uchar *s = &scalars[ (idx+n)*32 ];
     289         123 :   fd_curve25519_scalar_mul( &s[ 0*32 ], allinv, y ); // allinv also contains 1/y
     290             :   // s[i] = s[ i-k ] * u[ k+1 ]^2   (k the "next power of 2" wrt i)
     291         978 :   for( ulong k=0; k<logn; k++ ) {
     292         855 :     ulong powk = (1UL << k);
     293       18396 :     for( ulong j=0; j<powk; j++ ) {
     294       17541 :       ulong i = powk + j;
     295       17541 :       fd_curve25519_scalar_mul( &s[ i*32 ], &s[ j*32 ], &u_sq[ (logn-1-k)*32 ] );
     296       17541 :     }
     297         855 :   }
     298             : 
     299             :   // generators_H: (-a * s_i) + (-z)
     300         123 :   uchar const *a = ipp_proof->a;
     301         123 :   uchar const *b = ipp_proof->b;
     302         123 :   uchar minus_b[ 32 ];
     303         123 :   uchar exp_z[ 32 ];
     304         123 :   uchar exp_y_inv[ 32 ];
     305         123 :   uchar z_and_2[ 32 ];
     306         123 :   fd_curve25519_scalar_neg( minus_b, b );
     307         123 :   fd_memcpy( exp_z, zz, 32 );
     308         123 :   fd_memcpy( z_and_2, exp_z, 32 );
     309         123 :   fd_memcpy( exp_y_inv, y, 32 ); //TODO: remove 2 unnecessary muls
     310       17787 :   for( ulong i=0, j=0, m=0; i<n; i++, j++, idx++ ) {
     311       17664 :     if( j == bit_lengths[m] ) {
     312         861 :       j = 0;
     313         861 :       m++;
     314         861 :       fd_curve25519_scalar_mul ( exp_z, exp_z, z );
     315         861 :       fd_memcpy( z_and_2, exp_z, 32 );
     316         861 :     }
     317       17664 :     if( j != 0 ) {
     318       16680 :       fd_curve25519_scalar_add ( z_and_2, z_and_2, z_and_2 );
     319       16680 :     }
     320       17664 :     fd_curve25519_scalar_mul   ( exp_y_inv, exp_y_inv, y_inv );
     321       17664 :     fd_curve25519_scalar_muladd( &scalars[ idx*32 ], &s[ (n-1-i)*32 ], minus_b, z_and_2 );
     322       17664 :     fd_curve25519_scalar_muladd( &scalars[ idx*32 ], &scalars[ idx*32 ], exp_y_inv, z );
     323       17664 :   }
     324             : 
     325             :   // generators_G: (-a * s_i) + (-z)
     326         123 :   uchar minus_z[ 32 ];
     327         123 :   uchar minus_a[ 32 ];
     328         123 :   fd_curve25519_scalar_neg( minus_z, z );
     329         123 :   fd_curve25519_scalar_neg( minus_a, a );
     330       17787 :   for( ulong i=0; i<n; i++, idx++ ) {
     331       17664 :     fd_curve25519_scalar_muladd( &scalars[ idx*32 ], &s[ i*32 ], minus_a, minus_z );
     332       17664 :   }
     333             : 
     334             :   // G
     335             :   // w * (self.t_x - a * b) + c * (delta(&bit_lengths, &y, &z) - self.t_x)
     336         123 :   uchar delta[ 32 ];
     337         123 :   fd_rangeproofs_delta( delta, nm, y, z, zz, bit_lengths, batch_len );
     338         123 :   fd_curve25519_scalar_muladd(  &scalars[ 0 ], minus_a, b, range_proof->tx );
     339         123 :   fd_curve25519_scalar_sub(     delta, delta, range_proof->tx );
     340         123 :   fd_curve25519_scalar_mul(     delta, delta, c );
     341         123 :   fd_curve25519_scalar_muladd(  &scalars[ 0 ], &scalars[ 0 ], w, delta );
     342             : 
     343             :   /* Compute the final MSM */
     344         123 :   fd_ristretto255_multi_scalar_mul( res, scalars, points, idx );
     345             : 
     346         123 :   if( FD_LIKELY( fd_ristretto255_point_eq_neg( res, a_res ) ) ) {
     347          69 :     return FD_RANGEPROOFS_SUCCESS;
     348          69 :   }
     349             : 
     350          54 : #undef LOGN
     351          54 : #undef N
     352          54 : #undef MAX
     353          54 :   return FD_RANGEPROOFS_ERROR;
     354         123 : }

Generated by: LCOV version 1.14