1use std::iter::repeat_with;
4
5use itertools::{repeat_n, FoldWhile, Itertools};
6use rand::Rng;
7use rayon::prelude::*;
8use shared_types::Field;
9
10use crate::{
11 claims::RawClaim,
12 layer::LayerId,
13 mle::{betavalues::BetaValues, dense::DenseMle, evals::MultilinearExtension, MleIndex},
14};
15
/// Returns a copy of `data` extended with `padding_value` up to the next power of two.
///
/// The result has length `data.len().next_power_of_two()`: input whose length is
/// already a power of two is copied unchanged, and an empty slice yields a single
/// `padding_value`.
///
/// # Panics
/// Panics if the next power of two does not fit in a `usize`.
pub fn pad_with<F: Clone>(padding_value: F, data: &[F]) -> Vec<F> {
    let target_len = data.len().checked_next_power_of_two().unwrap();
    let mut padded = Vec::with_capacity(target_len);
    padded.extend_from_slice(data);
    // `resize` fills the tail with clones of `padding_value` (no-op if already full).
    padded.resize(target_len, padding_value);
    padded
}
38
/// Returns the permutation of indices that sorts `slice`.
///
/// Ascending when `invert` is `false`, descending when `true`. The sort is stable,
/// so indices of equal elements keep their original relative order.
pub fn argsort<T: Ord>(slice: &[T], invert: bool) -> Vec<usize> {
    let mut order: Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| {
        let ascending = slice[a].cmp(&slice[b]);
        if invert {
            ascending.reverse()
        } else {
            ascending
        }
    });
    order
}
63
64pub fn get_random_mle<F: Field>(num_vars: usize, rng: &mut impl Rng) -> DenseMle<F> {
66 let capacity = 2_u32.pow(num_vars as u32);
67 let bookkeeping_table = repeat_with(|| F::from(rng.gen::<u64>()) * F::from(rng.gen::<u64>()))
68 .take(capacity as usize)
69 .collect_vec();
70 DenseMle::new_from_raw(bookkeeping_table, LayerId::Input(0))
71}
72
73pub fn get_random_mle_from_capacity<F: Field>(capacity: usize, rng: &mut impl Rng) -> DenseMle<F> {
75 let bookkeeping_table = repeat_with(|| F::from(rng.gen::<u64>()) * F::from(rng.gen::<u64>()))
76 .take(capacity)
77 .collect_vec();
78 DenseMle::new_from_raw(bookkeeping_table, LayerId::Input(0))
79}
80
81pub fn get_dummy_random_mle_vec<F: Field>(
84 num_vars: usize,
85 num_dataparallel_bits: usize,
86 rng: &mut impl Rng,
87) -> Vec<DenseMle<F>> {
88 (0..(1 << num_dataparallel_bits))
89 .map(|_| {
90 let mle_vec = (0..(1 << num_vars))
91 .map(|_| F::from(rng.gen::<u64>()))
92 .collect_vec();
93 DenseMle::new_from_raw(mle_vec, LayerId::Input(0))
94 })
95 .collect_vec()
96}
97
98pub fn get_mle_idx_decomp_for_idx<F: Field>(idx: usize, num_bits: usize) -> Vec<MleIndex<F>> {
101 (0..(num_bits))
102 .rev()
103 .map(|cur_num_bits| {
104 let is_one =
105 (idx % 2_usize.pow(cur_num_bits as u32 + 1)) >= 2_usize.pow(cur_num_bits as u32);
106 MleIndex::Fixed(is_one)
107 })
108 .collect_vec()
109}
110
111pub fn get_total_mle_indices<F: Field>(
114 prefix_bits: &[bool],
115 num_free_bits: usize,
116) -> Vec<MleIndex<F>> {
117 prefix_bits
118 .iter()
119 .map(|bit| MleIndex::Fixed(*bit))
120 .chain(repeat_n(MleIndex::Free, num_free_bits))
121 .collect()
122}
123
124pub fn build_composite_mle<F: Field>(
145 mles: &[(&MultilinearExtension<F>, Vec<bool>)],
146) -> MultilinearExtension<F> {
147 assert!(!mles.is_empty());
148 let out_num_vars = mles[0].0.num_vars() + mles[0].1.len();
149 mles.iter().for_each(|(mle, prefix_bits)| {
153 assert_eq!(mle.num_vars() + prefix_bits.len(), out_num_vars);
154 });
155 let mut out = vec![F::ZERO; 1 << out_num_vars];
156 for (mle, prefix_bits) in mles {
157 let mut current_window = 1 << out_num_vars;
158 let starting_index = prefix_bits.iter().fold(0, |acc_index, bit| {
159 let starting_index_acc = if *bit {
160 acc_index + current_window / 2
161 } else {
162 acc_index
163 };
164 current_window /= 2;
165 starting_index_acc
166 });
167 assert_eq!(current_window, mle.len().next_power_of_two());
170 (starting_index..(starting_index + current_window))
171 .enumerate()
172 .for_each(|(mle_idx, out_idx)| {
173 out[out_idx] = mle.get(mle_idx).unwrap_or(F::ZERO);
174 });
175 }
176 MultilinearExtension::new(out)
177}
178
179pub fn verify_claim<F: Field>(mle_unpadded_evaluations: &[F], claim: &RawClaim<F>) {
182 let mle_evaluations = claim
183 .get_point()
184 .iter()
185 .fold_while(mle_unpadded_evaluations, |acc, elem| {
186 if elem == &F::ZERO {
187 let sliced_acc = &acc[..(acc.len() / 2)];
188 FoldWhile::Continue(sliced_acc)
189 } else if elem == &F::ONE {
190 let sliced_acc = &acc[(acc.len() / 2)..];
191 FoldWhile::Continue(sliced_acc)
192 } else {
193 FoldWhile::Done(acc)
194 }
195 })
196 .into_inner();
197 let filtered_claim = claim
198 .get_point()
199 .iter()
200 .skip_while(|x| x == &&F::ZERO || x == &&F::ONE)
201 .copied()
202 .collect_vec();
203 let mle = MultilinearExtension::new(mle_evaluations.to_vec());
204 assert_eq!(mle.num_vars(), filtered_claim.len());
205 let eval = evaluate_mle_at_a_point_gray_codes(&mle, &filtered_claim);
206 assert_eq!(eval, claim.get_eval());
207}
208
209#[derive(Debug)]
218pub struct GrayCodeIterator {
219 num_bits: usize,
220 current_iteration: u32,
221 end_iteration: Option<u32>,
222}
223
224impl GrayCodeIterator {
225 pub fn new(num_bits: usize) -> Self {
228 assert!(num_bits < 32);
229 Self {
230 num_bits,
231 current_iteration: 0,
232 end_iteration: None,
233 }
234 }
235
236 pub(crate) fn new_at_index(
237 num_bits: usize,
238 current_iteration: u32,
239 end_iteration: Option<u32>,
240 ) -> Self {
241 Self {
242 num_bits,
243 current_iteration,
244 end_iteration,
245 }
246 }
247
248 pub(crate) fn get_gray_index(num_bits: usize, index: u32) -> u32 {
249 let mask = (1 << num_bits) - 1;
250
251 (index ^ (index >> 1)) & mask
252 }
253}
254
255impl Iterator for GrayCodeIterator {
256 type Item = (u32, (u32, bool));
257
258 fn next(&mut self) -> Option<Self::Item> {
259 if self.current_iteration >= ((1 << self.num_bits) - 1) {
260 return None;
261 }
262
263 if self.end_iteration.is_some()
264 && self.current_iteration >= (self.end_iteration.unwrap() - 1)
265 {
266 return None;
267 }
268
269 if self.end_iteration.is_some()
270 && self.current_iteration >= (self.end_iteration.unwrap() - 1)
271 {
272 return None;
273 }
274
275 let mask = (1 << self.num_bits) - 1;
278
279 let prev_gray = (self.current_iteration ^ (self.current_iteration >> 1)) & mask;
282 self.current_iteration += 1;
286 let new_gray = (self.current_iteration ^ (self.current_iteration >> 1)) & mask;
287
288 Some((
292 new_gray,
293 compute_flipped_bit_idx_and_value_graycode(prev_gray, new_gray),
294 ))
295 }
296}
297
298pub struct LexicographicLE {
305 num_bits: usize,
306 current_val: u32,
307}
308
309impl LexicographicLE {
310 fn new(num_bits: usize) -> Self {
311 Self {
312 num_bits,
313 current_val: 0,
314 }
315 }
316}
317
318impl Iterator for LexicographicLE {
319 type Item = (u32, Vec<(u32, bool)>);
320
321 fn next(&mut self) -> Option<Self::Item> {
322 if self.current_val >= ((1 << self.num_bits) - 1) {
323 return None;
324 }
325
326 let prev_val = self.current_val;
327 self.current_val += 1;
328
329 let flipped_bit_idx_and_values =
330 compute_flipped_bit_idx_and_values_lexicographic(prev_val, self.current_val);
331
332 Some((self.current_val, flipped_bit_idx_and_values))
333 }
334}
335
/// Given two values expected to differ in exactly one bit (consecutive Gray codes),
/// returns `(index_of_lowest_differing_bit, value_of_that_bit_in_curr_val)`.
pub fn compute_flipped_bit_idx_and_value_graycode(curr_val: u32, next_val: u32) -> (u32, bool) {
    let differing_bits = curr_val ^ next_val;
    let bit_index = differing_bits.trailing_zeros();
    let was_set = (curr_val >> bit_index) & 1 == 1;
    (bit_index, was_set)
}
343
344pub fn compute_inverses_vec_and_one_minus_inverted_vec<F: Field>(
346 claim_points: &[&[F]],
347) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
348 let inverses_vec = claim_points
349 .iter()
350 .map(|claim_point| {
351 claim_point
352 .iter()
353 .map(|elem| elem.invert().unwrap())
354 .collect_vec()
355 })
356 .collect_vec();
357 let one_minus_inverses_vec = claim_points
358 .iter()
359 .map(|claim_point| {
360 claim_point
361 .iter()
362 .map(|elem| (F::ONE - elem).invert().unwrap())
363 .collect_vec()
364 })
365 .collect_vec();
366 (inverses_vec, one_minus_inverses_vec)
367}
368
/// Lists every bit that differs between `curr_val` and `next_val`.
///
/// Each entry is `(bit_index, value_of_that_bit_in_curr_val)`, ordered from the
/// lowest bit index to the highest. Equal inputs yield an empty vector.
pub fn compute_flipped_bit_idx_and_values_lexicographic(
    curr_val: u32,
    next_val: u32,
) -> Vec<(u32, bool)> {
    let mut pending = curr_val ^ next_val;
    let mut result = Vec::with_capacity(pending.count_ones() as usize);
    while pending != 0 {
        let bit_idx = pending.trailing_zeros();
        result.push((bit_idx, (curr_val >> bit_idx) & 1 == 1));
        // Clear the lowest set bit.
        pending &= pending - 1;
    }
    result
}
386
387pub fn compute_next_beta_values_vec_from_current<F: Field>(
391 current_beta_values: &[F],
392 inverses_vec: &[Vec<F>],
393 one_minus_elem_inverted_vec: &[Vec<F>],
394 claim_points: &[&[F]],
395 flipped_bit_idx_and_values: &[(u32, bool)],
396) -> Vec<F> {
397 current_beta_values
398 .iter()
399 .zip(inverses_vec.iter().zip(one_minus_elem_inverted_vec))
400 .zip(claim_points)
401 .map(
402 |((current_beta_value, (inverses, one_minus_inverses)), claim_point)| {
403 compute_next_beta_value_from_current(
404 current_beta_value,
405 inverses,
406 one_minus_inverses,
407 claim_point,
408 flipped_bit_idx_and_values,
409 )
410 },
411 )
412 .collect_vec()
413}
414
415pub fn compute_next_beta_value_from_current<F: Field>(
419 current_beta_value: &F,
420 inverses: &[F],
421 one_minus_elem_inverted: &[F],
422 claim_point: &[F],
423 flipped_bit_idx_and_values: &[(u32, bool)],
424) -> F {
425 let n = claim_point.len();
426 flipped_bit_idx_and_values.iter().fold(
427 *current_beta_value,
431 |acc, (idx, value)| {
432 if *value {
433 acc * inverses[n - 1 - *idx as usize]
434 * (F::ONE - claim_point[n - 1 - *idx as usize])
435 } else {
436 acc * (one_minus_elem_inverted[n - 1 - *idx as usize])
437 * claim_point[n - 1 - *idx as usize]
438 }
439 },
440 )
441}
442
443pub fn evaluate_mle_at_a_point_lexicographic_order<F: Field>(
446 mle: &MultilinearExtension<F>,
447 point: &[F],
448) -> F {
449 let n = point.len();
450 let mle_num_vars = mle.num_vars();
451 assert_eq!(n, mle_num_vars);
452
453 let starting_beta_value =
454 BetaValues::compute_beta_over_two_challenges(&vec![F::ZERO; mle_num_vars], point);
455
456 let starting_evaluation_acc = starting_beta_value * mle.first();
457 let lexicographic_le = LexicographicLE::new(mle_num_vars);
458 let inverses = point
459 .iter()
460 .map(|elem| elem.invert().unwrap())
461 .collect_vec();
462 let one_minus_inverses = point
463 .iter()
464 .map(|elem| (F::ONE - elem).invert().unwrap())
465 .collect_vec();
466
467 let (_final_beta_value, evaluation) = lexicographic_le.fold(
468 (starting_beta_value, starting_evaluation_acc),
469 |(prev_beta_value, evaluation_acc), (index, flipped_bit_indices_and_values)| {
470 let next_beta_value = flipped_bit_indices_and_values.iter().fold(
471 prev_beta_value,
472 |acc, (flipped_bit_index, flipped_bit_value)| {
473 if *flipped_bit_value {
480 acc * inverses[n - 1 - *flipped_bit_index as usize]
481 * (F::ONE - point[n - 1 - *flipped_bit_index as usize])
482 }
483 else {
487 acc * (one_minus_inverses[n - 1 - *flipped_bit_index as usize])
488 * point[n - 1 - *flipped_bit_index as usize]
489 }
490 },
491 );
492
493 let next_evaluation_acc = next_beta_value * mle.get(index as usize).unwrap();
495 (next_beta_value, evaluation_acc + next_evaluation_acc)
496 },
497 );
498 evaluation
499}
500
501pub fn evaluate_mle_at_a_point_gray_codes<F: Field>(
507 mle: &MultilinearExtension<F>,
508 point: &[F],
509) -> F {
510 let n = point.len();
511 let mle_num_vars = mle.num_vars();
512 assert_eq!(n, mle_num_vars);
513 let starting_beta_value =
516 BetaValues::compute_beta_over_two_challenges(&vec![F::ZERO; mle_num_vars], point);
517 let starting_evaluation_acc = starting_beta_value * mle.first();
521 let gray_code = GrayCodeIterator::new(mle_num_vars);
522 let inverses = point
523 .iter()
524 .map(|elem| elem.invert().unwrap())
525 .collect_vec();
526 let one_minus_inverses = point
527 .iter()
528 .map(|elem| (F::ONE - elem).invert().unwrap())
529 .collect_vec();
530
531 let multiplier_if_flipped_bit_is_one = inverses
537 .iter()
538 .zip(point.iter())
539 .map(|(inverse, point_elem)| *inverse * (F::ONE - point_elem))
540 .collect_vec();
541 let multiplier_if_flipped_bit_is_zero = one_minus_inverses
542 .iter()
543 .zip(point.iter())
544 .map(|(one_minus_inverse, point_elem)| *one_minus_inverse * point_elem)
545 .collect_vec();
546
547 let (_final_beta_value, evaluation) = gray_code.fold(
548 (starting_beta_value, starting_evaluation_acc),
549 |(prev_beta_value, evaluation_acc), (index, (flipped_bit_index, flipped_bit_value))| {
550 let next_beta_value = if flipped_bit_value {
556 prev_beta_value
557 * multiplier_if_flipped_bit_is_one[n - 1 - flipped_bit_index as usize]
558 }
559 else {
563 prev_beta_value
564 * multiplier_if_flipped_bit_is_zero[n - 1 - flipped_bit_index as usize]
565 };
566 let next_evaluation_acc = next_beta_value * mle.get(index as usize).unwrap();
568 (next_beta_value, evaluation_acc + next_evaluation_acc)
569 },
570 );
571 evaluation
572}
573
574pub fn evaluate_mle_at_a_point_gray_codes_parallel<F: Field, const K: usize>(
577 mle: &MultilinearExtension<F>,
578 point: &[F],
579) -> F {
580 let n = point.len();
581 let mle_num_vars = mle.num_vars();
582 assert_eq!(n, mle_num_vars);
583 assert!(
584 (1 << mle_num_vars) >= K,
585 "cannot have more partitions than the length of MLE"
586 );
587
588 let starting_indices = (0..K)
589 .map(|partition| partition * ((1 << mle_num_vars) / K))
590 .collect_vec();
591 let starting_gray_code_indices = starting_indices
592 .iter()
593 .map(|idx| GrayCodeIterator::get_gray_index(mle_num_vars, *idx as u32))
594 .collect_vec();
595 let starting_beta_values = starting_gray_code_indices
596 .iter()
597 .map(|gray_code| {
598 BetaValues::compute_beta_over_challenge_and_index(point, *gray_code as usize)
599 })
600 .collect_vec();
601
602 let starting_evaluation_accs = starting_beta_values
606 .iter()
607 .zip(starting_gray_code_indices.iter())
608 .map(|(beta_value, gray_code)| *beta_value * mle.get(*gray_code as usize).unwrap())
609 .collect_vec();
610
611 let gray_codes = starting_indices
612 .iter()
613 .enumerate()
614 .map(|(partition, &starting_index)| {
615 let end_iteration = if partition == K - 1 {
616 None
617 } else {
618 Some(starting_indices[partition + 1] as u32)
619 };
620 GrayCodeIterator::new_at_index(mle_num_vars, starting_index as u32, end_iteration)
621 })
622 .collect_vec();
623
624 let inverses = point
625 .iter()
626 .map(|elem| elem.invert().unwrap())
627 .collect_vec();
628
629 let one_minus_inverses = point
630 .iter()
631 .map(|elem| (F::ONE - elem).invert().unwrap())
632 .collect_vec();
633
634 let multiplier_if_flipped_bit_is_one = inverses
640 .iter()
641 .zip(point.iter())
642 .map(|(inverse, point_elem)| *inverse * (F::ONE - point_elem))
643 .collect_vec();
644
645 let multiplier_if_flipped_bit_is_zero = one_minus_inverses
646 .iter()
647 .zip(point.iter())
648 .map(|(one_minus_inverse, point_elem)| *one_minus_inverse * point_elem)
649 .collect_vec();
650
651 (0..K)
652 .into_par_iter()
653 .zip(gray_codes.into_par_iter())
654 .map(|(partition, gray_code)| {
655 let starting_beta_value = starting_beta_values[partition];
656 let starting_evaluation_acc = starting_evaluation_accs[partition];
657 let (_final_beta_value, evaluation) = gray_code.fold(
658 (starting_beta_value, starting_evaluation_acc),
659 |(prev_beta_value, evaluation_acc),
660 (index, (flipped_bit_index, flipped_bit_value))| {
661 let next_beta_value = if flipped_bit_value {
668 prev_beta_value
669 * multiplier_if_flipped_bit_is_one[n - 1 - flipped_bit_index as usize]
670 }
671 else {
675 prev_beta_value
676 * multiplier_if_flipped_bit_is_zero[n - 1 - flipped_bit_index as usize]
677 };
678 let next_evaluation_acc = next_beta_value * mle.get(index as usize).unwrap();
680 (next_beta_value, evaluation_acc + next_evaluation_acc)
681 },
682 );
683 evaluation
684 })
685 .reduce(|| F::ZERO, |a, b| a + b)
686}
687
688pub fn evaluate_mle_destructive<F: Field>(mle: &mut MultilinearExtension<F>, point: &[F]) -> F {
692 point.iter().for_each(|challenge| {
693 mle.fix_variable(*challenge);
694 });
695 assert!(mle.is_fully_bound());
696 mle.first()
697}
698
#[cfg(test)]
mod tests {
    use ark_std::test_rng;
    use itertools::Itertools;
    use shared_types::{ff_field, Fr};

    use crate::{
        mle::evals::MultilinearExtension,
        utils::mle::{
            argsort, evaluate_mle_at_a_point_gray_codes,
            evaluate_mle_at_a_point_gray_codes_parallel,
            evaluate_mle_at_a_point_lexicographic_order, evaluate_mle_destructive, pad_with,
            GrayCodeIterator,
        },
    };

    #[test]
    fn test_pad_with() {
        assert_eq!(pad_with(0, &[1, 2, 3]), vec![1, 2, 3, 0]);
        assert_eq!(pad_with(0, &[1, 2, 3, 4]), vec![1, 2, 3, 4]);
        assert_eq!(pad_with(9, &[5]), vec![5]);
        // An empty slice pads to a single element (length 1 = 2^0).
        assert_eq!(pad_with(7, &[] as &[i32]), vec![7]);
    }

    #[test]
    fn test_argsort() {
        assert_eq!(argsort(&[3, 1, 2], false), vec![1, 2, 0]);
        assert_eq!(argsort(&[3, 1, 2], true), vec![0, 2, 1]);
        // Stable: equal elements keep their original relative order.
        assert_eq!(argsort(&[1, 1, 0], false), vec![2, 0, 1]);
        assert_eq!(argsort(&[1, 1, 0], true), vec![0, 1, 2]);
    }

    #[test]
    fn test_gray_code_0_vars() {
        let mut gray_code_iterator = GrayCodeIterator::new(0);

        assert_eq!(gray_code_iterator.next(), None);
    }

    #[test]
    fn test_gray_code_iterator_len() {
        for n in 1..16 {
            assert_eq!(GrayCodeIterator::new(n).count(), (1 << n) - 1);
        }
    }

    /// Checks the exact 3-bit sequence; each item is
    /// `(new_code, (flipped_bit_index, previous_bit_value))`.
    #[test]
    fn test_gray_code_3_vars() {
        let mut gray_code_iterator = GrayCodeIterator::new(3);

        assert_eq!(gray_code_iterator.next(), Some((1, (0, false))));
        assert_eq!(gray_code_iterator.next(), Some((3, (1, false))));
        assert_eq!(gray_code_iterator.next(), Some((2, (0, true))));
        assert_eq!(gray_code_iterator.next(), Some((6, (2, false))));
        assert_eq!(gray_code_iterator.next(), Some((7, (0, false))));
        assert_eq!(gray_code_iterator.next(), Some((5, (1, true))));
        assert_eq!(gray_code_iterator.next(), Some((4, (0, true))));
        assert_eq!(gray_code_iterator.next(), None);
    }

    /// Every non-zero code is visited exactly once. Consecutive codes differ in
    /// exactly the reported bit, and the reported previous value is correct.
    #[test]
    fn test_gray_code_property() {
        for n in 1..16 {
            let gray_code_iterator = GrayCodeIterator::new(n);

            let mut seen: Vec<bool> = vec![false; 1 << n];
            // Code 0 is the implicit starting point and is never yielded.
            seen[0] = true;

            gray_code_iterator.fold(0, |prev, (cur, (idx, val))| {
                assert!(!seen[cur as usize]);
                seen[cur as usize] = true;

                let mask: u32 = 1 << idx;
                assert_eq!(prev ^ cur, mask);
                assert_eq!((prev & mask) >> idx, val as u32);

                cur
            });

            assert!(seen.iter().all(|x| *x));
        }
    }

    #[test]
    #[should_panic(expected = "cannot have more partitions than the length of MLE")]
    fn test_evaluate_mle_at_a_point_1_variable_gray_codes_parallel_more_threads_than_mle_length() {
        const K: usize = 3;
        let mle: MultilinearExtension<Fr> = vec![1, 2].into();
        let point = &[Fr::from(2)];
        let sequential_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
        let parallel_evaluation =
            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mle, point);
        assert_eq!(parallel_evaluation, sequential_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_1_variable_gray_codes_parallel_1_thread() {
        const K: usize = 1;
        let mle: MultilinearExtension<Fr> = vec![1, 2].into();
        let point = &[Fr::from(2)];
        let sequential_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
        let parallel_evaluation =
            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mle, point);
        assert_eq!(parallel_evaluation, sequential_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_1_variable_gray_codes_parallel_2_threads() {
        const K: usize = 2;
        let mle: MultilinearExtension<Fr> = vec![1, 2].into();
        let point = &[Fr::from(2)];
        let sequential_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
        let parallel_evaluation =
            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mle, point);
        assert_eq!(parallel_evaluation, sequential_evaluation);
    }

    /// Parallel evaluation matches the sequential walk for partition counts that
    /// do and do not divide `2^num_vars`, including one partition per index.
    #[test]
    fn test_gray_codes_parallel_equivalence() {
        let mut rng = test_rng();
        for num_vars in 3..10 {
            let mle = MultilinearExtension::new(
                (0..(1 << num_vars)).map(|_| Fr::random(&mut rng)).collect(),
            );
            let point = (0..num_vars).map(|_| Fr::random(&mut rng)).collect_vec();
            let sequential_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, &point);

            assert_eq!(
                evaluate_mle_at_a_point_gray_codes_parallel::<Fr, 1>(&mle, &point),
                sequential_evaluation
            );
            assert_eq!(
                evaluate_mle_at_a_point_gray_codes_parallel::<Fr, 3>(&mle, &point),
                sequential_evaluation
            );
            assert_eq!(
                evaluate_mle_at_a_point_gray_codes_parallel::<Fr, 8>(&mle, &point),
                sequential_evaluation
            );
        }
    }

    #[test]
    fn test_evaluate_mle_at_a_point_1_variable_gray_codes() {
        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
        let point = &[Fr::from(2)];
        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_2_variable_gray_codes() {
        let mut mle: MultilinearExtension<Fr> = vec![1, 2, 1, 2].into();
        let point = &[Fr::from(2), Fr::from(3)];
        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_3_variable_gray_codes_random_parallel() {
        const K: usize = 5;
        let mut rng = test_rng();
        let mut mle = MultilinearExtension::new((0..8).map(|_| Fr::random(&mut rng)).collect());
        let point = &(0..3).map(|_| Fr::random(&mut rng)).collect_vec();
        let computed_evaluation =
            evaluate_mle_at_a_point_gray_codes_parallel::<shared_types::Fr, K>(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_3_variable_gray_codes_random() {
        let mut rng = test_rng();
        let mut mle = MultilinearExtension::new((0..8).map(|_| Fr::random(&mut rng)).collect());
        let point = &(0..3).map(|_| Fr::random(&mut rng)).collect_vec();
        let computed_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_1_variable_lexicographic() {
        let mut mle: MultilinearExtension<Fr> = vec![1, 2].into();
        let point = &[Fr::from(2)];
        let computed_evaluation = evaluate_mle_at_a_point_lexicographic_order(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_2_variable_lexicographic() {
        let mut mle: MultilinearExtension<Fr> = vec![1, 2, 1, 2].into();
        let point = &[Fr::from(2), Fr::from(3)];
        let computed_evaluation = evaluate_mle_at_a_point_lexicographic_order(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    #[test]
    fn test_evaluate_mle_at_a_point_3_variable_lexicographic_random() {
        let mut rng = test_rng();
        let mut mle = MultilinearExtension::new((0..8).map(|_| Fr::random(&mut rng)).collect());
        let point = &(0..3).map(|_| Fr::random(&mut rng)).collect_vec();
        let computed_evaluation = evaluate_mle_at_a_point_lexicographic_order(&mle, point);
        let expected_evaluation = evaluate_mle_destructive(&mut mle, point);
        assert_eq!(computed_evaluation, expected_evaluation);
    }

    /// All three evaluation strategies agree on random MLEs of 1..16 variables.
    #[test]
    fn test_evaluation_equivalence() {
        let mut rng = test_rng();
        for n in 1..16 {
            let num_vars = n;
            let num_evals = 1 << num_vars;

            let mut mle =
                MultilinearExtension::new((0..num_evals).map(|_| Fr::random(&mut rng)).collect());
            let point = (0..num_vars).map(|_| Fr::random(&mut rng)).collect_vec();

            let gray_code_evaluation = evaluate_mle_at_a_point_gray_codes(&mle, &point);
            let lexicographic_evaluation =
                evaluate_mle_at_a_point_lexicographic_order(&mle, &point);
            let destructive_evaluation = evaluate_mle_destructive(&mut mle, &point);

            // Separate assertions so a failure reports which pair disagrees.
            assert_eq!(gray_code_evaluation, destructive_evaluation);
            assert_eq!(lexicographic_evaluation, destructive_evaluation);
        }
    }
}