remainder/expression/
circuit_expr.rs

1//! The "pure" polynomial relationship description between the MLE representing
2//! a single layer of a "structured" circuit and those representing data from
3//! other layers. See documentation in [crate::expression] for more details.
4
5use crate::{
6    circuit_layout::CircuitEvalMap,
7    layer::{
8        gate::BinaryOperation,
9        product::{PostSumcheckLayer, Product},
10    },
11    mle::{
12        evals::MultilinearExtension, mle_description::MleDescription, verifier_mle::VerifierMle,
13        MleIndex,
14    },
15};
16use ark_std::log2;
17use itertools::Itertools;
18use serde::{Deserialize, Serialize};
19use std::{
20    cmp::max,
21    collections::{HashMap, HashSet},
22    fmt::Debug,
23    ops::{Add, Mul, Neg, Sub},
24};
25
26use shared_types::{transcript::VerifierTranscript, Field};
27
28use super::{
29    expr_errors::ExpressionError,
30    generic_expr::{Expression, ExpressionNode, ExpressionType},
31    prover_expr::ProverExpr,
32    verifier_expr::VerifierExpr,
33};
34
35use anyhow::{anyhow, Result};
36
37/// Type for defining [Expression<F, ExprDescription>], the type used
38/// for representing expressions in the circuit description.
39#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
40pub struct ExprDescription;
41
42// The leaves of an expression of this type contain a [MleDescription], an analogue
43// of [crate::mle::dense::DenseMle], storing only metadata related to the MLE,
44// without any evaluations.
45impl<F: Field> ExpressionType<F> for ExprDescription {
46    type MLENodeRepr = MleDescription<F>;
47    type MleVec = ();
48}
49
50impl<F: Field> Expression<F, ExprDescription> {
51    /// Binds the variables of this expression to `point`, and retrieves the
52    /// leaf MLE values from the `transcript_reader`.  Returns a `Expression<F,
53    /// VerifierExpr>` version of `self`.
54    pub fn bind(
55        &self,
56        point: &[F],
57        transcript_reader: &mut impl VerifierTranscript<F>,
58    ) -> Result<Expression<F, VerifierExpr>> {
59        Ok(Expression::new(
60            self.expression_node
61                .into_verifier_node(point, transcript_reader)?,
62            (),
63        ))
64    }
65
66    /// Convenience function which creates a trivial [Expression<F, ExprDescription>]
67    /// referring to a single MLE.
68    pub fn from_mle_desc(mle_desc: MleDescription<F>) -> Self {
69        Self {
70            expression_node: ExpressionNode::<F, ExprDescription>::Mle(mle_desc),
71            mle_vec: (),
72        }
73    }
74
75    /// Traverses the expression tree to get the indices of all the nonlinear
76    /// rounds. Returns a sorted vector of indices.
77    pub fn get_all_nonlinear_rounds(&self) -> Vec<usize> {
78        self.expression_node
79            .get_all_nonlinear_rounds(&mut vec![], &self.mle_vec)
80            .into_iter()
81            .sorted()
82            .collect()
83    }
84
85    /// Traverses the expression tree to return all indices within the
86    /// expression. Can only be used after indexing the expression.
87    pub fn get_all_rounds(&self) -> Vec<usize> {
88        self.expression_node
89            .get_all_rounds(&mut vec![], &self.mle_vec)
90            .into_iter()
91            .sorted()
92            .collect()
93    }
94
95    /// Get the [MleDescription]s for this expression, which are at the leaves of the expression.
96    pub fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
97        let circuit_mles = self.expression_node.get_circuit_mles();
98        circuit_mles
99    }
100
101    /// Label the free variables in an expression.
102    pub fn index_mle_vars(&mut self, start_index: usize) {
103        self.expression_node.index_mle_vars(start_index);
104    }
105
106    /// Get the [Expression<F, ProverExpr>] corresponding to this [Expression<F, ExprDescription>] using the
107    /// associated data in the [CircuitEvalMap].
108    pub fn into_prover_expression(
109        &self,
110        circuit_map: &CircuitEvalMap<F>,
111    ) -> Expression<F, ProverExpr> {
112        self.expression_node.into_prover_expression(circuit_map)
113    }
114
115    /// Get the [PostSumcheckLayer] for this expression, which represents the fully bound values of the expression.
116    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
117    pub fn get_post_sumcheck_layer(
118        &self,
119        multiplier: F,
120        challenges: &[F],
121    ) -> PostSumcheckLayer<F, Option<F>> {
122        self.expression_node
123            .get_post_sumcheck_layer(multiplier, challenges, &self.mle_vec)
124    }
125
126    /// Get the maximum degree of any variable in this expression.
127    pub fn get_max_degree(&self) -> usize {
128        self.expression_node.get_max_degree(&self.mle_vec)
129    }
130
131    /// Returns the maximum degree of b_{curr_round} within an expression
132    /// (and therefore the number of prover messages we need to send)
133    pub fn get_round_degree(&self, curr_round: usize) -> usize {
134        // By default, all rounds have degree at least 2 (beta table included)
135        let mut round_degree = 1;
136
137        let mut get_degree_closure = |expr: &ExpressionNode<F, ExprDescription>,
138                                      _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec|
139         -> Result<()> {
140            let round_degree = &mut round_degree;
141
142            // The only exception is within a product of MLEs
143            if let ExpressionNode::Product(circuit_mles) = expr {
144                let mut product_round_degree: usize = 0;
145                for circuit_mle in circuit_mles {
146                    let mle_indices = circuit_mle.var_indices();
147                    for mle_index in mle_indices {
148                        if *mle_index == MleIndex::Indexed(curr_round) {
149                            product_round_degree += 1;
150                            break;
151                        }
152                    }
153                }
154                if *round_degree < product_round_degree {
155                    *round_degree = product_round_degree;
156                }
157            }
158            Ok(())
159        };
160
161        self.traverse(&mut get_degree_closure).unwrap();
162        // add 1 cuz beta table but idk if we would ever use this without a beta table
163        round_degree + 1
164    }
165}
166
167impl<F: Field> ExpressionNode<F, ExprDescription> {
168    /// Turn this expression into a [VerifierExpr] which represents a fully bound expression.
169    /// Should only be applicable after a full layer of sumcheck.
170    pub fn into_verifier_node(
171        &self,
172        point: &[F],
173        transcript_reader: &mut impl VerifierTranscript<F>,
174    ) -> Result<ExpressionNode<F, VerifierExpr>> {
175        match self {
176            ExpressionNode::Constant(scalar) => Ok(ExpressionNode::Constant(*scalar)),
177            ExpressionNode::Selector(index, lhs, rhs) => match index {
178                MleIndex::Indexed(idx) => Ok(ExpressionNode::Selector(
179                    MleIndex::Bound(point[*idx], *idx),
180                    Box::new(lhs.into_verifier_node(point, transcript_reader)?),
181                    Box::new(rhs.into_verifier_node(point, transcript_reader)?),
182                )),
183                _ => Err(anyhow!(ExpressionError::SelectorBitNotBoundError)),
184            },
185            ExpressionNode::Mle(circuit_mle) => Ok(ExpressionNode::Mle(
186                circuit_mle.into_verifier_mle(point, transcript_reader)?,
187            )),
188            ExpressionNode::Sum(lhs, rhs) => Ok(ExpressionNode::Sum(
189                Box::new(lhs.into_verifier_node(point, transcript_reader)?),
190                Box::new(rhs.into_verifier_node(point, transcript_reader)?),
191            )),
192            ExpressionNode::Product(circuit_mles) => {
193                let verifier_mles: Vec<VerifierMle<F>> = circuit_mles
194                    .iter()
195                    .map(|circuit_mle| circuit_mle.into_verifier_mle(point, transcript_reader))
196                    .collect::<Result<Vec<VerifierMle<F>>>>()?;
197
198                Ok(ExpressionNode::Product(verifier_mles))
199            }
200            ExpressionNode::Scaled(circuit_mle, scalar) => Ok(ExpressionNode::Scaled(
201                Box::new(circuit_mle.into_verifier_node(point, transcript_reader)?),
202                *scalar,
203            )),
204        }
205    }
206
207    /// Compute the expression-wise bookkeeping table (coefficients of the MLE representing the expression)
208    /// for a given [ExprDescription]. This uses a [CircuitEvalMap] in order to grab the correct data
209    /// corresponding to the [MleDescription].
210    pub fn compute_bookkeeping_table(
211        &self,
212        circuit_map: &CircuitEvalMap<F>,
213    ) -> Option<MultilinearExtension<F>> {
214        let output_data: Option<MultilinearExtension<F>> = match self {
215            ExpressionNode::Mle(circuit_mle) => {
216                let maybe_mle = circuit_map.get_data_from_circuit_mle(circuit_mle);
217                if let Ok(mle) = maybe_mle {
218                    Some(mle.clone())
219                } else {
220                    return None;
221                }
222            }
223            ExpressionNode::Product(circuit_mles) => {
224                let mle_bookkeeping_tables = circuit_mles
225                    .iter()
226                    .map(|circuit_mle| {
227                        circuit_map
228                            .get_data_from_circuit_mle(circuit_mle) // Returns Result
229                            .map(|data| data.to_vec()) // Map Ok value to slice
230                    })
231                    .collect::<Result<Vec<Vec<F>>>>() // Collect all into a Result
232                    .ok()?;
233                Some(evaluate_bookkeeping_tables_given_operation(
234                    &mle_bookkeeping_tables,
235                    BinaryOperation::Mul,
236                ))
237            }
238            ExpressionNode::Sum(a, b) => {
239                let a_bookkeeping_table = a.compute_bookkeeping_table(circuit_map)?;
240                let b_bookkeeping_table = b.compute_bookkeeping_table(circuit_map)?;
241                Some(evaluate_bookkeeping_tables_given_operation(
242                    &[
243                        (a_bookkeeping_table.to_vec()),
244                        (b_bookkeeping_table.to_vec()),
245                    ],
246                    BinaryOperation::Add,
247                ))
248            }
249            ExpressionNode::Scaled(a, scale) => {
250                let a_bookkeeping_table = a.compute_bookkeeping_table(circuit_map)?;
251                Some(MultilinearExtension::new(
252                    a_bookkeeping_table
253                        .iter()
254                        .map(|elem| elem * scale)
255                        .collect_vec(),
256                ))
257            }
258            ExpressionNode::Selector(_mle_index, a, b) => {
259                let a_bookkeeping_table = a.compute_bookkeeping_table(circuit_map)?;
260                let b_bookkeeping_table = b.compute_bookkeeping_table(circuit_map)?;
261                assert_eq!(
262                    a_bookkeeping_table.num_vars(),
263                    b_bookkeeping_table.num_vars()
264                );
265                Some(MultilinearExtension::new(
266                    a_bookkeeping_table
267                        .iter()
268                        .chain(b_bookkeeping_table.iter())
269                        .collect_vec(),
270                ))
271            }
272            ExpressionNode::Constant(value) => Some(MultilinearExtension::new(vec![*value])),
273        };
274
275        output_data
276    }
277
278    /// Evaluate the polynomial using the provided closures to perform the
279    /// operations.
280    #[allow(clippy::too_many_arguments)]
281    pub fn reduce<T>(
282        &self,
283        constant: &mut impl FnMut(F) -> T,
284        selector_column: &mut impl FnMut(&MleIndex<F>, T, T) -> T,
285        mle_eval: &mut impl FnMut(&<ExprDescription as ExpressionType<F>>::MLENodeRepr) -> T,
286        sum: &mut impl FnMut(T, T) -> T,
287        product: &mut impl FnMut(&[<ExprDescription as ExpressionType<F>>::MLENodeRepr]) -> T,
288        scaled: &mut impl FnMut(T, F) -> T,
289    ) -> T {
290        match self {
291            ExpressionNode::Constant(scalar) => constant(*scalar),
292            ExpressionNode::Selector(index, a, b) => {
293                let lhs = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
294                let rhs = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
295                selector_column(index, lhs, rhs)
296            }
297            ExpressionNode::Mle(query) => mle_eval(query),
298            ExpressionNode::Sum(a, b) => {
299                let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
300                let b = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
301                sum(a, b)
302            }
303            ExpressionNode::Product(queries) => product(queries),
304            ExpressionNode::Scaled(a, f) => {
305                let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
306                scaled(a, *f)
307            }
308        }
309    }
310
311    /// Traverse an expression tree in order and returns a vector of indices of
312    /// all the nonlinear rounds in an expression (in no particular order).
313    pub fn get_all_nonlinear_rounds(
314        &self,
315        curr_nonlinear_indices: &mut Vec<usize>,
316        _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec,
317    ) -> Vec<usize> {
318        let nonlinear_indices_in_node = {
319            match self {
320                // The only case where an index is nonlinear is if it is present in multiple mle
321                // refs that are part of a product. We iterate through all the indices in the
322                // product nodes to look for repeated indices within a single node.
323                ExpressionNode::Product(verifier_mles) => {
324                    let mut product_nonlinear_indices: HashSet<usize> = HashSet::new();
325                    let mut product_indices_counts: HashMap<MleIndex<F>, usize> = HashMap::new();
326
327                    verifier_mles.iter().for_each(|verifier_mle| {
328                        verifier_mle.var_indices().iter().for_each(|mle_index| {
329                            let curr_count = {
330                                if product_indices_counts.contains_key(mle_index) {
331                                    product_indices_counts.get(mle_index).unwrap()
332                                } else {
333                                    &0
334                                }
335                            };
336                            product_indices_counts.insert(mle_index.clone(), curr_count + 1);
337                        })
338                    });
339
340                    product_indices_counts
341                        .into_iter()
342                        .for_each(|(mle_index, count)| {
343                            if count > 1 {
344                                if let MleIndex::Indexed(i) = mle_index {
345                                    product_nonlinear_indices.insert(i);
346                                } else if let MleIndex::Bound(_, i) = mle_index {
347                                    product_nonlinear_indices.insert(i);
348                                }
349                            }
350                        });
351
352                    product_nonlinear_indices
353                }
354                // for the rest of the types of expressions, we simply traverse through the expression node to look
355                // for more leaves which are specifically product nodes.
356                ExpressionNode::Selector(_sel_index, a, b) => {
357                    let mut sel_nonlinear_indices: HashSet<usize> = HashSet::new();
358                    let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
359                    let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
360                    a_indices
361                        .into_iter()
362                        .zip(b_indices)
363                        .for_each(|(a_mle_idx, b_mle_idx)| {
364                            sel_nonlinear_indices.insert(a_mle_idx);
365                            sel_nonlinear_indices.insert(b_mle_idx);
366                        });
367                    sel_nonlinear_indices
368                }
369                ExpressionNode::Sum(a, b) => {
370                    let mut sum_nonlinear_indices: HashSet<usize> = HashSet::new();
371                    let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
372                    let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
373                    a_indices
374                        .into_iter()
375                        .zip(b_indices)
376                        .for_each(|(a_mle_idx, b_mle_idx)| {
377                            sum_nonlinear_indices.insert(a_mle_idx);
378                            sum_nonlinear_indices.insert(b_mle_idx);
379                        });
380                    sum_nonlinear_indices
381                }
382                ExpressionNode::Scaled(a, _) => a
383                    .get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec)
384                    .into_iter()
385                    .collect(),
386                ExpressionNode::Constant(_) | ExpressionNode::Mle(_) => HashSet::new(),
387            }
388        };
389        // we grab all of the indices and take the union of all of them to return all nonlinear rounds in an expression tree.
390        nonlinear_indices_in_node.into_iter().for_each(|index| {
391            if !curr_nonlinear_indices.contains(&index) {
392                curr_nonlinear_indices.push(index);
393            }
394        });
395        curr_nonlinear_indices.clone()
396    }
397
398    /// This function traverses an expression tree in order to determine what are
399    /// the labels for the variables in the expression.
400    pub(crate) fn get_all_rounds(
401        &self,
402        curr_indices: &mut Vec<usize>,
403        _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec,
404    ) -> Vec<usize> {
405        let indices_in_node = {
406            match self {
407                // In a product, we need the union of all the labels of the variables in each
408                // of the MLEs.
409                ExpressionNode::Product(verifier_mles) => {
410                    let mut product_indices: HashSet<usize> = HashSet::new();
411                    verifier_mles.iter().for_each(|mle| {
412                        mle.var_indices().iter().for_each(|mle_index| {
413                            if let MleIndex::Indexed(i) = mle_index {
414                                product_indices.insert(*i);
415                            }
416                        })
417                    });
418                    product_indices
419                }
420                // In an mle, all the variable labels are relevant in the expression.
421                ExpressionNode::Mle(verifier_mle) => verifier_mle
422                    .var_indices()
423                    .iter()
424                    .filter_map(|mle_index| match mle_index {
425                        MleIndex::Indexed(i) => Some(*i),
426                        _ => None,
427                    })
428                    .collect(),
429                // In a selector, we traverse each parts of the selector while adding the selector index
430                // itself to the total set of all variable labels in an expression.
431                ExpressionNode::Selector(sel_index, a, b) => {
432                    let mut sel_indices: HashSet<usize> = HashSet::new();
433                    if let MleIndex::Indexed(i) = sel_index {
434                        sel_indices.insert(*i);
435                    };
436
437                    let a_indices = a.get_all_rounds(curr_indices, _mle_vec);
438                    let b_indices = b.get_all_rounds(curr_indices, _mle_vec);
439                    a_indices
440                        .into_iter()
441                        .zip(b_indices)
442                        .for_each(|(a_mle_idx, b_mle_idx)| {
443                            sel_indices.insert(a_mle_idx);
444                            sel_indices.insert(b_mle_idx);
445                        });
446                    sel_indices
447                }
448                // We add the variable labels in each of the parts of the sum.
449                ExpressionNode::Sum(a, b) => {
450                    let mut sum_indices: HashSet<usize> = HashSet::new();
451                    let a_indices = a.get_all_rounds(curr_indices, _mle_vec);
452                    let b_indices = b.get_all_rounds(curr_indices, _mle_vec);
453                    a_indices
454                        .into_iter()
455                        .zip(b_indices)
456                        .for_each(|(a_mle_idx, b_mle_idx)| {
457                            sum_indices.insert(a_mle_idx);
458                            sum_indices.insert(b_mle_idx);
459                        });
460                    sum_indices
461                }
462                // For scaled, we can add all of variable labels found in the expression being scaled.
463                ExpressionNode::Scaled(a, _) => a
464                    .get_all_rounds(curr_indices, _mle_vec)
465                    .into_iter()
466                    .collect(),
467                // for a constant there are no new indices.
468                ExpressionNode::Constant(_) => HashSet::new(),
469            }
470        };
471        // Once all of them have been collected, we can take the union of all
472        // of them to grab the variable labels of an expression.
473        indices_in_node.into_iter().for_each(|index| {
474            if !curr_indices.contains(&index) {
475                curr_indices.push(index);
476            }
477        });
478        curr_indices.clone()
479    }
480
481    /// Get all the [MleDescription]s, recursively, for this expression by adding the MLEs in the leaves into the vector of MleDescriptions.
482    pub fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
483        let mut circuit_mles: Vec<&MleDescription<F>> = vec![];
484        match self {
485            ExpressionNode::Selector(_mle_index, a, b) => {
486                circuit_mles.extend(a.get_circuit_mles());
487                circuit_mles.extend(b.get_circuit_mles());
488            }
489            ExpressionNode::Sum(a, b) => {
490                circuit_mles.extend(a.get_circuit_mles());
491                circuit_mles.extend(b.get_circuit_mles());
492            }
493            ExpressionNode::Mle(mle) => {
494                circuit_mles.push(mle);
495            }
496            ExpressionNode::Product(mles) => mles.iter().for_each(|mle| circuit_mles.push(mle)),
497            ExpressionNode::Scaled(a, _scale_factor) => {
498                circuit_mles.extend(a.get_circuit_mles());
499            }
500            ExpressionNode::Constant(_constant) => {}
501        }
502        circuit_mles
503    }
504
505    /// Label the MLE indices of an expression, starting from the `start_index`.
506    pub fn index_mle_vars(&mut self, start_index: usize) {
507        match self {
508            ExpressionNode::Selector(mle_index, a, b) => {
509                match mle_index {
510                    MleIndex::Free => *mle_index = MleIndex::Indexed(start_index),
511                    MleIndex::Fixed(_bit) => {}
512                    _ => panic!("should not have indexed or bound bits at this point!"),
513                };
514                a.index_mle_vars(start_index + 1);
515                b.index_mle_vars(start_index + 1);
516            }
517            ExpressionNode::Sum(a, b) => {
518                a.index_mle_vars(start_index);
519                b.index_mle_vars(start_index);
520            }
521            ExpressionNode::Mle(mle) => {
522                mle.index_mle_indices(start_index);
523            }
524            ExpressionNode::Product(mles) => {
525                mles.iter_mut()
526                    .for_each(|mle| mle.index_mle_indices(start_index));
527            }
528            ExpressionNode::Scaled(a, _scale_factor) => {
529                a.index_mle_vars(start_index);
530            }
531            ExpressionNode::Constant(_constant) => {}
532        }
533    }
534
535    /// Get the [ExpressionNode<F, ProverExpr>] recursively, for this expression.
536    pub fn into_prover_expression(
537        &self,
538        circuit_map: &CircuitEvalMap<F>,
539    ) -> Expression<F, ProverExpr> {
540        match self {
541            ExpressionNode::Selector(_mle_index, a, b) => a
542                .into_prover_expression(circuit_map)
543                .select(b.into_prover_expression(circuit_map)),
544            ExpressionNode::Sum(a, b) => {
545                a.into_prover_expression(circuit_map) + b.into_prover_expression(circuit_map)
546            }
547            ExpressionNode::Mle(mle) => {
548                let prover_mle = mle.into_dense_mle(circuit_map);
549                prover_mle.expression()
550            }
551            ExpressionNode::Product(mles) => {
552                let dense_mles = mles
553                    .iter()
554                    .map(|mle| mle.into_dense_mle(circuit_map))
555                    .collect_vec();
556                Expression::<F, ProverExpr>::products(dense_mles)
557            }
558            ExpressionNode::Scaled(a, scale_factor) => {
559                a.into_prover_expression(circuit_map) * *scale_factor
560            }
561            ExpressionNode::Constant(constant) => Expression::<F, ProverExpr>::constant(*constant),
562        }
563    }
564
565    /// Recursively get the [PostSumcheckLayer] for an Expression node, which is the fully bound
566    /// representation of an expression.
567    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
568    pub fn get_post_sumcheck_layer(
569        &self,
570        multiplier: F,
571        challenges: &[F],
572        _mle_vec: &<VerifierExpr as ExpressionType<F>>::MleVec,
573    ) -> PostSumcheckLayer<F, Option<F>> {
574        let mut products: Vec<Product<F, Option<F>>> = vec![];
575        match self {
576            ExpressionNode::Selector(mle_index, a, b) => {
577                let idx_val = match mle_index {
578                    MleIndex::Indexed(idx) => challenges[*idx],
579                    MleIndex::Bound(chal, _idx) => *chal,
580                    // TODO(vishady): actually we should just have an assertion that circuit description only
581                    // contains indexed bits
582                    _ => panic!("should not have any other index here"),
583                };
584                let left_side_acc = multiplier * (F::ONE - idx_val);
585                let right_side_acc = multiplier * (idx_val);
586                products.extend(
587                    a.get_post_sumcheck_layer(left_side_acc, challenges, _mle_vec)
588                        .0,
589                );
590                products.extend(
591                    b.get_post_sumcheck_layer(right_side_acc, challenges, _mle_vec)
592                        .0,
593                );
594            }
595            ExpressionNode::Sum(a, b) => {
596                products.extend(
597                    a.get_post_sumcheck_layer(multiplier, challenges, _mle_vec)
598                        .0,
599                );
600                products.extend(
601                    b.get_post_sumcheck_layer(multiplier, challenges, _mle_vec)
602                        .0,
603                );
604            }
605            ExpressionNode::Mle(mle) => {
606                products.push(Product::<F, Option<F>>::new(
607                    std::slice::from_ref(mle),
608                    multiplier,
609                    challenges,
610                ));
611            }
612            ExpressionNode::Product(mles) => {
613                let product = Product::<F, Option<F>>::new(mles, multiplier, challenges);
614                products.push(product);
615            }
616            ExpressionNode::Scaled(a, scale_factor) => {
617                let acc = multiplier * scale_factor;
618                products.extend(a.get_post_sumcheck_layer(acc, challenges, _mle_vec).0);
619            }
620            ExpressionNode::Constant(constant) => {
621                products.push(Product::<F, Option<F>>::new(
622                    &[],
623                    *constant * multiplier,
624                    challenges,
625                ));
626            }
627        }
628        PostSumcheckLayer(products)
629    }
630
631    /// Get the maximum degree of an ExpressionNode, recursively.
632    fn get_max_degree(&self, _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec) -> usize {
633        match self {
634            ExpressionNode::Selector(_, a, b) | ExpressionNode::Sum(a, b) => {
635                let a_degree = a.get_max_degree(_mle_vec);
636                let b_degree = b.get_max_degree(_mle_vec);
637                max(a_degree, b_degree)
638            }
639            ExpressionNode::Mle(_) => {
640                // 1 for the current MLE
641                1
642            }
643            ExpressionNode::Product(mles) => {
644                // max degree is the number of MLEs in a product
645                mles.len()
646            }
647            ExpressionNode::Scaled(a, _) => a.get_max_degree(_mle_vec),
648            ExpressionNode::Constant(_) => 1,
649        }
650    }
651
652    /// Returns the total number of variables (i.e. number of rounds of sumcheck) within
653    /// the MLE representing the output "data" of this particular expression.
654    ///
655    /// Note that unlike within the `AbstractExpr` case, we don't need to return
656    /// a `Result` since all MLEs within a `ExprDescription` are instantiated with their
657    /// appropriate number of variables.
658    fn get_num_vars(&self) -> usize {
659        match self {
660            ExpressionNode::Constant(_) => 0,
661            ExpressionNode::Selector(_, lhs, rhs) => {
662                max(lhs.get_num_vars() + 1, rhs.get_num_vars() + 1)
663            }
664            ExpressionNode::Mle(circuit_mle_desc) => circuit_mle_desc.num_free_vars(),
665            ExpressionNode::Sum(lhs, rhs) => max(lhs.get_num_vars(), rhs.get_num_vars()),
666            ExpressionNode::Product(nodes) => nodes.iter().fold(0, |cur_max, circuit_mle_desc| {
667                max(cur_max, circuit_mle_desc.num_free_vars())
668            }),
669            ExpressionNode::Scaled(expr, _) => expr.get_num_vars(),
670        }
671    }
672}
673
674impl<F: Field> Expression<F, ExprDescription> {
675    /// Returns the total number of variables (i.e. number of rounds of sumcheck)
676    /// within the MLE representing the output "data" of this particular expression.
677    ///
678    /// Note that unlike within the AbstractExpr case, we don't need to return
679    /// a `Result` since all MLEs within a `ExprDescription` are instantiated with their appropriate number of variables.
680    pub fn num_vars(&self) -> usize {
681        self.expression_node.get_num_vars()
682    }
683
684    /// Creates an `Expression<F, ExprDescription>` which describes the polynomial relationship
685    ///
686    /// `circuit_mle_descs[0](x_1, ..., x_{n_0}) * circuit_mle_descs[1](x_1, ..., x_{n_1}) * ...`
687    pub fn products(circuit_mle_descs: Vec<MleDescription<F>>) -> Self {
688        let product_node = ExpressionNode::Product(circuit_mle_descs);
689
690        Expression::new(product_node, ())
691    }
692
693    /// Creates an [Expression<F, ExprDescription>] which describes the polynomial relationship
694    /// `(1 - x_0) * Self(x_1, ..., x_{n_lhs}) + b_0 * rhs(x_1, ..., x_{n_rhs})`
695    ///
696    /// NOTE that by default, performing a `select()` over an LHS and an RHS
697    /// with different numbers of variables will create a selector tree such that
698    /// the side with fewer variables always falls down the left-most side of
699    /// that subtree.
700    ///
701    /// For example, if we are calling `select()` on two MLEs,
702    /// V_i(x_0, ..., x_4) and V_i(x_0, ..., x_6)
703    /// then the resulting expression will have a single top-level selector, and
704    /// will forcibly move the first MLE (with two fewer variables) to the left-most
705    /// subtree with 5 variables:
706    /// (1 - x_0) * (1 - x_1) * (1 - x_2) * V_i(x_3, ..., x_7) +
707    /// x_0 * V_i(x_1, ..., x_7)
708    pub fn select(self, rhs: Expression<F, ExprDescription>) -> Self {
709        let (lhs_node, _) = self.deconstruct();
710        let (rhs_node, _) = rhs.deconstruct();
711
712        // Compute the difference in number of free variables, to add the appropriate number of selectors
713        let num_left_selectors = max(0, rhs_node.get_num_vars() - lhs_node.get_num_vars());
714        let num_right_selectors = max(0, lhs_node.get_num_vars() - rhs_node.get_num_vars());
715
716        let lhs_subtree = if num_left_selectors > 0 {
717            // Always "go left" and "select" against a constant zero
718            (0..num_left_selectors).fold(lhs_node, |cur_subtree, _| {
719                ExpressionNode::Selector(
720                    MleIndex::Free,
721                    Box::new(cur_subtree),
722                    Box::new(ExpressionNode::Constant(F::ZERO)),
723                )
724            })
725        } else {
726            lhs_node
727        };
728
729        let rhs_subtree = if num_right_selectors > 0 {
730            // Always "go left" and "select" against a constant zero
731            (0..num_right_selectors).fold(rhs_node, |cur_subtree, _| {
732                ExpressionNode::Selector(
733                    MleIndex::Free,
734                    Box::new(cur_subtree),
735                    Box::new(ExpressionNode::Constant(F::ZERO)),
736                )
737            })
738        } else {
739            rhs_node
740        };
741
742        // Sanitycheck
743        debug_assert_eq!(lhs_subtree.get_num_vars(), rhs_subtree.get_num_vars());
744
745        // Finally, a selector against the two (equal-num-vars) sides!
746        let concat_node =
747            ExpressionNode::Selector(MleIndex::Free, Box::new(lhs_subtree), Box::new(rhs_subtree));
748
749        Expression::new(concat_node, ())
750    }
751
752    /// Create a nested selector Expression that selects between 2^k Expressions
753    /// by creating a binary tree of Selector Expressions.
754    /// The order of the leaves is the order of the input expressions.
755    /// (Note that this is very different from calling [Self::select] consecutively.)
756    pub fn binary_tree_selector(expressions: Vec<Self>) -> Self {
757        // Ensure length is a power of two
758        assert!(expressions.len().is_power_of_two());
759        let mut expressions = expressions;
760        while expressions.len() > 1 {
761            // Iterate over consecutive pairs of expressions, creating a new expression that selects between them
762            expressions = expressions
763                .into_iter()
764                .tuples()
765                .map(|(lhs, rhs)| {
766                    let (lhs_node, _) = lhs.deconstruct();
767                    let (rhs_node, _) = rhs.deconstruct();
768
769                    let selector_node = ExpressionNode::Selector(
770                        MleIndex::Free,
771                        Box::new(lhs_node),
772                        Box::new(rhs_node),
773                    );
774
775                    Expression::new(selector_node, ())
776                })
777                .collect();
778        }
779        expressions[0].clone()
780    }
781
782    /// Literally just `constant` as a term, but as an "`Expression`"
783    pub fn constant(constant: F) -> Self {
784        let mle_node = ExpressionNode::Constant(constant);
785
786        Expression::new(mle_node, ())
787    }
788
789    /// Literally just `-expression`, as an "`Expression`"
790    pub fn negated(expression: Self) -> Self {
791        let (node, _) = expression.deconstruct();
792
793        let mle_node = ExpressionNode::Scaled(Box::new(node), F::from(1).neg());
794
795        Expression::new(mle_node, ())
796    }
797
798    /// Literally just `lhs` + `rhs`, as an "`Expression`"
799    pub fn sum(lhs: Self, rhs: Self) -> Self {
800        let (lhs_node, _) = lhs.deconstruct();
801        let (rhs_node, _) = rhs.deconstruct();
802
803        let sum_node = ExpressionNode::Sum(Box::new(lhs_node), Box::new(rhs_node));
804
805        Expression::new(sum_node, ())
806    }
807
808    /// scales an Expression by a field element
809    pub fn scaled(expression: Expression<F, ExprDescription>, scale: F) -> Self {
810        let (node, _) = expression.deconstruct();
811
812        Expression::new(ExpressionNode::Scaled(Box::new(node), scale), ())
813    }
814}
815
816/// Given a bookkeeping table, use the according prefix bits in order
817/// to filter it to the correct "view" that we want to see, assuming
818/// that the prefix bits are the most significant bits, and that
819/// the bookkeeping tables are stored in "big endian" format.
820pub fn filter_bookkeeping_table<F: Field>(
821    bookkeeping_table: &MultilinearExtension<F>,
822    unfiltered_prefix_bits: &[bool],
823) -> MultilinearExtension<F> {
824    let current_table = bookkeeping_table.to_vec();
825    let mut current_table_len = current_table.len();
826    let filtered_table = unfiltered_prefix_bits
827        .iter()
828        .fold(current_table, |acc, bit| {
829            let acc = if *bit {
830                acc.into_iter().skip(current_table_len / 2).collect_vec()
831            } else {
832                acc.into_iter().take(current_table_len / 2).collect_vec()
833            };
834            current_table_len /= 2;
835            acc
836        });
837    MultilinearExtension::new(filtered_table)
838}
839
840/// Evaluate the bookkeeping tables by applying the element-wise operation,
841/// which can either be addition or multiplication.
842pub(crate) fn evaluate_bookkeeping_tables_given_operation<F: Field>(
843    mle_bookkeeping_tables: &[Vec<F>],
844    binary_operation: BinaryOperation,
845) -> MultilinearExtension<F> {
846    let max_num_vars = mle_bookkeeping_tables
847        .iter()
848        .map(|bookkeeping_table| log2(bookkeeping_table.len()))
849        .max()
850        .unwrap();
851
852    let mut output_table = vec![F::ZERO; 1 << max_num_vars];
853    (0..1 << (max_num_vars)).for_each(|index| {
854        let evaluated_data_point = mle_bookkeeping_tables
855            .iter()
856            .map(|mle_bookkeeping_table| {
857                let zero = F::ZERO;
858                let index = if log2(mle_bookkeeping_table.len()) < max_num_vars {
859                    let max = 1 << log2(mle_bookkeeping_table.len());
860                    let multiple = (1 << max_num_vars) / max;
861                    index / multiple
862                } else {
863                    index
864                };
865                let value = *mle_bookkeeping_table.get(index).unwrap_or(&zero);
866                value
867            })
868            .reduce(|acc, value| binary_operation.perform_operation(acc, value))
869            .unwrap();
870        output_table[index] = evaluated_data_point;
871    });
872    MultilinearExtension::new(output_table)
873}
874
875impl<F: Field> Neg for Expression<F, ExprDescription> {
876    type Output = Expression<F, ExprDescription>;
877    fn neg(self) -> Self::Output {
878        Expression::<F, ExprDescription>::negated(self)
879    }
880}
881
882/// implement the Add, Sub, and Mul traits for the Expression
883impl<F: Field> Add for Expression<F, ExprDescription> {
884    type Output = Expression<F, ExprDescription>;
885    fn add(self, rhs: Expression<F, ExprDescription>) -> Expression<F, ExprDescription> {
886        Expression::<F, ExprDescription>::sum(self, rhs)
887    }
888}
889
890impl<F: Field> Sub for Expression<F, ExprDescription> {
891    type Output = Expression<F, ExprDescription>;
892    fn sub(self, rhs: Expression<F, ExprDescription>) -> Expression<F, ExprDescription> {
893        self.add(rhs.neg())
894    }
895}
896
897impl<F: Field> Mul<F> for Expression<F, ExprDescription> {
898    type Output = Expression<F, ExprDescription>;
899    fn mul(self, rhs: F) -> Self::Output {
900        Expression::<F, ExprDescription>::scaled(self, rhs)
901    }
902}
903
904impl<F: std::fmt::Debug + Field> std::fmt::Debug for Expression<F, ExprDescription> {
905    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
906        f.debug_struct("Circuit Expression")
907            .field("Expression_Node", &self.expression_node)
908            .finish()
909    }
910}
911
912impl<F: std::fmt::Debug + Field> std::fmt::Debug for ExpressionNode<F, ExprDescription> {
913    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
914        match self {
915            ExpressionNode::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(),
916            ExpressionNode::Selector(index, a, b) => f
917                .debug_tuple("Selector")
918                .field(index)
919                .field(a)
920                .field(b)
921                .finish(),
922            // Skip enum variant and print query struct directly to maintain backwards compatibility.
923            ExpressionNode::Mle(mle) => f.debug_struct("Circuit Mle").field("mle", mle).finish(),
924            ExpressionNode::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
925            ExpressionNode::Product(a) => f.debug_tuple("Product").field(a).finish(),
926            ExpressionNode::Scaled(poly, scalar) => {
927                f.debug_tuple("Scaled").field(poly).field(scalar).finish()
928            }
929        }
930    }
931}