remainder/layer/
regular_layer.rs

1//! The implementation of `RegularLayer`
2
3#[cfg(test)]
4mod tests;
5
6use std::collections::{HashMap, HashSet};
7
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use shared_types::{
11    config::{global_config::global_claim_agg_strategy, ClaimAggregationStrategy},
12    transcript::{ProverTranscript, VerifierTranscript},
13    Field,
14};
15use tracing::info;
16
17use crate::{
18    circuit_layout::{CircuitEvalMap, CircuitLocation},
19    claims::{Claim, ClaimError, RawClaim},
20    expression::{
21        circuit_expr::{filter_bookkeeping_table, ExprDescription},
22        generic_expr::{Expression, ExpressionNode, ExpressionType},
23        prover_expr::ProverExpr,
24        verifier_expr::VerifierExpr,
25    },
26    layer::{Layer, LayerId, VerificationError},
27    mle::{betavalues::BetaValues, dense::DenseMle, mle_description::MleDescription, Mle},
28    sumcheck::{evaluate_at_a_point, get_round_degree},
29};
30
31use super::{
32    layer_enum::{LayerEnum, VerifierLayerEnum},
33    product::PostSumcheckLayer,
34};
35
36use super::{LayerDescription, VerifierLayer};
37
38use anyhow::{anyhow, Ok, Result};
39
40/// The most common implementation of [crate::layer::Layer].
41///
42/// A regular layer is made up of a structured polynomial relationship between
43/// MLEs of previous layers.
44///
45/// Proofs are generated with the Sumcheck protocol.
46#[derive(Serialize, Deserialize, Clone, Debug)]
47#[serde(bound = "F: Field")]
48pub struct RegularLayer<F: Field> {
49    /// This layer's ID.
50    id: LayerId,
51
52    /// The polynomial expression defining this layer.
53    /// It includes information on how this layer relates to the others.
54    pub(crate) expression: Expression<F, ProverExpr>,
55
56    /// Stores the indices of the sumcheck rounds in this GKR layer so we
57    /// only produce sumcheck proofs over those. When we use interpolative
58    /// claim aggregation, this is all of the nonlinear variables in the
59    /// expression. When we use RLC claim aggregation, this is all of the
60    /// variables in the expression.
61    sumcheck_rounds: Vec<usize>,
62
63    /// Stores the beta values associated with the `expression`.
64    /// Initially set to `None`. Computed during initialization.
65    beta_vals_vec: Option<Vec<BetaValues<F>>>,
66}
67
68impl<F: Field> RegularLayer<F> {
69    /// Creates a new `RegularLayer` from an `Expression` and a `LayerId`
70    ///
71    /// The `Expression` is the relationship this `Layer` proves
72    /// and the `LayerId` is the location of this `Layer` in the overall circuit
73    pub fn new_raw(id: LayerId, mut expression: Expression<F, ProverExpr>) -> Self {
74        // Compute nonlinear rounds from `expression`
75        expression.index_mle_indices(0);
76        let sumcheck_rounds = match global_claim_agg_strategy() {
77            ClaimAggregationStrategy::Interpolative => expression.get_all_nonlinear_rounds(),
78            ClaimAggregationStrategy::RLC => expression.get_all_rounds(),
79        };
80        RegularLayer {
81            id,
82            expression,
83            sumcheck_rounds,
84            beta_vals_vec: None,
85        }
86    }
87
88    /// Returns a reference to the expression that this layer is proving.
89    pub fn get_expression(&self) -> &Expression<F, ProverExpr> {
90        &self.expression
91    }
92
93    /// Traverse the fully-bound `self.expression` and append all MLE values
94    /// to the trascript.
95    pub fn append_leaf_mles_to_transcript(
96        &self,
97        transcript_writer: &mut impl ProverTranscript<F>,
98    ) -> Result<()> {
99        let mut observer_fn = |expr_node: &ExpressionNode<F, ProverExpr>,
100                               mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec|
101         -> Result<()> {
102            match expr_node {
103                ExpressionNode::Mle(mle_vec_index) => {
104                    let mle: &DenseMle<F> = &mle_vec[mle_vec_index.index()];
105                    let val = mle.mle.value();
106                    transcript_writer.append("Fully bound MLE evaluation", val);
107                    Ok(())
108                }
109                ExpressionNode::Product(mle_vec_indices) => {
110                    for mle_vec_index in mle_vec_indices {
111                        let mle = &mle_vec[mle_vec_index.index()];
112                        let eval = mle.mle.value();
113                        transcript_writer.append("Fully bound MLE evaluation", eval);
114                    }
115                    Ok(())
116                }
117                ExpressionNode::Constant(_)
118                | ExpressionNode::Scaled(_, _)
119                | ExpressionNode::Sum(_, _)
120                | ExpressionNode::Selector(_, _, _) => Ok(()),
121            }
122        };
123
124        let _ = self.expression.traverse(&mut observer_fn);
125
126        Ok(())
127    }
128}
129
130impl<F: Field> Layer<F> for RegularLayer<F> {
131    fn layer_id(&self) -> LayerId {
132        self.id
133    }
134
135    fn prove(
136        &mut self,
137        claims: &[&RawClaim<F>],
138        transcript_writer: &mut impl ProverTranscript<F>,
139    ) -> Result<()> {
140        info!("Proving a GKR Layer.");
141
142        // Initialize tables and pre-fix variables.
143        let random_coefficients = match global_claim_agg_strategy() {
144            ClaimAggregationStrategy::Interpolative => {
145                assert_eq!(claims.len(), 1);
146                self.initialize(claims[0].get_point())?;
147                vec![F::ONE]
148            }
149            ClaimAggregationStrategy::RLC => {
150                let random_coefficients =
151                    transcript_writer.get_challenges("RLC Claim Agg Coefficients", claims.len());
152                self.initialize_rlc(&random_coefficients, claims);
153                random_coefficients
154            }
155        };
156
157        let mut previous_round_message = vec![claims
158            .iter()
159            .zip(&random_coefficients)
160            .fold(F::ZERO, |acc, (claim, random_coeff)| {
161                acc + claim.get_eval() * random_coeff
162            })];
163        let mut previous_challenge = F::ZERO;
164
165        let layer_id = self.layer_id();
166        for round_index in self.sumcheck_rounds.clone() {
167            // First compute the appropriate number of univariate evaluations for this round.
168            let prover_sumcheck_message =
169                self.compute_round_sumcheck_message(round_index, &random_coefficients)?;
170            // In debug mode, catch sumcheck round errors from the prover side.
171            debug_assert_eq!(
172                evaluate_at_a_point(&previous_round_message, previous_challenge).unwrap(),
173                prover_sumcheck_message[0] + prover_sumcheck_message[1],
174                "failed at round {round_index}, layer {layer_id}",
175            );
176            // Append the evaluations to the transcript.
177            // Since the verifier can deduce g_i(0) by computing claim - g_i(1), the prover does not send g_i(0)
178            transcript_writer.append_elements(
179                "Sumcheck round univariate evaluations",
180                &prover_sumcheck_message[1..],
181            );
182            // Sample the challenge
183            let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
184            // "Bind" the challenge to the expression at this point.
185            self.bind_round_variable(round_index, challenge)?;
186            // For debug mode, update the previous message and challenge for the purpose
187            // of checking whether these still pass the sumcheck round checks.
188            previous_round_message = prover_sumcheck_message;
189            previous_challenge = challenge;
190        }
191
192        // By now, `self.expression` should be fully bound.
193        assert_eq!(self.expression.get_expression_num_free_variables(), 0);
194
195        // Append the values of the leaf MLEs to the transcript.
196        self.append_leaf_mles_to_transcript(transcript_writer)?;
197
198        Ok(())
199    }
200
201    /// Initialize all necessary information in order to start sumcheck within a
202    /// layer of GKR. This includes pre-fixing all of the rounds within the
203    /// layer which are linear, and then appropriately initializing the
204    /// necessary beta values over the nonlinear rounds.
205    fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
206        let expression = &mut self.expression;
207        let expression_nonlinear_indices = expression.get_all_nonlinear_rounds();
208        let expression_linear_indices = expression.get_all_linear_rounds();
209
210        // For each of the linear indices in the expression, we can fix the
211        // variable at that index of the expression, so that now the only
212        // unbound indices are the nonlinear indices.
213        expression_linear_indices
214            .iter()
215            .sorted()
216            .for_each(|round_idx| {
217                expression.fix_variable_at_index(*round_idx, claim_point[*round_idx]);
218            });
219
220        // We need the beta values over the nonlinear indices of the expression,
221        // so we grab the claim points that are over these nonlinear indices and
222        // then initialize the betavalues struct over them.
223        let betavec = expression_nonlinear_indices
224            .iter()
225            .map(|idx| (*idx, claim_point[*idx]))
226            .collect_vec();
227        let newbeta = BetaValues::new(betavec);
228        self.beta_vals_vec = Some(vec![newbeta]);
229
230        Ok(())
231    }
232
233    fn initialize_rlc(&mut self, _random_coefficients: &[F], claims: &[&RawClaim<F>]) {
234        // We need the beta values over all the indices of the expression, as we
235        // cannot perform the linear round optimization with RLC claim agg since
236        // we have multiple points to bind each MLE to.
237        let expression = &mut self.expression;
238        let expression_all_indices = expression.get_all_rounds();
239
240        let beta_vals_vec = claims
241            .iter()
242            .map(|claim| {
243                let claim_point = claim.get_point();
244                let betavec = expression_all_indices
245                    .iter()
246                    .map(|idx| (*idx, claim_point[*idx]))
247                    .collect_vec();
248                BetaValues::new(betavec)
249            })
250            .collect();
251        self.beta_vals_vec = Some(beta_vals_vec);
252    }
253
254    fn compute_round_sumcheck_message(
255        &mut self,
256        round_index: usize,
257        random_coefficients: &[F],
258    ) -> Result<Vec<F>> {
259        // Grabs the expression/beta table.
260        let expression = &self.expression;
261        let newbeta = &self.beta_vals_vec;
262
263        // Grabs the degree of univariate polynomial we are sending over.
264        let degree = get_round_degree(expression, round_index);
265
266        // Computes the sumcheck message using the beta cascade algorithm.
267        let prover_sumcheck_message = expression.evaluate_sumcheck_beta_cascade(
268            &newbeta.as_ref().unwrap().iter().collect_vec(),
269            random_coefficients,
270            round_index,
271            degree,
272        );
273
274        Ok(prover_sumcheck_message.0)
275    }
276
277    fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
278        // Grabs the expression/beta table.
279        let expression = &mut self.expression;
280        let beta_vals_vec = &mut self.beta_vals_vec;
281
282        // Update the bookkeeping tables as necessary.
283        expression.fix_variable(round_index, challenge);
284        beta_vals_vec
285            .as_mut()
286            .unwrap()
287            .iter_mut()
288            .for_each(|beta_vals| {
289                beta_vals.beta_update(round_index, challenge);
290            });
291
292        Ok(())
293    }
294
295    /// Returns the round indices (with respect to the indices of all relevant
296    /// variables within the layer) which are nonlinear. For example, if the
297    /// current layer's [Expression] looks something like
298    /// V_{i + 1}(x_1, x_2, x_3) * V_{i + 2}(x_1, x_2)
299    /// then the `sumcheck_round_indices` of this layer would be [1, 2].
300    fn sumcheck_round_indices(&self) -> Vec<usize> {
301        self.sumcheck_rounds.clone()
302    }
303
304    fn max_degree(&self) -> usize {
305        &self.expression.get_max_degree() + 1
306    }
307
308    /// Get the [PostSumcheckLayer] for a regular layer, which represents the fully bound expression.
309    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
310    fn get_post_sumcheck_layer(
311        &self,
312        round_challenges: &[F],
313        claim_challenges: &[&[F]],
314        random_coefficients: &[F],
315    ) -> PostSumcheckLayer<F, F> {
316        let sumcheck_round_indices = self.sumcheck_round_indices();
317        // Filter the claim to get the values of the claim pertaining to the nonlinear rounds.
318        let sumcheck_claim_points_vec = claim_challenges
319            .iter()
320            .map(|claim_challenge| {
321                claim_challenge
322                    .iter()
323                    .enumerate()
324                    .filter_map(|(idx, point)| {
325                        if sumcheck_round_indices.contains(&idx) {
326                            Some(*point)
327                        } else {
328                            None
329                        }
330                    })
331                    .collect_vec()
332            })
333            .collect_vec();
334
335        // Compute beta over these and the sumcheck challenges.
336        let rlc_beta = sumcheck_claim_points_vec
337            .iter()
338            .zip(random_coefficients)
339            .fold(F::ZERO, |acc, (elem, random_coeff)| {
340                assert_eq!(round_challenges.len(), elem.len());
341                let fully_bound_beta =
342                    BetaValues::compute_beta_over_two_challenges(round_challenges, elem);
343                acc + fully_bound_beta * random_coeff
344            });
345
346        self.expression.get_post_sumcheck_layer(rlc_beta)
347    }
348
349    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
350        // First off, parse the expression that is associated with the layer.
351        // Next, get to the actual claims that are generated by each expression and grab them
352        // Return basically a list of (usize, Claim)
353        let layerwise_expr = &self.expression;
354
355        let mut claims: Vec<Claim<F>> = Vec::new();
356
357        // Define how to parse the expression tree.
358        // Basically we just want to go down it and pass up claims.
359        // We can only add a new claim if we see an MLE with all its indices
360        // bound.
361        let mut observer_fn = |expr: &ExpressionNode<F, ProverExpr>,
362                               mle_vec: &<ProverExpr as ExpressionType<F>>::MleVec|
363         -> Result<()> {
364            match expr {
365                ExpressionNode::Mle(mle_vec_idx) => {
366                    let mle = mle_vec_idx.get_mle(mle_vec);
367
368                    let fixed_mle_indices = mle
369                        .mle_indices
370                        .iter()
371                        .map(|index| index.val().ok_or(anyhow!(ClaimError::MleRefMleError)))
372                        .collect::<Result<Vec<_>>>()?;
373
374                    // Grab the layer ID (i.e. MLE index) which this mle refers to
375                    let mle_layer_id = mle.layer_id();
376
377                    let claimed_value = mle.value();
378
379                    // Note: No need to append claim values here.
380                    // We already appended them when evaluating the
381                    // expression for sumcheck.
382
383                    // Construct the claim
384                    let claim = Claim::new(
385                        fixed_mle_indices,
386                        claimed_value,
387                        self.layer_id(),
388                        mle_layer_id,
389                    );
390
391                    // Push it into the list of claims
392                    claims.push(claim);
393                }
394                ExpressionNode::Product(mle_vec_indices) => {
395                    for mle_vec_index in mle_vec_indices {
396                        let mle = mle_vec_index.get_mle(mle_vec);
397                        let fixed_mle_indices = mle
398                            .mle_indices
399                            .iter()
400                            .map(|index| index.val().ok_or(anyhow!(ClaimError::MleRefMleError)))
401                            .collect::<Result<Vec<_>>>()?;
402
403                        // Grab the layer ID (i.e. MLE index) which this mle refers to
404                        let mle_layer_id = mle.layer_id();
405
406                        let claimed_value = mle.value();
407
408                        // Note: No need to append the claim value to the transcript here. We
409                        // already appended when evaluating the expression for sumcheck.
410
411                        // Construct the claim
412                        // need to populate the claim with the mle ref we are grabbing the claim from
413                        let claim = Claim::new(
414                            fixed_mle_indices,
415                            claimed_value,
416                            self.layer_id(),
417                            mle_layer_id,
418                        );
419
420                        // Push it into the list of claims
421                        claims.push(claim);
422                    }
423                }
424                _ => {}
425            }
426            Ok(())
427        };
428
429        // Apply the observer function from above onto the expression
430        layerwise_expr.traverse(&mut observer_fn)?;
431
432        Ok(claims)
433    }
434}
435
436/// The circuit description counterpart of a [RegularLayer].
437#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
438#[serde(bound = "F: Field")]
439pub struct RegularLayerDescription<F: Field> {
440    /// This layer's ID.
441    id: LayerId,
442
443    /// A structural description of the polynomial expression defining this
444    /// layer. The leaves of the expression describe the MLE characteristics
445    /// without storing any values.
446    expression: Expression<F, ExprDescription>,
447}
448
449impl<F: Field> RegularLayerDescription<F> {
450    /// Generates a new [RegularLayerDescription] given raw data.
451    pub fn new_raw(id: LayerId, expression: Expression<F, ExprDescription>) -> Self {
452        Self { id, expression }
453    }
454
455    /// Get the number of variables in the underlying expression of the layer.
456    pub fn get_num_vars(&self) -> usize {
457        self.expression.num_vars()
458    }
459}
460
461/// The verifier's counterpart of a [RegularLayer].
462#[derive(Serialize, Deserialize, Clone, Debug)]
463#[serde(bound = "F: Field")]
464pub struct VerifierRegularLayer<F: Field> {
465    /// This layer's ID.
466    id: LayerId,
467
468    /// A fully-bound expression defining the layer.
469    expression: Expression<F, VerifierExpr>,
470}
471
472impl<F: Field> VerifierRegularLayer<F> {
473    /// Generates a new [VerifierRegularLayer] given raw data.
474    pub(crate) fn new_raw(id: LayerId, expression: Expression<F, VerifierExpr>) -> Self {
475        Self { id, expression }
476    }
477}
478
479impl<F: Field> LayerDescription<F> for RegularLayerDescription<F> {
480    type VerifierLayer = VerifierRegularLayer<F>;
481
482    fn layer_id(&self) -> LayerId {
483        self.id
484    }
485
486    fn compute_data_outputs(
487        &self,
488        mle_outputs_necessary: &HashSet<&MleDescription<F>>,
489        circuit_map: &mut CircuitEvalMap<F>,
490    ) {
491        let mut expression_nodes_to_compile =
492            HashMap::<&ExpressionNode<F, ExprDescription>, Vec<(Vec<bool>, Vec<bool>)>>::new();
493
494        mle_outputs_necessary
495            .iter()
496            .for_each(|mle_output_necessary| {
497                let prefix_bits = mle_output_necessary.prefix_bits();
498                let mut unfiltered_prefix_bits: Vec<bool> = vec![];
499                let expression_node_to_compile = prefix_bits.iter().fold(
500                    &self.expression.expression_node,
501                    |acc, bit| match acc {
502                        ExpressionNode::Selector(_mle_index, lhs, rhs) => {
503                            if *bit {
504                                rhs
505                            } else {
506                                lhs
507                            }
508                        }
509                        _ => {
510                            unfiltered_prefix_bits.push(*bit);
511                            acc
512                        }
513                    },
514                );
515                if expression_nodes_to_compile.contains_key(expression_node_to_compile) {
516                    expression_nodes_to_compile
517                        .get_mut(expression_node_to_compile)
518                        .unwrap()
519                        .push((unfiltered_prefix_bits.clone(), prefix_bits));
520                } else {
521                    expression_nodes_to_compile.insert(
522                        expression_node_to_compile,
523                        vec![(unfiltered_prefix_bits.clone(), prefix_bits)],
524                    );
525                }
526            });
527
528        expression_nodes_to_compile
529            .iter()
530            .for_each(|(expression_node, prefix_bit_vec)| {
531                let full_bookkeeping_table = expression_node
532                    .compute_bookkeeping_table(circuit_map)
533                    .unwrap();
534                prefix_bit_vec
535                    .iter()
536                    .for_each(|(unfiltered_prefix_bits, prefix_bits)| {
537                        let filtered_table = filter_bookkeeping_table(
538                            &full_bookkeeping_table,
539                            unfiltered_prefix_bits,
540                        );
541                        circuit_map.add_node(
542                            CircuitLocation::new(self.layer_id(), prefix_bits.clone()),
543                            filtered_table,
544                        );
545                    });
546            });
547    }
548
549    fn verify_rounds(
550        &self,
551        claims: &[&RawClaim<F>],
552        transcript_reader: &mut impl VerifierTranscript<F>,
553    ) -> Result<VerifierLayerEnum<F>> {
554        let rounds_sumchecked_over = match global_claim_agg_strategy() {
555            ClaimAggregationStrategy::Interpolative => self.expression.get_all_nonlinear_rounds(),
556            ClaimAggregationStrategy::RLC => self.expression.get_all_rounds(),
557        };
558
559        // Keeps track of challenges `r_1, ..., r_n` sent by the verifier.
560        let mut challenges = vec![];
561
562        // Random coefficients depending on claim aggregation strategy.
563        let random_coefficients = match global_claim_agg_strategy() {
564            ClaimAggregationStrategy::Interpolative => {
565                assert_eq!(claims.len(), 1);
566                vec![F::ONE]
567            }
568            ClaimAggregationStrategy::RLC => {
569                transcript_reader.get_challenges("RLC Claim Agg Coefficients", claims.len())?
570            }
571        };
572
573        // Represents `g_{i-1}(x)` of the previous round.
574        // This is initialized to the constant polynomial `g_0(x)` which evaluates
575        // to the claim result for any `x`.
576        let mut g_prev_round = match global_claim_agg_strategy() {
577            ClaimAggregationStrategy::Interpolative => {
578                vec![claims[0].get_eval()]
579            }
580            ClaimAggregationStrategy::RLC => vec![random_coefficients
581                .iter()
582                .zip(claims)
583                .fold(F::ZERO, |acc, (rlc_val, claim)| {
584                    acc + *rlc_val * claim.get_eval()
585                })],
586        };
587
588        // Previous round's challege: r_{i-1}.
589        let mut prev_challenge = F::ZERO;
590
591        // For round 1 <= i <= n, perform the check:
592        for round_index in &rounds_sumchecked_over {
593            let degree = self.expression.get_round_degree(*round_index);
594
595            // Receive `g_i(x)` from the Prover.
596            // Since we are using an evaluation representation for polynomials,
597            // the degree check is implicit: the verifier is requesting
598            // `degree + 1` evaluations, ensuring that `g_i` is of degree
599            // at most `degree`. If the prover appended more evaluations,
600            // there will be a transcript read error later on in the proving
601            // process which will result in the proof not verifying.
602            // Furthermore, since the verifier can deduce g_i(0) by computing `claim - g_i(1)`,
603            // the prover does not include g_i(0) in the message. Instead, the verifier
604            // reserves the spot of g_i(0) when reading from the transcript, and compute g_i(0)
605            // afterwards.
606            // TODO(Makis):
607            //   1. Modify the Transcript interface to catch any errors sooner.
608            //   2. This line is assuming a representation for the polynomial!
609            //   We should hide that under another function whose job is to take
610            //   the trascript reader and read the polynomial in whatever
611            //   representation is being used.
612            let mut g_cur_round: Vec<_> = [Ok(F::from(0))]
613                .into_iter()
614                .chain((0..degree).map(|_| {
615                    transcript_reader.consume_element("Sumcheck round univariate evaluations")
616                }))
617                .collect::<Result<_, _>>()?;
618
619            // Sample random challenge `r_i`.
620            let challenge = transcript_reader.get_challenge("Sumcheck round challenge")?;
621
622            // TODO(Makis): After refactoring `SumcheckEvals` to be a
623            // representation of a univariate polynomial, `evaluate_at_a_point`
624            // should just be a method.
625            // Compute:
626            //       `g_i(0) = g_{i - 1}(r_{i-1}) - g_i(1)`
627            let g_prev_r_prev = evaluate_at_a_point(&g_prev_round, prev_challenge).unwrap();
628            let g_i_one = evaluate_at_a_point(&g_cur_round, F::ONE).unwrap();
629            g_cur_round[0] = g_prev_r_prev - g_i_one;
630
631            g_prev_round = g_cur_round;
632            prev_challenge = challenge;
633            challenges.push(challenge);
634        }
635
636        // TODO(Makis): Add check that `expr` is on the same number of total vars.
637        let num_vars = claims[0].get_num_vars();
638
639        // Build an indicator vector for linear indices.
640        let mut var_is_linear: Vec<bool> = vec![true; num_vars];
641        if global_claim_agg_strategy() == ClaimAggregationStrategy::Interpolative {
642            for idx in &rounds_sumchecked_over {
643                var_is_linear[*idx] = false;
644            }
645        }
646        // Build point interlacing linear-round challenges with nonlinear-round
647        // challenges.
648        let mut nonlinear_idx = 0;
649        let point: &Vec<F> = match global_claim_agg_strategy() {
650            ClaimAggregationStrategy::Interpolative => &(0..num_vars)
651                .map(|idx| {
652                    if var_is_linear[idx] {
653                        claims[0].get_point()[idx]
654                    } else {
655                        let r = challenges[nonlinear_idx];
656                        nonlinear_idx += 1;
657                        r
658                    }
659                })
660                .collect(),
661            ClaimAggregationStrategy::RLC => &challenges,
662        };
663
664        let verifier_layer = self
665            .convert_into_verifier_layer(
666                point,
667                &claims.iter().map(|claim| claim.get_point()).collect_vec(),
668                transcript_reader,
669            )
670            .unwrap();
671
672        // Compute `P(r_1, ..., r_n)` over all challenge points (linear and
673        // non-linear).
674        // The MLE values are retrieved from the transcript.
675        let expr_value_at_challenge_point = verifier_layer.expression.evaluate()?;
676
677        let beta_fn_evaluated_at_challenge_point = match global_claim_agg_strategy() {
678            ClaimAggregationStrategy::Interpolative => {
679                // Compute `\beta((r_1, ..., r_n), (u_1, ..., u_n))`.
680                let claim_nonlinear_vals: Vec<F> = rounds_sumchecked_over
681                    .iter()
682                    .map(|idx| claims[0].get_point()[*idx])
683                    .collect();
684                debug_assert_eq!(claim_nonlinear_vals.len(), challenges.len());
685                BetaValues::compute_beta_over_two_challenges(&claim_nonlinear_vals, &challenges)
686            }
687            ClaimAggregationStrategy::RLC => random_coefficients.iter().zip(claims).fold(
688                F::ZERO,
689                |acc, (random_coeff, claim)| {
690                    acc + *random_coeff
691                        * BetaValues::compute_beta_over_two_challenges(
692                            claim.get_point(),
693                            &challenges,
694                        )
695                },
696            ),
697        };
698
699        // Evalute `g_n(r_n)`.
700        // Note: If there were no nonlinear rounds, this value reduces to
701        // `claim.get_result()` due to how we initialized `g_prev_round`.
702        let g_final_r_final = evaluate_at_a_point(&g_prev_round, prev_challenge)?;
703
704        // Final check:
705        // `\sum_{b_2} \sum_{b_4} P(g_1, b_2, g_3, b_4) * \beta( (b_2, b_4), (g_2, g_4) )`.
706        // P(g_1, challenge[0], g_3, challenge[0]) * \beta( challenge, (g_2, g_4) )
707        // `g_n(r_n) == P(r_1, ..., r_n) * \beta(r_1, ..., r_n, g_1, ..., g_n)`.
708        if g_final_r_final != expr_value_at_challenge_point * beta_fn_evaluated_at_challenge_point {
709            return Err(anyhow!(VerificationError::SumcheckFailed));
710        }
711
712        Ok(VerifierLayerEnum::Regular(verifier_layer))
713    }
714
715    fn sumcheck_round_indices(&self) -> Vec<usize> {
716        match global_claim_agg_strategy() {
717            ClaimAggregationStrategy::Interpolative => self.expression.get_all_nonlinear_rounds(),
718            ClaimAggregationStrategy::RLC => self.expression.get_all_rounds(),
719        }
720    }
721
722    fn convert_into_verifier_layer(
723        &self,
724        sumcheck_challenges: &[F],
725        _claim_point: &[&[F]],
726        transcript_reader: &mut impl VerifierTranscript<F>,
727    ) -> Result<Self::VerifierLayer> {
728        let verifier_expr = self
729            .expression
730            .bind(sumcheck_challenges, transcript_reader)?;
731
732        let verifier_layer = VerifierRegularLayer::new_raw(self.layer_id(), verifier_expr);
733        Ok(verifier_layer)
734    }
735
736    /// Get the [PostSumcheckLayer] for a [RegularLayerDescription], which represents the description of a fully bound expression.
737    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
738    fn get_post_sumcheck_layer(
739        &self,
740        round_challenges: &[F],
741        claim_challenges: &[&[F]],
742        random_coefficients: &[F],
743    ) -> PostSumcheckLayer<F, Option<F>> {
744        let sumcheck_round_indices = self.sumcheck_round_indices();
745        // Filter the claim to get the values of the claim pertaining to the nonlinear rounds.
746        let sumcheck_claim_points_vec = claim_challenges
747            .iter()
748            .map(|claim_challenge| {
749                claim_challenge
750                    .iter()
751                    .enumerate()
752                    .filter_map(|(idx, point)| {
753                        if sumcheck_round_indices.contains(&idx) {
754                            Some(*point)
755                        } else {
756                            None
757                        }
758                    })
759                    .collect_vec()
760            })
761            .collect_vec();
762
763        // Compute beta over these and the sumcheck challenges.
764        let rlc_beta = sumcheck_claim_points_vec
765            .iter()
766            .zip(random_coefficients)
767            .fold(F::ZERO, |acc, (elem, random_coeff)| {
768                assert_eq!(round_challenges.len(), elem.len());
769                let fully_bound_beta =
770                    BetaValues::compute_beta_over_two_challenges(round_challenges, elem);
771                acc + fully_bound_beta * random_coeff
772            });
773
774        // Compute the fully bound challenges, which include those pre-fixed for linear rounds
775        // and the sumcheck rounds.
776
777        let all_bound_challenges = match global_claim_agg_strategy() {
778            ClaimAggregationStrategy::Interpolative => {
779                assert_eq!(claim_challenges.len(), 1);
780                let mut sumcheck_round_index_counter = 0;
781                let all_chals = (0..claim_challenges[0].len())
782                    .map(|idx| {
783                        if sumcheck_round_indices.contains(&idx) {
784                            let chal = round_challenges[sumcheck_round_index_counter];
785                            sumcheck_round_index_counter += 1;
786                            chal
787                        } else {
788                            claim_challenges[0][idx]
789                        }
790                    })
791                    .collect_vec();
792                assert_eq!(sumcheck_round_index_counter, sumcheck_round_indices.len());
793                all_chals
794            }
795            ClaimAggregationStrategy::RLC => round_challenges.to_vec(),
796        };
797
798        self.expression
799            .get_post_sumcheck_layer(rlc_beta, &all_bound_challenges)
800    }
801
802    fn max_degree(&self) -> usize {
803        self.expression.get_max_degree() + 1
804    }
805
806    fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
807        self.expression.get_circuit_mles()
808    }
809
810    fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
811        let prover_expr = self.expression.into_prover_expression(circuit_map);
812        let regular_layer = RegularLayer::new_raw(self.layer_id(), prover_expr);
813        regular_layer.into()
814    }
815
816    fn index_mle_indices(&mut self, start_index: usize) {
817        self.expression.index_mle_vars(start_index);
818    }
819}
820
821impl<F: Field> VerifierLayer<F> for VerifierRegularLayer<F> {
822    fn layer_id(&self) -> LayerId {
823        self.id
824    }
825
826    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
827        let expr = &self.expression;
828
829        // Define how to parse the expression tree
830        // - Basically we just want to go down it and pass up claims
831        // - We can only add a new claim if we see an MLE with all its indices bound
832
833        let mut claims: Vec<Claim<F>> = Vec::new();
834
835        let mut observer_fn = |exp: &ExpressionNode<F, VerifierExpr>,
836                               _mle_vec: &<VerifierExpr as ExpressionType<F>>::MleVec|
837         -> Result<()> {
838            match exp {
839                ExpressionNode::Mle(verifier_mle) => {
840                    let fixed_mle_indices = verifier_mle
841                        .var_indices()
842                        .iter()
843                        .map(|index| index.val().ok_or(anyhow!(ClaimError::MleRefMleError)))
844                        .collect::<Result<Vec<_>>>()?;
845
846                    // Grab the layer ID (i.e. MLE index) which this mle refers to
847                    let mle_layer_id = verifier_mle.layer_id();
848
849                    // Grab the actual value that the claim is supposed to evaluate to
850                    let claimed_value = verifier_mle.value();
851
852                    // Construct the claim
853                    let claim: Claim<F> = Claim::new(
854                        fixed_mle_indices,
855                        claimed_value,
856                        self.layer_id(),
857                        mle_layer_id,
858                    );
859
860                    // Push it into the list of claims
861                    claims.push(claim);
862                }
863                ExpressionNode::Product(verifier_mle_vec) => {
864                    for verifier_mle in verifier_mle_vec {
865                        let fixed_mle_indices = verifier_mle
866                            .var_indices()
867                            .iter()
868                            .map(|index| index.val().ok_or(anyhow!(ClaimError::MleRefMleError)))
869                            .collect::<Result<Vec<_>>>()?;
870
871                        // Grab the layer ID (i.e. MLE index) which this mle refers to
872                        let mle_layer_id = verifier_mle.layer_id();
873
874                        let claimed_value = verifier_mle.value();
875
876                        // Construct the claim
877                        let claim: Claim<F> = Claim::new(
878                            fixed_mle_indices,
879                            claimed_value,
880                            self.layer_id(),
881                            mle_layer_id,
882                        );
883
884                        // Push it into the list of claims
885                        claims.push(claim);
886                    }
887                }
888                _ => {}
889            }
890            Ok(())
891        };
892
893        // Apply the observer function from above onto the expression
894        expr.traverse(&mut observer_fn)?;
895
896        Ok(claims)
897    }
898}