remainder/layer/
identity_gate.rs

//! Identity gate `id(z, x)` determines whether the `x`-th gate of layer
//! `i + 1` contributes to the `z`-th gate of layer `i`.
3
4use std::{
5    cmp::Ordering,
6    collections::HashSet,
7    fmt::{Debug, Formatter},
8};
9
10use crate::{
11    circuit_layout::{CircuitEvalMap, CircuitLocation},
12    claims::{Claim, ClaimError, RawClaim},
13    layer::{
14        gate::gate_helpers::compute_fully_bound_identity_gate_function, LayerError,
15        VerificationError,
16    },
17    mle::{
18        betavalues::BetaValues, dense::DenseMle, evals::MultilinearExtension,
19        mle_description::MleDescription, verifier_mle::VerifierMle, Mle, MleIndex,
20    },
21    sumcheck::*,
22};
23use itertools::Itertools;
24use serde::{Deserialize, Serialize};
25use shared_types::{
26    config::{global_config::global_claim_agg_strategy, ClaimAggregationStrategy},
27    transcript::{ProverTranscript, VerifierTranscript},
28    Field,
29};
30
31use thiserror::Error;
32
33use super::{
34    gate::gate_helpers::{
35        compute_sumcheck_message_data_parallel_identity_gate, evaluate_mle_product_no_beta_table,
36        fold_wiring_into_beta_mle_identity_gate,
37    },
38    layer_enum::{LayerEnum, VerifierLayerEnum},
39    product::{PostSumcheckLayer, Product},
40    Layer, LayerDescription, LayerId, VerifierLayer,
41};
42
43use anyhow::{anyhow, Ok, Result};
44
45/// The circuit Description for an [IdentityGate].
46#[derive(Serialize, Deserialize, Clone, Hash)]
47#[serde(bound = "F: Field")]
48pub struct IdentityGateLayerDescription<F: Field> {
49    /// The layer id associated with this gate layer.
50    id: LayerId,
51
52    /// A vector of tuples representing the "nonzero" gates, especially useful
53    /// in the sparse case the format is (z, x) where the gate at label z is the
54    /// output of adding all values from labels x.
55    wiring: Vec<(u32, u32)>,
56
57    /// The source MLE of the expression, i.e. the mle that makes up the "x"
58    /// variables.
59    source_mle: MleDescription<F>,
60
61    /// The total number of variables in the layer.
62    total_num_vars: usize,
63
64    /// The number of vars representing the number of "dataparallel" copies of
65    /// the circuit.
66    num_dataparallel_vars: usize,
67}
68
69impl<F: Field> std::fmt::Debug for IdentityGateLayerDescription<F> {
70    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("IdentityGateLayerDescription")
72            .field("id", &self.id)
73            .field("wiring.len()", &self.wiring.len())
74            .field("source_mle", &self.source_mle)
75            .field("num_dataparallel_vars", &self.num_dataparallel_vars)
76            .finish()
77    }
78}
79
80impl<F: Field> IdentityGateLayerDescription<F> {
81    /// Constructor for [IdentityGateLayerDescription]. Arguments:
82    /// * `id`: The layer id associated with this layer.
83    /// * `source_mle`: The Mle that is being routed to this layer.
84    /// * `nonzero_gates`: A list of tuples representing the gates that are
85    ///   nonzero, in the form `(dest_idx, src_idx)`.
86    /// * `total_num_vars`: The total number of variables in the layer.
87    /// * `num_dataparallel_vars`: The number of dataparallel variables to use
88    ///   in this layer.
89    pub fn new(
90        id: LayerId,
91        wiring: Vec<(u32, u32)>,
92        source_mle: MleDescription<F>,
93        total_num_vars: usize,
94        num_dataparallel_vars: Option<usize>,
95    ) -> Self {
96        Self {
97            id,
98            wiring,
99            source_mle,
100            total_num_vars,
101            num_dataparallel_vars: num_dataparallel_vars.unwrap_or(0),
102        }
103    }
104}
105
106impl<F: Field> LayerDescription<F> for IdentityGateLayerDescription<F> {
107    type VerifierLayer = VerifierIdentityGateLayer<F>;
108
109    fn layer_id(&self) -> LayerId {
110        self.id
111    }
112
113    fn verify_rounds(
114        &self,
115        claims: &[&RawClaim<F>],
116        transcript_reader: &mut impl VerifierTranscript<F>,
117    ) -> Result<VerifierLayerEnum<F>> {
118        // Keeps track of challenges `r_1, ..., r_n` sent by the verifier.
119        let mut challenges = vec![];
120
121        // Random coefficients depending on claim aggregation strategy.
122        let random_coefficients = match global_claim_agg_strategy() {
123            ClaimAggregationStrategy::Interpolative => {
124                assert_eq!(claims.len(), 1);
125                vec![F::ONE]
126            }
127            ClaimAggregationStrategy::RLC => {
128                transcript_reader.get_challenges("RLC Claim Agg Coefficients", claims.len())?
129            }
130        };
131
132        // Represents `g_{i-1}(x)` of the previous round. This is initialized to
133        // the constant polynomial `g_0(x)` which evaluates to the claim result
134        // for any `x`.
135        let mut g_prev_round = match global_claim_agg_strategy() {
136            ClaimAggregationStrategy::Interpolative => {
137                vec![claims[0].get_eval()]
138            }
139            ClaimAggregationStrategy::RLC => vec![random_coefficients
140                .iter()
141                .zip(claims)
142                .fold(F::ZERO, |acc, (rlc_val, claim)| {
143                    acc + *rlc_val * claim.get_eval()
144                })],
145        };
146
147        // Previous round's challege: r_{i-1}.
148        let mut prev_challenge = F::ZERO;
149
150        let num_rounds = self.sumcheck_round_indices().len();
151
152        // For round 1 <= i <= n, perform the check:
153        for _round in 0..num_rounds {
154            // Degree of independent variable is always quadratic! (regardless
155            // of if there's dataparallel or not) V_i(g_2, g_1) = \sum_{p_2}
156            // \sum_{x} \beta(g_2, p_2) f_1(g_1, x) (V_{i + 1}(p_2, x))
157            let degree = 2;
158
159            // Read g_i(1), ..., g_i(d+1) from the prover, reserve space to compute g_i(0)
160            let mut g_cur_round: Vec<_> = [Ok(F::from(0))]
161                .into_iter()
162                .chain((0..degree).map(|_| {
163                    transcript_reader.consume_element("Sumcheck round univariate evaluations")
164                }))
165                .collect::<Result<_, _>>()?;
166
167            // Sample random challenge `r_i`.
168            let challenge = transcript_reader.get_challenge("Sumcheck round challenge")?;
169
170            // Compute:
171            //       `g_i(0) = g_{i - 1}(r_{i-1}) - g_i(1)`
172            let g_prev_r_prev = evaluate_at_a_point(&g_prev_round, prev_challenge).unwrap();
173            let g_i_one = evaluate_at_a_point(&g_cur_round, F::ONE).unwrap();
174            g_cur_round[0] = g_prev_r_prev - g_i_one;
175
176            g_prev_round = g_cur_round;
177            prev_challenge = challenge;
178            challenges.push(challenge);
179        }
180
181        // Evalute `g_n(r_n)`. Note: If there were no nonlinear rounds, this
182        // value reduces to `claim.get_result()` due to how we initialized
183        // `g_prev_round`.
184        let g_final_r_final = evaluate_at_a_point(&g_prev_round, prev_challenge)?;
185
186        let verifier_id_gate_layer = self
187            .convert_into_verifier_layer(
188                &challenges,
189                &claims.iter().map(|claim| claim.get_point()).collect_vec(),
190                transcript_reader,
191            )
192            .unwrap();
193        let final_result = verifier_id_gate_layer.evaluate(
194            &claims.iter().map(|claim| claim.get_point()).collect_vec(),
195            &random_coefficients,
196        );
197
198        if g_final_r_final != final_result {
199            return Err(anyhow!(VerificationError::FinalSumcheckFailed));
200        }
201
202        Ok(VerifierLayerEnum::IdentityGate(verifier_id_gate_layer))
203    }
204
205    fn sumcheck_round_indices(&self) -> Vec<usize> {
206        let num_vars = self
207            .source_mle
208            .var_indices()
209            .iter()
210            .fold(0_usize, |acc, idx| {
211                acc + match idx {
212                    MleIndex::Fixed(_) => 0,
213                    _ => 1,
214                }
215            });
216
217        (0..num_vars).collect_vec()
218    }
219
220    fn convert_into_verifier_layer(
221        &self,
222        sumcheck_challenges: &[F],
223        _claim_points: &[&[F]],
224        transcript_reader: &mut impl VerifierTranscript<F>,
225    ) -> Result<Self::VerifierLayer> {
226        // WARNING: WE ARE ASSUMING HERE THAT MLE INDICES INCLUDE DATAPARALLEL
227        // INDICES AND MAKE NO DISTINCTION BETWEEN THOSE AND REGULAR
228        // FREE/INDEXED vars
229        let num_u = self
230            .source_mle
231            .var_indices()
232            .iter()
233            .fold(0_usize, |acc, idx| {
234                acc + match idx {
235                    MleIndex::Fixed(_) => 0,
236                    _ => 1,
237                }
238            })
239            - self.num_dataparallel_vars;
240
241        // We want to separate the challenges into which ones are from the
242        // dataparallel vars, which ones and are for binding x (phase 1)
243        let mut sumcheck_bindings_vec = sumcheck_challenges.to_vec();
244        let first_u_challenges = sumcheck_bindings_vec.split_off(self.num_dataparallel_vars);
245        let dataparallel_sumcheck_challenges = sumcheck_bindings_vec;
246
247        assert_eq!(first_u_challenges.len(), num_u);
248
249        // Since the original mles are dataparallel, the challenges are the
250        // concat of the copy vars and the variable bound vars.
251        let src_verifier_mle = self
252            .source_mle
253            .into_verifier_mle(sumcheck_challenges, transcript_reader)
254            .unwrap();
255
256        // Create the resulting verifier layer for claim tracking TODO(ryancao):
257        // This is not necessary; we only need to pass back the actual claims
258        let verifier_id_gate_layer = VerifierIdentityGateLayer {
259            layer_id: self.layer_id(),
260            wiring: self.wiring.clone(),
261            source_mle: src_verifier_mle,
262            first_u_challenges,
263            total_num_vars: self.total_num_vars,
264            num_dataparallel_rounds: self.num_dataparallel_vars,
265            dataparallel_sumcheck_challenges,
266        };
267
268        Ok(verifier_id_gate_layer)
269    }
270
271    fn get_post_sumcheck_layer(
272        &self,
273        round_challenges: &[F],
274        claim_challenges: &[&[F]],
275        random_coefficients: &[F],
276    ) -> PostSumcheckLayer<F, Option<F>> {
277        assert_eq!(claim_challenges.len(), random_coefficients.len());
278        let random_coefficients_scaled_by_beta_bound = claim_challenges
279            .iter()
280            .zip(random_coefficients)
281            .map(|(claim_chals, random_coeff)| {
282                let beta_bound = if self.num_dataparallel_vars > 0 {
283                    let g2_challenges = claim_chals[..self.num_dataparallel_vars].to_vec();
284                    BetaValues::compute_beta_over_two_challenges(
285                        &g2_challenges,
286                        &round_challenges[..self.num_dataparallel_vars],
287                    )
288                } else {
289                    F::ONE
290                };
291                beta_bound * random_coeff
292            })
293            .collect_vec();
294
295        let nondataparallel_claim_chals = claim_challenges
296            .iter()
297            .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
298            .collect_vec();
299
300        let f_1_gu = compute_fully_bound_identity_gate_function(
301            &round_challenges[self.num_dataparallel_vars..],
302            &nondataparallel_claim_chals,
303            &self.wiring,
304            &random_coefficients_scaled_by_beta_bound,
305        );
306
307        PostSumcheckLayer(vec![Product::<F, Option<F>>::new(
308            std::slice::from_ref(&self.source_mle),
309            f_1_gu,
310            round_challenges,
311        )])
312    }
313
314    fn max_degree(&self) -> usize {
315        2
316    }
317
318    fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
319        vec![&self.source_mle]
320    }
321
322    fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
323        let source_mle = self.source_mle.into_dense_mle(circuit_map);
324        let id_gate_layer = IdentityGate::new(
325            self.layer_id(),
326            self.wiring.clone(),
327            source_mle,
328            self.total_num_vars,
329            self.num_dataparallel_vars,
330        );
331        id_gate_layer.into()
332    }
333
334    fn index_mle_indices(&mut self, start_index: usize) {
335        self.source_mle.index_mle_indices(start_index);
336    }
337
338    fn compute_data_outputs(
339        &self,
340        mle_outputs_necessary: &HashSet<&MleDescription<F>>,
341        circuit_map: &mut CircuitEvalMap<F>,
342    ) {
343        // This may not be true, specifically because of e.g. `SplitNode`.
344        // assert_eq!(mle_outputs_necessary.len(), 1);
345        // let mle_output_necessary = mle_outputs_necessary.iter().next().unwrap();
346
347        let source_mle_data = circuit_map
348            .get_data_from_circuit_mle(&self.source_mle)
349            .unwrap();
350
351        let res_table_num_entries = 1 << self.total_num_vars;
352        let num_entries_per_dataparallel_instance =
353            1 << (self.total_num_vars - self.num_dataparallel_vars);
354        let mut remap_table = vec![F::ZERO; res_table_num_entries];
355
356        (0..(1 << self.num_dataparallel_vars)).for_each(|data_parallel_idx| {
357            self.wiring.iter().for_each(|(dest_idx, src_idx)| {
358                let id_val = source_mle_data
359                    .f
360                    .get(
361                        data_parallel_idx
362                            * (1 << (self.source_mle.num_free_vars() - self.num_dataparallel_vars))
363                            + (*src_idx as usize),
364                    )
365                    .unwrap_or(F::ZERO);
366                remap_table[num_entries_per_dataparallel_instance * data_parallel_idx
367                    + (*dest_idx as usize)] += id_val;
368            });
369        });
370
371        // Now that we have populated the entire bookkeeping table for the
372        // current layer, we can split it into what we are looking for.
373        // Because all "selector" variables occur before all "dataparallel"
374        // variables, we can always assume a split against the `prefix_vars`
375        // of the `mle_outputs_necessary`.
376        //
377        // Note that this is wasteful because we are cloning, but the
378        // interaction between dataparallel copies and splitting along
379        // dataparallel + non-dataparallel variables is too complicated
380        // to handle here. We should handle it at a higher level.
381        mle_outputs_necessary
382            .iter()
383            .for_each(|mle_output_necessary| {
384                let prefix_vars = mle_output_necessary.prefix_bits();
385                let bookkeeping_table_len = 1 << (self.total_num_vars - prefix_vars.len());
386                let start_idx =
387                    prefix_vars
388                        .iter()
389                        .enumerate()
390                        .fold(0, |acc, (var_idx, prefix_var)| {
391                            // Big-endian indexing means the following:
392                            // 0th variable adds 1/2 the bookkeeping table length
393                            // 1st variable adds 1/4 the bookkeeping table length
394                            // ...
395                            acc + if *prefix_var {
396                                1 << (self.total_num_vars - var_idx - 1)
397                            } else {
398                                0
399                            }
400                        });
401                let data_slice = &remap_table[start_idx..start_idx + bookkeeping_table_len];
402                let mle_output = MultilinearExtension::new(data_slice.to_vec());
403                circuit_map.add_node(
404                    CircuitLocation::new(self.layer_id(), prefix_vars),
405                    mle_output,
406                );
407            });
408    }
409}
410
411impl<F: Field> VerifierIdentityGateLayer<F> {
412    /// Computes the oracle query's value for a given
413    /// [VerifierIdentityGateLayer].
414    pub fn evaluate(&self, claim_points: &[&[F]], random_coefficients: &[F]) -> F {
415        assert_eq!(random_coefficients.len(), claim_points.len());
416        let scaled_random_coeffs = claim_points
417            .iter()
418            .zip(random_coefficients)
419            .map(|(claim, random_coeff)| {
420                let beta_bound = BetaValues::compute_beta_over_two_challenges(
421                    &claim[..self.num_dataparallel_rounds],
422                    &self.dataparallel_sumcheck_challenges,
423                );
424                beta_bound * random_coeff
425            })
426            .collect_vec();
427
428        let f_1_gu = compute_fully_bound_identity_gate_function(
429            &self.first_u_challenges,
430            &claim_points
431                .iter()
432                .map(|claim| &claim[self.num_dataparallel_rounds..])
433                .collect_vec(),
434            &self.wiring,
435            &scaled_random_coeffs,
436        );
437        // get the fully evaluated "expression"
438        f_1_gu * self.source_mle.value()
439    }
440}
441
442#[derive(Serialize, Deserialize, Clone, Debug)]
443#[serde(bound = "F: Field")]
444/// The layer representing a fully bound [IdentityGate].
445pub struct VerifierIdentityGateLayer<F: Field> {
446    /// The layer id associated with this gate layer.
447    layer_id: LayerId,
448
449    /// A vector of tuples representing the "nonzero" gates, especially useful
450    /// in the sparse case the format is (z, x) where the gate at label z is the
451    /// output of adding all values from labels x.
452    wiring: Vec<(u32, u32)>,
453
454    /// The source MLE of the expression, i.e. the mle that makes up the "x"
455    /// variables.
456    source_mle: VerifierMle<F>,
457
458    /// The challenges for `x`, as derived from sumcheck.
459    first_u_challenges: Vec<F>,
460
461    /// The total number of variables in the layer.
462    total_num_vars: usize,
463
464    /// The number of dataparallel rounds.
465    num_dataparallel_rounds: usize,
466
467    /// The challenges for `p_2`, as derived from sumcheck.
468    dataparallel_sumcheck_challenges: Vec<F>,
469}
470
471impl<F: Field> VerifierLayer<F> for VerifierIdentityGateLayer<F> {
472    fn layer_id(&self) -> LayerId {
473        self.layer_id
474    }
475
476    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
477        // Grab the claim on the left side.
478        let source_vars = self.source_mle.var_indices();
479        let source_point = source_vars
480            .iter()
481            .map(|idx| match idx {
482                MleIndex::Bound(chal, _bit_idx) => *chal,
483                MleIndex::Fixed(val) => {
484                    if *val {
485                        F::ONE
486                    } else {
487                        F::ZERO
488                    }
489                }
490                _ => panic!("Error: Not fully bound"),
491            })
492            .collect_vec();
493        let source_val = self.source_mle.value();
494
495        let source_claim: Claim<F> = Claim::new(
496            source_point,
497            source_val,
498            self.layer_id(),
499            self.source_mle.layer_id(),
500        );
501
502        Ok(vec![source_claim])
503    }
504}
505
506/// The layer trait implementation for [IdentityGate], which has the proving
507/// functionality as well as the modular functions for each round of sumcheck.
508impl<F: Field> Layer<F> for IdentityGate<F> {
509    fn prove(
510        &mut self,
511        claims: &[&RawClaim<F>],
512        transcript_writer: &mut impl ProverTranscript<F>,
513    ) -> Result<()> {
514        let random_coefficients = match global_claim_agg_strategy() {
515            ClaimAggregationStrategy::Interpolative => {
516                assert_eq!(claims.len(), 1);
517                self.initialize(claims[0].get_point())?;
518                vec![F::ONE]
519            }
520            ClaimAggregationStrategy::RLC => {
521                let random_coefficients =
522                    transcript_writer.get_challenges("RLC Claim Agg Coefficients", claims.len());
523                self.initialize_rlc(&random_coefficients, claims);
524                random_coefficients
525            }
526        };
527        let sumcheck_indices = self.sumcheck_round_indices();
528        (sumcheck_indices.iter()).for_each(|round_idx| {
529            let sumcheck_message = self
530                .compute_round_sumcheck_message(*round_idx, &random_coefficients)
531                .unwrap();
532            // Since the verifier can deduce g_i(0) by computing claim - g_i(1), the prover does not send g_i(0)
533            transcript_writer.append_elements(
534                "Sumcheck round univariate evaluations",
535                &sumcheck_message[1..],
536            );
537            let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
538            self.bind_round_variable(*round_idx, challenge).unwrap();
539        });
540        self.append_leaf_mles_to_transcript(transcript_writer);
541        Ok(())
542    }
543
544    fn layer_id(&self) -> LayerId {
545        self.layer_id
546    }
547
548    fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
549        self.challenges_vec = Some(vec![claim_point.to_vec()]);
550        let g2_challenges = &claim_point[..self.num_dataparallel_vars];
551        let g1_challenges = &claim_point[self.num_dataparallel_vars..];
552        self.g1_challenges_vec = Some(vec![g1_challenges.to_vec()]);
553
554        if self.num_dataparallel_vars > 0 {
555            let beta_g2 = BetaValues::new(g2_challenges.iter().copied().enumerate().collect());
556            self.beta_g2_vec = Some(vec![beta_g2]);
557        }
558
559        self.source_mle.index_mle_indices(0);
560        Ok(())
561    }
562
563    fn initialize_rlc(&mut self, random_coefficients: &[F], claims: &[&RawClaim<F>]) {
564        assert_eq!(random_coefficients.len(), claims.len());
565
566        // Split all of the claimed challenges into whether they are claimed
567        // challenges on the dataparallel variables or not.
568        self.challenges_vec = Some(
569            claims
570                .iter()
571                .map(|claim| claim.get_point().to_vec())
572                .collect_vec(),
573        );
574        let (g2_challenges_vec, g1_challenges_vec): (Vec<&[F]>, Vec<&[F]>) = claims
575            .iter()
576            .map(|claim| claim.get_point().split_at(self.num_dataparallel_vars))
577            .unzip();
578        self.g1_challenges_vec = Some(
579            g1_challenges_vec
580                .into_iter()
581                .map(|challenges| challenges.to_vec())
582                .collect_vec(),
583        );
584
585        if self.num_dataparallel_vars > 0 {
586            let beta_g2_vec = g2_challenges_vec
587                .iter()
588                .map(|g2_challenges| {
589                    BetaValues::new(g2_challenges.iter().copied().enumerate().collect())
590                })
591                .collect();
592            self.beta_g2_vec = Some(beta_g2_vec);
593        }
594        self.source_mle.index_mle_indices(0);
595    }
596
597    fn compute_round_sumcheck_message(
598        &mut self,
599        round_index: usize,
600        random_coefficients: &[F],
601    ) -> Result<Vec<F>> {
602        match round_index.cmp(&self.num_dataparallel_vars) {
603            // Dataparallel phase.
604            Ordering::Less => {
605                let sumcheck_message = compute_sumcheck_message_data_parallel_identity_gate(
606                    &self.source_mle,
607                    &self.wiring,
608                    self.num_dataparallel_vars - round_index,
609                    &self
610                        .challenges_vec
611                        .as_ref()
612                        .unwrap()
613                        .iter()
614                        .map(|claim| &claim[round_index..])
615                        .collect_vec(),
616                    &self
617                        .beta_g2_vec
618                        .as_ref()
619                        .unwrap()
620                        .iter()
621                        .zip(random_coefficients)
622                        .map(|(beta_values, random_coeff)| {
623                            *random_coeff * beta_values.fold_updated_values()
624                        })
625                        .collect_vec(),
626                )
627                .unwrap();
628                Ok(sumcheck_message)
629            }
630            _ => {
631                if round_index == self.num_dataparallel_vars {
632                    match global_claim_agg_strategy() {
633                        ClaimAggregationStrategy::Interpolative => {
634                            // We compute the singular fully bound value for the beta MLE over
635                            // the dataparallel challenges.
636                            let beta_g2_fully_bound = if self.num_dataparallel_vars > 0 {
637                                self.beta_g2_vec.as_ref().unwrap()[0].fold_updated_values()
638                            } else {
639                                F::ONE
640                            };
641
642                            self.init_phase_1(
643                                &self.g1_challenges_vec.as_ref().unwrap()[0].clone(),
644                                beta_g2_fully_bound,
645                            );
646                        }
647                        ClaimAggregationStrategy::RLC => {
648                            // We compute the beta MLE fully bound over all the claims rather
649                            // than the aggregated claim.
650                            let random_coefficients = if self.num_dataparallel_vars > 0 {
651                                random_coefficients
652                                    .iter()
653                                    .zip(self.beta_g2_vec.as_ref().unwrap())
654                                    .map(|(random_coeff, beta_values)| {
655                                        if self.num_dataparallel_vars > 0 {
656                                            beta_values.fold_updated_values() * random_coeff
657                                        } else {
658                                            F::ONE * random_coeff
659                                        }
660                                    })
661                                    .collect_vec()
662                            } else {
663                                random_coefficients.to_vec()
664                            };
665
666                            self.init_phase_1_rlc(
667                                &self
668                                    .g1_challenges_vec
669                                    .as_ref()
670                                    .unwrap()
671                                    .clone()
672                                    .iter()
673                                    .map(|challenge| challenge.as_slice())
674                                    .collect_vec(),
675                                &random_coefficients,
676                            );
677                        }
678                    }
679                }
680
681                let mles: Vec<&DenseMle<F>> =
682                    vec![&self.a_hg_mle_phase_1.as_ref().unwrap(), &self.source_mle];
683                let independent_variable = mles
684                    .iter()
685                    .map(|mle| mle.mle_indices().contains(&MleIndex::Indexed(round_index)))
686                    .reduce(|acc, item| acc | item)
687                    .unwrap();
688                let sumcheck_evals =
689                    evaluate_mle_product_no_beta_table(&mles, independent_variable, mles.len())
690                        .unwrap();
691                Ok(sumcheck_evals.0)
692            }
693        }
694    }
695
696    fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
697        if round_index < self.num_dataparallel_vars {
698            self.beta_g2_vec
699                .as_mut()
700                .unwrap()
701                .iter_mut()
702                .for_each(|beta| {
703                    beta.beta_update(round_index, challenge);
704                });
705            self.source_mle.fix_variable(round_index, challenge);
706
707            Ok(())
708        } else {
709            if self.num_dataparallel_vars > 0 {
710                self.beta_g2_vec.as_ref().unwrap().iter().for_each(|beta| {
711                    assert!(beta.is_fully_bounded());
712                })
713            }
714            let a_hg_mle = self.a_hg_mle_phase_1.as_mut().unwrap();
715
716            [a_hg_mle, &mut self.source_mle].iter_mut().for_each(|mle| {
717                mle.fix_variable(round_index, challenge);
718            });
719            Ok(())
720        }
721    }
722
723    fn sumcheck_round_indices(&self) -> Vec<usize> {
724        (0..self.source_mle.num_free_vars()).collect_vec()
725    }
726
727    fn max_degree(&self) -> usize {
728        2
729    }
730
731    fn get_post_sumcheck_layer(
732        &self,
733        round_challenges: &[F],
734        claim_challenges: &[&[F]],
735        random_coefficients: &[F],
736    ) -> PostSumcheckLayer<F, F> {
737        assert_eq!(claim_challenges.len(), random_coefficients.len());
738        let random_coefficients_scaled_by_beta_bound = claim_challenges
739            .iter()
740            .zip(random_coefficients)
741            .map(|(claim_chals, random_coeff)| {
742                let beta_bound = if self.num_dataparallel_vars > 0 {
743                    let g2_challenges = claim_chals[..self.num_dataparallel_vars].to_vec();
744                    BetaValues::compute_beta_over_two_challenges(
745                        &g2_challenges,
746                        &round_challenges[..self.num_dataparallel_vars],
747                    )
748                } else {
749                    F::ONE
750                };
751                beta_bound * random_coeff
752            })
753            .collect_vec();
754
755        let nondataparallel_claim_chals = claim_challenges
756            .iter()
757            .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
758            .collect_vec();
759
760        let f_1_gu = compute_fully_bound_identity_gate_function(
761            &round_challenges[self.num_dataparallel_vars..],
762            &nondataparallel_claim_chals,
763            &self.wiring,
764            &random_coefficients_scaled_by_beta_bound,
765        );
766
767        PostSumcheckLayer(vec![Product::<F, F>::new(
768            std::slice::from_ref(&self.source_mle),
769            f_1_gu,
770        )])
771    }
772
773    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
774        let mut claims = vec![];
775        let mut fixed_mle_indices_u: Vec<F> = vec![];
776
777        for index in self.source_mle.mle_indices() {
778            fixed_mle_indices_u.push(
779                index
780                    .val()
781                    .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))?,
782            );
783        }
784        let val = self.source_mle.first();
785        let claim: Claim<F> = Claim::new(
786            fixed_mle_indices_u,
787            val,
788            self.layer_id(),
789            self.source_mle.layer_id(),
790        );
791        claims.push(claim);
792
793        Ok(claims)
794    }
795}
796
/// Prover-side identity gate layer. It copies selected entries of
/// `source_mle` into its output according to an explicit
/// `(dest_idx, src_idx)` wiring list. The selected indices do not need to
/// follow any regular pattern.
// NOTE(review): this struct derives `thiserror::Error` even though it is a
// layer rather than an error type; confirm nothing relies on
// `IdentityGate: std::error::Error` before removing the derive.
#[derive(Error, Debug, Serialize, Deserialize, Clone)]
#[serde(bound = "F: Field")]
pub struct IdentityGate<F: Field> {
    /// Layer ID for this gate.
    layer_id: LayerId,
    /// Wiring tuples of the form `(dest_idx, src_idx)`. Each tuple copies the
    /// value at `src_idx` of `source_mle` into `dest_idx` of the output MLE.
    wiring: Vec<(u32, u32)>,
    /// The MLE from which the indices are selected.
    source_mle: DenseMle<F>,
    /// [BetaValues] over the incoming claims' challenge points on the
    /// dataparallel variables of the MLE (presumably one entry per claim).
    beta_g2_vec: Option<Vec<BetaValues<F>>>,
    /// The non-dataparallel parts of the claim points on this layer.
    g1_challenges_vec: Option<Vec<Vec<F>>>,
    /// The full claim points on this layer, covering both the dataparallel
    /// and the non-dataparallel challenges.
    challenges_vec: Option<Vec<Vec<F>>>,
    /// The bookkeeping table built in phase 1. It holds the beta values over
    /// the claims' non-dataparallel challenges, folded into the wiring
    /// function.
    a_hg_mle_phase_1: Option<DenseMle<F>>,
    /// The total number of variables in the layer.
    total_num_vars: usize,
    /// The number of variables indexing the "dataparallel" copies of the
    /// circuit.
    num_dataparallel_vars: usize,
}
827
828impl<F: Field> IdentityGate<F> {
829    /// Create a new [IdentityGate] struct.
830    pub fn new(
831        layer_id: LayerId,
832        wiring: Vec<(u32, u32)>,
833        mle: DenseMle<F>,
834        total_num_vars: usize,
835        num_dataparallel_vars: usize,
836    ) -> IdentityGate<F> {
837        IdentityGate {
838            layer_id,
839            wiring,
840            source_mle: mle,
841            beta_g2_vec: None,
842            a_hg_mle_phase_1: None,
843            total_num_vars,
844            num_dataparallel_vars,
845            g1_challenges_vec: None,
846            challenges_vec: None,
847        }
848    }
849
850    fn append_leaf_mles_to_transcript(&self, transcript_writer: &mut impl ProverTranscript<F>) {
851        assert!(self.source_mle.is_fully_bounded());
852        transcript_writer.append("Fully bound MLE evaluation", self.source_mle.first());
853    }
854
855    /// Initialize the bookkeeping table necessary for phase 1, which is the
856    /// binding of the non-dataparallel variables in the source MLE. This is the
857    /// initialization function used when we are doing interpolative claim
858    /// aggregation.
859    ///
860    /// For the random coefficients, we simply use the fully bound value of
861    /// beta_g2 since this is the value that scales all of the sumcheck
862    /// evaluations.
863    fn init_phase_1(&mut self, challenge: &[F], fully_bound_beta_g2: F) {
864        let a_hg_mle_vec = fold_wiring_into_beta_mle_identity_gate(
865            &self.wiring,
866            &[challenge],
867            self.source_mle.num_free_vars(),
868            &[fully_bound_beta_g2],
869        );
870        let mut a_hg_mle = DenseMle::new_from_raw(a_hg_mle_vec, self.layer_id());
871        a_hg_mle.index_mle_indices(self.num_dataparallel_vars);
872
873        self.a_hg_mle_phase_1 = Some(a_hg_mle);
874    }
875
876    /// Initialize the bookkeeping table necessary for phase 1, which is the
877    /// binding of the non-dataparallel variables in the source MLE. This is the
878    /// initialization function used when we are doing RLC claim
879    /// aggregation.
880    fn init_phase_1_rlc(&mut self, challenges: &[&[F]], random_coefficients: &[F]) {
881        let a_hg_mle_vec = fold_wiring_into_beta_mle_identity_gate(
882            &self.wiring,
883            challenges,
884            self.source_mle.num_free_vars(),
885            random_coefficients,
886        );
887        let mut a_hg_mle = DenseMle::new_from_raw(a_hg_mle_vec, self.layer_id());
888        a_hg_mle.index_mle_indices(self.num_dataparallel_vars);
889        self.a_hg_mle_phase_1 = Some(a_hg_mle);
890    }
891}