1use super::{
11 circuit_expr::evaluate_bookkeeping_tables_given_operation,
12 expr_errors::ExpressionError,
13 generic_expr::{Expression, ExpressionNode, ExpressionType},
14 verifier_expr::VerifierExpr,
15};
16use crate::{
17 layer::product::Product,
18 mle::{betavalues::BetaValues, dense::DenseMle, MleIndex},
19 sumcheck::{
20 apply_updated_beta_values_to_evals, beta_cascade, beta_cascade_no_independent_variable,
21 SumcheckEvals,
22 },
23};
24use crate::{
25 layer::{gate::BinaryOperation, product::PostSumcheckLayer},
26 mle::{verifier_mle::VerifierMle, Mle},
27};
28use itertools::{repeat_n, Itertools};
29use serde::{Deserialize, Serialize};
30use shared_types::Field;
31use std::{
32 cmp::max,
33 collections::HashSet,
34 fmt::Debug,
35 ops::{Add, Mul, Neg, Sub},
36};
37
38use anyhow::{anyhow, Ok, Result};
39
40#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
47pub struct MleVecIndex(usize);
48
49impl MleVecIndex {
50 pub fn new(index: usize) -> Self {
52 MleVecIndex(index)
53 }
54
55 pub fn index(&self) -> usize {
57 self.0
58 }
59
60 pub fn increment(&mut self, offset: usize) {
62 self.0 += offset;
63 }
64
65 pub fn get_mle<'a, F: Field>(&self, mle_vec: &'a [DenseMle<F>]) -> &'a DenseMle<F> {
67 &mle_vec[self.0]
68 }
69
70 pub fn get_mle_mut<'a, F: Field>(&self, mle_vec: &'a mut [DenseMle<F>]) -> &'a mut DenseMle<F> {
72 &mut mle_vec[self.0]
73 }
74}
75
76#[derive(Serialize, Deserialize, Clone, Debug)]
79pub struct ProverExpr;
80impl<F: Field> ExpressionType<F> for ProverExpr {
81 type MLENodeRepr = MleVecIndex;
82 type MleVec = Vec<DenseMle<F>>;
83}
84
85impl<F: Field> Expression<F, ProverExpr> {
88 pub fn select(self, mut rhs: Expression<F, ProverExpr>) -> Self {
91 let offset = self.num_mle();
92 rhs.increment_mle_vec_indices(offset);
93 let (lhs_node, lhs_mle_vec) = self.deconstruct();
94 let (rhs_node, rhs_mle_vec) = rhs.deconstruct();
95
96 let concat_node =
97 ExpressionNode::Selector(MleIndex::Free, Box::new(lhs_node), Box::new(rhs_node));
98
99 let concat_mle_vec = lhs_mle_vec.into_iter().chain(rhs_mle_vec).collect_vec();
100
101 Expression::new(concat_node, concat_mle_vec)
102 }
103
104 pub fn pow(pow: usize, mle: DenseMle<F>) -> Self {
106 let mle_vec_indices = (0..pow).map(|_index| MleVecIndex::new(0)).collect_vec();
107
108 let product_node = ExpressionNode::Product(mle_vec_indices);
109
110 Expression::new(product_node, vec![mle])
111 }
112
113 pub fn products(product_list: <ProverExpr as ExpressionType<F>>::MleVec) -> Self {
115 let mle_vec_indices = (0..product_list.len()).map(MleVecIndex::new).collect_vec();
116
117 let product_node = ExpressionNode::Product(mle_vec_indices);
118
119 Expression::new(product_node, product_list)
120 }
121
122 pub fn mle(mle: DenseMle<F>) -> Self {
124 let mle_node = ExpressionNode::Mle(MleVecIndex::new(0));
125
126 Expression::new(mle_node, [mle].to_vec())
127 }
128
129 pub fn constant(constant: F) -> Self {
131 let mle_node = ExpressionNode::Constant(constant);
132
133 Expression::new(mle_node, [].to_vec())
134 }
135
136 pub fn negated(expression: Self) -> Self {
138 let (node, mle_vec) = expression.deconstruct();
139
140 let mle_node = ExpressionNode::Scaled(Box::new(node), F::from(1).neg());
141
142 Expression::new(mle_node, mle_vec)
143 }
144
145 pub fn sum(lhs: Self, mut rhs: Self) -> Self {
147 let offset = lhs.num_mle();
148 rhs.increment_mle_vec_indices(offset);
149
150 let (lhs_node, lhs_mle_vec) = lhs.deconstruct();
151 let (rhs_node, rhs_mle_vec) = rhs.deconstruct();
152
153 let sum_node = ExpressionNode::Sum(Box::new(lhs_node), Box::new(rhs_node));
154 let sum_mle_vec = lhs_mle_vec.into_iter().chain(rhs_mle_vec).collect_vec();
155
156 Expression::new(sum_node, sum_mle_vec)
157 }
158
159 pub fn scaled(expression: Expression<F, ProverExpr>, scale: F) -> Self {
161 let (node, mle_vec) = expression.deconstruct();
162
163 Expression::new(ExpressionNode::Scaled(Box::new(node), scale), mle_vec)
164 }
165
166 pub fn num_mle(&self) -> usize {
168 self.mle_vec.len()
169 }
170
171 pub fn increment_mle_vec_indices(&mut self, offset: usize) {
173 let mut increment_closure = |expr: &mut ExpressionNode<F, ProverExpr>,
176 _mle_vec: &mut Vec<DenseMle<F>>|
177 -> Result<()> {
178 match expr {
179 ExpressionNode::Mle(mle_vec_index) => {
180 mle_vec_index.increment(offset);
181 Ok(())
182 }
183 ExpressionNode::Product(mle_indices) => {
184 for mle_vec_index in mle_indices {
185 mle_vec_index.increment(offset);
186 }
187 Ok(())
188 }
189 ExpressionNode::Constant(_)
190 | ExpressionNode::Scaled(_, _)
191 | ExpressionNode::Sum(_, _)
192 | ExpressionNode::Selector(_, _, _) => Ok(()),
193 }
194 };
195
196 self.traverse_mut(&mut increment_closure).unwrap();
197 }
198
199 pub fn transform_to_verifier_expression(self) -> Result<Expression<F, VerifierExpr>> {
209 let (mut expression_node, mle_vec) = self.deconstruct();
210 Ok(Expression::new(
211 expression_node
212 .transform_to_verifier_expression_node(&mle_vec)
213 .unwrap(),
214 (),
215 ))
216 }
217
218 pub fn fix_variable(&mut self, round_index: usize, challenge: F) {
220 let (expression_node, mle_vec) = self.deconstruct_mut();
221
222 expression_node.fix_variable_node(round_index, challenge, mle_vec)
223 }
224
225 pub fn fix_variable_at_index(&mut self, round_index: usize, challenge: F) {
227 let (expression_node, mle_vec) = self.deconstruct_mut();
228
229 expression_node.fix_variable_at_index_node(round_index, challenge, mle_vec)
230 }
231
232 pub fn evaluate_expr(&mut self, challenges: Vec<F>) -> Result<F> {
234 challenges
236 .iter()
237 .enumerate()
238 .for_each(|(round_idx, &challenge)| {
239 self.fix_variable(round_idx, challenge);
240 });
241
242 let mut observer_fn = |exp: &ExpressionNode<F, ProverExpr>,
244 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec|
245 -> Result<()> {
246 match exp {
247 ExpressionNode::Mle(mle_vec_idx) => {
248 let mle = mle_vec_idx.get_mle(mle_vec);
249 let indices = mle
250 .mle_indices()
251 .iter()
252 .filter_map(|index| match index {
253 MleIndex::Bound(chal, index) => Some((*chal, index)),
254 _ => None,
255 })
256 .collect_vec();
257
258 let start = *indices[0].1;
259 let end = *indices[indices.len() - 1].1;
260
261 let (indices, _): (Vec<_>, Vec<usize>) = indices.into_iter().unzip();
262
263 if indices.as_slice() == &challenges[start..=end] {
264 Ok(())
265 } else {
266 Err(anyhow!(ExpressionError::EvaluateBoundIndicesDontMatch))
267 }
268 }
269 ExpressionNode::Product(mle_vec_indices) => {
270 let mles = mle_vec_indices
271 .iter()
272 .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
273 .collect_vec();
274
275 mles.iter()
276 .map(|mle| {
277 let indices = mle
278 .mle_indices()
279 .iter()
280 .filter_map(|index| match index {
281 MleIndex::Bound(chal, index) => Some((*chal, index)),
282 _ => None,
283 })
284 .collect_vec();
285
286 let start = *indices[0].1;
287 let end = *indices[indices.len() - 1].1;
288
289 let (indices, _): (Vec<_>, Vec<usize>) = indices.into_iter().unzip();
290
291 if indices.as_slice() == &challenges[start..=end] {
292 Ok(())
293 } else {
294 Err(anyhow!(ExpressionError::EvaluateBoundIndicesDontMatch))
295 }
296 })
297 .try_collect()
298 }
299
300 _ => Ok(()),
301 }
302 };
303 self.traverse(&mut observer_fn)?;
304
305 self.clone()
307 .transform_to_verifier_expression()
308 .unwrap()
309 .evaluate()
310 }
311
312 #[allow(clippy::too_many_arguments)]
313 pub fn evaluate_sumcheck_beta_cascade(
317 &self,
318 beta: &[&BetaValues<F>],
319 random_coefficients: &[F],
320 round_index: usize,
321 degree: usize,
322 ) -> SumcheckEvals<F> {
323 self.expression_node.evaluate_sumcheck_node_beta_cascade(
324 beta,
325 &self.mle_vec,
326 random_coefficients,
327 round_index,
328 degree,
329 )
330 }
331
332 pub fn evaluate_sumcheck_node_beta_cascade_sum(
337 &self,
338 beta_values: &BetaValues<F>,
339 round_index: usize,
340 degree: usize,
341 ) -> SumcheckEvals<F> {
342 self.expression_node
343 .evaluate_sumcheck_node_beta_cascade_sum(
344 beta_values,
345 round_index,
346 degree,
347 &self.mle_vec,
348 )
349 }
350
351 pub fn get_all_rounds(&self) -> Vec<usize> {
354 let (expression_node, mle_vec) = self.deconstruct_ref();
355 let mut all_rounds = expression_node.get_all_rounds(mle_vec);
356 all_rounds.sort();
357 all_rounds
358 }
359
360 pub fn get_all_nonlinear_rounds(&self) -> Vec<usize> {
363 let (expression_node, mle_vec) = self.deconstruct_ref();
364 let mut nonlinear_rounds = expression_node.get_all_nonlinear_rounds(mle_vec);
365 nonlinear_rounds.sort();
366 nonlinear_rounds
367 }
368
369 pub fn get_all_linear_rounds(&self) -> Vec<usize> {
372 let (expression_node, mle_vec) = self.deconstruct_ref();
373 let mut linear_rounds = expression_node.get_all_linear_rounds(mle_vec);
374 linear_rounds.sort();
375 linear_rounds
376 }
377
378 pub fn index_mle_indices(&mut self, curr_index: usize) -> usize {
382 let (expression_node, mle_vec) = self.deconstruct_mut();
383 expression_node.index_mle_indices_node(curr_index, mle_vec)
384 }
385
386 pub fn get_expression_num_free_variables(&self) -> usize {
388 self.expression_node
389 .get_expression_num_free_variables_node(0, &self.mle_vec)
390 }
391
392 pub fn get_post_sumcheck_layer(&self, multiplier: F) -> PostSumcheckLayer<F, F> {
395 self.expression_node
396 .get_post_sumcheck_layer(multiplier, &self.mle_vec)
397 }
398
399 pub fn get_max_degree(&self) -> usize {
401 self.expression_node.get_max_degree(&self.mle_vec)
402 }
403}
404
405impl<F: Field> ExpressionNode<F, ProverExpr> {
406 pub fn transform_to_verifier_expression_node(
410 &mut self,
411 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
412 ) -> Result<ExpressionNode<F, VerifierExpr>> {
413 match self {
414 ExpressionNode::Constant(scalar) => Ok(ExpressionNode::Constant(*scalar)),
415 ExpressionNode::Selector(index, a, b) => Ok(ExpressionNode::Selector(
416 index.clone(),
417 Box::new(a.transform_to_verifier_expression_node(mle_vec)?),
418 Box::new(b.transform_to_verifier_expression_node(mle_vec)?),
419 )),
420 ExpressionNode::Mle(mle_vec_idx) => {
421 let mle = mle_vec_idx.get_mle(mle_vec);
422
423 if !mle.is_fully_bounded() {
424 return Err(anyhow!(ExpressionError::EvaluateNotFullyBoundError));
425 }
426
427 let layer_id = mle.layer_id();
428 let mle_indices = mle.mle_indices().to_vec();
429 let eval = mle.value();
430
431 Ok(ExpressionNode::Mle(VerifierMle::new(
432 layer_id,
433 mle_indices,
434 eval,
435 )))
436 }
437 ExpressionNode::Sum(a, b) => Ok(ExpressionNode::Sum(
438 Box::new(a.transform_to_verifier_expression_node(mle_vec)?),
439 Box::new(b.transform_to_verifier_expression_node(mle_vec)?),
440 )),
441 ExpressionNode::Product(mle_vec_indices) => {
442 let mles = mle_vec_indices
443 .iter_mut()
444 .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
445 .collect_vec();
446
447 for mle in mles.iter() {
448 if !mle.is_fully_bounded() {
449 return Err(anyhow!(ExpressionError::EvaluateNotFullyBoundError));
450 }
451 }
452
453 Ok(ExpressionNode::Product(
454 mles.into_iter()
455 .map(|mle| {
456 VerifierMle::new(
457 mle.layer_id(),
458 mle.mle_indices().to_vec(),
459 mle.value(),
460 )
461 })
462 .collect_vec(),
463 ))
464 }
465 ExpressionNode::Scaled(mle, scalar) => Ok(ExpressionNode::Scaled(
466 Box::new(mle.transform_to_verifier_expression_node(mle_vec)?),
467 *scalar,
468 )),
469 }
470 }
471
472 pub fn fix_variable_node(
474 &mut self,
475 round_index: usize,
476 challenge: F,
477 mle_vec: &mut <ProverExpr as ExpressionType<F>>::MleVec, ) {
479 match self {
480 ExpressionNode::Selector(index, a, b) => {
481 if *index == MleIndex::Indexed(round_index) {
482 index.bind_index(challenge);
483 } else {
484 a.fix_variable_node(round_index, challenge, mle_vec);
485 b.fix_variable_node(round_index, challenge, mle_vec);
486 }
487 }
488 ExpressionNode::Mle(mle_vec_idx) => {
489 let mle = mle_vec_idx.get_mle_mut(mle_vec);
490
491 if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
492 mle.fix_variable(round_index, challenge);
493 }
494 }
495 ExpressionNode::Sum(a, b) => {
496 a.fix_variable_node(round_index, challenge, mle_vec);
497 b.fix_variable_node(round_index, challenge, mle_vec);
498 }
499 ExpressionNode::Product(mle_vec_indices) => {
500 mle_vec_indices
501 .iter_mut()
502 .map(|mle_vec_index| {
503 let mle = mle_vec_index.get_mle_mut(mle_vec);
504
505 if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
506 mle.fix_variable(round_index, challenge);
507 }
508 })
509 .collect_vec();
510 }
511 ExpressionNode::Scaled(a, _) => {
512 a.fix_variable_node(round_index, challenge, mle_vec);
513 }
514 ExpressionNode::Constant(_) => (),
515 }
516 }
517
518 pub fn fix_variable_at_index_node(
520 &mut self,
521 round_index: usize,
522 challenge: F,
523 mle_vec: &mut <ProverExpr as ExpressionType<F>>::MleVec, ) {
525 match self {
526 ExpressionNode::Selector(index, a, b) => {
527 if *index == MleIndex::Indexed(round_index) {
528 index.bind_index(challenge);
529 } else {
530 a.fix_variable_at_index_node(round_index, challenge, mle_vec);
531 b.fix_variable_at_index_node(round_index, challenge, mle_vec);
532 }
533 }
534 ExpressionNode::Mle(mle_vec_idx) => {
535 let mle = mle_vec_idx.get_mle_mut(mle_vec);
536
537 if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
538 mle.fix_variable_at_index(round_index, challenge);
539 }
540 }
541 ExpressionNode::Sum(a, b) => {
542 a.fix_variable_at_index_node(round_index, challenge, mle_vec);
543 b.fix_variable_at_index_node(round_index, challenge, mle_vec);
544 }
545 ExpressionNode::Product(mle_vec_indices) => {
546 mle_vec_indices
547 .iter_mut()
548 .map(|mle_vec_index| {
549 let mle = mle_vec_index.get_mle_mut(mle_vec);
550
551 if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
552 mle.fix_variable_at_index(round_index, challenge);
553 }
554 })
555 .collect_vec();
556 }
557 ExpressionNode::Scaled(a, _) => {
558 a.fix_variable_at_index_node(round_index, challenge, mle_vec);
559 }
560 ExpressionNode::Constant(_) => (),
561 }
562 }
563
564 pub fn evaluate_sumcheck_node_beta_cascade_sum(
565 &self,
566 beta_values: &BetaValues<F>,
567 round_index: usize,
568 degree: usize,
569 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
570 ) -> SumcheckEvals<F> {
571 match self {
572 ExpressionNode::Constant(constant) => {
573 SumcheckEvals(repeat_n(*constant, degree + 1).collect())
574 }
575 ExpressionNode::Selector(selector_mle_index, lhs, rhs) => {
576 let lhs_eval = lhs.evaluate_sumcheck_node_beta_cascade_sum(
577 beta_values,
578 round_index,
579 degree,
580 mle_vec,
581 );
582 let rhs_eval = rhs.evaluate_sumcheck_node_beta_cascade_sum(
583 beta_values,
584 round_index,
585 degree,
586 mle_vec,
587 );
588 match selector_mle_index {
589 MleIndex::Indexed(var_number) => {
590 let index_claim = beta_values.get_unbound_value(*var_number).unwrap();
591 (lhs_eval * (F::ONE - index_claim)) + (rhs_eval * index_claim)
592 }
593 MleIndex::Bound(bound_value, var_number) => {
594 let identity = F::ONE;
595 let beta_bound = beta_values
596 .get_updated_value(*var_number)
597 .unwrap_or(identity);
598 ((lhs_eval * (F::ONE - bound_value)) + (rhs_eval * bound_value))
599 * beta_bound
600 }
601 _ => panic!("Invalid MLE Index for a selector bit, should be free or indexed"),
602 }
603 }
604 ExpressionNode::Mle(mle_idx) => {
605 let mle = mle_idx.get_mle(mle_vec);
606 let (unbound, bound) = beta_values.get_relevant_beta_unbound_and_bound(
607 mle.mle_indices(),
608 round_index,
609 false,
610 );
611 beta_cascade_no_independent_variable(mle.mle.to_vec(), &unbound, &bound, degree)
612 }
613 ExpressionNode::Sum(lhs, rhs) => {
614 let lhs_eval = lhs.evaluate_sumcheck_node_beta_cascade_sum(
615 beta_values,
616 round_index,
617 degree,
618 mle_vec,
619 );
620 let rhs_eval = rhs.evaluate_sumcheck_node_beta_cascade_sum(
621 beta_values,
622 round_index,
623 degree,
624 mle_vec,
625 );
626 lhs_eval + rhs_eval
627 }
628 ExpressionNode::Product(mle_idx_vec) => {
629 let (mles, mles_bookkeeping_tables): (Vec<&DenseMle<F>>, Vec<Vec<F>>) = mle_idx_vec
630 .iter()
631 .map(|mle_vec_index| {
632 let mle = mle_vec_index.get_mle(mle_vec);
633 (mle, mle.mle.to_vec())
634 })
635 .unzip();
636
637 let mut unique_mle_indices = HashSet::new();
638
639 let mle_indices_vec = mles
640 .iter()
641 .flat_map(|mle| mle.mle_indices.clone())
642 .filter(move |mle_index| unique_mle_indices.insert(mle_index.clone()))
643 .collect_vec();
644
645 let (unbound, bound) = beta_values.get_relevant_beta_unbound_and_bound(
646 &mle_indices_vec,
647 round_index,
648 false,
649 );
650 let evaluated_bookkeeping_tables = evaluate_bookkeeping_tables_given_operation(
651 &mles_bookkeeping_tables,
652 BinaryOperation::Mul,
653 );
654 beta_cascade_no_independent_variable(
655 evaluated_bookkeeping_tables.to_vec(),
656 &unbound,
657 &bound,
658 degree,
659 )
660 }
661 ExpressionNode::Scaled(expression_node, scale) => {
662 expression_node.evaluate_sumcheck_node_beta_cascade_sum(
663 beta_values,
664 round_index,
665 degree,
666 mle_vec,
667 ) * scale
668 }
669 }
670 }
671
672 #[allow(clippy::too_many_arguments)]
726 pub fn evaluate_sumcheck_node_beta_cascade(
727 &self,
728 beta_vec: &[&BetaValues<F>],
729 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
730 random_coefficients: &[F],
731 round_index: usize,
732 degree: usize,
733 ) -> SumcheckEvals<F> {
734 match self {
735 ExpressionNode::Constant(constant) => {
744 let sumcheck_eval_not_scaled_by_constant = beta_vec
745 .iter()
746 .zip(random_coefficients)
747 .map(|(beta_table, random_coeff)| {
748 let folded_updated_vals = beta_table.fold_updated_values();
749 let index_claim = beta_table.get_unbound_value(round_index).unwrap();
750 let one_minus_index_claim = F::ONE - index_claim;
751 let beta_step = index_claim - one_minus_index_claim;
752 let evals =
753 std::iter::successors(Some(one_minus_index_claim), move |item| {
754 Some(*item + beta_step)
755 })
756 .take(degree + 1)
757 .collect_vec();
758 apply_updated_beta_values_to_evals(evals, folded_updated_vals)
759 * random_coeff
760 })
761 .reduce(|acc, elem| acc + elem)
762 .unwrap();
763 sumcheck_eval_not_scaled_by_constant * constant
764 }
765
766 ExpressionNode::Selector(index, a, b) => {
774 match index {
775 MleIndex::Indexed(indexed_bit) => {
776 let (lhs_evals, rhs_evals) = (
777 beta_vec
778 .iter()
779 .map(|beta| {
780 a.evaluate_sumcheck_node_beta_cascade_sum(
781 beta,
782 round_index,
783 degree,
784 mle_vec,
785 )
786 })
787 .collect_vec(),
788 beta_vec
789 .iter()
790 .map(|beta| {
791 b.evaluate_sumcheck_node_beta_cascade_sum(
792 beta,
793 round_index,
794 degree,
795 mle_vec,
796 )
797 })
798 .collect_vec(),
799 );
800 match Ord::cmp(&round_index, indexed_bit) {
804 std::cmp::Ordering::Less => {
805 let sumcheck_eval = beta_vec
806 .iter()
807 .zip((lhs_evals.iter().zip(rhs_evals.iter())).zip(random_coefficients))
808 .map(|(beta_table, ((a, b), random_coeff))| {
809 let index_claim = beta_table.get_unbound_value(*indexed_bit).unwrap();
810 let a_eval: &SumcheckEvals<F> = a;
811 let b_eval: &SumcheckEvals<F> = b;
812 let a_with_sel: SumcheckEvals<F> =
820 a_eval.clone() * (F::ONE - index_claim);
821 let b_with_sel: SumcheckEvals<F> = b_eval.clone() * index_claim;
822 (a_with_sel + b_with_sel) * random_coeff
823 })
824 .reduce(|acc, elem| acc + elem)
825 .unwrap();
826 sumcheck_eval
827 }
828 std::cmp::Ordering::Equal => {
829 let sumcheck_eval = beta_vec
833 .iter()
834 .zip((lhs_evals.iter().zip(rhs_evals)).zip(random_coefficients))
835 .map(|(beta_table, ((a, b), random_coeff))| {
836 let SumcheckEvals(first_evals) = a;
837 let SumcheckEvals(second_evals) = b;
838 if first_evals.len() == second_evals.len() {
839 let bound_beta_values = beta_table.fold_updated_values();
840 let index_claim =
841 beta_table.get_unbound_value(*indexed_bit).unwrap();
842 let eval_len = first_evals.len();
847 let one_minus_index_claim = F::ONE - index_claim;
848 let beta_step = index_claim - one_minus_index_claim;
849 let beta_evals = std::iter::successors(
850 Some(one_minus_index_claim),
851 move |item| Some(*item + beta_step),
852 )
853 .take(eval_len)
854 .collect_vec();
855 let first_evals = SumcheckEvals(
859 first_evals
860 .clone()
861 .into_iter()
862 .enumerate()
863 .map(|(idx, first_eval)| {
864 first_eval
865 * (F::ONE - F::from(idx as u64))
866 * beta_evals[idx]
867 })
868 .collect(),
869 );
870 let second_evals = SumcheckEvals(
871 second_evals
872 .clone()
873 .into_iter()
874 .enumerate()
875 .map(|(idx, second_eval)| {
876 second_eval
877 * F::from(idx as u64) * beta_evals[idx]
878 })
879 .collect(),
880 );
881 (first_evals + second_evals) * random_coeff * bound_beta_values
882 } else {
883 panic!("Expression returns two evals that do not have the same length on a selector bit")
884 }
885 })
886 .reduce(|acc, elem| acc + elem)
887 .unwrap();
888 sumcheck_eval
889 }
890 std::cmp::Ordering::Greater => panic!(
894 "Invalid selector index, cannot be less than the current round index"
895 ),
896 }
897 }
898 MleIndex::Bound(coeff, _) => {
902 let (lhs_evals, rhs_evals) = (
903 beta_vec
904 .iter()
905 .map(|beta| {
906 a.evaluate_sumcheck_node_beta_cascade(
907 &[*beta],
908 mle_vec,
909 &[F::ONE],
910 round_index,
911 degree,
912 )
913 })
914 .collect_vec(),
915 beta_vec
916 .iter()
917 .map(|beta| {
918 b.evaluate_sumcheck_node_beta_cascade(
919 &[*beta],
920 mle_vec,
921 &[F::ONE],
922 round_index,
923 degree,
924 )
925 })
926 .collect_vec(),
927 );
928 let coeff_neg = F::ONE - coeff;
929 (lhs_evals.iter().zip(rhs_evals))
930 .zip(random_coefficients)
931 .map(|((a, b), random_coeff)| {
932 let a_eval = a;
933 let b_eval = b;
934 ((b_eval.clone() * coeff) + (a_eval.clone() * coeff_neg))
935 * random_coeff
936 })
937 .reduce(|acc, elem| acc + elem)
938 .unwrap()
939 }
940 _ => panic!("selector index should not be a Free or Fixed bit"),
941 }
942 }
943 ExpressionNode::Mle(mle_vec_idx) => {
946 let mle = mle_vec_idx.get_mle(mle_vec);
947 let (unbound_beta_vec, bound_beta_vec): (Vec<Vec<F>>, Vec<Vec<F>>) = beta_vec
948 .iter()
949 .map(|beta| {
950 beta.get_relevant_beta_unbound_and_bound(
951 mle.mle_indices(),
952 round_index,
953 true,
954 )
955 })
956 .unzip();
957
958 beta_cascade(
959 &[&mle.clone()],
960 degree,
961 round_index,
962 &unbound_beta_vec,
963 &bound_beta_vec,
964 random_coefficients,
965 )
966 }
967 ExpressionNode::Sum(a, b) => {
970 let a = a.evaluate_sumcheck_node_beta_cascade(
971 beta_vec,
972 mle_vec,
973 random_coefficients,
974 round_index,
975 degree,
976 );
977 let b = b.evaluate_sumcheck_node_beta_cascade(
978 beta_vec,
979 mle_vec,
980 random_coefficients,
981 round_index,
982 degree,
983 );
984 a + b
985 }
986 ExpressionNode::Product(mle_vec_indices) => {
990 let mles = mle_vec_indices
991 .iter()
992 .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
993 .collect_vec();
994
995 let mut unique_mle_indices = HashSet::new();
996
997 let mle_indices_vec = mles
998 .iter()
999 .flat_map(|mle| mle.mle_indices.clone())
1000 .filter(move |mle_index| unique_mle_indices.insert(mle_index.clone()))
1001 .collect_vec();
1002
1003 let (unbound_beta_vec, bound_beta_vec): (Vec<Vec<F>>, Vec<Vec<F>>) = beta_vec
1004 .iter()
1005 .map(|beta| {
1006 beta.get_relevant_beta_unbound_and_bound(
1007 &mle_indices_vec,
1008 round_index,
1009 true,
1010 )
1011 })
1012 .unzip();
1013
1014 beta_cascade(
1015 &mles,
1016 degree,
1017 round_index,
1018 &unbound_beta_vec,
1019 &bound_beta_vec,
1020 random_coefficients,
1021 )
1022 }
1023
1024 ExpressionNode::Scaled(a, scale) => {
1027 let a = a.evaluate_sumcheck_node_beta_cascade(
1028 beta_vec,
1029 mle_vec,
1030 random_coefficients,
1031 round_index,
1032 degree,
1033 );
1034 a * scale
1035 }
1036 }
1037 }
1038
1039 pub fn index_mle_indices_node(
1043 &mut self,
1044 curr_index: usize,
1045 mle_vec: &mut <ProverExpr as ExpressionType<F>>::MleVec,
1046 ) -> usize {
1047 match self {
1048 ExpressionNode::Selector(mle_index, a, b) => {
1049 let mut new_index = curr_index;
1050 if *mle_index == MleIndex::Free {
1051 *mle_index = MleIndex::Indexed(curr_index);
1052 new_index += 1;
1053 }
1054 let a_bits = a.index_mle_indices_node(new_index, mle_vec);
1055 let b_bits = b.index_mle_indices_node(new_index, mle_vec);
1056 max(a_bits, b_bits)
1057 }
1058 ExpressionNode::Mle(mle_vec_idx) => {
1059 let mle = mle_vec_idx.get_mle_mut(mle_vec);
1060 mle.index_mle_indices(curr_index)
1061 }
1062 ExpressionNode::Sum(a, b) => {
1063 let a_bits = a.index_mle_indices_node(curr_index, mle_vec);
1064 let b_bits = b.index_mle_indices_node(curr_index, mle_vec);
1065 max(a_bits, b_bits)
1066 }
1067 ExpressionNode::Product(mle_vec_indices) => mle_vec_indices
1068 .iter_mut()
1069 .map(|mle_vec_index| {
1070 let mle = mle_vec_index.get_mle_mut(mle_vec);
1071 mle.index_mle_indices(curr_index)
1072 })
1073 .reduce(max)
1074 .unwrap_or(curr_index),
1075 ExpressionNode::Scaled(a, _) => a.index_mle_indices_node(curr_index, mle_vec),
1076 ExpressionNode::Constant(_) => curr_index,
1077 }
1078 }
1079
1080 pub(crate) fn get_all_rounds(
1083 &self,
1084 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1085 ) -> Vec<usize> {
1086 let degree_per_index = self.get_rounds_helper(mle_vec);
1087 (0..degree_per_index.len())
1088 .filter(|&i| degree_per_index[i] > 0)
1089 .collect()
1090 }
1091
1092 pub fn get_all_nonlinear_rounds(
1094 &self,
1095 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1096 ) -> Vec<usize> {
1097 let degree_per_index = self.get_rounds_helper(mle_vec);
1098 (0..degree_per_index.len())
1099 .filter(|&i| degree_per_index[i] > 1)
1100 .collect()
1101 }
1102
1103 pub fn get_all_linear_rounds(
1105 &self,
1106 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1107 ) -> Vec<usize> {
1108 let degree_per_index = self.get_rounds_helper(mle_vec);
1109 (0..degree_per_index.len())
1110 .filter(|&i| degree_per_index[i] == 1)
1111 .collect()
1112 }
1113
1114 fn get_rounds_helper(&self, mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec) -> Vec<usize> {
1116 let mut degree_per_index = Vec::new();
1118 let max_degree = |degree_per_index: &mut Vec<usize>, index: usize, new_degree: usize| {
1120 if degree_per_index.len() <= index {
1121 degree_per_index.extend(vec![0; index + 1 - degree_per_index.len()]);
1122 }
1123 if degree_per_index[index] < new_degree {
1124 degree_per_index[index] = new_degree;
1125 }
1126 };
1127 let add_degree = |degree_per_index: &mut Vec<usize>, index: usize, new_degree: usize| {
1129 if degree_per_index.len() <= index {
1130 degree_per_index.extend(vec![0; index + 1 - degree_per_index.len()]);
1131 }
1132 degree_per_index[index] += new_degree;
1133 };
1134
1135 match self {
1136 ExpressionNode::Product(mle_vec_indices) => {
1138 let mles = mle_vec_indices
1139 .iter()
1140 .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
1141 .collect_vec();
1142 mles.into_iter().for_each(|mle| {
1143 mle.mle_indices.iter().for_each(|mle_index| {
1144 if let MleIndex::Indexed(i) = mle_index {
1145 add_degree(&mut degree_per_index, *i, 1);
1146 }
1147 })
1148 });
1149 }
1150 ExpressionNode::Mle(mle_vec_idx) => {
1152 let mle = mle_vec_idx.get_mle(mle_vec);
1153 mle.mle_indices.iter().for_each(|mle_index| {
1154 if let MleIndex::Indexed(i) = mle_index {
1155 max_degree(&mut degree_per_index, *i, 1);
1156 }
1157 });
1158 }
1159 ExpressionNode::Selector(sel_index, a, b) => {
1161 if let MleIndex::Indexed(i) = sel_index {
1162 add_degree(&mut degree_per_index, *i, 1);
1163 };
1164 let a_degree_per_index = a.get_rounds_helper(mle_vec);
1165 let b_degree_per_index = b.get_rounds_helper(mle_vec);
1166 for i in 0..max(a_degree_per_index.len(), b_degree_per_index.len()) {
1168 if let Some(a_degree) = a_degree_per_index.get(i) {
1169 max_degree(&mut degree_per_index, i, *a_degree);
1170 }
1171 if let Some(b_degree) = b_degree_per_index.get(i) {
1172 max_degree(&mut degree_per_index, i, *b_degree);
1173 }
1174 }
1175 }
1176 ExpressionNode::Sum(a, b) => {
1178 let a_degree_per_index = a.get_rounds_helper(mle_vec);
1179 let b_degree_per_index = b.get_rounds_helper(mle_vec);
1180 for i in 0..max(a_degree_per_index.len(), b_degree_per_index.len()) {
1182 if let Some(a_degree) = a_degree_per_index.get(i) {
1183 max_degree(&mut degree_per_index, i, *a_degree);
1184 }
1185 if let Some(b_degree) = b_degree_per_index.get(i) {
1186 max_degree(&mut degree_per_index, i, *b_degree);
1187 }
1188 }
1189 }
1190 ExpressionNode::Scaled(a, _) => {
1192 degree_per_index = a.get_rounds_helper(mle_vec);
1193 }
1194 ExpressionNode::Constant(_) => {}
1196 }
1197 degree_per_index
1198 }
1199
1200 pub fn get_expression_num_free_variables_node(
1202 &self,
1203 curr_size: usize,
1204 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1205 ) -> usize {
1206 match self {
1207 ExpressionNode::Selector(mle_index, a, b) => {
1208 let (a_bits, b_bits) = if matches!(mle_index, &MleIndex::Free) {
1209 (
1210 a.get_expression_num_free_variables_node(curr_size + 1, mle_vec),
1211 b.get_expression_num_free_variables_node(curr_size + 1, mle_vec),
1212 )
1213 } else {
1214 (
1215 a.get_expression_num_free_variables_node(curr_size, mle_vec),
1216 b.get_expression_num_free_variables_node(curr_size, mle_vec),
1217 )
1218 };
1219
1220 max(a_bits, b_bits)
1221 }
1222 ExpressionNode::Mle(mle_vec_idx) => {
1223 let mle = mle_vec_idx.get_mle(mle_vec);
1224
1225 mle.mle_indices()
1226 .iter()
1227 .filter(|item| matches!(item, &&MleIndex::Free))
1228 .collect_vec()
1229 .len()
1230 + curr_size
1231 }
1232 ExpressionNode::Sum(a, b) => {
1233 let a_bits = a.get_expression_num_free_variables_node(curr_size, mle_vec);
1234 let b_bits = b.get_expression_num_free_variables_node(curr_size, mle_vec);
1235 max(a_bits, b_bits)
1236 }
1237 ExpressionNode::Product(mle_vec_indices) => {
1238 let mles = mle_vec_indices
1239 .iter()
1240 .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
1241 .collect_vec();
1242
1243 mles.iter()
1244 .map(|mle| {
1245 mle.mle_indices()
1246 .iter()
1247 .filter(|item| matches!(item, &&MleIndex::Free))
1248 .collect_vec()
1249 .len()
1250 })
1251 .max()
1252 .unwrap_or(0)
1253 + curr_size
1254 }
1255 ExpressionNode::Scaled(a, _) => {
1256 a.get_expression_num_free_variables_node(curr_size, mle_vec)
1257 }
1258 ExpressionNode::Constant(_) => curr_size,
1259 }
1260 }
1261
1262 pub fn get_post_sumcheck_layer(
1266 &self,
1267 multiplier: F,
1268 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1269 ) -> PostSumcheckLayer<F, F> {
1270 let mut products: Vec<Product<F, F>> = vec![];
1271 match self {
1272 ExpressionNode::Selector(mle_index, a, b) => {
1273 let left_side_acc = multiplier * (F::ONE - mle_index.val().unwrap());
1274 let right_side_acc = multiplier * (mle_index.val().unwrap());
1275 products.extend(a.get_post_sumcheck_layer(left_side_acc, mle_vec).0);
1276 products.extend(b.get_post_sumcheck_layer(right_side_acc, mle_vec).0);
1277 }
1278 ExpressionNode::Sum(a, b) => {
1279 products.extend(a.get_post_sumcheck_layer(multiplier, mle_vec).0);
1280 products.extend(b.get_post_sumcheck_layer(multiplier, mle_vec).0);
1281 }
1282 ExpressionNode::Mle(mle_vec_idx) => {
1283 let mle = mle_vec_idx.get_mle(mle_vec);
1284 assert!(mle.is_fully_bounded());
1285 products.push(Product::<F, F>::new(std::slice::from_ref(mle), multiplier));
1286 }
1287 ExpressionNode::Product(mle_vec_indices) => {
1288 let mles = mle_vec_indices
1289 .iter()
1290 .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec).clone())
1291 .collect_vec();
1292 let product = Product::<F, F>::new(&mles, multiplier);
1293 products.push(product);
1294 }
1295 ExpressionNode::Scaled(a, scale_factor) => {
1296 let acc = multiplier * scale_factor;
1297 products.extend(a.get_post_sumcheck_layer(acc, mle_vec).0);
1298 }
1299 ExpressionNode::Constant(constant) => {
1300 products.push(Product::<F, F>::new(&[], *constant * multiplier));
1301 }
1302 }
1303 PostSumcheckLayer(products)
1304 }
1305
1306 fn get_max_degree(&self, _mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec) -> usize {
1308 match self {
1309 ExpressionNode::Selector(_, a, b) | ExpressionNode::Sum(a, b) => {
1310 let a_degree = a.get_max_degree(_mle_vec);
1311 let b_degree = b.get_max_degree(_mle_vec);
1312 max(a_degree, b_degree)
1313 }
1314 ExpressionNode::Mle(_) => {
1315 1
1317 }
1318 ExpressionNode::Product(mles) => {
1319 mles.len()
1321 }
1322 ExpressionNode::Scaled(a, _) => a.get_max_degree(_mle_vec),
1323 ExpressionNode::Constant(_) => 1,
1324 }
1325 }
1326}
1327
1328impl<F: Field> Neg for Expression<F, ProverExpr> {
1329 type Output = Expression<F, ProverExpr>;
1330 fn neg(self) -> Self::Output {
1331 Expression::<F, ProverExpr>::negated(self)
1332 }
1333}
1334
1335impl<F: Field> Add for Expression<F, ProverExpr> {
1337 type Output = Expression<F, ProverExpr>;
1338 fn add(self, rhs: Expression<F, ProverExpr>) -> Expression<F, ProverExpr> {
1339 Expression::<F, ProverExpr>::sum(self, rhs)
1340 }
1341}
1342
1343impl<F: Field> Sub for Expression<F, ProverExpr> {
1344 type Output = Expression<F, ProverExpr>;
1345 fn sub(self, rhs: Expression<F, ProverExpr>) -> Expression<F, ProverExpr> {
1346 self.add(rhs.neg())
1347 }
1348}
1349
1350impl<F: Field> Mul<F> for Expression<F, ProverExpr> {
1351 type Output = Expression<F, ProverExpr>;
1352 fn mul(self, rhs: F) -> Self::Output {
1353 Expression::<F, ProverExpr>::scaled(self, rhs)
1354 }
1355}
1356
1357impl<F: std::fmt::Debug + Field> std::fmt::Debug for Expression<F, ProverExpr> {
1359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360 f.debug_struct("Expression")
1361 .field("Expression_Node", &self.expression_node)
1362 .field("MleRef_Vec", &self.mle_vec)
1363 .finish()
1364 }
1365}
1366
1367impl<F: std::fmt::Debug + Field> std::fmt::Debug for ExpressionNode<F, ProverExpr> {
1369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1370 match self {
1371 ExpressionNode::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(),
1372 ExpressionNode::Selector(index, a, b) => f
1373 .debug_tuple("Selector")
1374 .field(index)
1375 .field(a)
1376 .field(b)
1377 .finish(),
1378 ExpressionNode::Mle(_mle) => f.debug_struct("Mle").field("mle", _mle).finish(),
1380 ExpressionNode::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
1381 ExpressionNode::Product(a) => f.debug_tuple("Product").field(a).finish(),
1382 ExpressionNode::Scaled(poly, scalar) => {
1383 f.debug_tuple("Scaled").field(poly).field(scalar).finish()
1384 }
1385 }
1386 }
1387}