remainder/expression/
prover_expr.rs

//! The prover's view of an [Expression] -- see the module-level documentation
//! of [crate::expression] for more details.
//!
//! Conceptually, [ProverExpr] contains the "circuit structure" between the
//! layer whose values are the output of the given expression and the layers
//! whose values are its inputs, i.e. the polynomial relationship between
//! them, as well as (in an ownership sense) the actual data, stored in
//! [DenseMle]s.
9
10use super::{
11    circuit_expr::evaluate_bookkeeping_tables_given_operation,
12    expr_errors::ExpressionError,
13    generic_expr::{Expression, ExpressionNode, ExpressionType},
14    verifier_expr::VerifierExpr,
15};
16use crate::{
17    layer::product::Product,
18    mle::{betavalues::BetaValues, dense::DenseMle, MleIndex},
19    sumcheck::{
20        apply_updated_beta_values_to_evals, beta_cascade, beta_cascade_no_independent_variable,
21        SumcheckEvals,
22    },
23};
24use crate::{
25    layer::{gate::BinaryOperation, product::PostSumcheckLayer},
26    mle::{verifier_mle::VerifierMle, Mle},
27};
28use itertools::{repeat_n, Itertools};
29use serde::{Deserialize, Serialize};
30use shared_types::Field;
31use std::{
32    cmp::max,
33    collections::HashSet,
34    fmt::Debug,
35    ops::{Add, Mul, Neg, Sub},
36};
37
38use anyhow::{anyhow, Ok, Result};
39
40/// mid-term solution for deduplication of DenseMleRefs
41/// basically a wrapper around usize, which denotes the index
42/// of the MleRef in an expression's MleRef list/// Generic Expressions
43///
44/// TODO(ryancao): We should deprecate this and instead just have
45/// references to the `DenseMLE<F>`s which are stored in the circuit_map.
46#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
47pub struct MleVecIndex(usize);
48
49impl MleVecIndex {
50    /// create a new MleRefIndex
51    pub fn new(index: usize) -> Self {
52        MleVecIndex(index)
53    }
54
55    /// returns the index
56    pub fn index(&self) -> usize {
57        self.0
58    }
59
60    /// add the index with an increment amount
61    pub fn increment(&mut self, offset: usize) {
62        self.0 += offset;
63    }
64
65    /// return the actual mle in the vec within the prover expression
66    pub fn get_mle<'a, F: Field>(&self, mle_vec: &'a [DenseMle<F>]) -> &'a DenseMle<F> {
67        &mle_vec[self.0]
68    }
69
70    /// return the actual mle in the vec within the prover expression
71    pub fn get_mle_mut<'a, F: Field>(&self, mle_vec: &'a mut [DenseMle<F>]) -> &'a mut DenseMle<F> {
72        &mut mle_vec[self.0]
73    }
74}
75
76/// Prover Expression
77/// the leaf nodes of the expression tree are DenseMleRefs
78#[derive(Serialize, Deserialize, Clone, Debug)]
79pub struct ProverExpr;
80impl<F: Field> ExpressionType<F> for ProverExpr {
81    type MLENodeRepr = MleVecIndex;
82    type MleVec = Vec<DenseMle<F>>;
83}
84
85/// this is what the prover manipulates to prove the correctness of the computation.
86/// Methods here include ones to fix bits, evaluate sumcheck messages, etc.
87impl<F: Field> Expression<F, ProverExpr> {
88    /// See documentation in [super::circuit_expr::ExprDescription]'s `select()`
89    /// function for more details!
90    pub fn select(self, mut rhs: Expression<F, ProverExpr>) -> Self {
91        let offset = self.num_mle();
92        rhs.increment_mle_vec_indices(offset);
93        let (lhs_node, lhs_mle_vec) = self.deconstruct();
94        let (rhs_node, rhs_mle_vec) = rhs.deconstruct();
95
96        let concat_node =
97            ExpressionNode::Selector(MleIndex::Free, Box::new(lhs_node), Box::new(rhs_node));
98
99        let concat_mle_vec = lhs_mle_vec.into_iter().chain(rhs_mle_vec).collect_vec();
100
101        Expression::new(concat_node, concat_mle_vec)
102    }
103
104    /// Create a product Expression that raises one MLE to a given power
105    pub fn pow(pow: usize, mle: DenseMle<F>) -> Self {
106        let mle_vec_indices = (0..pow).map(|_index| MleVecIndex::new(0)).collect_vec();
107
108        let product_node = ExpressionNode::Product(mle_vec_indices);
109
110        Expression::new(product_node, vec![mle])
111    }
112
113    /// Create a product Expression that multiplies many MLEs together
114    pub fn products(product_list: <ProverExpr as ExpressionType<F>>::MleVec) -> Self {
115        let mle_vec_indices = (0..product_list.len()).map(MleVecIndex::new).collect_vec();
116
117        let product_node = ExpressionNode::Product(mle_vec_indices);
118
119        Expression::new(product_node, product_list)
120    }
121
122    /// Create a mle Expression that contains one MLE
123    pub fn mle(mle: DenseMle<F>) -> Self {
124        let mle_node = ExpressionNode::Mle(MleVecIndex::new(0));
125
126        Expression::new(mle_node, [mle].to_vec())
127    }
128
129    /// Create a constant Expression that contains one field element
130    pub fn constant(constant: F) -> Self {
131        let mle_node = ExpressionNode::Constant(constant);
132
133        Expression::new(mle_node, [].to_vec())
134    }
135
136    /// negates an Expression
137    pub fn negated(expression: Self) -> Self {
138        let (node, mle_vec) = expression.deconstruct();
139
140        let mle_node = ExpressionNode::Scaled(Box::new(node), F::from(1).neg());
141
142        Expression::new(mle_node, mle_vec)
143    }
144
145    /// Create a Sum Expression that contains two MLEs
146    pub fn sum(lhs: Self, mut rhs: Self) -> Self {
147        let offset = lhs.num_mle();
148        rhs.increment_mle_vec_indices(offset);
149
150        let (lhs_node, lhs_mle_vec) = lhs.deconstruct();
151        let (rhs_node, rhs_mle_vec) = rhs.deconstruct();
152
153        let sum_node = ExpressionNode::Sum(Box::new(lhs_node), Box::new(rhs_node));
154        let sum_mle_vec = lhs_mle_vec.into_iter().chain(rhs_mle_vec).collect_vec();
155
156        Expression::new(sum_node, sum_mle_vec)
157    }
158
159    /// scales an Expression by a field element
160    pub fn scaled(expression: Expression<F, ProverExpr>, scale: F) -> Self {
161        let (node, mle_vec) = expression.deconstruct();
162
163        Expression::new(ExpressionNode::Scaled(Box::new(node), scale), mle_vec)
164    }
165
166    /// returns the number of MleRefs in the expression
167    pub fn num_mle(&self) -> usize {
168        self.mle_vec.len()
169    }
170
171    /// which increments all the MleVecIndex in the expression by *param* amount
172    pub fn increment_mle_vec_indices(&mut self, offset: usize) {
173        // define a closure that increments the MleVecIndex by the given amount
174        // use traverse_mut
175        let mut increment_closure = |expr: &mut ExpressionNode<F, ProverExpr>,
176                                     _mle_vec: &mut Vec<DenseMle<F>>|
177         -> Result<()> {
178            match expr {
179                ExpressionNode::Mle(mle_vec_index) => {
180                    mle_vec_index.increment(offset);
181                    Ok(())
182                }
183                ExpressionNode::Product(mle_indices) => {
184                    for mle_vec_index in mle_indices {
185                        mle_vec_index.increment(offset);
186                    }
187                    Ok(())
188                }
189                ExpressionNode::Constant(_)
190                | ExpressionNode::Scaled(_, _)
191                | ExpressionNode::Sum(_, _)
192                | ExpressionNode::Selector(_, _, _) => Ok(()),
193            }
194        };
195
196        self.traverse_mut(&mut increment_closure).unwrap();
197    }
198
199    /// Transforms the prover expression to a verifier expression.
200    ///
201    /// Should only be called when the entire expression is fully bound.
202    ///
203    /// Traverses the expression and changes the DenseMle to VerifierMle,
204    /// by grabbing their bookkeeping table's 1st and only element.
205    ///
206    /// If the bookkeeping table has more than 1 element, it
207    /// throws an ExpressionError::EvaluateNotFullyBoundError
208    pub fn transform_to_verifier_expression(self) -> Result<Expression<F, VerifierExpr>> {
209        let (mut expression_node, mle_vec) = self.deconstruct();
210        Ok(Expression::new(
211            expression_node
212                .transform_to_verifier_expression_node(&mle_vec)
213                .unwrap(),
214            (),
215        ))
216    }
217
218    /// fix the variable at a certain round index, always MSB index
219    pub fn fix_variable(&mut self, round_index: usize, challenge: F) {
220        let (expression_node, mle_vec) = self.deconstruct_mut();
221
222        expression_node.fix_variable_node(round_index, challenge, mle_vec)
223    }
224
225    /// fix the variable at a certain round index, arbitrary index
226    pub fn fix_variable_at_index(&mut self, round_index: usize, challenge: F) {
227        let (expression_node, mle_vec) = self.deconstruct_mut();
228
229        expression_node.fix_variable_at_index_node(round_index, challenge, mle_vec)
230    }
231
232    /// evaluates an expression on the given challenges points, by fixing the variables
233    pub fn evaluate_expr(&mut self, challenges: Vec<F>) -> Result<F> {
234        // It's as simple as fixing all variables
235        challenges
236            .iter()
237            .enumerate()
238            .for_each(|(round_idx, &challenge)| {
239                self.fix_variable(round_idx, challenge);
240            });
241
242        // ----- this is literally a check -----
243        let mut observer_fn = |exp: &ExpressionNode<F, ProverExpr>,
244                               mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec|
245         -> Result<()> {
246            match exp {
247                ExpressionNode::Mle(mle_vec_idx) => {
248                    let mle = mle_vec_idx.get_mle(mle_vec);
249                    let indices = mle
250                        .mle_indices()
251                        .iter()
252                        .filter_map(|index| match index {
253                            MleIndex::Bound(chal, index) => Some((*chal, index)),
254                            _ => None,
255                        })
256                        .collect_vec();
257
258                    let start = *indices[0].1;
259                    let end = *indices[indices.len() - 1].1;
260
261                    let (indices, _): (Vec<_>, Vec<usize>) = indices.into_iter().unzip();
262
263                    if indices.as_slice() == &challenges[start..=end] {
264                        Ok(())
265                    } else {
266                        Err(anyhow!(ExpressionError::EvaluateBoundIndicesDontMatch))
267                    }
268                }
269                ExpressionNode::Product(mle_vec_indices) => {
270                    let mles = mle_vec_indices
271                        .iter()
272                        .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
273                        .collect_vec();
274
275                    mles.iter()
276                        .map(|mle| {
277                            let indices = mle
278                                .mle_indices()
279                                .iter()
280                                .filter_map(|index| match index {
281                                    MleIndex::Bound(chal, index) => Some((*chal, index)),
282                                    _ => None,
283                                })
284                                .collect_vec();
285
286                            let start = *indices[0].1;
287                            let end = *indices[indices.len() - 1].1;
288
289                            let (indices, _): (Vec<_>, Vec<usize>) = indices.into_iter().unzip();
290
291                            if indices.as_slice() == &challenges[start..=end] {
292                                Ok(())
293                            } else {
294                                Err(anyhow!(ExpressionError::EvaluateBoundIndicesDontMatch))
295                            }
296                        })
297                        .try_collect()
298                }
299
300                _ => Ok(()),
301            }
302        };
303        self.traverse(&mut observer_fn)?;
304
305        // Traverse the expression and pick up all the evals
306        self.clone()
307            .transform_to_verifier_expression()
308            .unwrap()
309            .evaluate()
310    }
311
312    #[allow(clippy::too_many_arguments)]
313    /// This evaluates a sumcheck message using the beta cascade algorithm by calling it on the root
314    /// node of the expression tree. This assumes that there is an independent variable in the
315    /// expression, which is the `round_index`.
316    pub fn evaluate_sumcheck_beta_cascade(
317        &self,
318        beta: &[&BetaValues<F>],
319        random_coefficients: &[F],
320        round_index: usize,
321        degree: usize,
322    ) -> SumcheckEvals<F> {
323        self.expression_node.evaluate_sumcheck_node_beta_cascade(
324            beta,
325            &self.mle_vec,
326            random_coefficients,
327            round_index,
328            degree,
329        )
330    }
331
332    /// This evaluates a sumcheck message using the beta cascade algorithm, taking the sum
333    /// of the expression over all the variables `round_index` and after. For the variables
334    /// before, we compute the fully bound beta equality MLE and scale the rest of the sum
335    /// by this value.
336    pub fn evaluate_sumcheck_node_beta_cascade_sum(
337        &self,
338        beta_values: &BetaValues<F>,
339        round_index: usize,
340        degree: usize,
341    ) -> SumcheckEvals<F> {
342        self.expression_node
343            .evaluate_sumcheck_node_beta_cascade_sum(
344                beta_values,
345                round_index,
346                degree,
347                &self.mle_vec,
348            )
349    }
350
351    /// Traverses the expression tree to return all indices within the
352    /// expression. Can only be used after indexing the expression.
353    pub fn get_all_rounds(&self) -> Vec<usize> {
354        let (expression_node, mle_vec) = self.deconstruct_ref();
355        let mut all_rounds = expression_node.get_all_rounds(mle_vec);
356        all_rounds.sort();
357        all_rounds
358    }
359
360    /// this traverses the expression tree to get all of the nonlinear rounds. can only be used after indexing the expression.
361    /// returns the indices sorted.
362    pub fn get_all_nonlinear_rounds(&self) -> Vec<usize> {
363        let (expression_node, mle_vec) = self.deconstruct_ref();
364        let mut nonlinear_rounds = expression_node.get_all_nonlinear_rounds(mle_vec);
365        nonlinear_rounds.sort();
366        nonlinear_rounds
367    }
368
369    /// this traverses the expression tree to get all of the linear rounds. can only be used after indexing the expression.
370    /// returns the indices sorted.
371    pub fn get_all_linear_rounds(&self) -> Vec<usize> {
372        let (expression_node, mle_vec) = self.deconstruct_ref();
373        let mut linear_rounds = expression_node.get_all_linear_rounds(mle_vec);
374        linear_rounds.sort();
375        linear_rounds
376    }
377
378    /// Mutate the MLE indices that are [MleIndex::Free] in the expression and
379    /// turn them into [MleIndex::Indexed]. Returns the max number of bits
380    /// that are indexed.
381    pub fn index_mle_indices(&mut self, curr_index: usize) -> usize {
382        let (expression_node, mle_vec) = self.deconstruct_mut();
383        expression_node.index_mle_indices_node(curr_index, mle_vec)
384    }
385
386    /// Gets the number of free variables in an expression.
387    pub fn get_expression_num_free_variables(&self) -> usize {
388        self.expression_node
389            .get_expression_num_free_variables_node(0, &self.mle_vec)
390    }
391
392    /// Get the [PostSumcheckLayer] for this expression, which represents the fully bound values of the expression.
393    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
394    pub fn get_post_sumcheck_layer(&self, multiplier: F) -> PostSumcheckLayer<F, F> {
395        self.expression_node
396            .get_post_sumcheck_layer(multiplier, &self.mle_vec)
397    }
398
399    /// Get the maximum degree of any variable in htis expression
400    pub fn get_max_degree(&self) -> usize {
401        self.expression_node.get_max_degree(&self.mle_vec)
402    }
403}
404
405impl<F: Field> ExpressionNode<F, ProverExpr> {
406    /// Transforms the expression to a verifier expression
407    /// should only be called when no variables are bound in the expression.
408    /// Traverses the expression and changes the DenseMle to MleDescription.
409    pub fn transform_to_verifier_expression_node(
410        &mut self,
411        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
412    ) -> Result<ExpressionNode<F, VerifierExpr>> {
413        match self {
414            ExpressionNode::Constant(scalar) => Ok(ExpressionNode::Constant(*scalar)),
415            ExpressionNode::Selector(index, a, b) => Ok(ExpressionNode::Selector(
416                index.clone(),
417                Box::new(a.transform_to_verifier_expression_node(mle_vec)?),
418                Box::new(b.transform_to_verifier_expression_node(mle_vec)?),
419            )),
420            ExpressionNode::Mle(mle_vec_idx) => {
421                let mle = mle_vec_idx.get_mle(mle_vec);
422
423                if !mle.is_fully_bounded() {
424                    return Err(anyhow!(ExpressionError::EvaluateNotFullyBoundError));
425                }
426
427                let layer_id = mle.layer_id();
428                let mle_indices = mle.mle_indices().to_vec();
429                let eval = mle.value();
430
431                Ok(ExpressionNode::Mle(VerifierMle::new(
432                    layer_id,
433                    mle_indices,
434                    eval,
435                )))
436            }
437            ExpressionNode::Sum(a, b) => Ok(ExpressionNode::Sum(
438                Box::new(a.transform_to_verifier_expression_node(mle_vec)?),
439                Box::new(b.transform_to_verifier_expression_node(mle_vec)?),
440            )),
441            ExpressionNode::Product(mle_vec_indices) => {
442                let mles = mle_vec_indices
443                    .iter_mut()
444                    .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
445                    .collect_vec();
446
447                for mle in mles.iter() {
448                    if !mle.is_fully_bounded() {
449                        return Err(anyhow!(ExpressionError::EvaluateNotFullyBoundError));
450                    }
451                }
452
453                Ok(ExpressionNode::Product(
454                    mles.into_iter()
455                        .map(|mle| {
456                            VerifierMle::new(
457                                mle.layer_id(),
458                                mle.mle_indices().to_vec(),
459                                mle.value(),
460                            )
461                        })
462                        .collect_vec(),
463                ))
464            }
465            ExpressionNode::Scaled(mle, scalar) => Ok(ExpressionNode::Scaled(
466                Box::new(mle.transform_to_verifier_expression_node(mle_vec)?),
467                *scalar,
468            )),
469        }
470    }
471
472    /// fix the variable at a certain round index, always the most significant index.
473    pub fn fix_variable_node(
474        &mut self,
475        round_index: usize,
476        challenge: F,
477        mle_vec: &mut <ProverExpr as ExpressionType<F>>::MleVec, // remove all other cases other than selector, call mle.fix_variable on all mle_vec contents
478    ) {
479        match self {
480            ExpressionNode::Selector(index, a, b) => {
481                if *index == MleIndex::Indexed(round_index) {
482                    index.bind_index(challenge);
483                } else {
484                    a.fix_variable_node(round_index, challenge, mle_vec);
485                    b.fix_variable_node(round_index, challenge, mle_vec);
486                }
487            }
488            ExpressionNode::Mle(mle_vec_idx) => {
489                let mle = mle_vec_idx.get_mle_mut(mle_vec);
490
491                if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
492                    mle.fix_variable(round_index, challenge);
493                }
494            }
495            ExpressionNode::Sum(a, b) => {
496                a.fix_variable_node(round_index, challenge, mle_vec);
497                b.fix_variable_node(round_index, challenge, mle_vec);
498            }
499            ExpressionNode::Product(mle_vec_indices) => {
500                mle_vec_indices
501                    .iter_mut()
502                    .map(|mle_vec_index| {
503                        let mle = mle_vec_index.get_mle_mut(mle_vec);
504
505                        if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
506                            mle.fix_variable(round_index, challenge);
507                        }
508                    })
509                    .collect_vec();
510            }
511            ExpressionNode::Scaled(a, _) => {
512                a.fix_variable_node(round_index, challenge, mle_vec);
513            }
514            ExpressionNode::Constant(_) => (),
515        }
516    }
517
518    /// fix the variable at a certain round index, can be arbitrary indices.
519    pub fn fix_variable_at_index_node(
520        &mut self,
521        round_index: usize,
522        challenge: F,
523        mle_vec: &mut <ProverExpr as ExpressionType<F>>::MleVec, // remove all other cases other than selector, call mle.fix_variable on all mle_vec contents
524    ) {
525        match self {
526            ExpressionNode::Selector(index, a, b) => {
527                if *index == MleIndex::Indexed(round_index) {
528                    index.bind_index(challenge);
529                } else {
530                    a.fix_variable_at_index_node(round_index, challenge, mle_vec);
531                    b.fix_variable_at_index_node(round_index, challenge, mle_vec);
532                }
533            }
534            ExpressionNode::Mle(mle_vec_idx) => {
535                let mle = mle_vec_idx.get_mle_mut(mle_vec);
536
537                if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
538                    mle.fix_variable_at_index(round_index, challenge);
539                }
540            }
541            ExpressionNode::Sum(a, b) => {
542                a.fix_variable_at_index_node(round_index, challenge, mle_vec);
543                b.fix_variable_at_index_node(round_index, challenge, mle_vec);
544            }
545            ExpressionNode::Product(mle_vec_indices) => {
546                mle_vec_indices
547                    .iter_mut()
548                    .map(|mle_vec_index| {
549                        let mle = mle_vec_index.get_mle_mut(mle_vec);
550
551                        if mle.mle_indices().contains(&MleIndex::Indexed(round_index)) {
552                            mle.fix_variable_at_index(round_index, challenge);
553                        }
554                    })
555                    .collect_vec();
556            }
557            ExpressionNode::Scaled(a, _) => {
558                a.fix_variable_at_index_node(round_index, challenge, mle_vec);
559            }
560            ExpressionNode::Constant(_) => (),
561        }
562    }
563
564    pub fn evaluate_sumcheck_node_beta_cascade_sum(
565        &self,
566        beta_values: &BetaValues<F>,
567        round_index: usize,
568        degree: usize,
569        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
570    ) -> SumcheckEvals<F> {
571        match self {
572            ExpressionNode::Constant(constant) => {
573                SumcheckEvals(repeat_n(*constant, degree + 1).collect())
574            }
575            ExpressionNode::Selector(selector_mle_index, lhs, rhs) => {
576                let lhs_eval = lhs.evaluate_sumcheck_node_beta_cascade_sum(
577                    beta_values,
578                    round_index,
579                    degree,
580                    mle_vec,
581                );
582                let rhs_eval = rhs.evaluate_sumcheck_node_beta_cascade_sum(
583                    beta_values,
584                    round_index,
585                    degree,
586                    mle_vec,
587                );
588                match selector_mle_index {
589                    MleIndex::Indexed(var_number) => {
590                        let index_claim = beta_values.get_unbound_value(*var_number).unwrap();
591                        (lhs_eval * (F::ONE - index_claim)) + (rhs_eval * index_claim)
592                    }
593                    MleIndex::Bound(bound_value, var_number) => {
594                        let identity = F::ONE;
595                        let beta_bound = beta_values
596                            .get_updated_value(*var_number)
597                            .unwrap_or(identity);
598                        ((lhs_eval * (F::ONE - bound_value)) + (rhs_eval * bound_value))
599                            * beta_bound
600                    }
601                    _ => panic!("Invalid MLE Index for a selector bit, should be free or indexed"),
602                }
603            }
604            ExpressionNode::Mle(mle_idx) => {
605                let mle = mle_idx.get_mle(mle_vec);
606                let (unbound, bound) = beta_values.get_relevant_beta_unbound_and_bound(
607                    mle.mle_indices(),
608                    round_index,
609                    false,
610                );
611                beta_cascade_no_independent_variable(mle.mle.to_vec(), &unbound, &bound, degree)
612            }
613            ExpressionNode::Sum(lhs, rhs) => {
614                let lhs_eval = lhs.evaluate_sumcheck_node_beta_cascade_sum(
615                    beta_values,
616                    round_index,
617                    degree,
618                    mle_vec,
619                );
620                let rhs_eval = rhs.evaluate_sumcheck_node_beta_cascade_sum(
621                    beta_values,
622                    round_index,
623                    degree,
624                    mle_vec,
625                );
626                lhs_eval + rhs_eval
627            }
628            ExpressionNode::Product(mle_idx_vec) => {
629                let (mles, mles_bookkeeping_tables): (Vec<&DenseMle<F>>, Vec<Vec<F>>) = mle_idx_vec
630                    .iter()
631                    .map(|mle_vec_index| {
632                        let mle = mle_vec_index.get_mle(mle_vec);
633                        (mle, mle.mle.to_vec())
634                    })
635                    .unzip();
636
637                let mut unique_mle_indices = HashSet::new();
638
639                let mle_indices_vec = mles
640                    .iter()
641                    .flat_map(|mle| mle.mle_indices.clone())
642                    .filter(move |mle_index| unique_mle_indices.insert(mle_index.clone()))
643                    .collect_vec();
644
645                let (unbound, bound) = beta_values.get_relevant_beta_unbound_and_bound(
646                    &mle_indices_vec,
647                    round_index,
648                    false,
649                );
650                let evaluated_bookkeeping_tables = evaluate_bookkeeping_tables_given_operation(
651                    &mles_bookkeeping_tables,
652                    BinaryOperation::Mul,
653                );
654                beta_cascade_no_independent_variable(
655                    evaluated_bookkeeping_tables.to_vec(),
656                    &unbound,
657                    &bound,
658                    degree,
659                )
660            }
661            ExpressionNode::Scaled(expression_node, scale) => {
662                expression_node.evaluate_sumcheck_node_beta_cascade_sum(
663                    beta_values,
664                    round_index,
665                    degree,
666                    mle_vec,
667                ) * scale
668            }
669        }
670    }
671
672    /// This is the function to compute a single-round sumcheck message using the
673    /// beta cascade algorithm.
674    ///
675    /// # Arguments
676    ///
677    /// * `expr`: the Expression `P` defining a GKR layer. The caller is expected to
678    ///   have already fixed the variables of previous rounds.
679    /// * `round_index`: the MLE index corresponding to the variable that is going
680    ///   to be the independent variable for this round. The caller is expected to
681    ///   have already fixed variables `1 .. (round_index - 1)` in expression `P` to
682    ///   the verifier's challanges.
683    /// * `max_degree`: the degree of the polynomial to be exchanged in this round's
684    ///   sumcheck message.
685    /// * `beta_value`: the `beta` function associated with expression `exp`.  It is
686    ///   the caller's responsibility to keep this consistent with `expr`
687    ///   before/after each call.
688    ///
689    /// In particular, if `round_index == k`, and the current GKR layer expression
690    /// was originally on `n` variables, `expr` is expected to represent a
691    /// polynomial expression on `n - k + 1` variables: `P(r_1, r_2, ..., r_{k-1},
692    /// x_k, x_{k+1}, ..., x_n): F^{n - k + 1} -> F`, with the first `k - 1` free
693    /// variables already fixed to random challenges `r_1, ..., r_{k-1}`. Similarly,
694    /// `beta_values` should represent the polynomial: `\beta(r_1, ..., r_{k-1},
695    /// b_k, ..., b_n, g_1, ..., g_n)` whose unbound variables are `b_k, ..., b_n`.
696    ///
697    /// # Returns
698    ///
699    /// If successful, this functions returns a representation of the univariate
700    /// polynomial:
701    /// ```text
702    ///     g_{round_index}(x) =
703    ///         \sum_{b_{k+1} \in {0, 1}}
704    ///         \sum_{b_{k+2} \in {0, 1}}
705    ///             ...
706    ///         \sum_{b_{n} \in {0, 1}}
707    ///             \beta(r_1, ..., r_k, x, b_{k+1}, ..., b_{n}, g_1, ..., g_n)
708    ///                 * P(r_1, ..., r_k, x, b_{k+1}, ..., b_n)
709    /// ```
710    ///
711    /// 1. This function should be responsible for mutating `expr` and `beta_values`
712    ///    by fixing variables (if any) *after* the sumcheck round. It should
713    ///    maintain the invariant that `expr` and `beta_values` are consistent with
714    ///    each other!
715    /// 2. `max_degree` should NOT be the caller's responsibility to compute. The
716    ///    degree should be determined through `expr` and `round_index`.  It is
717    ///    error-prone to allow for sumcheck message to go through with an arbitrary
718    ///    degree.
719    ///
720    /// # Beta cascade
721    ///
722    /// Instead of using a beta table to linearize an expression, we utilize the
723    /// fact that for each specific node in an expression tree, we only need exactly
724    /// the beta values corresponding to the indices present in that node.
725    #[allow(clippy::too_many_arguments)]
726    pub fn evaluate_sumcheck_node_beta_cascade(
727        &self,
728        beta_vec: &[&BetaValues<F>],
729        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
730        random_coefficients: &[F],
731        round_index: usize,
732        degree: usize,
733    ) -> SumcheckEvals<F> {
734        match self {
735            // Each different type of expression node (constant, selector, product, sum,
736            // neg, scaled, mle) is treated differently, so we create closures for each
737            // which are then evaluated by the `evaluate_sumcheck_beta_cascade`
738            // function.
739
740            // A constant does not have any variables, so we do not need a beta table at
741            // all. Therefore we just repeat the constant evaluation for the `degree +
742            // 1` number of times as this is how many evaluations we need.
743            ExpressionNode::Constant(constant) => {
744                let sumcheck_eval_not_scaled_by_constant = beta_vec
745                    .iter()
746                    .zip(random_coefficients)
747                    .map(|(beta_table, random_coeff)| {
748                        let folded_updated_vals = beta_table.fold_updated_values();
749                        let index_claim = beta_table.get_unbound_value(round_index).unwrap();
750                        let one_minus_index_claim = F::ONE - index_claim;
751                        let beta_step = index_claim - one_minus_index_claim;
752                        let evals =
753                            std::iter::successors(Some(one_minus_index_claim), move |item| {
754                                Some(*item + beta_step)
755                            })
756                            .take(degree + 1)
757                            .collect_vec();
758                        apply_updated_beta_values_to_evals(evals, folded_updated_vals)
759                            * random_coeff
760                    })
761                    .reduce(|acc, elem| acc + elem)
762                    .unwrap();
763                sumcheck_eval_not_scaled_by_constant * constant
764            }
765
766            // the selector is split into three cases:
767            // - when the selector bit itself is not the independent variable and hasn't
768            //   been bound yet,
769            // - when the selector bit is the independent variable
770            // - when the selector bit has already been bound we determine which case we
771            // are in by comparing the round_index to the selector index which is an
772            // argument to the closure.
773            ExpressionNode::Selector(index, a, b) => {
774                match index {
775                    MleIndex::Indexed(indexed_bit) => {
776                        let (lhs_evals, rhs_evals) = (
777                            beta_vec
778                                .iter()
779                                .map(|beta| {
780                                    a.evaluate_sumcheck_node_beta_cascade_sum(
781                                        beta,
782                                        round_index,
783                                        degree,
784                                        mle_vec,
785                                    )
786                                })
787                                .collect_vec(),
788                            beta_vec
789                                .iter()
790                                .map(|beta| {
791                                    b.evaluate_sumcheck_node_beta_cascade_sum(
792                                        beta,
793                                        round_index,
794                                        degree,
795                                        mle_vec,
796                                    )
797                                })
798                                .collect_vec(),
799                        );
800                        // because the selector bit itself only has one variable (1 - b_i) *
801                        // (a) + b_i * b we only need one value within the beta table in
802                        // order to evaluate the selector at this point.
803                        match Ord::cmp(&round_index, indexed_bit) {
804                            std::cmp::Ordering::Less => {
805                                let sumcheck_eval = beta_vec
806                                    .iter()
807                                    .zip((lhs_evals.iter().zip(rhs_evals.iter())).zip(random_coefficients))
808                                    .map(|(beta_table, ((a, b), random_coeff))| {
809                                        let index_claim = beta_table.get_unbound_value(*indexed_bit).unwrap();
810                                        let a_eval: &SumcheckEvals<F> = a;
811                                        let b_eval: &SumcheckEvals<F> = b;
812                                        // when the selector bit is not the independent variable and
813                                        // has not been bound yet, we are simply summing over
814                                        // everything. in order to take the beta values into account
815                                        // this means for everything on the "left" side of the
816                                        // selector we want to multiply by (1 - g_i) and for
817                                        // everything on the "right" side of the selector we want to
818                                        // multiply by g_i. we can then add these!
819                                        let a_with_sel: SumcheckEvals<F> =
820                                            a_eval.clone() * (F::ONE - index_claim);
821                                        let b_with_sel: SumcheckEvals<F> = b_eval.clone() * index_claim;
822                                        (a_with_sel + b_with_sel) * random_coeff
823                                    })
824                                    .reduce(|acc, elem| acc + elem)
825                                    .unwrap();
826                                sumcheck_eval
827                            }
828                            std::cmp::Ordering::Equal => {
829                                // this is when the selector index is the independent
830                                // variable! this means the beta value at this index also
831                                // has an independent variable.
832                                let sumcheck_eval = beta_vec
833                                        .iter()
834                                        .zip((lhs_evals.iter().zip(rhs_evals)).zip(random_coefficients))
835                                        .map(|(beta_table, ((a, b), random_coeff))| {
836                                            let SumcheckEvals(first_evals) = a;
837                                            let SumcheckEvals(second_evals) = b;
838                                            if first_evals.len() == second_evals.len() {
839                                                let bound_beta_values = beta_table.fold_updated_values();
840                                                let index_claim =
841                                                    beta_table.get_unbound_value(*indexed_bit).unwrap();
842                                                // therefore we compute the successors of the beta
843                                                // values as well, as the successors correspond to
844                                                // evaluations at the points 0, 1, ... for the
845                                                // independent variable.
846                                                let eval_len = first_evals.len();
847                                                let one_minus_index_claim = F::ONE - index_claim;
848                                                let beta_step = index_claim - one_minus_index_claim;
849                                                let beta_evals = std::iter::successors(
850                                                    Some(one_minus_index_claim),
851                                                    move |item| Some(*item + beta_step),
852                                                )
853                                                .take(eval_len)
854                                                .collect_vec();
855                                                // the selector index also has an independent variable
856                                                // so we factor this as well as the corresponding beta
857                                                // successor at this index.
858                                                let first_evals = SumcheckEvals(
859                                                    first_evals
860                                                        .clone()
861                                                        .into_iter()
862                                                        .enumerate()
863                                                        .map(|(idx, first_eval)| {
864                                                            first_eval
865                                                                * (F::ONE - F::from(idx as u64))
866                                                                * beta_evals[idx]
867                                                        })
868                                                        .collect(),
869                                                );
870                                                let second_evals = SumcheckEvals(
871                                                    second_evals
872                                                        .clone()
873                                                        .into_iter()
874                                                        .enumerate()
875                                                        .map(|(idx, second_eval)| {
876                                                            second_eval
877                                                            * F::from(idx as u64) * beta_evals[idx]
878                                                        })
879                                                        .collect(),
880                                                );
881                                                (first_evals + second_evals) * random_coeff * bound_beta_values
882                                            } else {
883                                                panic!("Expression returns two evals that do not have the same length on a selector bit")
884                                            }
885                                        })
886                                        .reduce(|acc, elem| acc + elem)
887                                        .unwrap();
888                                sumcheck_eval
889                            }
890                            // we cannot have an indexed bit for the selector bit that is
891                            // less than the current sumcheck round. therefore this is an
892                            // error
893                            std::cmp::Ordering::Greater => panic!(
894                                "Invalid selector index, cannot be less than the current round index"
895                            ),
896                        }
897                    }
898                    // if the selector bit has already been bound, that means the beta value
899                    // at this index has also already been bound, if it exists! otherwise we
900                    // just treat it as the identity
901                    MleIndex::Bound(coeff, _) => {
902                        let (lhs_evals, rhs_evals) = (
903                            beta_vec
904                                .iter()
905                                .map(|beta| {
906                                    a.evaluate_sumcheck_node_beta_cascade(
907                                        &[*beta],
908                                        mle_vec,
909                                        &[F::ONE],
910                                        round_index,
911                                        degree,
912                                    )
913                                })
914                                .collect_vec(),
915                            beta_vec
916                                .iter()
917                                .map(|beta| {
918                                    b.evaluate_sumcheck_node_beta_cascade(
919                                        &[*beta],
920                                        mle_vec,
921                                        &[F::ONE],
922                                        round_index,
923                                        degree,
924                                    )
925                                })
926                                .collect_vec(),
927                        );
928                        let coeff_neg = F::ONE - coeff;
929                        (lhs_evals.iter().zip(rhs_evals))
930                            .zip(random_coefficients)
931                            .map(|((a, b), random_coeff)| {
932                                let a_eval = a;
933                                let b_eval = b;
934                                ((b_eval.clone() * coeff) + (a_eval.clone() * coeff_neg))
935                                    * random_coeff
936                            })
937                            .reduce(|acc, elem| acc + elem)
938                            .unwrap()
939                    }
940                    _ => panic!("selector index should not be a Free or Fixed bit"),
941                }
942            }
943            // the mle evaluation takes in the mle ref, and the corresponding unbound
944            // and bound beta values to pass into the `beta_cascade` function
945            ExpressionNode::Mle(mle_vec_idx) => {
946                let mle = mle_vec_idx.get_mle(mle_vec);
947                let (unbound_beta_vec, bound_beta_vec): (Vec<Vec<F>>, Vec<Vec<F>>) = beta_vec
948                    .iter()
949                    .map(|beta| {
950                        beta.get_relevant_beta_unbound_and_bound(
951                            mle.mle_indices(),
952                            round_index,
953                            true,
954                        )
955                    })
956                    .unzip();
957
958                beta_cascade(
959                    &[&mle.clone()],
960                    degree,
961                    round_index,
962                    &unbound_beta_vec,
963                    &bound_beta_vec,
964                    random_coefficients,
965                )
966            }
967            // when we have a sum, we can evaluate both parts of the expression
968            // separately and just add the evaluations
969            ExpressionNode::Sum(a, b) => {
970                let a = a.evaluate_sumcheck_node_beta_cascade(
971                    beta_vec,
972                    mle_vec,
973                    random_coefficients,
974                    round_index,
975                    degree,
976                );
977                let b = b.evaluate_sumcheck_node_beta_cascade(
978                    beta_vec,
979                    mle_vec,
980                    random_coefficients,
981                    round_index,
982                    degree,
983                );
984                a + b
985            }
986            // when we have a product, the node can only contain mle refs. therefore
987            // this is similar to the mle evaluation, but instead we have a list of mle
988            // refs, and the corresponding unbound and bound  beta values for that node.
989            ExpressionNode::Product(mle_vec_indices) => {
990                let mles = mle_vec_indices
991                    .iter()
992                    .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
993                    .collect_vec();
994
995                let mut unique_mle_indices = HashSet::new();
996
997                let mle_indices_vec = mles
998                    .iter()
999                    .flat_map(|mle| mle.mle_indices.clone())
1000                    .filter(move |mle_index| unique_mle_indices.insert(mle_index.clone()))
1001                    .collect_vec();
1002
1003                let (unbound_beta_vec, bound_beta_vec): (Vec<Vec<F>>, Vec<Vec<F>>) = beta_vec
1004                    .iter()
1005                    .map(|beta| {
1006                        beta.get_relevant_beta_unbound_and_bound(
1007                            &mle_indices_vec,
1008                            round_index,
1009                            true,
1010                        )
1011                    })
1012                    .unzip();
1013
1014                beta_cascade(
1015                    &mles,
1016                    degree,
1017                    round_index,
1018                    &unbound_beta_vec,
1019                    &bound_beta_vec,
1020                    random_coefficients,
1021                )
1022            }
1023
1024            // when the expression is scaled by a field element, we can scale the
1025            // evaluations by this element as well
1026            ExpressionNode::Scaled(a, scale) => {
1027                let a = a.evaluate_sumcheck_node_beta_cascade(
1028                    beta_vec,
1029                    mle_vec,
1030                    random_coefficients,
1031                    round_index,
1032                    degree,
1033                );
1034                a * scale
1035            }
1036        }
1037    }
1038
1039    /// Mutate the MLE indices that are [MleIndex::Free] in the expression and
1040    /// turn them into [MleIndex::Indexed]. Returns the max number of bits
1041    /// that are indexed.
1042    pub fn index_mle_indices_node(
1043        &mut self,
1044        curr_index: usize,
1045        mle_vec: &mut <ProverExpr as ExpressionType<F>>::MleVec,
1046    ) -> usize {
1047        match self {
1048            ExpressionNode::Selector(mle_index, a, b) => {
1049                let mut new_index = curr_index;
1050                if *mle_index == MleIndex::Free {
1051                    *mle_index = MleIndex::Indexed(curr_index);
1052                    new_index += 1;
1053                }
1054                let a_bits = a.index_mle_indices_node(new_index, mle_vec);
1055                let b_bits = b.index_mle_indices_node(new_index, mle_vec);
1056                max(a_bits, b_bits)
1057            }
1058            ExpressionNode::Mle(mle_vec_idx) => {
1059                let mle = mle_vec_idx.get_mle_mut(mle_vec);
1060                mle.index_mle_indices(curr_index)
1061            }
1062            ExpressionNode::Sum(a, b) => {
1063                let a_bits = a.index_mle_indices_node(curr_index, mle_vec);
1064                let b_bits = b.index_mle_indices_node(curr_index, mle_vec);
1065                max(a_bits, b_bits)
1066            }
1067            ExpressionNode::Product(mle_vec_indices) => mle_vec_indices
1068                .iter_mut()
1069                .map(|mle_vec_index| {
1070                    let mle = mle_vec_index.get_mle_mut(mle_vec);
1071                    mle.index_mle_indices(curr_index)
1072                })
1073                .reduce(max)
1074                .unwrap_or(curr_index),
1075            ExpressionNode::Scaled(a, _) => a.index_mle_indices_node(curr_index, mle_vec),
1076            ExpressionNode::Constant(_) => curr_index,
1077        }
1078    }
1079
1080    /// this traverses the expression to get all of the rounds, in total. requires going through each of the nodes
1081    /// and collecting the leaf node indices.
1082    pub(crate) fn get_all_rounds(
1083        &self,
1084        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1085    ) -> Vec<usize> {
1086        let degree_per_index = self.get_rounds_helper(mle_vec);
1087        (0..degree_per_index.len())
1088            .filter(|&i| degree_per_index[i] > 0)
1089            .collect()
1090    }
1091
1092    /// traverse an expression tree in order to get all of the nonlinear rounds in an expression.
1093    pub fn get_all_nonlinear_rounds(
1094        &self,
1095        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1096    ) -> Vec<usize> {
1097        let degree_per_index = self.get_rounds_helper(mle_vec);
1098        (0..degree_per_index.len())
1099            .filter(|&i| degree_per_index[i] > 1)
1100            .collect()
1101    }
1102
1103    /// get all of the linear rounds from an expression tree
1104    pub fn get_all_linear_rounds(
1105        &self,
1106        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1107    ) -> Vec<usize> {
1108        let degree_per_index = self.get_rounds_helper(mle_vec);
1109        (0..degree_per_index.len())
1110            .filter(|&i| degree_per_index[i] == 1)
1111            .collect()
1112    }
1113
1114    // a recursive helper for get_all_rounds, get_all_nonlinear_rounds, and get_all_linear_rounds
1115    fn get_rounds_helper(&self, mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec) -> Vec<usize> {
1116        // degree of each index
1117        let mut degree_per_index = Vec::new();
1118        // set the degree of the corresponding index to max(OLD_DEGREE, NEW_DEGREE)
1119        let max_degree = |degree_per_index: &mut Vec<usize>, index: usize, new_degree: usize| {
1120            if degree_per_index.len() <= index {
1121                degree_per_index.extend(vec![0; index + 1 - degree_per_index.len()]);
1122            }
1123            if degree_per_index[index] < new_degree {
1124                degree_per_index[index] = new_degree;
1125            }
1126        };
1127        // set the degree of the corresponding index to OLD_DEGREE + NEW_DEGREE
1128        let add_degree = |degree_per_index: &mut Vec<usize>, index: usize, new_degree: usize| {
1129            if degree_per_index.len() <= index {
1130                degree_per_index.extend(vec![0; index + 1 - degree_per_index.len()]);
1131            }
1132            degree_per_index[index] += new_degree;
1133        };
1134
1135        match self {
1136            // in a product, we need the union of all the indices in each of the individual mle refs.
1137            ExpressionNode::Product(mle_vec_indices) => {
1138                let mles = mle_vec_indices
1139                    .iter()
1140                    .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
1141                    .collect_vec();
1142                mles.into_iter().for_each(|mle| {
1143                    mle.mle_indices.iter().for_each(|mle_index| {
1144                        if let MleIndex::Indexed(i) = mle_index {
1145                            add_degree(&mut degree_per_index, *i, 1);
1146                        }
1147                    })
1148                });
1149            }
1150            // in an mle, we need all of the mle indices in the mle.
1151            ExpressionNode::Mle(mle_vec_idx) => {
1152                let mle = mle_vec_idx.get_mle(mle_vec);
1153                mle.mle_indices.iter().for_each(|mle_index| {
1154                    if let MleIndex::Indexed(i) = mle_index {
1155                        max_degree(&mut degree_per_index, *i, 1);
1156                    }
1157                });
1158            }
1159            // in selector, take the max degree of each children, and add 1 degree to the selector itself
1160            ExpressionNode::Selector(sel_index, a, b) => {
1161                if let MleIndex::Indexed(i) = sel_index {
1162                    add_degree(&mut degree_per_index, *i, 1);
1163                };
1164                let a_degree_per_index = a.get_rounds_helper(mle_vec);
1165                let b_degree_per_index = b.get_rounds_helper(mle_vec);
1166                // linear operator -- take the max degree
1167                for i in 0..max(a_degree_per_index.len(), b_degree_per_index.len()) {
1168                    if let Some(a_degree) = a_degree_per_index.get(i) {
1169                        max_degree(&mut degree_per_index, i, *a_degree);
1170                    }
1171                    if let Some(b_degree) = b_degree_per_index.get(i) {
1172                        max_degree(&mut degree_per_index, i, *b_degree);
1173                    }
1174                }
1175            }
1176            // in sum, take the max degree of each children
1177            ExpressionNode::Sum(a, b) => {
1178                let a_degree_per_index = a.get_rounds_helper(mle_vec);
1179                let b_degree_per_index = b.get_rounds_helper(mle_vec);
1180                // linear operator -- take the max degree
1181                for i in 0..max(a_degree_per_index.len(), b_degree_per_index.len()) {
1182                    if let Some(a_degree) = a_degree_per_index.get(i) {
1183                        max_degree(&mut degree_per_index, i, *a_degree);
1184                    }
1185                    if let Some(b_degree) = b_degree_per_index.get(i) {
1186                        max_degree(&mut degree_per_index, i, *b_degree);
1187                    }
1188                }
1189            }
1190            // scaled and negated, does not affect degree
1191            ExpressionNode::Scaled(a, _) => {
1192                degree_per_index = a.get_rounds_helper(mle_vec);
1193            }
1194            // for a constant there are no new indices.
1195            ExpressionNode::Constant(_) => {}
1196        }
1197        degree_per_index
1198    }
1199
1200    /// Gets the number of free variables in an expression.
1201    pub fn get_expression_num_free_variables_node(
1202        &self,
1203        curr_size: usize,
1204        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1205    ) -> usize {
1206        match self {
1207            ExpressionNode::Selector(mle_index, a, b) => {
1208                let (a_bits, b_bits) = if matches!(mle_index, &MleIndex::Free) {
1209                    (
1210                        a.get_expression_num_free_variables_node(curr_size + 1, mle_vec),
1211                        b.get_expression_num_free_variables_node(curr_size + 1, mle_vec),
1212                    )
1213                } else {
1214                    (
1215                        a.get_expression_num_free_variables_node(curr_size, mle_vec),
1216                        b.get_expression_num_free_variables_node(curr_size, mle_vec),
1217                    )
1218                };
1219
1220                max(a_bits, b_bits)
1221            }
1222            ExpressionNode::Mle(mle_vec_idx) => {
1223                let mle = mle_vec_idx.get_mle(mle_vec);
1224
1225                mle.mle_indices()
1226                    .iter()
1227                    .filter(|item| matches!(item, &&MleIndex::Free))
1228                    .collect_vec()
1229                    .len()
1230                    + curr_size
1231            }
1232            ExpressionNode::Sum(a, b) => {
1233                let a_bits = a.get_expression_num_free_variables_node(curr_size, mle_vec);
1234                let b_bits = b.get_expression_num_free_variables_node(curr_size, mle_vec);
1235                max(a_bits, b_bits)
1236            }
1237            ExpressionNode::Product(mle_vec_indices) => {
1238                let mles = mle_vec_indices
1239                    .iter()
1240                    .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec))
1241                    .collect_vec();
1242
1243                mles.iter()
1244                    .map(|mle| {
1245                        mle.mle_indices()
1246                            .iter()
1247                            .filter(|item| matches!(item, &&MleIndex::Free))
1248                            .collect_vec()
1249                            .len()
1250                    })
1251                    .max()
1252                    .unwrap_or(0)
1253                    + curr_size
1254            }
1255            ExpressionNode::Scaled(a, _) => {
1256                a.get_expression_num_free_variables_node(curr_size, mle_vec)
1257            }
1258            ExpressionNode::Constant(_) => curr_size,
1259        }
1260    }
1261
1262    /// Recursively get the [PostSumcheckLayer] for an Expression node, which is the fully bound
1263    /// representation of an expression.
1264    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
1265    pub fn get_post_sumcheck_layer(
1266        &self,
1267        multiplier: F,
1268        mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec,
1269    ) -> PostSumcheckLayer<F, F> {
1270        let mut products: Vec<Product<F, F>> = vec![];
1271        match self {
1272            ExpressionNode::Selector(mle_index, a, b) => {
1273                let left_side_acc = multiplier * (F::ONE - mle_index.val().unwrap());
1274                let right_side_acc = multiplier * (mle_index.val().unwrap());
1275                products.extend(a.get_post_sumcheck_layer(left_side_acc, mle_vec).0);
1276                products.extend(b.get_post_sumcheck_layer(right_side_acc, mle_vec).0);
1277            }
1278            ExpressionNode::Sum(a, b) => {
1279                products.extend(a.get_post_sumcheck_layer(multiplier, mle_vec).0);
1280                products.extend(b.get_post_sumcheck_layer(multiplier, mle_vec).0);
1281            }
1282            ExpressionNode::Mle(mle_vec_idx) => {
1283                let mle = mle_vec_idx.get_mle(mle_vec);
1284                assert!(mle.is_fully_bounded());
1285                products.push(Product::<F, F>::new(std::slice::from_ref(mle), multiplier));
1286            }
1287            ExpressionNode::Product(mle_vec_indices) => {
1288                let mles = mle_vec_indices
1289                    .iter()
1290                    .map(|mle_vec_index| mle_vec_index.get_mle(mle_vec).clone())
1291                    .collect_vec();
1292                let product = Product::<F, F>::new(&mles, multiplier);
1293                products.push(product);
1294            }
1295            ExpressionNode::Scaled(a, scale_factor) => {
1296                let acc = multiplier * scale_factor;
1297                products.extend(a.get_post_sumcheck_layer(acc, mle_vec).0);
1298            }
1299            ExpressionNode::Constant(constant) => {
1300                products.push(Product::<F, F>::new(&[], *constant * multiplier));
1301            }
1302        }
1303        PostSumcheckLayer(products)
1304    }
1305
1306    /// Get the maximum degree of an ExpressionNode, recursively.
1307    fn get_max_degree(&self, _mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec) -> usize {
1308        match self {
1309            ExpressionNode::Selector(_, a, b) | ExpressionNode::Sum(a, b) => {
1310                let a_degree = a.get_max_degree(_mle_vec);
1311                let b_degree = b.get_max_degree(_mle_vec);
1312                max(a_degree, b_degree)
1313            }
1314            ExpressionNode::Mle(_) => {
1315                // 1 for the current MLE
1316                1
1317            }
1318            ExpressionNode::Product(mles) => {
1319                // max degree is the number of MLEs in a product
1320                mles.len()
1321            }
1322            ExpressionNode::Scaled(a, _) => a.get_max_degree(_mle_vec),
1323            ExpressionNode::Constant(_) => 1,
1324        }
1325    }
1326}
1327
1328impl<F: Field> Neg for Expression<F, ProverExpr> {
1329    type Output = Expression<F, ProverExpr>;
1330    fn neg(self) -> Self::Output {
1331        Expression::<F, ProverExpr>::negated(self)
1332    }
1333}
1334
1335/// implement the Add, Sub, and Mul traits for the Expression
1336impl<F: Field> Add for Expression<F, ProverExpr> {
1337    type Output = Expression<F, ProverExpr>;
1338    fn add(self, rhs: Expression<F, ProverExpr>) -> Expression<F, ProverExpr> {
1339        Expression::<F, ProverExpr>::sum(self, rhs)
1340    }
1341}
1342
1343impl<F: Field> Sub for Expression<F, ProverExpr> {
1344    type Output = Expression<F, ProverExpr>;
1345    fn sub(self, rhs: Expression<F, ProverExpr>) -> Expression<F, ProverExpr> {
1346        self.add(rhs.neg())
1347    }
1348}
1349
1350impl<F: Field> Mul<F> for Expression<F, ProverExpr> {
1351    type Output = Expression<F, ProverExpr>;
1352    fn mul(self, rhs: F) -> Self::Output {
1353        Expression::<F, ProverExpr>::scaled(self, rhs)
1354    }
1355}
1356
1357// defines how the Expressions are printed and displayed
1358impl<F: std::fmt::Debug + Field> std::fmt::Debug for Expression<F, ProverExpr> {
1359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360        f.debug_struct("Expression")
1361            .field("Expression_Node", &self.expression_node)
1362            .field("MleRef_Vec", &self.mle_vec)
1363            .finish()
1364    }
1365}
1366
1367// defines how the ExpressionNodes are printed and displayed
1368impl<F: std::fmt::Debug + Field> std::fmt::Debug for ExpressionNode<F, ProverExpr> {
1369    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1370        match self {
1371            ExpressionNode::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(),
1372            ExpressionNode::Selector(index, a, b) => f
1373                .debug_tuple("Selector")
1374                .field(index)
1375                .field(a)
1376                .field(b)
1377                .finish(),
1378            // Skip enum variant and print query struct directly to maintain backwards compatibility.
1379            ExpressionNode::Mle(_mle) => f.debug_struct("Mle").field("mle", _mle).finish(),
1380            ExpressionNode::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
1381            ExpressionNode::Product(a) => f.debug_tuple("Product").field(a).finish(),
1382            ExpressionNode::Scaled(poly, scalar) => {
1383                f.debug_tuple("Scaled").field(poly).field(scalar).finish()
1384            }
1385        }
1386    }
1387}