remainder/
sumcheck.rs

1//! Contains cryptographic algorithms for going through the sumcheck protocol in
2//! the context of a GKR prover.
3//!
4//! Let `P: F^n -> F` denote the polynomial [Expression] used to define some GKR
5//! Layer. This means that the value at a certain index `b \in {0, 1}^n` of the
6//! layer is given by `P(b)`. Denote by `V: {0, 1}^n -> F` the restriction of
7//! `P` on the hypercube.
8//!
9//! As part of the GKR protocol, the prover needs to assert the following
//! statement about the multilinear extension `\tilde{V}: F^n -> F` of `V`:
11//! ```text
//!     \tilde{V}(g_1, ..., g_n) = r \in F,
13//!         for some challenges g_1, ..., g_n \in F                    (1)
14//! ```
//! (Note that, in general, `P` and `\tilde{V}` are different functions. They
//!  are both extensions of `V`, but `\tilde{V}` is linear in each of its
//!  variables (i.e. multilinear), whereas `P` need not be.)
18//!
19//! The left-hand side of (1) can be expressed as a sum over the hypercube as
20//! follows:
21//! ```text
22//!     \sum_{b_1 \in {0, 1}}
23//!     \sum_{b_2 \in {0, 1}}
24//!         ...
25//!     \sum_{b_n \in {0, 1}}
26//!        \beta(b_1, ..., b_n, g_1, ..., g_n) * P(b_1, b_2, ..., b_n) = r  (2)
27//! ```
28//! where `\beta` is the following polynomial extending the equality predicate:
29//! ```text
30//!     \beta(b_1, ..., b_n, g_1, ..., g_n) =
31//!         \prod_{i = 1}^n [ b_i * g_i + (1 - b_i) * (1 - g_i) ]
32//! ```
33//!
34//! The functions in this module run the sumcheck protocol on expressions of the
35//! form described in equation (2). See the documentation of
36//! `compute_sumcheck_message_beta_cascade` for more information.
37
38use std::{
39    iter::{repeat, successors},
40    ops::{Add, Mul, Neg},
41};
42
43/// Tests for sumcheck with various expressions.
44#[cfg(test)]
45pub mod tests;
46
47use anyhow::{anyhow, Result};
48use ark_std::{cfg_chunks, cfg_into_iter};
49use itertools::{repeat_n, Itertools};
50use thiserror::Error;
51
52use crate::{
53    expression::{
54        generic_expr::{Expression, ExpressionNode, ExpressionType},
55        prover_expr::ProverExpr,
56    },
57    mle::{Mle, MleIndex},
58};
59#[cfg(feature = "parallel")]
60use rayon::iter::{IntoParallelIterator, ParallelIterator};
61#[cfg(feature = "parallel")]
62use rayon::prelude::ParallelSlice;
63
64use shared_types::Field;
65
66/// Errors to do with the evaluation of MleRefs.
67#[derive(Error, Debug, Clone, PartialEq)]
68pub enum MleError {
69    /// Passed list of Mles is empty.
70    #[error("Passed list of Mles is empty")]
71    EmptyMleList,
72
73    /// Beta table not yet initialized for Mle.
74    #[error("Beta table not yet initialized for Mle")]
75    NoBetaTable,
76
77    /// Layer does not have claims yet.
78    #[error("Layer does not have claims yet")]
79    NoClaim,
80
81    /// Unable to eval beta.
82    #[error("Unable to eval beta")]
83    BetaEvalError,
84
85    /// Cannot compute sumcheck message on un-indexed MLE.
86    #[error("Cannot compute sumcheck message on un-indexed MLE")]
87    NotIndexedError,
88}
89
90/// Verification error.
91#[derive(Error, Debug, Clone)]
92pub enum VerifyError {
93    /// Failed sumcheck round.
94    #[error("Failed sumcheck round")]
95    SumcheckBad,
96}
97
98/// Error when Interpolating a univariate polynomial.
99#[derive(Error, Debug, Clone)]
100pub enum InterpError {
101    /// Too few evaluation points.
102    #[error("Too few evaluation points")]
103    EvalLessThanDegree,
104
105    /// No possible polynomial.
106    #[error("No possible polynomial")]
107    NoInverse,
108}
109
110/// A type representing the univariate polynomial `g_i: F -> F` which the prover
111/// sends to the verifier in each round of sumcheck. Note that we are using an
112/// evaluation representation of polynomials, which means this type just holds
113/// the evaluations: `[g_i(0), g_i(1), ..., g_i(d)]`, where `d` is the degree of
114/// `g_i`.
115#[derive(PartialEq, Debug, Clone)]
116pub struct SumcheckEvals<F: Field>(pub Vec<F>);
117
118impl<F: Field> Neg for SumcheckEvals<F> {
119    type Output = Self;
120    fn neg(self) -> Self::Output {
121        // Negation for a bunch of eval points is just element-wise negation
122        SumcheckEvals(self.0.into_iter().map(|eval| eval.neg()).collect_vec())
123    }
124}
125
126impl<F: Field> Add for SumcheckEvals<F> {
127    type Output = Self;
128    fn add(self, rhs: Self) -> Self {
129        SumcheckEvals(
130            self.0
131                .into_iter()
132                .zip(rhs.0)
133                .map(|(lhs, rhs)| lhs + rhs)
134                .collect_vec(),
135        )
136    }
137}
138
139impl<F: Field> Mul<F> for SumcheckEvals<F> {
140    type Output = Self;
141    fn mul(self, rhs: F) -> Self {
142        SumcheckEvals(
143            self.0
144                .into_iter()
145                .zip(repeat(rhs))
146                .map(|(lhs, rhs)| lhs * rhs)
147                .collect_vec(),
148        )
149    }
150}
151
152impl<F: Field> Mul<&F> for SumcheckEvals<F> {
153    type Output = Self;
154    fn mul(self, rhs: &F) -> Self {
155        SumcheckEvals(
156            self.0
157                .into_iter()
158                .zip(repeat(rhs))
159                .map(|(lhs, rhs)| lhs * rhs)
160                .collect_vec(),
161        )
162    }
163}
164
165/// this function will take a list of mle refs, and compute the element-wise
166/// product of all of their bookkeeping tables along with the "successors."
167///
168/// for example, if we have two bookkeeping tables [a_1, a_2, a_3, a_4] and
169/// [c_1, c_2, c_3, c_4] and the degree of our expression at this index is 3, we
170/// need 4 evaluations for a unique curve. therefore first we will compute [a_1,
171/// a_2, (1-2)a_1 + 2a_2, (1-3)a_1 + 3a_2, a_3, a_4, (1-2)a_3 + 2a_4, (1-3)a_3 +
172/// 3a_4] and the same thing for the other mle and element-wise multiply both
173/// results. the resulting vector will always be size (degree + 1) * (2 ^
174/// (max_num_vars - 1))
175///
176/// this function assumes that the first variable is an independent variable.
177pub fn successors_from_mle_product<F: Field>(
178    mles: &[&impl Mle<F>],
179    degree: usize,
180    round_index: usize,
181) -> Result<Vec<Vec<F>>> {
182    // Gets the total number of free variables across all MLEs within this
183    // product
184    let mut max_num_vars = mles
185        .iter()
186        .map(|mle| mle.num_free_vars())
187        .max()
188        .ok_or(MleError::EmptyMleList)?;
189
190    let mles_have_independent_variable = mles
191        .iter()
192        .map(|mle| mle.mle_indices().contains(&MleIndex::Indexed(round_index)))
193        .reduce(|acc, item| acc | item)
194        .unwrap();
195
196    // We add 1 to the max number of variables if there is no independent variable
197    // to account for the independent variable contained within the beta.
198    if !mles_have_independent_variable {
199        max_num_vars += 1;
200    }
201
202    let successors_vec = cfg_into_iter!((0..1 << (max_num_vars - 1)))
203        .map(|mle_index| {
204            mles.iter()
205                .map(|mle| {
206                    let num_coefficients_in_mle = 1 << mle.num_free_vars();
207
208                    // Over here, we perform the wrap-around functionality if we
209                    // are multiplying two mles with different number of
210                    // variables. for example if we are multiplying V(b_1, b_2)
211                    // * V(b_1), and summing over b_2, then the overall sum is
212                    // V(b_1, 0) * V(b_1) + V(b_1, 1) * V(b_1). it can be seen
213                    // that the "smaller" mle (the one over less variables) has
214                    // to repeat itself an according number of times when the
215                    // sum is over a variable it does not contain. the
216                    // appropriate index is therefore determined as follows.
217                    let mle_index = if mle.num_free_vars() < max_num_vars {
218                        // If we have less than the max number of variables,
219                        // then we perform this variable-repeat functionality by
220                        // first rounding to the nearest power of 2, and then
221                        // taking the floor of the index divided by the difference in
222                        // power of 2.
223                        let multiple = (1 << max_num_vars) / num_coefficients_in_mle;
224                        mle_index / multiple
225                    } else {
226                        mle_index
227                    };
228                    // Over here, we get the elements in the pair so when index
229                    // = 0, it's [0] and [1], if index = 1, it's [2] and [3],
230                    // etc. because we are extending a function that was
231                    // originally defined over the hypercube, each pair
232                    // corresponds to two points on a line. we grab these two
233                    // points here
234                    let first = mle.get(mle_index).unwrap_or(F::ZERO);
235                    let second = if mle.num_free_vars() != 0 {
236                        mle.get(mle_index + (num_coefficients_in_mle / 2))
237                            .unwrap_or(F::ZERO)
238                    } else {
239                        first
240                    };
241                    let step = second - first;
242
243                    // creating the successors representing the evaluations for
244                    // \pi_{i = 1}^n f_i(X, b_2, ..., b_n) across X = 0,1, 2,
245                    // ... for a specific set of b_2, ..., b_n.
246                    Box::new(successors(Some(first), move |item| Some(*item + step)))
247                        as Box<dyn Iterator<Item = F> + Send>
248                })
249                .reduce(|mut a, mut b| {
250                    Box::new(
251                        successors(Some((a.next().unwrap(), b.next().unwrap())), move |_| {
252                            Some((a.next().unwrap(), b.next().unwrap()))
253                        })
254                        .map(|(a_val, b_val)| a_val * b_val),
255                    ) as Box<dyn Iterator<Item = F> + Send>
256                })
257                .unwrap()
258                .take(degree + 1)
259                .collect()
260        })
261        .collect();
262    Ok(successors_vec)
263}
264
265/// This is one step of the beta cascade algorithm, performing `(1 - beta_val) *
266/// mle[index] + beta_val * mle[index + 1]`
267pub(crate) fn beta_cascade_step<F: Field>(
268    mle_successor_vec: &[Vec<F>],
269    beta_val: F,
270) -> Vec<Vec<F>> {
271    let (one_minus_beta_val, beta_val) = (F::ONE - beta_val, beta_val);
272
273    cfg_chunks!(mle_successor_vec, 2)
274        .map(|successor_pair| {
275            let first_evals = &successor_pair[0];
276            let second_evals = &successor_pair[1];
277            let mut inner_result = Vec::with_capacity(first_evals.len());
278            inner_result.extend(
279                first_evals
280                    .iter()
281                    .zip(second_evals)
282                    .map(|(fold_a, fold_b)| one_minus_beta_val * fold_a + beta_val * fold_b),
283            );
284            inner_result
285        })
286        .collect()
287}
288
289/// This is the final step of beta cascade, where we take all the "bound" beta
290/// values and scale all of the evaluations by the product of all of these
291/// values.
292pub(crate) fn apply_updated_beta_values_to_evals<F: Field>(
293    evals: Vec<F>,
294    folded_updated_vals: F,
295) -> SumcheckEvals<F> {
296    let evals = evals
297        .iter()
298        .map(|elem| folded_updated_vals * elem)
299        .collect_vec();
300
301    SumcheckEvals(evals)
302}
303
304/// this is how we compute the evaluations of a product of mle refs along with a
305/// beta table. rather than using the full expanded version of a beta table, we
306/// instead just use the beta values vectors (which are unbound beta values, and
307/// the bound beta values) there are (degree + 1) evaluations that are returned
308/// which are the evaluations of the univariate polynomial where the
309/// "round_index"-th bit is the independent variable.
310pub fn beta_cascade<F: Field>(
311    mles: &[&impl Mle<F>],
312    degree: usize,
313    round_index: usize,
314    beta_vals_vec: &[Vec<F>],
315    beta_updated_vals_vec: &[Vec<F>],
316    random_coefficients: &[F],
317) -> SumcheckEvals<F> {
318    // Check that the number of beta values that we have is equal to the number
319    // of random coefficients, which must be the same because these are the
320    // number of claims we are aggregating over.
321    assert_eq!(beta_vals_vec.len(), beta_updated_vals_vec.len());
322    assert_eq!(beta_vals_vec.len(), random_coefficients.len());
323
324    let mle_successor_vec = successors_from_mle_product(mles, degree, round_index).unwrap();
325    // We compute the sumcheck evaluations using beta cascade for the same
326    // set of MLE successors, but different beta values. All of these are
327    // stored in the iterator.
328    let evals_iter = (beta_vals_vec.iter().zip(beta_updated_vals_vec))
329        .zip(random_coefficients)
330        .map(|((beta_vals, beta_updated_vals), random_coeff)| {
331            // Apply beta cascade steps, reducing `mle_successor_vec` size
332            // progressively.
333            let final_successor_vec = if beta_vals.len() > 1 {
334                let mut current_successor_vec =
335                    beta_cascade_step(&mle_successor_vec, *beta_vals.last().unwrap());
336                // All the skips, a really gross way of making sure we
337                // don't clone all of mle_successor_vec each time.
338                for val in beta_vals.iter().skip(1).rev().skip(1) {
339                    // Apply beta cascade step and return the new vector, replacing
340                    // the previous one
341                    current_successor_vec = beta_cascade_step(&current_successor_vec, *val);
342                }
343                current_successor_vec
344            } else {
345                // Only clone if this is going to be the final one we fold to get evaluations.
346                mle_successor_vec.clone()
347            };
348
349            // Check that mle_successor_vec now contains only one element after
350            // cascading
351            assert_eq!(final_successor_vec.len(), 1);
352
353            // Extract the remaining iterator from mle_successor_vec by popping it
354            let folded_mle_successors = &final_successor_vec[0];
355            // for the MSB of the beta value, this must be
356            // the independent variable. otherwise it would already be bound.
357            // therefore we need to compute the successors of this value in order to
358            // get its evaluations.
359            let evals = if !beta_vals.is_empty() {
360                let second_beta_successor = beta_vals[0];
361                let first_beta_successor = F::ONE - second_beta_successor;
362                let step = second_beta_successor - first_beta_successor;
363                let beta_successors =
364                    std::iter::successors(Some(first_beta_successor), move |item| {
365                        Some(*item + step)
366                    });
367                // the length of the mle successor vec before this last step must be
368                // degree + 1. therefore we can just do a zip with the beta
369                // successors to get the final degree + 1 evaluations.
370                beta_successors
371                    .zip(folded_mle_successors)
372                    .map(|(beta_succ, mle_succ)| beta_succ * mle_succ)
373                    .take(degree + 1)
374                    .collect_vec()
375            } else {
376                vec![F::ONE]
377            };
378            // apply the bound beta values as a scalar factor to each of the
379            // evaluations Multiply by the random coefficient to get the
380            // random linear combination by summing at the end.
381            apply_updated_beta_values_to_evals(evals, beta_updated_vals.iter().product())
382                * random_coeff
383        });
384    // Combine all the evaluations using a random linear combination. We
385    // simply sum because all evaluations are already multiplied by their
386    // random coefficient.
387    evals_iter.reduce(|acc, elem| acc + elem).unwrap()
388}
389
390/// Similar to [beta_cascade], but does not compute any evaluations and
391/// simply multiplies the appropriate beta values by the evaluated bookkeeping table.
392pub fn beta_cascade_no_independent_variable<F: Field>(
393    mut evaluated_bookkeeping_table: Vec<F>,
394    beta_vals: &[F],
395    beta_updated_vals: &[F],
396    degree: usize,
397) -> SumcheckEvals<F> {
398    if evaluated_bookkeeping_table.len() > 1 {
399        beta_vals.iter().rev().for_each(|beta_val| {
400            let (one_minus_beta_val, beta_val) = (F::ONE - beta_val, beta_val);
401            evaluated_bookkeeping_table = evaluated_bookkeeping_table
402                .chunks(2)
403                .map(|bits| bits[0] * one_minus_beta_val + bits[1] * beta_val)
404                .collect_vec();
405        });
406    }
407
408    assert_eq!(evaluated_bookkeeping_table.len(), 1);
409    let eval_vec: Vec<F> = repeat_n(evaluated_bookkeeping_table[0], degree + 1).collect();
410
411    apply_updated_beta_values_to_evals(eval_vec, beta_updated_vals.iter().product())
412}
413
414/// Returns the maximum degree of b_{curr_round} within an expression (and
415/// therefore the number of prover messages we need to send)
416pub(crate) fn get_round_degree<F: Field>(
417    expr: &Expression<F, ProverExpr>,
418    curr_round: usize,
419) -> usize {
420    // By default, all rounds have degree at least 2 (beta table included)
421    let mut round_degree = 1;
422
423    let mut get_degree_closure = |expr: &ExpressionNode<F, ProverExpr>,
424                                  mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec|
425     -> Result<()> {
426        let round_degree = &mut round_degree;
427
428        // The only exception is within a product of MLEs
429        if let ExpressionNode::Product(mle_vec_indices) = expr {
430            let mut product_round_degree: usize = 0;
431            for mle_vec_index in mle_vec_indices {
432                let mle = mle_vec_index.get_mle(mle_vec);
433
434                let mle_indices = mle.mle_indices();
435                for mle_index in mle_indices {
436                    if *mle_index == MleIndex::Indexed(curr_round) {
437                        product_round_degree += 1;
438                        break;
439                    }
440                }
441            }
442            if *round_degree < product_round_degree {
443                *round_degree = product_round_degree;
444            }
445        }
446        Ok(())
447    };
448
449    expr.traverse(&mut get_degree_closure).unwrap();
450    // add 1 cuz beta table but idk if we would ever use this without a beta
451    // table
452    round_degree + 1
453}
454
455/// Use degree + 1 evaluations to figure out the evaluation at some arbitrary
456/// point
457pub fn evaluate_at_a_point<F: Field>(given_evals: &[F], point: F) -> Result<F> {
458    // Special case for the constant polynomial.
459    if given_evals.len() == 1 {
460        return Ok(given_evals[0]);
461    }
462
463    debug_assert!(given_evals.len() > 1);
464
465    // Special cases for `point == 0` and `point == 1`.
466    if point == F::ZERO {
467        return Ok(given_evals[0]);
468    }
469    if point == F::ONE {
470        return Ok(*given_evals.get(1).unwrap_or(&given_evals[0]));
471    }
472
473    // Need degree + 1 evaluations to interpolate
474    let eval = (0..given_evals.len())
475        .map(
476            // Create an iterator of everything except current value
477            |x| {
478                (0..x)
479                    .chain(x + 1..given_evals.len())
480                    .map(|x| F::from(x as u64))
481                    .fold(
482                        // Compute vector of (numerator, denominator)
483                        (F::ONE, F::ONE),
484                        |(num, denom), val| {
485                            (num * (point - val), denom * (F::from(x as u64) - val))
486                        },
487                    )
488            },
489        )
490        .enumerate()
491        .map(
492            // Add up barycentric weight * current eval at point
493            |(x, (num, denom))| given_evals[x] * num * denom.invert().unwrap(),
494        )
495        .reduce(|x, y| x + y);
496    eval.ok_or(anyhow!("Interpretation Error: No Inverse"))
497}