remainder/expression/
verifier_expr.rs

//! The verifier's "view" of a "fully-bound" polynomial relationship between
//! the MLE of an output layer and the MLEs of its input layers (see
//! documentation within [crate::expression] for more general information).
4//!
5//! Specifically, the [VerifierExpr] is exactly the struct a GKR verifier
6//! uses to compute its "oracle query" at the end of the sumcheck protocol,
7//! wherein
//! * The prover has sent over the last sumcheck message, i.e. the univariate
//!   polynomial f_{n - 1}(x) = Expr(r_1, ..., r_{n - 1}, x), where
//!   r_1, ..., r_{n - 1} are the challenges sampled in the previous rounds.
//! * The verifier samples r_n uniformly from \mathbb{F} and wishes to check that
//!   Expr(r_1, ..., r_n) = f_{n - 1}(r_n).
12//!   Specifically, the verifier wishes to "plug in" r_1, ..., r_n to Expr(x_1, ..., x_n)
13//!   and additionally add in prover-claimed values for each of the MLEs at the
14//!   leaves of Expr(x_1, ..., x_n) to check the above. The [VerifierExpr] allows
15//!   the verifier to do exactly this, as the conversion from a
16//!   [super::circuit_expr::ExprDescription] to a [VerifierExpr] involves exactly the
17//!   process of "binding" the sumcheck challenges and "populating" each leaf
18//!   MLE with the prover-claimed value for the evaluation of that MLE at the
19//!   bound sumcheck challenge points.
20
21use crate::mle::{verifier_mle::VerifierMle, MleIndex};
22use serde::{Deserialize, Serialize};
23use std::{
24    collections::{HashMap, HashSet},
25    fmt::Debug,
26};
27
28use shared_types::Field;
29
30use super::{
31    expr_errors::ExpressionError,
32    generic_expr::{Expression, ExpressionNode, ExpressionType},
33};
34
35use anyhow::{anyhow, Ok, Result};
36
37/// Placeholder type for defining `Expression<F, VerifierExpr>`, the type used
38/// for representing expressions for the Verifier.
39#[derive(Serialize, Deserialize, Clone, Debug)]
40pub struct VerifierExpr;
41
42// The leaves of an expression of this type contain a [VerifierMle], an analogue
43// of [crate::mle::dense::DenseMle], storing fully bound MLEs.
44// TODO(Makis): Consider allowing for re-use of MLEs, like in a [ProverExpr]:
45// ```ignore
46//     type MLENodeRepr = usize,
47//     type MleVec = Vec<VerifierMle<F>>,
48// ```
49impl<F: Field> ExpressionType<F> for VerifierExpr {
50    type MLENodeRepr = VerifierMle<F>;
51    type MleVec = ();
52}
53
54impl<F: Field> Expression<F, VerifierExpr> {
55    /// Create a mle Expression that contains one MLE
56    pub fn mle(mle: VerifierMle<F>) -> Self {
57        let mle_node = ExpressionNode::Mle(mle);
58
59        Expression::new(mle_node, ())
60    }
61
62    /// Evaluate this fully bound expression.
63    pub fn evaluate(&self) -> Result<F> {
64        let constant = |c| Ok(c);
65        let selector_column = |idx: &MleIndex<F>, lhs: Result<F>, rhs: Result<F>| -> Result<F> {
66            // Selector bit must be bound
67            if let MleIndex::Bound(val, _) = idx {
68                return Ok(*val * rhs? + (F::ONE - val) * lhs?);
69            }
70            Err(anyhow!(ExpressionError::SelectorBitNotBoundError))
71        };
72        let mle_eval = |verifier_mle: &VerifierMle<F>| -> Result<F> { Ok(verifier_mle.value()) };
73        let sum = |lhs: Result<F>, rhs: Result<F>| Ok(lhs? + rhs?);
74        let product = |verifier_mles: &[VerifierMle<F>]| -> Result<F> {
75            verifier_mles
76                .iter()
77                .try_fold(F::ONE, |acc, verifier_mle| Ok(acc * verifier_mle.value()))
78        };
79        let scaled = |val: Result<F>, scalar: F| Ok(val? * scalar);
80
81        self.expression_node.reduce(
82            &constant,
83            &selector_column,
84            &mle_eval,
85            &sum,
86            &product,
87            &scaled,
88        )
89    }
90
91    /// Traverses the expression tree to get the indices of all the nonlinear
92    /// rounds. Returns a sorted vector of indices.
93    pub fn get_all_nonlinear_rounds(&mut self) -> Vec<usize> {
94        let (expression_node, mle_vec) = self.deconstruct_mut();
95        let mut nonlinear_rounds: Vec<usize> =
96            expression_node.get_all_nonlinear_rounds(&mut vec![], mle_vec);
97        nonlinear_rounds.sort();
98        nonlinear_rounds
99    }
100}
101
102impl<F: Field> ExpressionNode<F, VerifierExpr> {
103    /// Evaluate the polynomial using the provided closures to perform the
104    /// operations.
105    #[allow(clippy::too_many_arguments)]
106    pub fn reduce<T>(
107        &self,
108        constant: &impl Fn(F) -> T,
109        selector_column: &impl Fn(&MleIndex<F>, T, T) -> T,
110        mle_eval: &impl Fn(&<VerifierExpr as ExpressionType<F>>::MLENodeRepr) -> T,
111        sum: &impl Fn(T, T) -> T,
112        product: &impl Fn(&[<VerifierExpr as ExpressionType<F>>::MLENodeRepr]) -> T,
113        scaled: &impl Fn(T, F) -> T,
114    ) -> T {
115        match self {
116            ExpressionNode::Constant(scalar) => constant(*scalar),
117            ExpressionNode::Selector(index, a, b) => {
118                let lhs = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
119                let rhs = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
120                selector_column(index, lhs, rhs)
121            }
122            ExpressionNode::Mle(query) => mle_eval(query),
123            ExpressionNode::Sum(a, b) => {
124                let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
125                let b = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
126                sum(a, b)
127            }
128            ExpressionNode::Product(queries) => product(queries),
129            ExpressionNode::Scaled(a, f) => {
130                let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
131                scaled(a, *f)
132            }
133        }
134    }
135
136    /// Traverse an expression tree in order and returns a vector of indices of
137    /// all the nonlinear rounds in an expression (in no particular order).
138    pub fn get_all_nonlinear_rounds(
139        &self,
140        curr_nonlinear_indices: &mut Vec<usize>,
141        _mle_vec: &<VerifierExpr as ExpressionType<F>>::MleVec,
142    ) -> Vec<usize> {
143        let nonlinear_indices_in_node = {
144            match self {
145                // The only case where an index is nonlinear is if it is present in multiple mle
146                // refs that are part of a product. We iterate through all the indices in the
147                // product nodes to look for repeated indices within a single node.
148                ExpressionNode::Product(verifier_mles) => {
149                    let mut product_nonlinear_indices: HashSet<usize> = HashSet::new();
150                    let mut product_indices_counts: HashMap<MleIndex<F>, usize> = HashMap::new();
151
152                    verifier_mles.iter().for_each(|verifier_mle| {
153                        verifier_mle.var_indices().iter().for_each(|mle_index| {
154                            let curr_count = {
155                                if product_indices_counts.contains_key(mle_index) {
156                                    product_indices_counts.get(mle_index).unwrap()
157                                } else {
158                                    &0
159                                }
160                            };
161                            product_indices_counts.insert(mle_index.clone(), curr_count + 1);
162                        })
163                    });
164
165                    product_indices_counts
166                        .into_iter()
167                        .for_each(|(mle_index, count)| {
168                            if count > 1 {
169                                if let MleIndex::Indexed(i) = mle_index {
170                                    product_nonlinear_indices.insert(i);
171                                }
172                            }
173                        });
174
175                    product_nonlinear_indices
176                }
177                // for the rest of the types of expressions, we simply traverse through the expression node to look
178                // for more leaves which are specifically product nodes.
179                ExpressionNode::Selector(_sel_index, a, b) => {
180                    let mut sel_nonlinear_indices: HashSet<usize> = HashSet::new();
181                    let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
182                    let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
183                    a_indices
184                        .into_iter()
185                        .zip(b_indices)
186                        .for_each(|(a_mle_idx, b_mle_idx)| {
187                            sel_nonlinear_indices.insert(a_mle_idx);
188                            sel_nonlinear_indices.insert(b_mle_idx);
189                        });
190                    sel_nonlinear_indices
191                }
192                ExpressionNode::Sum(a, b) => {
193                    let mut sum_nonlinear_indices: HashSet<usize> = HashSet::new();
194                    let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
195                    let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
196                    a_indices
197                        .into_iter()
198                        .zip(b_indices)
199                        .for_each(|(a_mle_idx, b_mle_idx)| {
200                            sum_nonlinear_indices.insert(a_mle_idx);
201                            sum_nonlinear_indices.insert(b_mle_idx);
202                        });
203                    sum_nonlinear_indices
204                }
205                ExpressionNode::Scaled(a, _) => a
206                    .get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec)
207                    .into_iter()
208                    .collect(),
209                ExpressionNode::Constant(_) | ExpressionNode::Mle(_) => HashSet::new(),
210            }
211        };
212        // we grab all of the indices and take the union of all of them to return all nonlinear rounds in an expression tree.
213        nonlinear_indices_in_node.into_iter().for_each(|index| {
214            if !curr_nonlinear_indices.contains(&index) {
215                curr_nonlinear_indices.push(index);
216            }
217        });
218        curr_nonlinear_indices.clone()
219    }
220}
221
222impl<F: std::fmt::Debug + Field> std::fmt::Debug for Expression<F, VerifierExpr> {
223    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224        f.debug_struct("Circuit Expression")
225            .field("Expression_Node", &self.expression_node)
226            .finish()
227    }
228}
229
230impl<F: std::fmt::Debug + Field> std::fmt::Debug for ExpressionNode<F, VerifierExpr> {
231    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232        match self {
233            ExpressionNode::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(),
234            ExpressionNode::Selector(index, a, b) => f
235                .debug_tuple("Selector")
236                .field(index)
237                .field(a)
238                .field(b)
239                .finish(),
240            // Skip enum variant and print query struct directly to maintain backwards compatibility.
241            ExpressionNode::Mle(mle) => f.debug_struct("Circuit Mle").field("mle", mle).finish(),
242            ExpressionNode::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
243            ExpressionNode::Product(a) => f.debug_tuple("Product").field(a).finish(),
244            ExpressionNode::Scaled(poly, scalar) => {
245                f.debug_tuple("Scaled").field(poly).field(scalar).finish()
246            }
247        }
248    }
249}