remainder/utils/
mle.rs

//! Module for generating and manipulating MLEs (multilinear extensions).
2
3use std::iter::repeat_with;
4
5use itertools::{repeat_n, FoldWhile, Itertools};
6use rand::Rng;
7use rayon::prelude::*;
8use shared_types::Field;
9
10use crate::{
11    claims::RawClaim,
12    layer::LayerId,
13    mle::{betavalues::BetaValues, dense::DenseMle, evals::MultilinearExtension, MleIndex},
14};
15
/// Returns a copy of `data` extended with clones of `padding_value` until its
/// length equals `data.len().next_power_of_two()`.
///
/// When `data.len()` is already a power of two the result is simply a copy of
/// `data`. An empty slice becomes a single padding element, because
/// `0.next_power_of_two()` is `1`.
///
/// # Panics
/// If the padded length does not fit in a `usize`.
///
/// # Examples:
/// ```
/// use remainder::utils::mle::pad_with;
/// let data = vec![1, 2, 3];
/// let padded_data = pad_with(0, &data);
/// assert_eq!(padded_data, vec![1, 2, 3, 0]);
/// assert_eq!(pad_with(0, &padded_data), vec![1, 2, 3, 0]); // length is already a power of two.
/// ```
pub fn pad_with<F: Clone>(padding_value: F, data: &[F]) -> Vec<F> {
    let target_len = data.len().checked_next_power_of_two().unwrap();
    let mut result: Vec<F> = Vec::with_capacity(target_len);
    result.extend_from_slice(data);
    // `resize` appends clones of the padding value to fill the tail.
    result.resize(target_len, padding_value);
    result
}
38
/// Returns the permutation of indices that sorts `slice` (an "argsort").
///
/// With `invert == false` the returned indices order the elements ascending;
/// with `invert == true` they order them descending. The underlying sort is
/// stable, so equal elements keep their original relative order in both modes.
///
/// ## Example:
/// ```
/// use remainder::utils::mle::argsort;
/// let data = vec![3, 1, 4, 1, 5, 9, 2];
///
/// // Ascending order
/// let indices = argsort(&data, false);
/// assert_eq!(indices, vec![1, 3, 6, 0, 2, 4, 5]);
/// ```
pub fn argsort<T: Ord>(slice: &[T], invert: bool) -> Vec<usize> {
    let mut order: Vec<usize> = (0..slice.len()).collect();
    if invert {
        // `Reverse` flips the comparison while keeping the sort stable.
        order.sort_by_key(|&pos| std::cmp::Reverse(&slice[pos]));
    } else {
        order.sort_by_key(|&pos| &slice[pos]);
    }
    order
}
63
64/// Helper function to create random MLE with specific number of vars
65pub fn get_random_mle<F: Field>(num_vars: usize, rng: &mut impl Rng) -> DenseMle<F> {
66    let capacity = 2_u32.pow(num_vars as u32);
67    let bookkeeping_table = repeat_with(|| F::from(rng.gen::<u64>()) * F::from(rng.gen::<u64>()))
68        .take(capacity as usize)
69        .collect_vec();
70    DenseMle::new_from_raw(bookkeeping_table, LayerId::Input(0))
71}
72
73/// Helper function to create random MLE with specific number of vars
74pub fn get_random_mle_from_capacity<F: Field>(capacity: usize, rng: &mut impl Rng) -> DenseMle<F> {
75    let bookkeeping_table = repeat_with(|| F::from(rng.gen::<u64>()) * F::from(rng.gen::<u64>()))
76        .take(capacity)
77        .collect_vec();
78    DenseMle::new_from_raw(bookkeeping_table, LayerId::Input(0))
79}
80
81/// Returns a vector of MLEs for dataparallel testing according to the number of
82/// variables and number of dataparallel bits.
83pub fn get_dummy_random_mle_vec<F: Field>(
84    num_vars: usize,
85    num_dataparallel_bits: usize,
86    rng: &mut impl Rng,
87) -> Vec<DenseMle<F>> {
88    (0..(1 << num_dataparallel_bits))
89        .map(|_| {
90            let mle_vec = (0..(1 << num_vars))
91                .map(|_| F::from(rng.gen::<u64>()))
92                .collect_vec();
93            DenseMle::new_from_raw(mle_vec, LayerId::Input(0))
94        })
95        .collect_vec()
96}
97
98/// Returns the specific bit decomp for a given index, using `num_bits` bits.
99/// Note that this returns the decomposition in BIG ENDIAN!
100pub fn get_mle_idx_decomp_for_idx<F: Field>(idx: usize, num_bits: usize) -> Vec<MleIndex<F>> {
101    (0..(num_bits))
102        .rev()
103        .map(|cur_num_bits| {
104            let is_one =
105                (idx % 2_usize.pow(cur_num_bits as u32 + 1)) >= 2_usize.pow(cur_num_bits as u32);
106            MleIndex::Fixed(is_one)
107        })
108        .collect_vec()
109}
110
111/// Returns the total MLE indices given a `Vec<bool>`. for the prefix bits and
112/// then the number of free bits after.
113pub fn get_total_mle_indices<F: Field>(
114    prefix_bits: &[bool],
115    num_free_bits: usize,
116) -> Vec<MleIndex<F>> {
117    prefix_bits
118        .iter()
119        .map(|bit| MleIndex::Fixed(*bit))
120        .chain(repeat_n(MleIndex::Free, num_free_bits))
121        .collect()
122}
123
124/// Construct a parent MLE for the given MLEs and prefix bits, where the prefix
125/// bits of each MLE specify how it should be inserted into the parent. Entries
126/// left unspecified are filled with `F::ZERO`.
127/// # Requires:
128/// * that the number of variables in each MLE, plus the number of its prefix
129///   bits, is the same across all pairs; this will be the number of variables
130///   in the returned MLE.
131/// * the slice is non-empty.
132///
133/// # Example:
134/// ```
135/// use remainder::utils::mle::build_composite_mle;
136/// use remainder::mle::evals::MultilinearExtension;
137/// use shared_types::Fr;
138/// use itertools::{Itertools};
139/// let mle1 = MultilinearExtension::new(vec![Fr::from(2)]);
140/// let mle2 = MultilinearExtension::new(vec![Fr::from(1), Fr::from(3)]);
141/// let result = build_composite_mle(&[(&mle1, vec![false, true]), (&mle2, vec![true])]);
142/// assert_eq!(*result.f.iter().collect_vec().clone(), vec![Fr::from(0), Fr::from(2), Fr::from(1), Fr::from(3)]);
143/// ```
144pub fn build_composite_mle<F: Field>(
145    mles: &[(&MultilinearExtension<F>, Vec<bool>)],
146) -> MultilinearExtension<F> {
147    assert!(!mles.is_empty());
148    let out_num_vars = mles[0].0.num_vars() + mles[0].1.len();
149    // Check that all (MLE, prefix bit) pairs require the same total number of
150    // variables.
151
152    mles.iter().for_each(|(mle, prefix_bits)| {
153        assert_eq!(mle.num_vars() + prefix_bits.len(), out_num_vars);
154    });
155    let mut out = vec![F::ZERO; 1 << out_num_vars];
156    for (mle, prefix_bits) in mles {
157        let mut current_window = 1 << out_num_vars;
158        let starting_index = prefix_bits.iter().fold(0, |acc_index, bit| {
159            let starting_index_acc = if *bit {
160                acc_index + current_window / 2
161            } else {
162                acc_index
163            };
164            current_window /= 2;
165            starting_index_acc
166        });
167        // REMARK: Modifying this check so that the input mles can be non-
168        // powers of 2.
169        assert_eq!(current_window, mle.len().next_power_of_two());
170        (starting_index..(starting_index + current_window))
171            .enumerate()
172            .for_each(|(mle_idx, out_idx)| {
173                out[out_idx] = mle.get(mle_idx).unwrap_or(F::ZERO);
174            });
175    }
176    MultilinearExtension::new(out)
177}
178
179/// Verifies a claim by evaluating the MLE at the challenge point and checking
180/// that the result.
181pub fn verify_claim<F: Field>(mle_unpadded_evaluations: &[F], claim: &RawClaim<F>) {
182    let mle_evaluations = claim
183        .get_point()
184        .iter()
185        .fold_while(mle_unpadded_evaluations, |acc, elem| {
186            if elem == &F::ZERO {
187                let sliced_acc = &acc[..(acc.len() / 2)];
188                FoldWhile::Continue(sliced_acc)
189            } else if elem == &F::ONE {
190                let sliced_acc = &acc[(acc.len() / 2)..];
191                FoldWhile::Continue(sliced_acc)
192            } else {
193                FoldWhile::Done(acc)
194            }
195        })
196        .into_inner();
197    let filtered_claim = claim
198        .get_point()
199        .iter()
200        .skip_while(|x| x == &&F::ZERO || x == &&F::ONE)
201        .copied()
202        .collect_vec();
203    let mle = MultilinearExtension::new(mle_evaluations.to_vec());
204    assert_eq!(mle.num_vars(), filtered_claim.len());
205    let eval = evaluate_mle_at_a_point_gray_codes(&mle, &filtered_claim);
206    assert_eq!(eval, claim.get_eval());
207}
208
/// An iterator over the codes `1..2^num_bits` in Gray-code order. Any two
/// consecutive codes, starting from the implicit initial code `0`, differ in
/// exactly one bit.
///
/// Each item has type `(u32, (u32, bool))`, meaning
/// `(code, (index of the flipped bit, previous value of the flipped bit))`.
/// Bit indices are little-endian: bit 0 is the least significant bit of the
/// code.
#[derive(Debug)]
pub struct GrayCodeIterator {
    /// Number of bits per code. Must be below 32 because codes are `u32`s.
    num_bits: usize,
    /// Counter whose Gray code is the most recently visited code.
    current_iteration: u32,
    /// Optional bound on the counter. Iteration stops once the counter reaches
    /// `end_iteration - 1`, so the last code yielded is the Gray code of
    /// `end_iteration - 1`.
    end_iteration: Option<u32>,
}

impl GrayCodeIterator {
    /// Creates an iterator over all `2^num_bits - 1` nonzero codes.
    ///
    /// # Panics
    /// If `num_bits >= 32`, because codes are stored as `u32`s.
    pub fn new(num_bits: usize) -> Self {
        assert!(num_bits < 32);
        Self {
            num_bits,
            current_iteration: 0,
            end_iteration: None,
        }
    }

    /// Creates an iterator whose counter starts at `current_iteration` and
    /// which stops early according to `end_iteration`, if given.
    ///
    /// The code for the starting counter value is treated as already visited.
    ///
    /// # Panics
    /// If `num_bits >= 32`, for the same reason as `GrayCodeIterator::new`.
    pub(crate) fn new_at_index(
        num_bits: usize,
        current_iteration: u32,
        end_iteration: Option<u32>,
    ) -> Self {
        // Same precondition as `new`. Without it, `1 << num_bits` overflows
        // while iterating.
        assert!(num_bits < 32);
        Self {
            num_bits,
            current_iteration,
            end_iteration,
        }
    }

    /// Returns the Gray code of `index` (`index ^ (index >> 1)`), truncated to
    /// its lowest `num_bits` bits.
    pub(crate) fn get_gray_index(num_bits: usize, index: u32) -> u32 {
        let mask = (1 << num_bits) - 1;

        (index ^ (index >> 1)) & mask
    }
}
254
255impl Iterator for GrayCodeIterator {
256    type Item = (u32, (u32, bool));
257
258    fn next(&mut self) -> Option<Self::Item> {
259        if self.current_iteration >= ((1 << self.num_bits) - 1) {
260            return None;
261        }
262
263        if self.end_iteration.is_some()
264            && self.current_iteration >= (self.end_iteration.unwrap() - 1)
265        {
266            return None;
267        }
268
269        if self.end_iteration.is_some()
270            && self.current_iteration >= (self.end_iteration.unwrap() - 1)
271        {
272            return None;
273        }
274
275        // Mask current value to ensure we only get num_bits number of bits per
276        // result.
277        let mask = (1 << self.num_bits) - 1;
278
279        // Because we don't store the previous value, just calculate it using
280        // the current_val that's stored.
281        let prev_gray = (self.current_iteration ^ (self.current_iteration >> 1)) & mask;
282        // The next value is simply XOR of the counter incremented by 1 and
283        // itself right-shifted. (source: algorithm in Wikipedia and ChatGPT
284        // hehe)
285        self.current_iteration += 1;
286        let new_gray = (self.current_iteration ^ (self.current_iteration >> 1)) & mask;
287
288        // Internally, the bits are stored in little-endian. NOTE: Our
289        // bookkeeping tables are in "big-endian", so we need to take this into
290        // account when evaluating an MLE.
291        Some((
292            new_gray,
293            compute_flipped_bit_idx_and_value_graycode(prev_gray, new_gray),
294        ))
295    }
296}
297
298/// A struct representing lexicographic bit-order in little-endian.
299///
300/// The iterator returns elements of the form (u32, Vec<(u32, bool)>) where the
301/// first u32 is the current index, and the Vec<(u32, bool)> represents the bits
302/// that flipped and what they used to be. As opposed to [GrayCodeIterator],
303/// these don't have Hamming distance 1, so the flipped bits go in a Vec.
304pub struct LexicographicLE {
305    num_bits: usize,
306    current_val: u32,
307}
308
309impl LexicographicLE {
310    fn new(num_bits: usize) -> Self {
311        Self {
312            num_bits,
313            current_val: 0,
314        }
315    }
316}
317
318impl Iterator for LexicographicLE {
319    type Item = (u32, Vec<(u32, bool)>);
320
321    fn next(&mut self) -> Option<Self::Item> {
322        if self.current_val >= ((1 << self.num_bits) - 1) {
323            return None;
324        }
325
326        let prev_val = self.current_val;
327        self.current_val += 1;
328
329        let flipped_bit_idx_and_values =
330            compute_flipped_bit_idx_and_values_lexicographic(prev_val, self.current_val);
331
332        Some((self.current_val, flipped_bit_idx_and_values))
333    }
334}
335
/// Returns `(index, previous value)` of the lowest bit in which `curr_val` and
/// `next_val` differ.
///
/// For consecutive Gray codes this is the single flipped bit. The index is
/// little-endian (bit 0 is least significant), and the previous value is that
/// bit's value in `curr_val`.
pub fn compute_flipped_bit_idx_and_value_graycode(curr_val: u32, next_val: u32) -> (u32, bool) {
    let lowest_differing_bit = (curr_val ^ next_val).trailing_zeros();
    let bit_mask: u32 = 1 << lowest_differing_bit;
    let was_set = (curr_val & bit_mask) != 0;
    (lowest_differing_bit, was_set)
}
343
344/// Compute the inverses and one minus the elem inverted for a vec of claim challenges.
345pub fn compute_inverses_vec_and_one_minus_inverted_vec<F: Field>(
346    claim_points: &[&[F]],
347) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
348    let inverses_vec = claim_points
349        .iter()
350        .map(|claim_point| {
351            claim_point
352                .iter()
353                .map(|elem| elem.invert().unwrap())
354                .collect_vec()
355        })
356        .collect_vec();
357    let one_minus_inverses_vec = claim_points
358        .iter()
359        .map(|claim_point| {
360            claim_point
361                .iter()
362                .map(|elem| (F::ONE - elem).invert().unwrap())
363                .collect_vec()
364        })
365        .collect_vec();
366    (inverses_vec, one_minus_inverses_vec)
367}
368
/// Returns `(index, previous value)` for every bit that differs between
/// `curr_val` and `next_val`, in increasing order of bit index.
///
/// Indices are little-endian (bit 0 is least significant), and the previous
/// value is the bit's value in `curr_val`.
///
/// NOTE: Our bookkeeping tables are in "big-endian", so callers need to take
/// this into account when evaluating an MLE.
pub fn compute_flipped_bit_idx_and_values_lexicographic(
    curr_val: u32,
    next_val: u32,
) -> Vec<(u32, bool)> {
    let mut remaining = curr_val ^ next_val;
    let mut flipped = Vec::with_capacity(remaining.count_ones() as usize);
    while remaining != 0 {
        let bit = remaining.trailing_zeros();
        flipped.push((bit, (curr_val >> bit) & 1 == 1));
        // Clear the lowest set bit so the next iteration finds the next one.
        remaining &= remaining - 1;
    }
    flipped
}
386
387/// Compute the next beta values from the previous by multiplying by appropriate
388/// inverses and challenge points given the flipped bits and their previous
389/// values. This is for when we have multiple claims to compute the beta over.
390pub fn compute_next_beta_values_vec_from_current<F: Field>(
391    current_beta_values: &[F],
392    inverses_vec: &[Vec<F>],
393    one_minus_elem_inverted_vec: &[Vec<F>],
394    claim_points: &[&[F]],
395    flipped_bit_idx_and_values: &[(u32, bool)],
396) -> Vec<F> {
397    current_beta_values
398        .iter()
399        .zip(inverses_vec.iter().zip(one_minus_elem_inverted_vec))
400        .zip(claim_points)
401        .map(
402            |((current_beta_value, (inverses, one_minus_inverses)), claim_point)| {
403                compute_next_beta_value_from_current(
404                    current_beta_value,
405                    inverses,
406                    one_minus_inverses,
407                    claim_point,
408                    flipped_bit_idx_and_values,
409                )
410            },
411        )
412        .collect_vec()
413}
414
415/// Compute the next beta value from the previous by multiplying by appropriate
416/// inverses and challenge points given the flipped bits and their previous
417/// values. This is for when we have a single claim to compute the beta over.
418pub fn compute_next_beta_value_from_current<F: Field>(
419    current_beta_value: &F,
420    inverses: &[F],
421    one_minus_elem_inverted: &[F],
422    claim_point: &[F],
423    flipped_bit_idx_and_values: &[(u32, bool)],
424) -> F {
425    let n = claim_point.len();
426    flipped_bit_idx_and_values.iter().fold(
427        // For each of the flipped bits, multiply by the
428        // appropriate inverse depending on if the value was
429        // previously 0 or 1.
430        *current_beta_value,
431        |acc, (idx, value)| {
432            if *value {
433                acc * inverses[n - 1 - *idx as usize]
434                    * (F::ONE - claim_point[n - 1 - *idx as usize])
435            } else {
436                acc * (one_minus_elem_inverted[n - 1 - *idx as usize])
437                    * claim_point[n - 1 - *idx as usize]
438            }
439        },
440    )
441}
442
443/// This function non-destructively evaluates an MLE at a given point using the
444/// [LexicographicLE] iterator.
445pub fn evaluate_mle_at_a_point_lexicographic_order<F: Field>(
446    mle: &MultilinearExtension<F>,
447    point: &[F],
448) -> F {
449    let n = point.len();
450    let mle_num_vars = mle.num_vars();
451    assert_eq!(n, mle_num_vars);
452
453    let starting_beta_value =
454        BetaValues::compute_beta_over_two_challenges(&vec![F::ZERO; mle_num_vars], point);
455
456    let starting_evaluation_acc = starting_beta_value * mle.first();
457    let lexicographic_le = LexicographicLE::new(mle_num_vars);
458    let inverses = point
459        .iter()
460        .map(|elem| elem.invert().unwrap())
461        .collect_vec();
462    let one_minus_inverses = point
463        .iter()
464        .map(|elem| (F::ONE - elem).invert().unwrap())
465        .collect_vec();
466
467    let (_final_beta_value, evaluation) = lexicographic_le.fold(
468        (starting_beta_value, starting_evaluation_acc),
469        |(prev_beta_value, evaluation_acc), (index, flipped_bit_indices_and_values)| {
470            let next_beta_value = flipped_bit_indices_and_values.iter().fold(
471                prev_beta_value,
472                |acc, (flipped_bit_index, flipped_bit_value)| {
473                    // For every bit i that is flipped, if it used to be a 1,
474                    // then we multiply by r_i^{-1} and multiply by (1 - r_i) to
475                    // account for this bit flip. NOTE: we subtract from n - 1
476                    // to account for the fact that internally, these u32s are
477                    // stored in little endian, but our bookkeeping tables are
478                    // stored in "big endian" indexing.
479                    if *flipped_bit_value {
480                        acc * inverses[n - 1 - *flipped_bit_index as usize]
481                            * (F::ONE - point[n - 1 - *flipped_bit_index as usize])
482                    }
483                    // For every bit i that is flipped, if it used to be a 0,
484                    // then we multiply by (1 - r_i)^{-1} and multiply by r_i to
485                    // account for this bit flip.
486                    else {
487                        acc * (one_minus_inverses[n - 1 - *flipped_bit_index as usize])
488                            * point[n - 1 - *flipped_bit_index as usize]
489                    }
490                },
491            );
492
493            // Multiply this by the appropriate MLE coefficient.
494            let next_evaluation_acc = next_beta_value * mle.get(index as usize).unwrap();
495            (next_beta_value, evaluation_acc + next_evaluation_acc)
496        },
497    );
498    evaluation
499}
500
501/// This function non-destructively evaluates an MLE at a given point using the
502/// gray codes iterator. Optimized version that uses 2 multiplications instead
503/// of 1.
504///
505/// Currently does not support for when the value in the point is either 0 or 1.
506pub fn evaluate_mle_at_a_point_gray_codes<F: Field>(
507    mle: &MultilinearExtension<F>,
508    point: &[F],
509) -> F {
510    let n = point.len();
511    let mle_num_vars = mle.num_vars();
512    assert_eq!(n, mle_num_vars);
513    // The gray codes start at index 1, so we start with the first value which
514    // is \widetilde{\beta}(\vec{0}, point).
515    let starting_beta_value =
516        BetaValues::compute_beta_over_two_challenges(&vec![F::ZERO; mle_num_vars], point);
517    // This is the value that gets multiplied to the first MLE coefficient,
518    // which is (1 - r_1) * (1 - r_2) * ... * (1 - r_n) where (r_1, ..., r_n) is
519    // the point.
520    let starting_evaluation_acc = starting_beta_value * mle.first();
521    let gray_code = GrayCodeIterator::new(mle_num_vars);
522    let inverses = point
523        .iter()
524        .map(|elem| elem.invert().unwrap())
525        .collect_vec();
526    let one_minus_inverses = point
527        .iter()
528        .map(|elem| (F::ONE - elem).invert().unwrap())
529        .collect_vec();
530
531    // We simply compute the correct inverse and new multiplicative term for
532    // each bit that is flipped in the beta value, and accumulate these by doing
533    // an element-wise multiplication with the correct index of the MLE
534    // coefficients. We precompute these so that during the scanning we only
535    // need to do one multiplication instead of two
536    let multiplier_if_flipped_bit_is_one = inverses
537        .iter()
538        .zip(point.iter())
539        .map(|(inverse, point_elem)| *inverse * (F::ONE - point_elem))
540        .collect_vec();
541    let multiplier_if_flipped_bit_is_zero = one_minus_inverses
542        .iter()
543        .zip(point.iter())
544        .map(|(one_minus_inverse, point_elem)| *one_minus_inverse * point_elem)
545        .collect_vec();
546
547    let (_final_beta_value, evaluation) = gray_code.fold(
548        (starting_beta_value, starting_evaluation_acc),
549        |(prev_beta_value, evaluation_acc), (index, (flipped_bit_index, flipped_bit_value))| {
550            // For every bit i that is flipped, if it used to be a 1, then we
551            // multiply by r_i^{-1} and multiply by (1 - r_i) to account for
552            // this bit flip. NOTE: we subtract from n - 1 to account for the
553            // fact that internally, these u32s are stored in little endian, but
554            // our bookkeeping tables are stored in "big endian" indexing.
555            let next_beta_value = if flipped_bit_value {
556                prev_beta_value
557                    * multiplier_if_flipped_bit_is_one[n - 1 - flipped_bit_index as usize]
558            }
559            // For every bit i that is flipped, if it used to be a 0, then we
560            // multiply by (1 - r_i)^{-1} and multiply by r_i to account for
561            // this bit flip.
562            else {
563                prev_beta_value
564                    * multiplier_if_flipped_bit_is_zero[n - 1 - flipped_bit_index as usize]
565            };
566            // Multiply this by the appropriate MLE coefficient.
567            let next_evaluation_acc = next_beta_value * mle.get(index as usize).unwrap();
568            (next_beta_value, evaluation_acc + next_evaluation_acc)
569        },
570    );
571    evaluation
572}
573
574/// This function non-destructively evaluates an MLE at a given point using the
575/// gray codes iterator. Parallelized version that uses K threads.
576pub fn evaluate_mle_at_a_point_gray_codes_parallel<F: Field, const K: usize>(
577    mle: &MultilinearExtension<F>,
578    point: &[F],
579) -> F {
580    let n = point.len();
581    let mle_num_vars = mle.num_vars();
582    assert_eq!(n, mle_num_vars);
583    assert!(
584        (1 << mle_num_vars) >= K,
585        "cannot have more partitions than the length of MLE"
586    );
587
588    let starting_indices = (0..K)
589        .map(|partition| partition * ((1 << mle_num_vars) / K))
590        .collect_vec();
591    let starting_gray_code_indices = starting_indices
592        .iter()
593        .map(|idx| GrayCodeIterator::get_gray_index(mle_num_vars, *idx as u32))
594        .collect_vec();
595    let starting_beta_values = starting_gray_code_indices
596        .iter()
597        .map(|gray_code| {
598            BetaValues::compute_beta_over_challenge_and_index(point, *gray_code as usize)
599        })
600        .collect_vec();
601
602    // This is the value that gets multiplied to the first MLE coefficient,
603    // which is (1 - r_1) * (1 - r_2) * ... * (1 - r_n) where (r_1, ..., r_n) is
604    // the point.
605    let starting_evaluation_accs = starting_beta_values
606        .iter()
607        .zip(starting_gray_code_indices.iter())
608        .map(|(beta_value, gray_code)| *beta_value * mle.get(*gray_code as usize).unwrap())
609        .collect_vec();
610
611    let gray_codes = starting_indices
612        .iter()
613        .enumerate()
614        .map(|(partition, &starting_index)| {
615            let end_iteration = if partition == K - 1 {
616                None
617            } else {
618                Some(starting_indices[partition + 1] as u32)
619            };
620            GrayCodeIterator::new_at_index(mle_num_vars, starting_index as u32, end_iteration)
621        })
622        .collect_vec();
623
624    let inverses = point
625        .iter()
626        .map(|elem| elem.invert().unwrap())
627        .collect_vec();
628
629    let one_minus_inverses = point
630        .iter()
631        .map(|elem| (F::ONE - elem).invert().unwrap())
632        .collect_vec();
633
634    // We simply compute the correct inverse and new multiplicative term for
635    // each bit that is flipped in the beta value, and accumulate these by doing
636    // an element-wise multiplication with the correct index of the MLE
637    // coefficients. We precompute these so that during the scanning we only
638    // need to do one multiplication instead of two
639    let multiplier_if_flipped_bit_is_one = inverses
640        .iter()
641        .zip(point.iter())
642        .map(|(inverse, point_elem)| *inverse * (F::ONE - point_elem))
643        .collect_vec();
644
645    let multiplier_if_flipped_bit_is_zero = one_minus_inverses
646        .iter()
647        .zip(point.iter())
648        .map(|(one_minus_inverse, point_elem)| *one_minus_inverse * point_elem)
649        .collect_vec();
650
651    (0..K)
652        .into_par_iter()
653        .zip(gray_codes.into_par_iter())
654        .map(|(partition, gray_code)| {
655            let starting_beta_value = starting_beta_values[partition];
656            let starting_evaluation_acc = starting_evaluation_accs[partition];
657            let (_final_beta_value, evaluation) = gray_code.fold(
658                (starting_beta_value, starting_evaluation_acc),
659                |(prev_beta_value, evaluation_acc),
660                 (index, (flipped_bit_index, flipped_bit_value))| {
661                    // For every bit i that is flipped, if it used to be a 1,
662                    // then we multiply by r_i^{-1} and multiply by (1 - r_i) to
663                    // account for this bit flip. NOTE: we subtract from n - 1
664                    // to account for the fact that internally, these u32s are
665                    // stored in little endian, but our bookkeeping tables are
666                    // stored in "big endian" indexing.
667                    let next_beta_value = if flipped_bit_value {
668                        prev_beta_value
669                            * multiplier_if_flipped_bit_is_one[n - 1 - flipped_bit_index as usize]
670                    }
671                    // For every bit i that is flipped, if it used to be a 0,
672                    // then we multiply by (1 - r_i)^{-1} and multiply by r_i to
673                    // account for this bit flip.
674                    else {
675                        prev_beta_value
676                            * multiplier_if_flipped_bit_is_zero[n - 1 - flipped_bit_index as usize]
677                    };
678                    // Multiply this by the appropriate MLE coefficient.
679                    let next_evaluation_acc = next_beta_value * mle.get(index as usize).unwrap();
680                    (next_beta_value, evaluation_acc + next_evaluation_acc)
681                },
682            );
683            evaluation
684        })
685        .reduce(|| F::ZERO, |a, b| a + b)
686}
687
688/// Destructively evaluate an MLE at a point by using the `fix_variable`
689/// algorithm iteratively until all of the variables have been bound.
690///
691pub fn evaluate_mle_destructive<F: Field>(mle: &mut MultilinearExtension<F>, point: &[F]) -> F {
692    point.iter().for_each(|challenge| {
693        mle.fix_variable(*challenge);
694    });
695    assert!(mle.is_fully_bound());
696    mle.first()
697}
698
699#[cfg(test)]
700mod tests {
701    use ark_std::test_rng;
702    use itertools::Itertools;
703    use shared_types::{ff_field, Fr};
704
705    use crate::{
706        mle::evals::MultilinearExtension,
707        utils::mle::{
708            evaluate_mle_at_a_point_gray_codes_parallel,
709            evaluate_mle_at_a_point_lexicographic_order, evaluate_mle_destructive,
710            GrayCodeIterator,
711        },
712    };
713
714    use super::evaluate_mle_at_a_point_gray_codes;
715
716    #[test]
717    fn test_gray_code_0_vars() {
718        let mut gray_code_iterator = GrayCodeIterator::new(0);
719
720        assert_eq!(gray_code_iterator.next(), None);
721    }
722
723    #[test]
724    fn test_gray_code_iterator_len() {
725        for n in 1..16 {
726            assert_eq!(GrayCodeIterator::new(n).count(), (1 << n) - 1);
727        }
728    }
729
730    // Note that this is the only test we have here that verifies that we're
731    // actually using Gray Codes and not any of the other codes that could
732    // satisfy the properties listed on the `test_gray_code_property` test.
733    #[test]
734    fn test_gray_code_3_vars() {
735        let mut gray_code_iterator = GrayCodeIterator::new(3);
736
737        assert_eq!(gray_code_iterator.next(), Some((1, (0, false))));
738        assert_eq!(gray_code_iterator.next(), Some((3, (1, false))));
739        assert_eq!(gray_code_iterator.next(), Some((2, (0, true))));
740        assert_eq!(gray_code_iterator.next(), Some((6, (2, false))));
741        assert_eq!(gray_code_iterator.next(), Some((7, (0, false))));
742        assert_eq!(gray_code_iterator.next(), Some((5, (1, true))));
743        assert_eq!(gray_code_iterator.next(), Some((4, (0, true))));
744        assert_eq!(gray_code_iterator.next(), None);
745    }
746
747    // Property testing of `GrayCode` for values of `num_bits` of up to 15 (for
748    // efficient testing). This test ensures that:
749    //   1. The Hamming distance between consecutive codes is exactly 1.
750    //   2. The index of the flipped bit correct.
751    //   3. Each of the `2^num_codes` appear exactly once.
752    //
753    // Note that there are multiple codes that satisfy the above properties. Any
754    // code with those properties can be used for computing MLEs in linear time.
755    #[test]
756    fn test_gray_code_property() {
757        for n in 1..16 {
758            let gray_code_iterator = GrayCodeIterator::new(n);
759
760            let mut seen: Vec<bool> = vec![false; 1 << n];
761
762            // Assume we're starting from 0.
763            seen[0] = true;
764
765            gray_code_iterator.fold(0, |prev, (cur, (idx, val))| {
766                // Ensure `cur` has NOT been seen before.
767                assert!(!seen[cur as usize]);
768                seen[cur as usize] = true;
769
770                let mask: u32 = 1 << idx;
771
772                // This ensures that:
773                //   1. The Hamming Distance between `prev` and `cur` is exactly
774                //      1, because `mask` contains exactly one bit set to 1,
775                //      and,
776                //   2. The flipped bit is indeed in the `idx`-th position.
777                assert_eq!(prev ^ cur, mask);
778
779                // Ensures `val` is the previous value of the flipped bit.
780                assert_eq!((prev & mask) >> idx, val as u32);
781
782                cur
783            });
784
785            // Ensure all codes have been encountered during the iteration.
786            assert!(seen.iter().all(|x| *x))
787        }
788    }
789
790    /// there cannot be more threads than the length of the MLE
791    #[test]
792    #[should_panic]
793    fn test_evaluate_mle_at_a_point_1_variable_gray_codes_parallel_more_threads_than_mle_length() {
794        const K: usize = 3;
795        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
796        let point = &[Fr::from(2)];
797        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
798        let expected_evaluation =
799            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mut mle, point);
800        assert_eq!(computed_evaluation, expected_evaluation);
801    }
802
803    #[test]
804    fn test_evaluate_mle_at_a_point_1_variable_gray_codes_parallel_1_thread() {
805        const K: usize = 1;
806        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
807        let point = &[Fr::from(2)];
808        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
809        let expected_evaluation =
810            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mut mle, point);
811        assert_eq!(computed_evaluation, expected_evaluation);
812    }
813
814    #[test]
815    fn test_evaluate_mle_at_a_point_1_variable_gray_codes_parallel_2_threads() {
816        const K: usize = 2;
817        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
818        let point = &[Fr::from(2)];
819        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
820        let expected_evaluation =
821            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mut mle, point);
822        assert_eq!(computed_evaluation, expected_evaluation);
823    }
824
825    #[test]
826    fn test_evaluate_mle_at_a_point_1_variable_gray_codes() {
827        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
828        let point = &[Fr::from(2)];
829        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
830        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
831        assert_eq!(computed_evaluation, expected_evaluation);
832    }
833
834    #[test]
835    fn test_evaluate_mle_at_a_point_2_variable_gray_codes() {
836        let mut mle: MultilinearExtension<Fr> = vec![1, 2, 1, 2].into();
837        let point = &[Fr::from(2), Fr::from(3)];
838        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
839        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
840        assert_eq!(computed_evaluation, expected_evaluation);
841    }
842
843    #[test]
844    fn test_evaluate_mle_at_a_point_3_variable_gray_codes_random_parallel() {
845        const K: usize = 5;
846        let mut rng = test_rng();
847        let mut mle = MultilinearExtension::new((0..8).map(|_| Fr::random(&mut rng)).collect());
848        let point = &(0..3).map(|_| Fr::random(&mut rng)).collect_vec();
849        let computed_evaluation =
850            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mle, point);
851        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
852        assert_eq!(computed_evaluation, expected_evaluation);
853    }
854
855    #[test]
856    fn test_evaluate_mle_at_a_point_3_variable_gray_codes_random() {
857        let mut rng = test_rng();
858        let mut mle = MultilinearExtension::new((0..8).map(|_| Fr::random(&mut rng)).collect());
859        let point = &(0..3).map(|_| Fr::random(&mut rng)).collect_vec();
860        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
861        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
862        assert_eq!(computed_evaluation, expected_evaluation);
863    }
864
865    #[test]
866    fn test_evaluate_mle_at_a_point_1_variable_lexicographic() {
867        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
868        let point = &[Fr::from(2)];
869        let computed_evaluation = evaluate_mle_at_a_point_lexicographic_order(&mle, point);
870        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
871        assert_eq!(computed_evaluation, expected_evaluation);
872    }
873
874    #[test]
875    fn test_evaluate_mle_at_a_point_2_variable_lexicographic() {
876        let mut mle: MultilinearExtension<Fr> = vec![1, 2, 1, 2].into();
877        let point = &[Fr::from(2), Fr::from(3)];
878        let computed_evaluation = evaluate_mle_at_a_point_lexicographic_order(&mle, point);
879        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
880        assert_eq!(computed_evaluation, expected_evaluation);
881    }
882
883    #[test]
884    fn test_evaluate_mle_at_a_point_3_variable_lexicographic_random() {
885        let mut rng = test_rng();
886        let mut mle = MultilinearExtension::new((0..8).map(|_| Fr::random(&mut rng)).collect());
887        let point = &(0..3).map(|_| Fr::random(&mut rng)).collect_vec();
888        let computed_evaluation = evaluate_mle_at_a_point_lexicographic_order(&mle, point);
889        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
890        assert_eq!(computed_evaluation, expected_evaluation);
891    }
892
893    /// Ensure that all three methods of computing MLEs at a point produce the
894    /// same result for random MLEs and points of sizes up to 15 bits.
895    #[test]
896    fn test_evaluation_equivalence() {
897        for n in 1..16 {
898            let num_vars = n;
899            let num_evals = 1 << num_vars;
900
901            let mut rng = test_rng();
902            let mut mle =
903                MultilinearExtension::new((0..num_evals).map(|_| Fr::random(&mut rng)).collect());
904            let point = (0..num_vars).map(|_| Fr::random(&mut rng)).collect_vec();
905
906            let gray_code_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, &point);
907            let lexicographic_evaluation =
908                evaluate_mle_at_a_point_lexicographic_order(&mle, &point);
909            let destructive_evaluation = evaluate_mle_destructive(&mut mle, &point);
910
911            assert!(
912                gray_code_evaluation == lexicographic_evaluation
913                    && lexicographic_evaluation == destructive_evaluation
914            );
915        }
916    }
917}