remainder/sumcheck.rs
1//! Contains cryptographic algorithms for going through the sumcheck protocol in
2//! the context of a GKR prover.
3//!
4//! Let `P: F^n -> F` denote the polynomial [Expression] used to define some GKR
5//! Layer. This means that the value at a certain index `b \in {0, 1}^n` of the
6//! layer is given by `P(b)`. Denote by `V: {0, 1}^n -> F` the restriction of
7//! `P` on the hypercube.
8//!
//! As part of the GKR protocol, the prover needs to assert the following
//! statement about the multilinear extension `\tilde{V}: F^n -> F` of `V`:
//! ```text
//!     \tilde{V}(g_1, ..., g_n) = r \in F,
//!         for some challenges g_1, ..., g_n \in F         (1)
//! ```
//! (Note that, in general, `P` and `\tilde{V}` are different functions. They
//! are both extensions of `V`, but `\tilde{V}` is linear in each of its
//! variables.)
18//!
19//! The left-hand side of (1) can be expressed as a sum over the hypercube as
20//! follows:
21//! ```text
22//! \sum_{b_1 \in {0, 1}}
23//! \sum_{b_2 \in {0, 1}}
24//! ...
25//! \sum_{b_n \in {0, 1}}
26//! \beta(b_1, ..., b_n, g_1, ..., g_n) * P(b_1, b_2, ..., b_n) = r (2)
27//! ```
28//! where `\beta` is the following polynomial extending the equality predicate:
29//! ```text
30//! \beta(b_1, ..., b_n, g_1, ..., g_n) =
31//! \prod_{i = 1}^n [ b_i * g_i + (1 - b_i) * (1 - g_i) ]
32//! ```
33//!
34//! The functions in this module run the sumcheck protocol on expressions of the
35//! form described in equation (2). See the documentation of
36//! `compute_sumcheck_message_beta_cascade` for more information.
37
38use std::{
39 iter::{repeat, successors},
40 ops::{Add, Mul, Neg},
41};
42
43/// Tests for sumcheck with various expressions.
44#[cfg(test)]
45pub mod tests;
46
47use anyhow::{anyhow, Result};
48use ark_std::{cfg_chunks, cfg_into_iter};
49use itertools::{repeat_n, Itertools};
50use thiserror::Error;
51
52use crate::{
53 expression::{
54 generic_expr::{Expression, ExpressionNode, ExpressionType},
55 prover_expr::ProverExpr,
56 },
57 mle::{Mle, MleIndex},
58};
59#[cfg(feature = "parallel")]
60use rayon::iter::{IntoParallelIterator, ParallelIterator};
61#[cfg(feature = "parallel")]
62use rayon::prelude::ParallelSlice;
63
64use shared_types::Field;
65
66/// Errors to do with the evaluation of MleRefs.
67#[derive(Error, Debug, Clone, PartialEq)]
68pub enum MleError {
69 /// Passed list of Mles is empty.
70 #[error("Passed list of Mles is empty")]
71 EmptyMleList,
72
73 /// Beta table not yet initialized for Mle.
74 #[error("Beta table not yet initialized for Mle")]
75 NoBetaTable,
76
77 /// Layer does not have claims yet.
78 #[error("Layer does not have claims yet")]
79 NoClaim,
80
81 /// Unable to eval beta.
82 #[error("Unable to eval beta")]
83 BetaEvalError,
84
85 /// Cannot compute sumcheck message on un-indexed MLE.
86 #[error("Cannot compute sumcheck message on un-indexed MLE")]
87 NotIndexedError,
88}
89
90/// Verification error.
91#[derive(Error, Debug, Clone)]
92pub enum VerifyError {
93 /// Failed sumcheck round.
94 #[error("Failed sumcheck round")]
95 SumcheckBad,
96}
97
98/// Error when Interpolating a univariate polynomial.
99#[derive(Error, Debug, Clone)]
100pub enum InterpError {
101 /// Too few evaluation points.
102 #[error("Too few evaluation points")]
103 EvalLessThanDegree,
104
105 /// No possible polynomial.
106 #[error("No possible polynomial")]
107 NoInverse,
108}
109
110/// A type representing the univariate polynomial `g_i: F -> F` which the prover
111/// sends to the verifier in each round of sumcheck. Note that we are using an
112/// evaluation representation of polynomials, which means this type just holds
113/// the evaluations: `[g_i(0), g_i(1), ..., g_i(d)]`, where `d` is the degree of
114/// `g_i`.
115#[derive(PartialEq, Debug, Clone)]
116pub struct SumcheckEvals<F: Field>(pub Vec<F>);
117
118impl<F: Field> Neg for SumcheckEvals<F> {
119 type Output = Self;
120 fn neg(self) -> Self::Output {
121 // Negation for a bunch of eval points is just element-wise negation
122 SumcheckEvals(self.0.into_iter().map(|eval| eval.neg()).collect_vec())
123 }
124}
125
126impl<F: Field> Add for SumcheckEvals<F> {
127 type Output = Self;
128 fn add(self, rhs: Self) -> Self {
129 SumcheckEvals(
130 self.0
131 .into_iter()
132 .zip(rhs.0)
133 .map(|(lhs, rhs)| lhs + rhs)
134 .collect_vec(),
135 )
136 }
137}
138
139impl<F: Field> Mul<F> for SumcheckEvals<F> {
140 type Output = Self;
141 fn mul(self, rhs: F) -> Self {
142 SumcheckEvals(
143 self.0
144 .into_iter()
145 .zip(repeat(rhs))
146 .map(|(lhs, rhs)| lhs * rhs)
147 .collect_vec(),
148 )
149 }
150}
151
152impl<F: Field> Mul<&F> for SumcheckEvals<F> {
153 type Output = Self;
154 fn mul(self, rhs: &F) -> Self {
155 SumcheckEvals(
156 self.0
157 .into_iter()
158 .zip(repeat(rhs))
159 .map(|(lhs, rhs)| lhs * rhs)
160 .collect_vec(),
161 )
162 }
163}
164
165/// this function will take a list of mle refs, and compute the element-wise
166/// product of all of their bookkeeping tables along with the "successors."
167///
168/// for example, if we have two bookkeeping tables [a_1, a_2, a_3, a_4] and
169/// [c_1, c_2, c_3, c_4] and the degree of our expression at this index is 3, we
170/// need 4 evaluations for a unique curve. therefore first we will compute [a_1,
171/// a_2, (1-2)a_1 + 2a_2, (1-3)a_1 + 3a_2, a_3, a_4, (1-2)a_3 + 2a_4, (1-3)a_3 +
172/// 3a_4] and the same thing for the other mle and element-wise multiply both
173/// results. the resulting vector will always be size (degree + 1) * (2 ^
174/// (max_num_vars - 1))
175///
176/// this function assumes that the first variable is an independent variable.
177pub fn successors_from_mle_product<F: Field>(
178 mles: &[&impl Mle<F>],
179 degree: usize,
180 round_index: usize,
181) -> Result<Vec<Vec<F>>> {
182 // Gets the total number of free variables across all MLEs within this
183 // product
184 let mut max_num_vars = mles
185 .iter()
186 .map(|mle| mle.num_free_vars())
187 .max()
188 .ok_or(MleError::EmptyMleList)?;
189
190 let mles_have_independent_variable = mles
191 .iter()
192 .map(|mle| mle.mle_indices().contains(&MleIndex::Indexed(round_index)))
193 .reduce(|acc, item| acc | item)
194 .unwrap();
195
196 // We add 1 to the max number of variables if there is no independent variable
197 // to account for the independent variable contained within the beta.
198 if !mles_have_independent_variable {
199 max_num_vars += 1;
200 }
201
202 let successors_vec = cfg_into_iter!((0..1 << (max_num_vars - 1)))
203 .map(|mle_index| {
204 mles.iter()
205 .map(|mle| {
206 let num_coefficients_in_mle = 1 << mle.num_free_vars();
207
208 // Over here, we perform the wrap-around functionality if we
209 // are multiplying two mles with different number of
210 // variables. for example if we are multiplying V(b_1, b_2)
211 // * V(b_1), and summing over b_2, then the overall sum is
212 // V(b_1, 0) * V(b_1) + V(b_1, 1) * V(b_1). it can be seen
213 // that the "smaller" mle (the one over less variables) has
214 // to repeat itself an according number of times when the
215 // sum is over a variable it does not contain. the
216 // appropriate index is therefore determined as follows.
217 let mle_index = if mle.num_free_vars() < max_num_vars {
218 // If we have less than the max number of variables,
219 // then we perform this variable-repeat functionality by
220 // first rounding to the nearest power of 2, and then
221 // taking the floor of the index divided by the difference in
222 // power of 2.
223 let multiple = (1 << max_num_vars) / num_coefficients_in_mle;
224 mle_index / multiple
225 } else {
226 mle_index
227 };
228 // Over here, we get the elements in the pair so when index
229 // = 0, it's [0] and [1], if index = 1, it's [2] and [3],
230 // etc. because we are extending a function that was
231 // originally defined over the hypercube, each pair
232 // corresponds to two points on a line. we grab these two
233 // points here
234 let first = mle.get(mle_index).unwrap_or(F::ZERO);
235 let second = if mle.num_free_vars() != 0 {
236 mle.get(mle_index + (num_coefficients_in_mle / 2))
237 .unwrap_or(F::ZERO)
238 } else {
239 first
240 };
241 let step = second - first;
242
243 // creating the successors representing the evaluations for
244 // \pi_{i = 1}^n f_i(X, b_2, ..., b_n) across X = 0,1, 2,
245 // ... for a specific set of b_2, ..., b_n.
246 Box::new(successors(Some(first), move |item| Some(*item + step)))
247 as Box<dyn Iterator<Item = F> + Send>
248 })
249 .reduce(|mut a, mut b| {
250 Box::new(
251 successors(Some((a.next().unwrap(), b.next().unwrap())), move |_| {
252 Some((a.next().unwrap(), b.next().unwrap()))
253 })
254 .map(|(a_val, b_val)| a_val * b_val),
255 ) as Box<dyn Iterator<Item = F> + Send>
256 })
257 .unwrap()
258 .take(degree + 1)
259 .collect()
260 })
261 .collect();
262 Ok(successors_vec)
263}
264
265/// This is one step of the beta cascade algorithm, performing `(1 - beta_val) *
266/// mle[index] + beta_val * mle[index + 1]`
267pub(crate) fn beta_cascade_step<F: Field>(
268 mle_successor_vec: &[Vec<F>],
269 beta_val: F,
270) -> Vec<Vec<F>> {
271 let (one_minus_beta_val, beta_val) = (F::ONE - beta_val, beta_val);
272
273 cfg_chunks!(mle_successor_vec, 2)
274 .map(|successor_pair| {
275 let first_evals = &successor_pair[0];
276 let second_evals = &successor_pair[1];
277 let mut inner_result = Vec::with_capacity(first_evals.len());
278 inner_result.extend(
279 first_evals
280 .iter()
281 .zip(second_evals)
282 .map(|(fold_a, fold_b)| one_minus_beta_val * fold_a + beta_val * fold_b),
283 );
284 inner_result
285 })
286 .collect()
287}
288
289/// This is the final step of beta cascade, where we take all the "bound" beta
290/// values and scale all of the evaluations by the product of all of these
291/// values.
292pub(crate) fn apply_updated_beta_values_to_evals<F: Field>(
293 evals: Vec<F>,
294 folded_updated_vals: F,
295) -> SumcheckEvals<F> {
296 let evals = evals
297 .iter()
298 .map(|elem| folded_updated_vals * elem)
299 .collect_vec();
300
301 SumcheckEvals(evals)
302}
303
304/// this is how we compute the evaluations of a product of mle refs along with a
305/// beta table. rather than using the full expanded version of a beta table, we
306/// instead just use the beta values vectors (which are unbound beta values, and
307/// the bound beta values) there are (degree + 1) evaluations that are returned
308/// which are the evaluations of the univariate polynomial where the
309/// "round_index"-th bit is the independent variable.
310pub fn beta_cascade<F: Field>(
311 mles: &[&impl Mle<F>],
312 degree: usize,
313 round_index: usize,
314 beta_vals_vec: &[Vec<F>],
315 beta_updated_vals_vec: &[Vec<F>],
316 random_coefficients: &[F],
317) -> SumcheckEvals<F> {
318 // Check that the number of beta values that we have is equal to the number
319 // of random coefficients, which must be the same because these are the
320 // number of claims we are aggregating over.
321 assert_eq!(beta_vals_vec.len(), beta_updated_vals_vec.len());
322 assert_eq!(beta_vals_vec.len(), random_coefficients.len());
323
324 let mle_successor_vec = successors_from_mle_product(mles, degree, round_index).unwrap();
325 // We compute the sumcheck evaluations using beta cascade for the same
326 // set of MLE successors, but different beta values. All of these are
327 // stored in the iterator.
328 let evals_iter = (beta_vals_vec.iter().zip(beta_updated_vals_vec))
329 .zip(random_coefficients)
330 .map(|((beta_vals, beta_updated_vals), random_coeff)| {
331 // Apply beta cascade steps, reducing `mle_successor_vec` size
332 // progressively.
333 let final_successor_vec = if beta_vals.len() > 1 {
334 let mut current_successor_vec =
335 beta_cascade_step(&mle_successor_vec, *beta_vals.last().unwrap());
336 // All the skips, a really gross way of making sure we
337 // don't clone all of mle_successor_vec each time.
338 for val in beta_vals.iter().skip(1).rev().skip(1) {
339 // Apply beta cascade step and return the new vector, replacing
340 // the previous one
341 current_successor_vec = beta_cascade_step(¤t_successor_vec, *val);
342 }
343 current_successor_vec
344 } else {
345 // Only clone if this is going to be the final one we fold to get evaluations.
346 mle_successor_vec.clone()
347 };
348
349 // Check that mle_successor_vec now contains only one element after
350 // cascading
351 assert_eq!(final_successor_vec.len(), 1);
352
353 // Extract the remaining iterator from mle_successor_vec by popping it
354 let folded_mle_successors = &final_successor_vec[0];
355 // for the MSB of the beta value, this must be
356 // the independent variable. otherwise it would already be bound.
357 // therefore we need to compute the successors of this value in order to
358 // get its evaluations.
359 let evals = if !beta_vals.is_empty() {
360 let second_beta_successor = beta_vals[0];
361 let first_beta_successor = F::ONE - second_beta_successor;
362 let step = second_beta_successor - first_beta_successor;
363 let beta_successors =
364 std::iter::successors(Some(first_beta_successor), move |item| {
365 Some(*item + step)
366 });
367 // the length of the mle successor vec before this last step must be
368 // degree + 1. therefore we can just do a zip with the beta
369 // successors to get the final degree + 1 evaluations.
370 beta_successors
371 .zip(folded_mle_successors)
372 .map(|(beta_succ, mle_succ)| beta_succ * mle_succ)
373 .take(degree + 1)
374 .collect_vec()
375 } else {
376 vec![F::ONE]
377 };
378 // apply the bound beta values as a scalar factor to each of the
379 // evaluations Multiply by the random coefficient to get the
380 // random linear combination by summing at the end.
381 apply_updated_beta_values_to_evals(evals, beta_updated_vals.iter().product())
382 * random_coeff
383 });
384 // Combine all the evaluations using a random linear combination. We
385 // simply sum because all evaluations are already multiplied by their
386 // random coefficient.
387 evals_iter.reduce(|acc, elem| acc + elem).unwrap()
388}
389
390/// Similar to [beta_cascade], but does not compute any evaluations and
391/// simply multiplies the appropriate beta values by the evaluated bookkeeping table.
392pub fn beta_cascade_no_independent_variable<F: Field>(
393 mut evaluated_bookkeeping_table: Vec<F>,
394 beta_vals: &[F],
395 beta_updated_vals: &[F],
396 degree: usize,
397) -> SumcheckEvals<F> {
398 if evaluated_bookkeeping_table.len() > 1 {
399 beta_vals.iter().rev().for_each(|beta_val| {
400 let (one_minus_beta_val, beta_val) = (F::ONE - beta_val, beta_val);
401 evaluated_bookkeeping_table = evaluated_bookkeeping_table
402 .chunks(2)
403 .map(|bits| bits[0] * one_minus_beta_val + bits[1] * beta_val)
404 .collect_vec();
405 });
406 }
407
408 assert_eq!(evaluated_bookkeeping_table.len(), 1);
409 let eval_vec: Vec<F> = repeat_n(evaluated_bookkeeping_table[0], degree + 1).collect();
410
411 apply_updated_beta_values_to_evals(eval_vec, beta_updated_vals.iter().product())
412}
413
414/// Returns the maximum degree of b_{curr_round} within an expression (and
415/// therefore the number of prover messages we need to send)
416pub(crate) fn get_round_degree<F: Field>(
417 expr: &Expression<F, ProverExpr>,
418 curr_round: usize,
419) -> usize {
420 // By default, all rounds have degree at least 2 (beta table included)
421 let mut round_degree = 1;
422
423 let mut get_degree_closure = |expr: &ExpressionNode<F, ProverExpr>,
424 mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec|
425 -> Result<()> {
426 let round_degree = &mut round_degree;
427
428 // The only exception is within a product of MLEs
429 if let ExpressionNode::Product(mle_vec_indices) = expr {
430 let mut product_round_degree: usize = 0;
431 for mle_vec_index in mle_vec_indices {
432 let mle = mle_vec_index.get_mle(mle_vec);
433
434 let mle_indices = mle.mle_indices();
435 for mle_index in mle_indices {
436 if *mle_index == MleIndex::Indexed(curr_round) {
437 product_round_degree += 1;
438 break;
439 }
440 }
441 }
442 if *round_degree < product_round_degree {
443 *round_degree = product_round_degree;
444 }
445 }
446 Ok(())
447 };
448
449 expr.traverse(&mut get_degree_closure).unwrap();
450 // add 1 cuz beta table but idk if we would ever use this without a beta
451 // table
452 round_degree + 1
453}
454
455/// Use degree + 1 evaluations to figure out the evaluation at some arbitrary
456/// point
457pub fn evaluate_at_a_point<F: Field>(given_evals: &[F], point: F) -> Result<F> {
458 // Special case for the constant polynomial.
459 if given_evals.len() == 1 {
460 return Ok(given_evals[0]);
461 }
462
463 debug_assert!(given_evals.len() > 1);
464
465 // Special cases for `point == 0` and `point == 1`.
466 if point == F::ZERO {
467 return Ok(given_evals[0]);
468 }
469 if point == F::ONE {
470 return Ok(*given_evals.get(1).unwrap_or(&given_evals[0]));
471 }
472
473 // Need degree + 1 evaluations to interpolate
474 let eval = (0..given_evals.len())
475 .map(
476 // Create an iterator of everything except current value
477 |x| {
478 (0..x)
479 .chain(x + 1..given_evals.len())
480 .map(|x| F::from(x as u64))
481 .fold(
482 // Compute vector of (numerator, denominator)
483 (F::ONE, F::ONE),
484 |(num, denom), val| {
485 (num * (point - val), denom * (F::from(x as u64) - val))
486 },
487 )
488 },
489 )
490 .enumerate()
491 .map(
492 // Add up barycentric weight * current eval at point
493 |(x, (num, denom))| given_evals[x] * num * denom.invert().unwrap(),
494 )
495 .reduce(|x, y| x + y);
496 eval.ok_or(anyhow!("Interpretation Error: No Inverse"))
497}