remainder/layer/gate.rs

1//! module for defining the gate layer, uses the libra trick
2//! to reduce the number of rounds for gate layers (with binary operations)
3
4/// Helper functions used in the gate sumcheck algorithms.
5pub mod gate_helpers;
6use std::{cmp::max, collections::HashSet};
7
8use gate_helpers::{
9    compute_fully_bound_binary_gate_function, fold_binary_gate_wiring_into_mles_phase_1,
10    fold_binary_gate_wiring_into_mles_phase_2,
11};
12use itertools::Itertools;
13use serde::{Deserialize, Serialize};
14use shared_types::{
15    config::{global_config::global_claim_agg_strategy, ClaimAggregationStrategy},
16    transcript::{ProverTranscript, VerifierTranscript},
17    Field,
18};
19
20use crate::{
21    circuit_layout::{CircuitEvalMap, CircuitLocation},
22    claims::{Claim, ClaimError, RawClaim},
23    layer::{
24        product::{PostSumcheckLayer, Product},
25        Layer, LayerError, LayerId, VerificationError,
26    },
27    mle::{
28        betavalues::BetaValues, dense::DenseMle, evals::MultilinearExtension,
29        mle_description::MleDescription, verifier_mle::VerifierMle, Mle, MleIndex,
30    },
31    sumcheck::{evaluate_at_a_point, SumcheckEvals},
32};
33
34use anyhow::{anyhow, Ok, Result};
35
36pub use self::gate_helpers::{
37    compute_sumcheck_message_data_parallel_gate, compute_sumcheck_message_no_beta_table,
38    index_mle_indices_gate, GateError,
39};
40
41use super::{
42    layer_enum::{LayerEnum, VerifierLayerEnum},
43    LayerDescription, VerifierLayer,
44};
45
46#[derive(PartialEq, Serialize, Deserialize, Clone, Debug, Copy)]
47
48/// Operations that are currently supported by the gate. Binary because these
49/// are fan-in-two gates.
50#[derive(Hash)]
51pub enum BinaryOperation {
52    /// An addition gate.
53    Add,
54
55    /// A multiplication gate.
56    Mul,
57}
58
59impl BinaryOperation {
60    /// Method to perform the respective operation.
61    pub fn perform_operation<F: Field>(&self, a: F, b: F) -> F {
62        match self {
63            BinaryOperation::Add => a + b,
64            BinaryOperation::Mul => a * b,
65        }
66    }
67}
68
/// Generic gate struct -- the binary operation performed by the gate is specified by
/// the `gate_operation` parameter. Additionally, the number of dataparallel variables
/// is specified by `num_dataparallel_vars` in order to account for batched and un-batched
/// gates.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "F: Field")]
pub struct GateLayer<F: Field> {
    /// The layer id associated with this gate layer.
    pub layer_id: LayerId,
    /// The number of bits representing the number of "dataparallel" copies of the circuit.
    pub num_dataparallel_vars: usize,
    /// A vector of tuples representing the "nonzero" gates, especially useful in the sparse case;
    /// the format is (z, x, y) where the gate at label z is the output of performing an operation
    /// on gates with labels x and y.
    pub nonzero_gates: Vec<(u32, u32, u32)>,
    /// The left side of the expression, i.e. the mle that makes up the "x" variables.
    pub lhs: DenseMle<F>,
    /// The right side of the expression, i.e. the mle that makes up the "y" variables.
    pub rhs: DenseMle<F>,
    /// The mles that are constructed when initializing phase 1 (binding the x variables).
    /// `None` until phase 1 is initialized during proving.
    pub phase_1_mles: Option<Vec<Vec<DenseMle<F>>>>,
    /// The mles that are constructed when initializing phase 2 (binding the y variables).
    /// `None` until phase 2 is initialized during proving.
    pub phase_2_mles: Option<Vec<Vec<DenseMle<F>>>>,
    /// The gate operation representing the fan-in-two relationship.
    pub gate_operation: BinaryOperation,
    /// The beta table which enumerates the incoming claim's challenge points on the
    /// dataparallel vars of the MLE (one table per aggregated claim; set by
    /// `initialize`/`initialize_rlc`).
    beta_g2_vec: Option<Vec<BetaValues<F>>>,
    /// The incoming claim's challenge points (one point per aggregated claim).
    g_vec: Option<Vec<Vec<F>>>,
    /// The number of sumcheck rounds in phase 1 (binding the x variables).
    num_rounds_phase1: usize,
}
102
impl<F: Field> Layer<F> for GateLayer<F> {
    /// Gets this layer's id.
    fn layer_id(&self) -> LayerId {
        self.layer_id
    }

    /// Runs the prover side of sumcheck for this gate layer: aggregates the
    /// incoming claim(s) per the global strategy, writes each round's
    /// univariate evaluations to the transcript, binds each sampled challenge,
    /// and finally appends the two fully-bound MLE evaluations
    /// (V_{i + 1}(g_2, u) and V_{i + 1}(g_2, v)) for the verifier.
    fn prove(
        &mut self,
        claims: &[&RawClaim<F>],
        transcript_writer: &mut impl ProverTranscript<F>,
    ) -> Result<()> {
        // Snapshot free-var counts now: binding during the round loop below
        // mutates them, and the zero-variable edge cases need the originals.
        let original_lhs_num_free_vars = self.lhs.num_free_vars();
        let original_rhs_num_free_vars = self.rhs.num_free_vars();
        let random_coefficients = match global_claim_agg_strategy() {
            ClaimAggregationStrategy::Interpolative => {
                // Interpolative aggregation expects exactly one claim; the
                // single coefficient of ONE leaves it unscaled.
                assert_eq!(claims.len(), 1);
                self.initialize(claims[0].get_point())?;
                vec![F::ONE]
            }
            ClaimAggregationStrategy::RLC => {
                let random_coefficients =
                    transcript_writer.get_challenges("RLC Claim Agg Coefficients", claims.len());
                self.initialize_rlc(&random_coefficients, claims);
                random_coefficients
            }
        };
        // Main sumcheck loop: one message + one bound challenge per round.
        let sumcheck_indices = self.sumcheck_round_indices();
        (sumcheck_indices.iter()).for_each(|round_idx| {
            let sumcheck_message = self
                .compute_round_sumcheck_message(*round_idx, &random_coefficients)
                .unwrap();
            transcript_writer
                .append_elements("Sumcheck round univariate evaluations", &sumcheck_message);
            let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
            self.bind_round_variable(*round_idx, challenge).unwrap();
        });

        // Edge case for if the LHS or RHS have 0 variables.
        // With no x variables, phase 1 never ran inside the loop above, so the
        // phase-1 MLEs must still be initialized here for the final claims.
        if original_lhs_num_free_vars - self.num_dataparallel_vars == 0 {
            match global_claim_agg_strategy() {
                ClaimAggregationStrategy::Interpolative => {
                    self.init_phase_1(
                        self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec(),
                    );
                }
                ClaimAggregationStrategy::RLC => {
                    // NOTE: the `.clone()` detaches the borrowed slices from
                    // `self.g_vec` so the `&mut self` call below is legal.
                    self.init_phase_1_rlc(
                        &self
                            .g_vec
                            .as_ref()
                            .unwrap()
                            .clone()
                            .iter()
                            .map(|challenge| &challenge[self.num_dataparallel_vars..])
                            .collect_vec(),
                        &random_coefficients,
                    );
                }
            }
        }
        // Symmetric edge case: no y variables means phase 2 never ran.
        if original_rhs_num_free_vars - self.num_dataparallel_vars == 0 {
            let f2 = &self.phase_1_mles.as_ref().unwrap()[0][1];
            let f2_at_u = f2.value();
            // Recover the u challenges from the bound indices of f2, skipping
            // the dataparallel prefix.
            let u_challenges = &f2
                .mle_indices()
                .iter()
                .filter_map(|mle_index| match mle_index {
                    MleIndex::Bound(value, _idx) => Some(*value),
                    MleIndex::Fixed(_) => None,
                    _ => panic!("Should not have any unbound values"),
                })
                .collect_vec()[self.num_dataparallel_vars..];

            match global_claim_agg_strategy() {
                ClaimAggregationStrategy::Interpolative => {
                    let g_challenges =
                        self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec();
                    self.init_phase_2(u_challenges, f2_at_u, &g_challenges);
                }
                ClaimAggregationStrategy::RLC => {
                    self.init_phase_2_rlc(
                        u_challenges,
                        f2_at_u,
                        &self
                            .g_vec
                            .as_ref()
                            .unwrap()
                            .clone()
                            .iter()
                            .map(|claim| &claim[self.num_dataparallel_vars..])
                            .collect_vec(),
                        &random_coefficients,
                    );
                }
            }
        }

        // Finally, send the claimed values for each of the bound MLEs to the verifier
        // First, send the claimed value of V_{i + 1}(g_2, u)
        let lhs_reduced = &self.phase_1_mles.as_ref().unwrap()[0][1];
        let rhs_reduced = &self.phase_2_mles.as_ref().unwrap()[0][1];
        transcript_writer.append("Fully bound MLE evaluation", lhs_reduced.value());
        // Next, send the claimed value of V_{i + 1}(g_2, v)
        transcript_writer.append("Fully bound MLE evaluation", rhs_reduced.value());

        Ok(())
    }

    /// Initializes prover state from a single claim point: builds the beta
    /// table over the point's dataparallel prefix, stores the full point in
    /// `g_vec`, and indexes the LHS/RHS MLE variables starting from 0.
    fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
        self.beta_g2_vec = Some(vec![BetaValues::new(
            claim_point[..self.num_dataparallel_vars]
                .iter()
                .copied()
                .enumerate()
                .collect(),
        )]);
        self.g_vec = Some(vec![claim_point.to_vec()]);
        self.lhs.index_mle_indices(0);
        self.rhs.index_mle_indices(0);

        Ok(())
    }

    /// RLC counterpart of [GateLayer::initialize]: stores every claim's point
    /// and a beta table over each point's dataparallel prefix. The random
    /// coefficients are not needed here; they are applied later when the
    /// round messages are computed.
    fn initialize_rlc(&mut self, _random_coefficients: &[F], claims: &[&RawClaim<F>]) {
        self.lhs.index_mle_indices(0);
        self.rhs.index_mle_indices(0);
        let (g_vec, beta_g2_vec): (Vec<Vec<F>>, Vec<BetaValues<F>>) = claims
            .iter()
            .map(|claim| {
                (
                    claim.get_point().to_vec(),
                    BetaValues::new(
                        claim.get_point()[..self.num_dataparallel_vars]
                            .iter()
                            .copied()
                            .enumerate()
                            .collect(),
                    ),
                )
            })
            .unzip();
        self.g_vec = Some(g_vec);
        self.beta_g2_vec = Some(beta_g2_vec);
    }

    /// Computes the prover's univariate message for `round_index`.
    ///
    /// Rounds proceed in three phases: the dataparallel rounds first, then
    /// phase 1 (binding the x variables), then phase 2 (binding the y
    /// variables). The phase 1 / phase 2 MLEs are lazily initialized on the
    /// first round of their respective phase.
    fn compute_round_sumcheck_message(
        &mut self,
        round_index: usize,
        random_coefficients: &[F],
    ) -> Result<Vec<F>> {
        let rounds_before_phase_2 = self.num_dataparallel_vars + self.num_rounds_phase1;

        if round_index < self.num_dataparallel_vars {
            // dataparallel phase
            Ok(compute_sumcheck_message_data_parallel_gate(
                &self.lhs,
                &self.rhs,
                self.gate_operation,
                &self.nonzero_gates,
                self.num_dataparallel_vars - round_index,
                &self
                    .g_vec
                    .as_ref()
                    .unwrap()
                    .iter()
                    .map(|challenge| &challenge[round_index..])
                    .collect_vec(),
                // Per-claim RLC coefficient times the folded beta values for
                // that claim's dataparallel prefix.
                &self
                    .beta_g2_vec
                    .as_ref()
                    .unwrap()
                    .iter()
                    .zip(random_coefficients)
                    .map(|(beta_values, random_coeff)| {
                        *random_coeff * beta_values.fold_updated_values()
                    })
                    .collect_vec(),
            )
            .unwrap())
        } else if round_index < rounds_before_phase_2 {
            // First round of phase 1: initialize the phase-1 MLEs from the
            // non-dataparallel suffix of each claim point.
            if round_index == self.num_dataparallel_vars {
                match global_claim_agg_strategy() {
                    ClaimAggregationStrategy::Interpolative => {
                        self.init_phase_1(
                            self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec(),
                        );
                    }
                    ClaimAggregationStrategy::RLC => {
                        // NOTE: `.clone()` detaches the slices from `self` so
                        // the `&mut self` call is legal under the borrow rules.
                        self.init_phase_1_rlc(
                            &self
                                .g_vec
                                .as_ref()
                                .unwrap()
                                .clone()
                                .iter()
                                .map(|challenge| &challenge[self.num_dataparallel_vars..])
                                .collect_vec(),
                            random_coefficients,
                        );
                    }
                }
            }
            // Degree bound for this round = largest product length among the
            // phase-1 MLE groups.
            let max_deg = self
                .phase_1_mles
                .as_ref()
                .unwrap()
                .iter()
                .fold(0, |acc, elem| max(acc, elem.len()));

            let init_mles: Vec<Vec<&DenseMle<F>>> = self
                .phase_1_mles
                .as_ref()
                .unwrap()
                .iter()
                .map(|mle_vec| {
                    let mle_reference: Vec<&DenseMle<F>> = mle_vec.iter().collect();
                    mle_reference
                })
                .collect();
            let evals_vec = init_mles
                .iter()
                .map(|mle_vec| {
                    compute_sumcheck_message_no_beta_table(mle_vec, round_index, max_deg).unwrap()
                })
                .collect_vec();
            // Sum the per-group messages into the single round message.
            let final_evals = evals_vec
                .clone()
                .into_iter()
                .skip(1)
                .fold(SumcheckEvals(evals_vec[0].clone()), |acc, elem| {
                    acc + SumcheckEvals(elem)
                });

            Ok(final_evals.0)
        } else {
            // First round of phase 2: fully bound phase-1 state (f2 at u)
            // seeds the phase-2 MLEs.
            if round_index == rounds_before_phase_2 {
                let f2 = &self.phase_1_mles.as_ref().unwrap()[0][1];
                let f2_at_u = f2.value();
                let u_challenges = &f2
                    .mle_indices()
                    .iter()
                    .filter_map(|mle_index| match mle_index {
                        MleIndex::Bound(value, _idx) => Some(*value),
                        MleIndex::Fixed(_) => None,
                        _ => panic!("Should not have any unbound values"),
                    })
                    .collect_vec()[self.num_dataparallel_vars..];

                match global_claim_agg_strategy() {
                    ClaimAggregationStrategy::Interpolative => {
                        let g_challenges =
                            self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec();
                        self.init_phase_2(u_challenges, f2_at_u, &g_challenges);
                    }
                    ClaimAggregationStrategy::RLC => {
                        self.init_phase_2_rlc(
                            u_challenges,
                            f2_at_u,
                            &self
                                .g_vec
                                .as_ref()
                                .unwrap()
                                .clone()
                                .iter()
                                .map(|claim| &claim[self.num_dataparallel_vars..])
                                .collect_vec(),
                            random_coefficients,
                        );
                    }
                }
            }
            if self.phase_2_mles.as_ref().unwrap()[0][1].num_free_vars() > 0 {
                // Return the first sumcheck message of this phase.
                let max_deg = self
                    .phase_2_mles
                    .as_ref()
                    .unwrap()
                    .iter()
                    .fold(0, |acc, elem| max(acc, elem.len()));

                let init_mles: Vec<Vec<&DenseMle<F>>> = self
                    .phase_2_mles
                    .as_ref()
                    .unwrap()
                    .iter()
                    .map(|mle_vec| {
                        let mle_references: Vec<&DenseMle<F>> = mle_vec.iter().collect();
                        mle_references
                    })
                    .collect();
                let evals_vec = init_mles
                    .iter()
                    .map(|mle_vec| {
                        // Phase-2 indexing offsets by num_rounds_phase1 only
                        // (not the dataparallel vars) — this matches the
                        // offset used in `bind_round_variable`.
                        compute_sumcheck_message_no_beta_table(
                            mle_vec,
                            round_index - self.num_rounds_phase1,
                            max_deg,
                        )
                        .unwrap()
                    })
                    .collect_vec();
                let final_evals = evals_vec
                    .clone()
                    .into_iter()
                    .skip(1)
                    .fold(SumcheckEvals(evals_vec[0].clone()), |acc, elem| {
                        acc + SumcheckEvals(elem)
                    });
                Ok(final_evals.0)
            } else {
                Ok(vec![])
            }
        }
    }

    /// Binds the variable of round `round_index` to `challenge`, updating
    /// whichever state is live in the current phase: the beta tables plus raw
    /// LHS/RHS during dataparallel rounds, the phase-1 MLEs during phase 1,
    /// and the phase-2 MLEs (with an offset index) during phase 2.
    fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
        if round_index < self.num_dataparallel_vars {
            self.beta_g2_vec
                .as_mut()
                .unwrap()
                .iter_mut()
                .for_each(|beta| beta.beta_update(round_index, challenge));
            self.lhs.fix_variable(round_index, challenge);
            self.rhs.fix_variable(round_index, challenge);

            Ok(())
        } else if round_index < self.num_rounds_phase1 + self.num_dataparallel_vars {
            let mles = self.phase_1_mles.as_mut().unwrap();
            mles.iter_mut().for_each(|mle_vec| {
                mle_vec.iter_mut().for_each(|mle| {
                    mle.fix_variable(round_index, challenge);
                })
            });
            Ok(())
        } else {
            // Phase-2 MLEs are indexed with only num_rounds_phase1 subtracted;
            // see the matching offset in `compute_round_sumcheck_message`.
            let round_index = round_index - self.num_rounds_phase1;
            let mles = self.phase_2_mles.as_mut().unwrap();
            mles.iter_mut().for_each(|mle_vec| {
                mle_vec.iter_mut().for_each(|mle| {
                    mle.fix_variable(round_index, challenge);
                })
            });
            Ok(())
        }
    }

    /// Returns the sequential round indices `0..(dataparallel + num_u + num_v)`,
    /// where `num_u` / `num_v` count the non-fixed variables of the LHS / RHS
    /// with the dataparallel prefix excluded (it is shared, so it is added
    /// back exactly once).
    fn sumcheck_round_indices(&self) -> Vec<usize> {
        let num_u = self.lhs.mle_indices().iter().fold(0_usize, |acc, idx| {
            acc + match idx {
                MleIndex::Fixed(_) => 0,
                _ => 1,
            }
        }) - self.num_dataparallel_vars;
        let num_v = self.rhs.mle_indices().iter().fold(0_usize, |acc, idx| {
            acc + match idx {
                MleIndex::Fixed(_) => 0,
                _ => 1,
            }
        }) - self.num_dataparallel_vars;

        (0..num_u + num_v + self.num_dataparallel_vars).collect_vec()
    }

    /// Maximum degree of any round's univariate polynomial: 3 for a mul gate
    /// with dataparallel rounds (beta * V * V), otherwise 2.
    fn max_degree(&self) -> usize {
        match self.gate_operation {
            BinaryOperation::Add => 2,
            BinaryOperation::Mul => {
                if self.num_dataparallel_vars != 0 {
                    3
                } else {
                    2
                }
            }
        }
    }

    /// Builds the post-sumcheck product layer: fully binds the gate wiring
    /// function f_1 at the sumcheck challenges (scaled per claim by the RLC
    /// coefficient and the bound dataparallel beta factor), then pairs it with
    /// the reduced LHS/RHS MLEs — two single-MLE products for Add, one
    /// two-MLE product for Mul.
    fn get_post_sumcheck_layer(
        &self,
        round_challenges: &[F],
        claim_challenges: &[&[F]],
        random_coefficients: &[F],
    ) -> super::product::PostSumcheckLayer<F, F> {
        assert_eq!(claim_challenges.len(), random_coefficients.len());
        let lhs_mle = &self.phase_1_mles.as_ref().unwrap()[0][1];
        let rhs_mle = &self.phase_2_mles.as_ref().unwrap()[0][1];

        // Split each claim point into its dataparallel (g2) and label (g1)
        // parts.
        let g2_challenges_vec = claim_challenges
            .iter()
            .map(|claim_chal| &claim_chal[..self.num_dataparallel_vars])
            .collect_vec();
        let g1_challenges_vec = claim_challenges
            .iter()
            .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
            .collect_vec();

        // Split the round challenges into the three sumcheck phases.
        let dataparallel_sumcheck_challenges =
            round_challenges[..self.num_dataparallel_vars].to_vec();
        let first_u_challenges = round_challenges
            [self.num_dataparallel_vars..self.num_dataparallel_vars + self.num_rounds_phase1]
            .to_vec();
        let last_v_challenges =
            round_challenges[self.num_dataparallel_vars + self.num_rounds_phase1..].to_vec();
        let random_coefficients_scaled_by_beta_bound = g2_challenges_vec
            .iter()
            .zip(random_coefficients)
            .map(|(g2_challenges, random_coeff)| {
                let beta_bound = if self.num_dataparallel_vars != 0 {
                    BetaValues::compute_beta_over_two_challenges(
                        g2_challenges,
                        &dataparallel_sumcheck_challenges,
                    )
                } else {
                    F::ONE
                };
                beta_bound * random_coeff
            })
            .collect_vec();

        let f_1_uv = compute_fully_bound_binary_gate_function(
            &first_u_challenges,
            &last_v_challenges,
            &g1_challenges_vec,
            &self.nonzero_gates,
            &random_coefficients_scaled_by_beta_bound,
        );

        match self.gate_operation {
            BinaryOperation::Add => PostSumcheckLayer(vec![
                Product::<F, F>::new(std::slice::from_ref(lhs_mle), f_1_uv),
                Product::<F, F>::new(std::slice::from_ref(rhs_mle), f_1_uv),
            ]),
            BinaryOperation::Mul => PostSumcheckLayer(vec![Product::<F, F>::new(
                &[lhs_mle.clone(), rhs_mle.clone()],
                f_1_uv,
            )]),
        }
    }

    /// Extracts the two outgoing claims (one per reduced MLE): the bound
    /// point read off each MLE's indices, paired with its fully-bound value,
    /// attributed to the layer that owns the respective MLE.
    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
        let lhs_reduced = self.phase_1_mles.clone().unwrap()[0][1].clone();
        let rhs_reduced = self.phase_2_mles.clone().unwrap()[0][1].clone();

        let mut claims = vec![];

        // Grab the claim on the left side.
        let mut fixed_mle_indices_u: Vec<F> = vec![];
        for index in lhs_reduced.mle_indices() {
            fixed_mle_indices_u.push(
                index
                    .val()
                    .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))?,
            );
        }
        let val = lhs_reduced.value();
        let claim: Claim<F> = Claim::new(
            fixed_mle_indices_u,
            val,
            self.layer_id(),
            self.lhs.layer_id(),
        );
        claims.push(claim);

        // Grab the claim on the right side.
        let mut fixed_mle_indices_v: Vec<F> = vec![];
        for index in rhs_reduced.mle_indices() {
            fixed_mle_indices_v.push(
                index
                    .val()
                    .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))?,
            );
        }
        let val = rhs_reduced.value();
        let claim: Claim<F> = Claim::new(
            fixed_mle_indices_v,
            val,
            self.layer_id(),
            self.rhs.layer_id(),
        );
        claims.push(claim);

        Ok(claims)
    }
}
586
/// The circuit-description counterpart of a Gate layer description.
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "F: Field")]
pub struct GateLayerDescription<F: Field> {
    /// The layer id associated with this gate layer.
    id: LayerId,

    /// The gate operation representing the fan-in-two relationship.
    gate_operation: BinaryOperation,

    /// A vector of tuples representing the "nonzero" gates, especially useful
    /// in the sparse case; the format is (z, x, y) where the gate at label z is
    /// the output of performing an operation on gates with labels x and y.
    nonzero_gates: Vec<(u32, u32, u32)>,

    /// The left side of the expression, i.e. the mle that makes up the "x"
    /// variables.
    lhs_mle: MleDescription<F>,

    /// The right side of the expression, i.e. the mle that makes up the "y"
    /// variables.
    rhs_mle: MleDescription<F>,

    /// The number of bits representing the number of "dataparallel" copies of
    /// the circuit.
    num_dataparallel_vars: usize,
}
614
615impl<F: Field> GateLayerDescription<F> {
616    /// Constructor for a [GateLayerDescription].
617    pub fn new(
618        num_dataparallel_vars: Option<usize>,
619        wiring: Vec<(u32, u32, u32)>,
620        lhs_circuit_mle: MleDescription<F>,
621        rhs_circuit_mle: MleDescription<F>,
622        gate_layer_id: LayerId,
623        gate_operation: BinaryOperation,
624    ) -> Self {
625        GateLayerDescription {
626            id: gate_layer_id,
627            gate_operation,
628            nonzero_gates: wiring,
629            lhs_mle: lhs_circuit_mle,
630            rhs_mle: rhs_circuit_mle,
631            num_dataparallel_vars: num_dataparallel_vars.unwrap_or(0),
632        }
633    }
634}
635
/// Degree of independent variable is cubic for mul dataparallel binding and
/// quadratic for all other bindings (see below expression to verify for yourself!)
///
/// V_i(g_2, g_1) = \sum_{p_2, x, y} \beta(g_2, p_2) f_1(g_1, x, y) (V_{i + 1}(p_2, x) \op V_{i + 1}(p_2, y))
///
/// A degree-d univariate is communicated as d + 1 evaluations, hence 4 for
/// the cubic case below and 3 for each quadratic case.
const DATAPARALLEL_ROUND_MUL_NUM_EVALS: usize = 4; // degree 3: beta * V * V
const DATAPARALLEL_ROUND_ADD_NUM_EVALS: usize = 3; // degree 2
const NON_DATAPARALLEL_ROUND_MUL_NUM_EVALS: usize = 3; // degree 2
const NON_DATAPARALLEL_ROUND_ADD_NUM_EVALS: usize = 3; // degree 2
644
645impl<F: Field> LayerDescription<F> for GateLayerDescription<F> {
646    type VerifierLayer = VerifierGateLayer<F>;
647
    /// Gets this layer's id (the [LayerId] assigned to this gate layer
    /// description at construction).
    fn layer_id(&self) -> LayerId {
        self.id
    }
652
653    fn verify_rounds(
654        &self,
655        claims: &[&RawClaim<F>],
656        transcript_reader: &mut impl VerifierTranscript<F>,
657    ) -> Result<VerifierLayerEnum<F>> {
658        // Storing challenges for the sake of claim generation later
659        let mut challenges = vec![];
660
661        // Random coefficients depending on claim aggregation strategy.
662        let random_coefficients = match global_claim_agg_strategy() {
663            ClaimAggregationStrategy::Interpolative => {
664                assert_eq!(claims.len(), 1);
665                vec![F::ONE]
666            }
667            ClaimAggregationStrategy::RLC => {
668                transcript_reader.get_challenges("RLC Claim Agg Coefficients", claims.len())?
669            }
670        };
671
672        // WARNING: WE ARE ASSUMING HERE THAT MLE INDICES INCLUDE DATAPARALLEL
673        // INDICES AND MAKE NO DISTINCTION BETWEEN THOSE AND REGULAR FREE/INDEXED
674        // BITS
675        let num_u = self.lhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
676            acc + match idx {
677                MleIndex::Fixed(_) => 0,
678                _ => 1,
679            }
680        }) - self.num_dataparallel_vars;
681        let num_v = self.rhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
682            acc + match idx {
683                MleIndex::Fixed(_) => 0,
684                _ => 1,
685            }
686        }) - self.num_dataparallel_vars;
687
688        // Store all prover sumcheck messages to check against
689        let mut sumcheck_messages: Vec<Vec<F>> = vec![];
690
691        // First round check against the claim.
692        let first_round_num_evals = match (self.gate_operation, self.num_dataparallel_vars) {
693            (BinaryOperation::Add, 0) => NON_DATAPARALLEL_ROUND_ADD_NUM_EVALS,
694            (BinaryOperation::Mul, 0) => NON_DATAPARALLEL_ROUND_MUL_NUM_EVALS,
695            (BinaryOperation::Add, _) => DATAPARALLEL_ROUND_ADD_NUM_EVALS,
696            (BinaryOperation::Mul, _) => DATAPARALLEL_ROUND_MUL_NUM_EVALS,
697        };
698        let first_round_sumcheck_messages = transcript_reader.consume_elements(
699            "Sumcheck round univariate evaluations",
700            first_round_num_evals,
701        )?;
702        sumcheck_messages.push(first_round_sumcheck_messages.clone());
703
704        match global_claim_agg_strategy() {
705            ClaimAggregationStrategy::Interpolative => {
706                // Check: V_i(g_2, g_1) =? g_1(0) + g_1(1)
707                // TODO(ryancao): SUPER overloaded notation (in e.g. above comments); fix across the board
708                if first_round_sumcheck_messages[0] + first_round_sumcheck_messages[1]
709                    != claims[0].get_eval()
710                {
711                    return Err(anyhow!(VerificationError::SumcheckStartFailed));
712                }
713            }
714            ClaimAggregationStrategy::RLC => {
715                let rlc_claim_eval = random_coefficients
716                    .iter()
717                    .zip(claims)
718                    .fold(F::ZERO, |acc, (rlc_val, claim)| {
719                        acc + *rlc_val * claim.get_eval()
720                    });
721                if first_round_sumcheck_messages[0] + first_round_sumcheck_messages[1]
722                    != rlc_claim_eval
723                {
724                    return Err(anyhow!(VerificationError::SumcheckStartFailed));
725                }
726            }
727        }
728
729        // Check each of the messages -- note that here the verifier doesn't actually see the difference
730        // between dataparallel rounds, phase 1 rounds, and phase 2 rounds; instead, the prover's proof reads
731        // as a single continuous proof.
732        for sumcheck_round_idx in 1..self.num_dataparallel_vars + num_u + num_v {
733            // Read challenge r_{i - 1} from transcript
734            let challenge = transcript_reader
735                .get_challenge("Sumcheck round challenge")
736                .unwrap();
737            let g_i_minus_1_evals = sumcheck_messages[sumcheck_messages.len() - 1].clone();
738
739            // Evaluate g_{i - 1}(r_{i - 1})
740            let prev_at_r = evaluate_at_a_point(&g_i_minus_1_evals, challenge).unwrap();
741
742            // Read off g_i(0), g_i(1), ..., g_i(d) from transcript
743            let univariate_num_evals = match (
744                sumcheck_round_idx < self.num_dataparallel_vars, // 0-indexed, so strictly less-than is correct
745                self.gate_operation,
746            ) {
747                (true, BinaryOperation::Add) => DATAPARALLEL_ROUND_ADD_NUM_EVALS,
748                (true, BinaryOperation::Mul) => DATAPARALLEL_ROUND_MUL_NUM_EVALS,
749                (false, BinaryOperation::Add) => NON_DATAPARALLEL_ROUND_ADD_NUM_EVALS,
750                (false, BinaryOperation::Mul) => NON_DATAPARALLEL_ROUND_MUL_NUM_EVALS,
751            };
752
753            let curr_evals = transcript_reader
754                .consume_elements(
755                    "Sumcheck round univariate evaluations",
756                    univariate_num_evals,
757                )
758                .unwrap();
759
760            // Check: g_i(0) + g_i(1) =? g_{i - 1}(r_{i - 1})
761            if prev_at_r != curr_evals[0] + curr_evals[1] {
762                dbg!(&sumcheck_round_idx);
763                return Err(anyhow!(VerificationError::SumcheckFailed));
764            };
765
766            // Add the prover message to the sumcheck messages
767            sumcheck_messages.push(curr_evals);
768            // Add the challenge.
769            challenges.push(challenge);
770        }
771
772        // Final round of sumcheck -- sample r_n from transcript.
773        let final_chal = transcript_reader
774            .get_challenge("Sumcheck round challenge")
775            .unwrap();
776        challenges.push(final_chal);
777
778        // Create the resulting verifier layer for claim tracking
779        // TODO(ryancao): This is not necessary; we only need to pass back the actual claims
780        let verifier_gate_layer = self
781            .convert_into_verifier_layer(
782                &challenges,
783                &claims.iter().map(|claim| claim.get_point()).collect_vec(),
784                transcript_reader,
785            )
786            .unwrap();
787        let final_result = verifier_gate_layer.evaluate(
788            &claims.iter().map(|claim| claim.get_point()).collect_vec(),
789            &random_coefficients,
790        );
791
792        // Finally, compute g_n(r_n).
793        let g_n_evals = sumcheck_messages[sumcheck_messages.len() - 1].clone();
794        let prev_at_r = evaluate_at_a_point(&g_n_evals, final_chal).unwrap();
795
796        // Final check in sumcheck.
797        if final_result != prev_at_r {
798            return Err(anyhow!(VerificationError::FinalSumcheckFailed));
799        }
800
801        Ok(VerifierLayerEnum::Gate(verifier_gate_layer))
802    }
803
804    fn sumcheck_round_indices(&self) -> Vec<usize> {
805        let num_u = self.lhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
806            acc + match idx {
807                MleIndex::Fixed(_) => 0,
808                _ => 1,
809            }
810        }) - self.num_dataparallel_vars;
811        let num_v = self.rhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
812            acc + match idx {
813                MleIndex::Fixed(_) => 0,
814                _ => 1,
815            }
816        }) - self.num_dataparallel_vars;
817        (0..num_u + num_v + self.num_dataparallel_vars).collect_vec()
818    }
819
820    fn convert_into_verifier_layer(
821        &self,
822        sumcheck_bindings: &[F],
823        claim_points: &[&[F]],
824        transcript_reader: &mut impl VerifierTranscript<F>,
825    ) -> Result<Self::VerifierLayer> {
826        // WARNING: WE ARE ASSUMING HERE THAT MLE INDICES INCLUDE DATAPARALLEL
827        // INDICES AND MAKE NO DISTINCTION BETWEEN THOSE AND REGULAR FREE/INDEXED
828        // BITS
829        let num_u = self.lhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
830            acc + match idx {
831                MleIndex::Fixed(_) => 0,
832                _ => 1,
833            }
834        }) - self.num_dataparallel_vars;
835        let num_v = self.rhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
836            acc + match idx {
837                MleIndex::Fixed(_) => 0,
838                _ => 1,
839            }
840        }) - self.num_dataparallel_vars;
841
842        // We want to separate the challenges into which ones are from the dataparallel bits, which ones
843        // are for binding x (phase 1), and which are for binding y (phase 2).
844        let mut sumcheck_bindings_vec = sumcheck_bindings.to_vec();
845        let last_v_challenges = sumcheck_bindings_vec.split_off(self.num_dataparallel_vars + num_u);
846        let first_u_challenges = sumcheck_bindings_vec.split_off(self.num_dataparallel_vars);
847        let dataparallel_challenges = sumcheck_bindings_vec;
848
849        assert_eq!(last_v_challenges.len(), num_v);
850
851        // Since the original mles are dataparallel, the challenges are the concat of the copy bits and the variable bound bits.
852        let lhs_challenges = dataparallel_challenges
853            .iter()
854            .chain(first_u_challenges.iter())
855            .copied()
856            .collect_vec();
857        let rhs_challenges = dataparallel_challenges
858            .iter()
859            .chain(last_v_challenges.iter())
860            .copied()
861            .collect_vec();
862
863        let lhs_verifier_mle = self
864            .lhs_mle
865            .into_verifier_mle(&lhs_challenges, transcript_reader)
866            .unwrap();
867        let rhs_verifier_mle = self
868            .rhs_mle
869            .into_verifier_mle(&rhs_challenges, transcript_reader)
870            .unwrap();
871
872        // Create the resulting verifier layer for claim tracking
873        // TODO(ryancao): This is not necessary; we only need to pass back the actual claims
874        let verifier_gate_layer = VerifierGateLayer {
875            layer_id: self.layer_id(),
876            gate_operation: self.gate_operation,
877            wiring: self.nonzero_gates.clone(),
878            lhs_mle: lhs_verifier_mle,
879            rhs_mle: rhs_verifier_mle,
880            num_dataparallel_rounds: self.num_dataparallel_vars,
881            claim_challenge_points: claim_points
882                .iter()
883                .cloned()
884                .map(|claim| claim.to_vec())
885                .collect_vec(),
886            dataparallel_sumcheck_challenges: dataparallel_challenges,
887            first_u_challenges,
888            last_v_challenges,
889        };
890
891        Ok(verifier_gate_layer)
892    }
893
894    fn get_post_sumcheck_layer(
895        &self,
896        round_challenges: &[F],
897        claim_challenges: &[&[F]],
898        random_coefficients: &[F],
899    ) -> super::product::PostSumcheckLayer<F, Option<F>> {
900        let num_rounds_phase1 = self.lhs_mle.num_free_vars() - self.num_dataparallel_vars;
901
902        let g2_challenges_vec = claim_challenges
903            .iter()
904            .map(|claim_chal| &claim_chal[..self.num_dataparallel_vars])
905            .collect_vec();
906        let g1_challenges_vec = claim_challenges
907            .iter()
908            .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
909            .collect_vec();
910
911        let dataparallel_sumcheck_challenges =
912            round_challenges[..self.num_dataparallel_vars].to_vec();
913        let first_u_challenges = round_challenges
914            [self.num_dataparallel_vars..self.num_dataparallel_vars + num_rounds_phase1]
915            .to_vec();
916        let last_v_challenges =
917            round_challenges[self.num_dataparallel_vars + num_rounds_phase1..].to_vec();
918        let random_coefficients_scaled_by_beta_bound = g2_challenges_vec
919            .iter()
920            .zip(random_coefficients)
921            .map(|(g2_challenges, random_coeff)| {
922                let beta_bound = if self.num_dataparallel_vars != 0 {
923                    BetaValues::compute_beta_over_two_challenges(
924                        g2_challenges,
925                        &dataparallel_sumcheck_challenges,
926                    )
927                } else {
928                    F::ONE
929                };
930                beta_bound * random_coeff
931            })
932            .collect_vec();
933
934        let f_1_uv = compute_fully_bound_binary_gate_function(
935            &first_u_challenges,
936            &last_v_challenges,
937            &g1_challenges_vec,
938            &self.nonzero_gates,
939            &random_coefficients_scaled_by_beta_bound,
940        );
941        let lhs_challenges = &round_challenges[..self.num_dataparallel_vars + num_rounds_phase1];
942        let rhs_challenges = &round_challenges[..self.num_dataparallel_vars]
943            .iter()
944            .copied()
945            .chain(round_challenges[self.num_dataparallel_vars + num_rounds_phase1..].to_vec())
946            .collect_vec();
947
948        match self.gate_operation {
949            BinaryOperation::Add => PostSumcheckLayer(vec![
950                Product::<F, Option<F>>::new(
951                    std::slice::from_ref(&self.lhs_mle),
952                    f_1_uv,
953                    lhs_challenges,
954                ),
955                Product::<F, Option<F>>::new(
956                    std::slice::from_ref(&self.rhs_mle),
957                    f_1_uv,
958                    rhs_challenges,
959                ),
960            ]),
961            BinaryOperation::Mul => {
962                PostSumcheckLayer(vec![Product::<F, Option<F>>::new_from_mul_gate(
963                    &[self.lhs_mle.clone(), self.rhs_mle.clone()],
964                    f_1_uv,
965                    &[lhs_challenges, rhs_challenges],
966                )])
967            }
968        }
969    }
970
971    fn max_degree(&self) -> usize {
972        match self.gate_operation {
973            BinaryOperation::Add => 2,
974            BinaryOperation::Mul => {
975                if self.num_dataparallel_vars != 0 {
976                    3
977                } else {
978                    2
979                }
980            }
981        }
982    }
983
984    fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
985        vec![&self.lhs_mle, &self.rhs_mle]
986    }
987
988    fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
989        let lhs_mle = self.lhs_mle.into_dense_mle(circuit_map);
990        let rhs_mle = self.rhs_mle.into_dense_mle(circuit_map);
991        let num_dataparallel_vars = if self.num_dataparallel_vars == 0 {
992            None
993        } else {
994            Some(self.num_dataparallel_vars)
995        };
996        let gate_layer = GateLayer::new(
997            num_dataparallel_vars,
998            self.nonzero_gates.clone(),
999            lhs_mle,
1000            rhs_mle,
1001            self.gate_operation,
1002            self.layer_id(),
1003        );
1004        gate_layer.into()
1005    }
1006
1007    fn index_mle_indices(&mut self, start_index: usize) {
1008        self.lhs_mle.index_mle_indices(start_index);
1009        self.rhs_mle.index_mle_indices(start_index);
1010    }
1011
1012    fn compute_data_outputs(
1013        &self,
1014        mle_outputs_necessary: &HashSet<&MleDescription<F>>,
1015        circuit_map: &mut CircuitEvalMap<F>,
1016    ) {
1017        assert_eq!(mle_outputs_necessary.len(), 1);
1018        let mle_output_necessary = mle_outputs_necessary.iter().next().unwrap();
1019
1020        let max_gate_val = self
1021            .nonzero_gates
1022            .iter()
1023            .fold(&0, |acc, (z, _, _)| std::cmp::max(acc, z));
1024
1025        // number of entries in the resulting table is the max gate z value * 2 to the power of the number of dataparallel bits, as we are
1026        // evaluating over all values in the boolean hypercube which includes dataparallel bits
1027        let num_dataparallel_vals = 1 << (self.num_dataparallel_vars);
1028        let res_table_num_entries =
1029            ((max_gate_val + 1) * num_dataparallel_vals).next_power_of_two();
1030
1031        let lhs_data = circuit_map
1032            .get_data_from_circuit_mle(&self.lhs_mle)
1033            .unwrap();
1034        let rhs_data = circuit_map
1035            .get_data_from_circuit_mle(&self.rhs_mle)
1036            .unwrap();
1037
1038        let num_gate_outputs_per_dataparallel_instance = (max_gate_val + 1).next_power_of_two();
1039        let mut res_table = vec![F::ZERO; res_table_num_entries as usize];
1040        (0..num_dataparallel_vals).for_each(|idx| {
1041            self.nonzero_gates.iter().for_each(|(z_ind, x_ind, y_ind)| {
1042                let zero = F::ZERO;
1043                let f2_val = lhs_data
1044                    .f
1045                    .get(
1046                        (idx * (1 << (lhs_data.num_vars() - self.num_dataparallel_vars)) + x_ind)
1047                            as usize,
1048                    )
1049                    .unwrap_or(zero);
1050                let f3_val = rhs_data
1051                    .f
1052                    .get(
1053                        (idx * (1 << (rhs_data.num_vars() - self.num_dataparallel_vars)) + y_ind)
1054                            as usize,
1055                    )
1056                    .unwrap_or(zero);
1057                res_table[(num_gate_outputs_per_dataparallel_instance * idx + z_ind) as usize] +=
1058                    self.gate_operation.perform_operation(f2_val, f3_val);
1059            });
1060        });
1061
1062        let output_data = MultilinearExtension::new(res_table);
1063        assert_eq!(
1064            output_data.num_vars(),
1065            mle_output_necessary.var_indices().len()
1066        );
1067
1068        circuit_map.add_node(CircuitLocation::new(self.layer_id(), vec![]), output_data);
1069    }
1070}
1071
1072impl<F: Field> VerifierGateLayer<F> {
1073    /// Computes the oracle query's value for a given [VerifierGateLayer].
1074    pub fn evaluate(&self, claims: &[&[F]], random_coefficients: &[F]) -> F {
1075        assert_eq!(random_coefficients.len(), claims.len());
1076        let scaled_random_coeffs = claims
1077            .iter()
1078            .zip(random_coefficients)
1079            .map(|(claim, random_coeff)| {
1080                let beta_bound = BetaValues::compute_beta_over_two_challenges(
1081                    &claim[..self.num_dataparallel_rounds],
1082                    &self.dataparallel_sumcheck_challenges,
1083                );
1084                beta_bound * random_coeff
1085            })
1086            .collect_vec();
1087
1088        let f_1_uv = compute_fully_bound_binary_gate_function(
1089            &self.first_u_challenges,
1090            &self.last_v_challenges,
1091            &claims
1092                .iter()
1093                .map(|claim| &claim[self.num_dataparallel_rounds..])
1094                .collect_vec(),
1095            &self.wiring,
1096            &scaled_random_coeffs,
1097        );
1098
1099        // Compute the final result of the bound expression (this is the oracle query).
1100        f_1_uv
1101            * self
1102                .gate_operation
1103                .perform_operation(self.lhs_mle.value(), self.rhs_mle.value())
1104    }
1105}
1106
/// The verifier's counterpart of a Gate layer.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "F: Field")]
pub struct VerifierGateLayer<F: Field> {
    /// The layer id associated with this gate layer.
    layer_id: LayerId,

    /// The gate operation representing the fan-in-two relationship.
    gate_operation: BinaryOperation,

    /// A vector of tuples representing the "nonzero" gates, especially useful
    /// in the sparse case. The format is (z, x, y) where the gate at label z is
    /// the output of performing an operation on gates with labels x and y.
    wiring: Vec<(u32, u32, u32)>,

    /// The verifier's view of the left operand (the "x" variables), bound to
    /// the dataparallel and phase-1 sumcheck challenges.
    lhs_mle: VerifierMle<F>,

    /// The verifier's view of the right operand (the "y" variables), bound to
    /// the dataparallel and phase-2 sumcheck challenges.
    rhs_mle: VerifierMle<F>,

    /// The challenge points for the claim(s) on this [VerifierGateLayer].
    claim_challenge_points: Vec<Vec<F>>,

    /// The number of dataparallel rounds.
    num_dataparallel_rounds: usize,

    /// The sumcheck challenges binding the dataparallel (`p_2`) variables.
    dataparallel_sumcheck_challenges: Vec<F>,

    /// The sumcheck challenges binding the `x` variables (phase 1).
    first_u_challenges: Vec<F>,

    /// The sumcheck challenges binding the `y` variables (phase 2).
    last_v_challenges: Vec<F>,
}
1145
1146impl<F: Field> VerifierLayer<F> for VerifierGateLayer<F> {
1147    fn layer_id(&self) -> LayerId {
1148        self.layer_id
1149    }
1150
1151    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
1152        // Grab the claim on the left side.
1153        // TODO!(ryancao): Do error handling here!
1154        let lhs_vars = self.lhs_mle.var_indices();
1155        let lhs_point = lhs_vars
1156            .iter()
1157            .map(|idx| match idx {
1158                MleIndex::Bound(chal, _bit_idx) => *chal,
1159                MleIndex::Fixed(val) => {
1160                    if *val {
1161                        F::ONE
1162                    } else {
1163                        F::ZERO
1164                    }
1165                }
1166                _ => panic!("Error: Not fully bound"),
1167            })
1168            .collect_vec();
1169        let lhs_val = self.lhs_mle.value();
1170
1171        let lhs_claim: Claim<F> =
1172            Claim::new(lhs_point, lhs_val, self.layer_id(), self.lhs_mle.layer_id());
1173
1174        // Grab the claim on the right side.
1175        // TODO!(ryancao): Do error handling here!
1176        let rhs_vars: &[MleIndex<F>] = self.rhs_mle.var_indices();
1177        let rhs_point = rhs_vars
1178            .iter()
1179            .map(|idx| match idx {
1180                MleIndex::Bound(chal, _bit_idx) => *chal,
1181                MleIndex::Fixed(val) => {
1182                    if *val {
1183                        F::ONE
1184                    } else {
1185                        F::ZERO
1186                    }
1187                }
1188                _ => panic!("Error: Not fully bound"),
1189            })
1190            .collect_vec();
1191        let rhs_val = self.rhs_mle.value();
1192
1193        let rhs_claim: Claim<F> =
1194            Claim::new(rhs_point, rhs_val, self.layer_id(), self.rhs_mle.layer_id());
1195
1196        Ok(vec![lhs_claim, rhs_claim])
1197    }
1198}
1199
1200impl<F: Field> GateLayer<F> {
1201    /// Construct a new gate layer
1202    ///
1203    /// # Arguments
1204    /// * `num_dataparallel_vars`: an optional representing the number of bits representing the circuit copy we are looking at.
1205    ///
1206    /// None if this is not dataparallel, otherwise specify the number of bits
1207    /// * `nonzero_gates`: the gate wiring between single-copy circuit (as the wiring for each circuit remains the same)
1208    ///
1209    /// x is the label on the batched mle `lhs`, y is the label on the batched mle `rhs`, and z is the label on the next layer, batched
1210    /// * `lhs`: the flattened mle representing the left side of the summation
1211    /// * `rhs`: the flattened mle representing the right side of the summation
1212    /// * `gate_operation`: which operation the gate is performing. right now, can either be an 'add' or 'mul' gate
1213    /// * `layer_id`: the id representing which current layer this is
1214    ///
1215    /// # Returns
1216    /// A `Gate` struct that can now prove and verify rounds
1217    pub fn new(
1218        num_dataparallel_vars: Option<usize>,
1219        nonzero_gates: Vec<(u32, u32, u32)>,
1220        lhs: DenseMle<F>,
1221        rhs: DenseMle<F>,
1222        gate_operation: BinaryOperation,
1223        layer_id: LayerId,
1224    ) -> Self {
1225        let num_dataparallel_vars = num_dataparallel_vars.unwrap_or(0);
1226        let num_rounds_phase1 = lhs.num_free_vars() - num_dataparallel_vars;
1227
1228        GateLayer {
1229            num_dataparallel_vars,
1230            nonzero_gates,
1231            lhs,
1232            rhs,
1233            layer_id,
1234            phase_1_mles: None,
1235            phase_2_mles: None,
1236            gate_operation,
1237            beta_g2_vec: None,
1238            g_vec: None,
1239            num_rounds_phase1,
1240        }
1241    }
1242
1243    /// Initialize phase 1, or the necessary mles in order to bind the variables in the `lhs` of the
1244    /// expression. Once this phase is initialized, the sumcheck rounds binding the "x" variables can
1245    /// be performed.
1246    fn init_phase_1(&mut self, challenges: Vec<F>) {
1247        let beta_g2_fully_bound = self.beta_g2_vec.as_ref().unwrap()[0].fold_updated_values();
1248
1249        let (a_hg_lhs_vec, a_hg_rhs_vec) = fold_binary_gate_wiring_into_mles_phase_1(
1250            &self.nonzero_gates,
1251            &[&challenges],
1252            &self.lhs,
1253            &self.rhs,
1254            &[beta_g2_fully_bound],
1255            self.gate_operation,
1256        );
1257
1258        // The actual mles differ based on whether we are doing a add gate or a mul gate, because
1259        // in the case of an add gate, we distribute the gate function whereas in the case of the
1260        // mul gate, we simply take the product over all three mles.
1261        let mut phase_1_mles = match self.gate_operation {
1262            BinaryOperation::Add => {
1263                vec![
1264                    vec![
1265                        DenseMle::new_from_raw(a_hg_lhs_vec, LayerId::Input(0)),
1266                        self.lhs.clone(),
1267                    ],
1268                    vec![DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0))],
1269                ]
1270            }
1271            BinaryOperation::Mul => {
1272                vec![vec![
1273                    DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0)),
1274                    self.lhs.clone(),
1275                ]]
1276            }
1277        };
1278
1279        phase_1_mles.iter_mut().for_each(|mle_vec| {
1280            index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1281        });
1282        self.phase_1_mles = Some(phase_1_mles);
1283    }
1284
1285    fn init_phase_1_rlc(&mut self, challenges: &[&[F]], random_coefficients: &[F]) {
1286        let random_coefficients_scaled_by_beta_g2 = self
1287            .beta_g2_vec
1288            .as_ref()
1289            .unwrap()
1290            .iter()
1291            .zip(random_coefficients)
1292            .map(|(beta_values, random_coeff)| {
1293                assert!(beta_values.is_fully_bounded());
1294                beta_values.fold_updated_values() * random_coeff
1295            })
1296            .collect_vec();
1297
1298        let (a_hg_lhs_vec, a_hg_rhs_vec) = fold_binary_gate_wiring_into_mles_phase_1(
1299            &self.nonzero_gates,
1300            challenges,
1301            &self.lhs,
1302            &self.rhs,
1303            &random_coefficients_scaled_by_beta_g2,
1304            self.gate_operation,
1305        );
1306
1307        // The actual mles differ based on whether we are doing a add gate or a mul gate, because
1308        // in the case of an add gate, we distribute the gate function whereas in the case of the
1309        // mul gate, we simply take the product over all three mles.
1310        let mut phase_1_mles = match self.gate_operation {
1311            BinaryOperation::Add => {
1312                vec![
1313                    vec![
1314                        DenseMle::new_from_raw(a_hg_lhs_vec, LayerId::Input(0)),
1315                        self.lhs.clone(),
1316                    ],
1317                    vec![DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0))],
1318                ]
1319            }
1320            BinaryOperation::Mul => {
1321                vec![vec![
1322                    DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0)),
1323                    self.lhs.clone(),
1324                ]]
1325            }
1326        };
1327
1328        phase_1_mles.iter_mut().for_each(|mle_vec| {
1329            index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1330        });
1331        self.phase_1_mles = Some(phase_1_mles);
1332    }
1333
1334    /// Initialize phase 2, or the necessary mles in order to bind the variables in the `rhs` of the
1335    /// expression. Once this phase is initialized, the sumcheck rounds binding the "y" variables can
1336    /// be performed.
1337    fn init_phase_2(&mut self, u_claim: &[F], f_at_u: F, g1_claim_points: &[F]) {
1338        let beta_g2_fully_bound = self.beta_g2_vec.as_ref().unwrap()[0].fold_updated_values();
1339
1340        let (a_f1_lhs, a_f1_rhs) = fold_binary_gate_wiring_into_mles_phase_2(
1341            &self.nonzero_gates,
1342            f_at_u,
1343            u_claim,
1344            &[g1_claim_points],
1345            &[beta_g2_fully_bound],
1346            self.rhs.num_free_vars(),
1347            self.gate_operation,
1348        );
1349
1350        // We need to multiply h_g(x) by f_2(x)
1351        let mut phase_2_mles = match self.gate_operation {
1352            BinaryOperation::Add => {
1353                vec![
1354                    vec![
1355                        DenseMle::new_from_raw(a_f1_rhs, LayerId::Input(0)),
1356                        self.rhs.clone(),
1357                    ],
1358                    vec![DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0))],
1359                ]
1360            }
1361            BinaryOperation::Mul => {
1362                vec![vec![
1363                    DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0)),
1364                    self.rhs.clone(),
1365                ]]
1366            }
1367        };
1368
1369        phase_2_mles.iter_mut().for_each(|mle_vec| {
1370            index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1371        });
1372        self.phase_2_mles = Some(phase_2_mles);
1373    }
1374
1375    fn init_phase_2_rlc(
1376        &mut self,
1377        u_claim: &[F],
1378        f_at_u: F,
1379        g1_claim_points: &[&[F]],
1380        random_coefficients: &[F],
1381    ) {
1382        let random_coefficients_scaled_by_beta_g2 = self
1383            .beta_g2_vec
1384            .as_ref()
1385            .unwrap()
1386            .iter()
1387            .zip(random_coefficients)
1388            .map(|(beta_values, random_coeff)| {
1389                assert!(beta_values.is_fully_bounded());
1390                beta_values.fold_updated_values() * random_coeff
1391            })
1392            .collect_vec();
1393
1394        let (a_f1_lhs, a_f1_rhs) = fold_binary_gate_wiring_into_mles_phase_2(
1395            &self.nonzero_gates,
1396            f_at_u,
1397            u_claim,
1398            g1_claim_points,
1399            &random_coefficients_scaled_by_beta_g2,
1400            self.rhs.num_free_vars(),
1401            self.gate_operation,
1402        );
1403
1404        // We need to multiply h_g(x) by f_2(x)
1405        let mut phase_2_mles = match self.gate_operation {
1406            BinaryOperation::Add => {
1407                vec![
1408                    vec![
1409                        DenseMle::new_from_raw(a_f1_rhs, LayerId::Input(0)),
1410                        self.rhs.clone(),
1411                    ],
1412                    vec![DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0))],
1413                ]
1414            }
1415            BinaryOperation::Mul => {
1416                vec![vec![
1417                    DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0)),
1418                    self.rhs.clone(),
1419                ]]
1420            }
1421        };
1422
1423        phase_2_mles.iter_mut().for_each(|mle_vec| {
1424            index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1425        });
1426        self.phase_2_mles = Some(phase_2_mles.clone());
1427    }
1428}
1429
1430/// Computes the correct result of a gate layer,
1431/// Used for data generation and testing.
1432/// Arguments:
1433/// - wiring: A vector of tuples representing the "nonzero" gates, especially useful
1434///   in the sparse case the format is (z, x, y) where the gate at label z is
1435///   the output of performing an operation on gates with labels x and y.
1436///
1437/// - num_dataparallel_bits: The number of bits representing the number of "dataparallel"
1438///   copies of the circuit.
1439///
1440/// - lhs_data: The left side of the expression, i.e. the mle that makes up the "x"
1441///   variables.
1442///
1443/// - rhs_data: The mles that are constructed when initializing phase 2 (binding the y
1444///   variables).
1445///
1446/// - gate_operation: The gate operation representing the fan-in-two relationship.
1447pub fn compute_gate_data_outputs<F: Field>(
1448    wiring: Vec<(u32, u32, u32)>,
1449    num_dataparallel_bits: usize,
1450    lhs_data: &MultilinearExtension<F>,
1451    rhs_data: &MultilinearExtension<F>,
1452    gate_operation: BinaryOperation,
1453) -> MultilinearExtension<F> {
1454    let max_gate_val = wiring
1455        .iter()
1456        .fold(&0, |acc, (z, _, _)| std::cmp::max(acc, z));
1457
1458    // number of entries in the resulting table is the max gate z value * 2 to the power of the number of dataparallel bits, as we are
1459    // evaluating over all values in the boolean hypercube which includes dataparallel bits
1460    let num_dataparallel_vals = 1 << num_dataparallel_bits;
1461    let res_table_num_entries = ((max_gate_val + 1) * num_dataparallel_vals).next_power_of_two();
1462    let num_gate_outputs_per_dataparallel_instance = (max_gate_val + 1).next_power_of_two();
1463
1464    let mut res_table = vec![F::ZERO; res_table_num_entries as usize];
1465    // TDH(ende): investigate if this can be parallelized (and if it's a bottleneck)
1466    (0..num_dataparallel_vals).for_each(|idx| {
1467        wiring.iter().for_each(|(z_ind, x_ind, y_ind)| {
1468            let zero = F::ZERO;
1469            let f2_val = lhs_data
1470                .f
1471                .get((idx * (1 << (lhs_data.num_vars() - num_dataparallel_bits)) + x_ind) as usize)
1472                .unwrap_or(zero);
1473            let f3_val = rhs_data
1474                .f
1475                .get((idx * (1 << (rhs_data.num_vars() - num_dataparallel_bits)) + y_ind) as usize)
1476                .unwrap_or(zero);
1477            res_table[(num_gate_outputs_per_dataparallel_instance * idx + z_ind) as usize] +=
1478                gate_operation.perform_operation(f2_val, f3_val);
1479        });
1480    });
1481
1482    MultilinearExtension::new(res_table)
1483}