1use std::{
5 cmp::Ordering,
6 collections::HashSet,
7 fmt::{Debug, Formatter},
8};
9
10use crate::{
11 circuit_layout::{CircuitEvalMap, CircuitLocation},
12 claims::{Claim, ClaimError, RawClaim},
13 layer::{
14 gate::gate_helpers::compute_fully_bound_identity_gate_function, LayerError,
15 VerificationError,
16 },
17 mle::{
18 betavalues::BetaValues, dense::DenseMle, evals::MultilinearExtension,
19 mle_description::MleDescription, verifier_mle::VerifierMle, Mle, MleIndex,
20 },
21 sumcheck::*,
22};
23use itertools::Itertools;
24use serde::{Deserialize, Serialize};
25use shared_types::{
26 config::{global_config::global_claim_agg_strategy, ClaimAggregationStrategy},
27 transcript::{ProverTranscript, VerifierTranscript},
28 Field,
29};
30
31use thiserror::Error;
32
33use super::{
34 gate::gate_helpers::{
35 compute_sumcheck_message_data_parallel_identity_gate, evaluate_mle_product_no_beta_table,
36 fold_wiring_into_beta_mle_identity_gate,
37 },
38 layer_enum::{LayerEnum, VerifierLayerEnum},
39 product::{PostSumcheckLayer, Product},
40 Layer, LayerDescription, LayerId, VerifierLayer,
41};
42
43use anyhow::{anyhow, Ok, Result};
44
45#[derive(Serialize, Deserialize, Clone, Hash)]
47#[serde(bound = "F: Field")]
48pub struct IdentityGateLayerDescription<F: Field> {
49 id: LayerId,
51
52 wiring: Vec<(u32, u32)>,
56
57 source_mle: MleDescription<F>,
60
61 total_num_vars: usize,
63
64 num_dataparallel_vars: usize,
67}
68
69impl<F: Field> std::fmt::Debug for IdentityGateLayerDescription<F> {
70 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("IdentityGateLayerDescription")
72 .field("id", &self.id)
73 .field("wiring.len()", &self.wiring.len())
74 .field("source_mle", &self.source_mle)
75 .field("num_dataparallel_vars", &self.num_dataparallel_vars)
76 .finish()
77 }
78}
79
80impl<F: Field> IdentityGateLayerDescription<F> {
81 pub fn new(
90 id: LayerId,
91 wiring: Vec<(u32, u32)>,
92 source_mle: MleDescription<F>,
93 total_num_vars: usize,
94 num_dataparallel_vars: Option<usize>,
95 ) -> Self {
96 Self {
97 id,
98 wiring,
99 source_mle,
100 total_num_vars,
101 num_dataparallel_vars: num_dataparallel_vars.unwrap_or(0),
102 }
103 }
104}
105
106impl<F: Field> LayerDescription<F> for IdentityGateLayerDescription<F> {
107 type VerifierLayer = VerifierIdentityGateLayer<F>;
108
109 fn layer_id(&self) -> LayerId {
110 self.id
111 }
112
113 fn verify_rounds(
114 &self,
115 claims: &[&RawClaim<F>],
116 transcript_reader: &mut impl VerifierTranscript<F>,
117 ) -> Result<VerifierLayerEnum<F>> {
118 let mut challenges = vec![];
120
121 let random_coefficients = match global_claim_agg_strategy() {
123 ClaimAggregationStrategy::Interpolative => {
124 assert_eq!(claims.len(), 1);
125 vec![F::ONE]
126 }
127 ClaimAggregationStrategy::RLC => {
128 transcript_reader.get_challenges("RLC Claim Agg Coefficients", claims.len())?
129 }
130 };
131
132 let mut g_prev_round = match global_claim_agg_strategy() {
136 ClaimAggregationStrategy::Interpolative => {
137 vec![claims[0].get_eval()]
138 }
139 ClaimAggregationStrategy::RLC => vec![random_coefficients
140 .iter()
141 .zip(claims)
142 .fold(F::ZERO, |acc, (rlc_val, claim)| {
143 acc + *rlc_val * claim.get_eval()
144 })],
145 };
146
147 let mut prev_challenge = F::ZERO;
149
150 let num_rounds = self.sumcheck_round_indices().len();
151
152 for _round in 0..num_rounds {
154 let degree = 2;
158
159 let mut g_cur_round: Vec<_> = [Ok(F::from(0))]
161 .into_iter()
162 .chain((0..degree).map(|_| {
163 transcript_reader.consume_element("Sumcheck round univariate evaluations")
164 }))
165 .collect::<Result<_, _>>()?;
166
167 let challenge = transcript_reader.get_challenge("Sumcheck round challenge")?;
169
170 let g_prev_r_prev = evaluate_at_a_point(&g_prev_round, prev_challenge).unwrap();
173 let g_i_one = evaluate_at_a_point(&g_cur_round, F::ONE).unwrap();
174 g_cur_round[0] = g_prev_r_prev - g_i_one;
175
176 g_prev_round = g_cur_round;
177 prev_challenge = challenge;
178 challenges.push(challenge);
179 }
180
181 let g_final_r_final = evaluate_at_a_point(&g_prev_round, prev_challenge)?;
185
186 let verifier_id_gate_layer = self
187 .convert_into_verifier_layer(
188 &challenges,
189 &claims.iter().map(|claim| claim.get_point()).collect_vec(),
190 transcript_reader,
191 )
192 .unwrap();
193 let final_result = verifier_id_gate_layer.evaluate(
194 &claims.iter().map(|claim| claim.get_point()).collect_vec(),
195 &random_coefficients,
196 );
197
198 if g_final_r_final != final_result {
199 return Err(anyhow!(VerificationError::FinalSumcheckFailed));
200 }
201
202 Ok(VerifierLayerEnum::IdentityGate(verifier_id_gate_layer))
203 }
204
205 fn sumcheck_round_indices(&self) -> Vec<usize> {
206 let num_vars = self
207 .source_mle
208 .var_indices()
209 .iter()
210 .fold(0_usize, |acc, idx| {
211 acc + match idx {
212 MleIndex::Fixed(_) => 0,
213 _ => 1,
214 }
215 });
216
217 (0..num_vars).collect_vec()
218 }
219
220 fn convert_into_verifier_layer(
221 &self,
222 sumcheck_challenges: &[F],
223 _claim_points: &[&[F]],
224 transcript_reader: &mut impl VerifierTranscript<F>,
225 ) -> Result<Self::VerifierLayer> {
226 let num_u = self
230 .source_mle
231 .var_indices()
232 .iter()
233 .fold(0_usize, |acc, idx| {
234 acc + match idx {
235 MleIndex::Fixed(_) => 0,
236 _ => 1,
237 }
238 })
239 - self.num_dataparallel_vars;
240
241 let mut sumcheck_bindings_vec = sumcheck_challenges.to_vec();
244 let first_u_challenges = sumcheck_bindings_vec.split_off(self.num_dataparallel_vars);
245 let dataparallel_sumcheck_challenges = sumcheck_bindings_vec;
246
247 assert_eq!(first_u_challenges.len(), num_u);
248
249 let src_verifier_mle = self
252 .source_mle
253 .into_verifier_mle(sumcheck_challenges, transcript_reader)
254 .unwrap();
255
256 let verifier_id_gate_layer = VerifierIdentityGateLayer {
259 layer_id: self.layer_id(),
260 wiring: self.wiring.clone(),
261 source_mle: src_verifier_mle,
262 first_u_challenges,
263 total_num_vars: self.total_num_vars,
264 num_dataparallel_rounds: self.num_dataparallel_vars,
265 dataparallel_sumcheck_challenges,
266 };
267
268 Ok(verifier_id_gate_layer)
269 }
270
271 fn get_post_sumcheck_layer(
272 &self,
273 round_challenges: &[F],
274 claim_challenges: &[&[F]],
275 random_coefficients: &[F],
276 ) -> PostSumcheckLayer<F, Option<F>> {
277 assert_eq!(claim_challenges.len(), random_coefficients.len());
278 let random_coefficients_scaled_by_beta_bound = claim_challenges
279 .iter()
280 .zip(random_coefficients)
281 .map(|(claim_chals, random_coeff)| {
282 let beta_bound = if self.num_dataparallel_vars > 0 {
283 let g2_challenges = claim_chals[..self.num_dataparallel_vars].to_vec();
284 BetaValues::compute_beta_over_two_challenges(
285 &g2_challenges,
286 &round_challenges[..self.num_dataparallel_vars],
287 )
288 } else {
289 F::ONE
290 };
291 beta_bound * random_coeff
292 })
293 .collect_vec();
294
295 let nondataparallel_claim_chals = claim_challenges
296 .iter()
297 .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
298 .collect_vec();
299
300 let f_1_gu = compute_fully_bound_identity_gate_function(
301 &round_challenges[self.num_dataparallel_vars..],
302 &nondataparallel_claim_chals,
303 &self.wiring,
304 &random_coefficients_scaled_by_beta_bound,
305 );
306
307 PostSumcheckLayer(vec![Product::<F, Option<F>>::new(
308 std::slice::from_ref(&self.source_mle),
309 f_1_gu,
310 round_challenges,
311 )])
312 }
313
314 fn max_degree(&self) -> usize {
315 2
316 }
317
318 fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
319 vec![&self.source_mle]
320 }
321
322 fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
323 let source_mle = self.source_mle.into_dense_mle(circuit_map);
324 let id_gate_layer = IdentityGate::new(
325 self.layer_id(),
326 self.wiring.clone(),
327 source_mle,
328 self.total_num_vars,
329 self.num_dataparallel_vars,
330 );
331 id_gate_layer.into()
332 }
333
334 fn index_mle_indices(&mut self, start_index: usize) {
335 self.source_mle.index_mle_indices(start_index);
336 }
337
338 fn compute_data_outputs(
339 &self,
340 mle_outputs_necessary: &HashSet<&MleDescription<F>>,
341 circuit_map: &mut CircuitEvalMap<F>,
342 ) {
343 let source_mle_data = circuit_map
348 .get_data_from_circuit_mle(&self.source_mle)
349 .unwrap();
350
351 let res_table_num_entries = 1 << self.total_num_vars;
352 let num_entries_per_dataparallel_instance =
353 1 << (self.total_num_vars - self.num_dataparallel_vars);
354 let mut remap_table = vec![F::ZERO; res_table_num_entries];
355
356 (0..(1 << self.num_dataparallel_vars)).for_each(|data_parallel_idx| {
357 self.wiring.iter().for_each(|(dest_idx, src_idx)| {
358 let id_val = source_mle_data
359 .f
360 .get(
361 data_parallel_idx
362 * (1 << (self.source_mle.num_free_vars() - self.num_dataparallel_vars))
363 + (*src_idx as usize),
364 )
365 .unwrap_or(F::ZERO);
366 remap_table[num_entries_per_dataparallel_instance * data_parallel_idx
367 + (*dest_idx as usize)] += id_val;
368 });
369 });
370
371 mle_outputs_necessary
382 .iter()
383 .for_each(|mle_output_necessary| {
384 let prefix_vars = mle_output_necessary.prefix_bits();
385 let bookkeeping_table_len = 1 << (self.total_num_vars - prefix_vars.len());
386 let start_idx =
387 prefix_vars
388 .iter()
389 .enumerate()
390 .fold(0, |acc, (var_idx, prefix_var)| {
391 acc + if *prefix_var {
396 1 << (self.total_num_vars - var_idx - 1)
397 } else {
398 0
399 }
400 });
401 let data_slice = &remap_table[start_idx..start_idx + bookkeeping_table_len];
402 let mle_output = MultilinearExtension::new(data_slice.to_vec());
403 circuit_map.add_node(
404 CircuitLocation::new(self.layer_id(), prefix_vars),
405 mle_output,
406 );
407 });
408 }
409}
410
411impl<F: Field> VerifierIdentityGateLayer<F> {
412 pub fn evaluate(&self, claim_points: &[&[F]], random_coefficients: &[F]) -> F {
415 assert_eq!(random_coefficients.len(), claim_points.len());
416 let scaled_random_coeffs = claim_points
417 .iter()
418 .zip(random_coefficients)
419 .map(|(claim, random_coeff)| {
420 let beta_bound = BetaValues::compute_beta_over_two_challenges(
421 &claim[..self.num_dataparallel_rounds],
422 &self.dataparallel_sumcheck_challenges,
423 );
424 beta_bound * random_coeff
425 })
426 .collect_vec();
427
428 let f_1_gu = compute_fully_bound_identity_gate_function(
429 &self.first_u_challenges,
430 &claim_points
431 .iter()
432 .map(|claim| &claim[self.num_dataparallel_rounds..])
433 .collect_vec(),
434 &self.wiring,
435 &scaled_random_coeffs,
436 );
437 f_1_gu * self.source_mle.value()
439 }
440}
441
442#[derive(Serialize, Deserialize, Clone, Debug)]
443#[serde(bound = "F: Field")]
444pub struct VerifierIdentityGateLayer<F: Field> {
446 layer_id: LayerId,
448
449 wiring: Vec<(u32, u32)>,
453
454 source_mle: VerifierMle<F>,
457
458 first_u_challenges: Vec<F>,
460
461 total_num_vars: usize,
463
464 num_dataparallel_rounds: usize,
466
467 dataparallel_sumcheck_challenges: Vec<F>,
469}
470
471impl<F: Field> VerifierLayer<F> for VerifierIdentityGateLayer<F> {
472 fn layer_id(&self) -> LayerId {
473 self.layer_id
474 }
475
476 fn get_claims(&self) -> Result<Vec<Claim<F>>> {
477 let source_vars = self.source_mle.var_indices();
479 let source_point = source_vars
480 .iter()
481 .map(|idx| match idx {
482 MleIndex::Bound(chal, _bit_idx) => *chal,
483 MleIndex::Fixed(val) => {
484 if *val {
485 F::ONE
486 } else {
487 F::ZERO
488 }
489 }
490 _ => panic!("Error: Not fully bound"),
491 })
492 .collect_vec();
493 let source_val = self.source_mle.value();
494
495 let source_claim: Claim<F> = Claim::new(
496 source_point,
497 source_val,
498 self.layer_id(),
499 self.source_mle.layer_id(),
500 );
501
502 Ok(vec![source_claim])
503 }
504}
505
506impl<F: Field> Layer<F> for IdentityGate<F> {
509 fn prove(
510 &mut self,
511 claims: &[&RawClaim<F>],
512 transcript_writer: &mut impl ProverTranscript<F>,
513 ) -> Result<()> {
514 let random_coefficients = match global_claim_agg_strategy() {
515 ClaimAggregationStrategy::Interpolative => {
516 assert_eq!(claims.len(), 1);
517 self.initialize(claims[0].get_point())?;
518 vec![F::ONE]
519 }
520 ClaimAggregationStrategy::RLC => {
521 let random_coefficients =
522 transcript_writer.get_challenges("RLC Claim Agg Coefficients", claims.len());
523 self.initialize_rlc(&random_coefficients, claims);
524 random_coefficients
525 }
526 };
527 let sumcheck_indices = self.sumcheck_round_indices();
528 (sumcheck_indices.iter()).for_each(|round_idx| {
529 let sumcheck_message = self
530 .compute_round_sumcheck_message(*round_idx, &random_coefficients)
531 .unwrap();
532 transcript_writer.append_elements(
534 "Sumcheck round univariate evaluations",
535 &sumcheck_message[1..],
536 );
537 let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
538 self.bind_round_variable(*round_idx, challenge).unwrap();
539 });
540 self.append_leaf_mles_to_transcript(transcript_writer);
541 Ok(())
542 }
543
544 fn layer_id(&self) -> LayerId {
545 self.layer_id
546 }
547
548 fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
549 self.challenges_vec = Some(vec![claim_point.to_vec()]);
550 let g2_challenges = &claim_point[..self.num_dataparallel_vars];
551 let g1_challenges = &claim_point[self.num_dataparallel_vars..];
552 self.g1_challenges_vec = Some(vec![g1_challenges.to_vec()]);
553
554 if self.num_dataparallel_vars > 0 {
555 let beta_g2 = BetaValues::new(g2_challenges.iter().copied().enumerate().collect());
556 self.beta_g2_vec = Some(vec![beta_g2]);
557 }
558
559 self.source_mle.index_mle_indices(0);
560 Ok(())
561 }
562
563 fn initialize_rlc(&mut self, random_coefficients: &[F], claims: &[&RawClaim<F>]) {
564 assert_eq!(random_coefficients.len(), claims.len());
565
566 self.challenges_vec = Some(
569 claims
570 .iter()
571 .map(|claim| claim.get_point().to_vec())
572 .collect_vec(),
573 );
574 let (g2_challenges_vec, g1_challenges_vec): (Vec<&[F]>, Vec<&[F]>) = claims
575 .iter()
576 .map(|claim| claim.get_point().split_at(self.num_dataparallel_vars))
577 .unzip();
578 self.g1_challenges_vec = Some(
579 g1_challenges_vec
580 .into_iter()
581 .map(|challenges| challenges.to_vec())
582 .collect_vec(),
583 );
584
585 if self.num_dataparallel_vars > 0 {
586 let beta_g2_vec = g2_challenges_vec
587 .iter()
588 .map(|g2_challenges| {
589 BetaValues::new(g2_challenges.iter().copied().enumerate().collect())
590 })
591 .collect();
592 self.beta_g2_vec = Some(beta_g2_vec);
593 }
594 self.source_mle.index_mle_indices(0);
595 }
596
597 fn compute_round_sumcheck_message(
598 &mut self,
599 round_index: usize,
600 random_coefficients: &[F],
601 ) -> Result<Vec<F>> {
602 match round_index.cmp(&self.num_dataparallel_vars) {
603 Ordering::Less => {
605 let sumcheck_message = compute_sumcheck_message_data_parallel_identity_gate(
606 &self.source_mle,
607 &self.wiring,
608 self.num_dataparallel_vars - round_index,
609 &self
610 .challenges_vec
611 .as_ref()
612 .unwrap()
613 .iter()
614 .map(|claim| &claim[round_index..])
615 .collect_vec(),
616 &self
617 .beta_g2_vec
618 .as_ref()
619 .unwrap()
620 .iter()
621 .zip(random_coefficients)
622 .map(|(beta_values, random_coeff)| {
623 *random_coeff * beta_values.fold_updated_values()
624 })
625 .collect_vec(),
626 )
627 .unwrap();
628 Ok(sumcheck_message)
629 }
630 _ => {
631 if round_index == self.num_dataparallel_vars {
632 match global_claim_agg_strategy() {
633 ClaimAggregationStrategy::Interpolative => {
634 let beta_g2_fully_bound = if self.num_dataparallel_vars > 0 {
637 self.beta_g2_vec.as_ref().unwrap()[0].fold_updated_values()
638 } else {
639 F::ONE
640 };
641
642 self.init_phase_1(
643 &self.g1_challenges_vec.as_ref().unwrap()[0].clone(),
644 beta_g2_fully_bound,
645 );
646 }
647 ClaimAggregationStrategy::RLC => {
648 let random_coefficients = if self.num_dataparallel_vars > 0 {
651 random_coefficients
652 .iter()
653 .zip(self.beta_g2_vec.as_ref().unwrap())
654 .map(|(random_coeff, beta_values)| {
655 if self.num_dataparallel_vars > 0 {
656 beta_values.fold_updated_values() * random_coeff
657 } else {
658 F::ONE * random_coeff
659 }
660 })
661 .collect_vec()
662 } else {
663 random_coefficients.to_vec()
664 };
665
666 self.init_phase_1_rlc(
667 &self
668 .g1_challenges_vec
669 .as_ref()
670 .unwrap()
671 .clone()
672 .iter()
673 .map(|challenge| challenge.as_slice())
674 .collect_vec(),
675 &random_coefficients,
676 );
677 }
678 }
679 }
680
681 let mles: Vec<&DenseMle<F>> =
682 vec![&self.a_hg_mle_phase_1.as_ref().unwrap(), &self.source_mle];
683 let independent_variable = mles
684 .iter()
685 .map(|mle| mle.mle_indices().contains(&MleIndex::Indexed(round_index)))
686 .reduce(|acc, item| acc | item)
687 .unwrap();
688 let sumcheck_evals =
689 evaluate_mle_product_no_beta_table(&mles, independent_variable, mles.len())
690 .unwrap();
691 Ok(sumcheck_evals.0)
692 }
693 }
694 }
695
696 fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
697 if round_index < self.num_dataparallel_vars {
698 self.beta_g2_vec
699 .as_mut()
700 .unwrap()
701 .iter_mut()
702 .for_each(|beta| {
703 beta.beta_update(round_index, challenge);
704 });
705 self.source_mle.fix_variable(round_index, challenge);
706
707 Ok(())
708 } else {
709 if self.num_dataparallel_vars > 0 {
710 self.beta_g2_vec.as_ref().unwrap().iter().for_each(|beta| {
711 assert!(beta.is_fully_bounded());
712 })
713 }
714 let a_hg_mle = self.a_hg_mle_phase_1.as_mut().unwrap();
715
716 [a_hg_mle, &mut self.source_mle].iter_mut().for_each(|mle| {
717 mle.fix_variable(round_index, challenge);
718 });
719 Ok(())
720 }
721 }
722
723 fn sumcheck_round_indices(&self) -> Vec<usize> {
724 (0..self.source_mle.num_free_vars()).collect_vec()
725 }
726
727 fn max_degree(&self) -> usize {
728 2
729 }
730
731 fn get_post_sumcheck_layer(
732 &self,
733 round_challenges: &[F],
734 claim_challenges: &[&[F]],
735 random_coefficients: &[F],
736 ) -> PostSumcheckLayer<F, F> {
737 assert_eq!(claim_challenges.len(), random_coefficients.len());
738 let random_coefficients_scaled_by_beta_bound = claim_challenges
739 .iter()
740 .zip(random_coefficients)
741 .map(|(claim_chals, random_coeff)| {
742 let beta_bound = if self.num_dataparallel_vars > 0 {
743 let g2_challenges = claim_chals[..self.num_dataparallel_vars].to_vec();
744 BetaValues::compute_beta_over_two_challenges(
745 &g2_challenges,
746 &round_challenges[..self.num_dataparallel_vars],
747 )
748 } else {
749 F::ONE
750 };
751 beta_bound * random_coeff
752 })
753 .collect_vec();
754
755 let nondataparallel_claim_chals = claim_challenges
756 .iter()
757 .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
758 .collect_vec();
759
760 let f_1_gu = compute_fully_bound_identity_gate_function(
761 &round_challenges[self.num_dataparallel_vars..],
762 &nondataparallel_claim_chals,
763 &self.wiring,
764 &random_coefficients_scaled_by_beta_bound,
765 );
766
767 PostSumcheckLayer(vec![Product::<F, F>::new(
768 std::slice::from_ref(&self.source_mle),
769 f_1_gu,
770 )])
771 }
772
773 fn get_claims(&self) -> Result<Vec<Claim<F>>> {
774 let mut claims = vec![];
775 let mut fixed_mle_indices_u: Vec<F> = vec![];
776
777 for index in self.source_mle.mle_indices() {
778 fixed_mle_indices_u.push(
779 index
780 .val()
781 .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))?,
782 );
783 }
784 let val = self.source_mle.first();
785 let claim: Claim<F> = Claim::new(
786 fixed_mle_indices_u,
787 val,
788 self.layer_id(),
789 self.source_mle.layer_id(),
790 );
791 claims.push(claim);
792
793 Ok(claims)
794 }
795}
796
797#[derive(Error, Debug, Serialize, Deserialize, Clone)]
800#[serde(bound = "F: Field")]
801pub struct IdentityGate<F: Field> {
802 layer_id: LayerId,
804 wiring: Vec<(u32, u32)>,
808 source_mle: DenseMle<F>,
810 beta_g2_vec: Option<Vec<BetaValues<F>>>,
813 g1_challenges_vec: Option<Vec<Vec<F>>>,
815 challenges_vec: Option<Vec<Vec<F>>>,
818 a_hg_mle_phase_1: Option<DenseMle<F>>,
821 total_num_vars: usize,
823 num_dataparallel_vars: usize,
826}
827
828impl<F: Field> IdentityGate<F> {
829 pub fn new(
831 layer_id: LayerId,
832 wiring: Vec<(u32, u32)>,
833 mle: DenseMle<F>,
834 total_num_vars: usize,
835 num_dataparallel_vars: usize,
836 ) -> IdentityGate<F> {
837 IdentityGate {
838 layer_id,
839 wiring,
840 source_mle: mle,
841 beta_g2_vec: None,
842 a_hg_mle_phase_1: None,
843 total_num_vars,
844 num_dataparallel_vars,
845 g1_challenges_vec: None,
846 challenges_vec: None,
847 }
848 }
849
850 fn append_leaf_mles_to_transcript(&self, transcript_writer: &mut impl ProverTranscript<F>) {
851 assert!(self.source_mle.is_fully_bounded());
852 transcript_writer.append("Fully bound MLE evaluation", self.source_mle.first());
853 }
854
855 fn init_phase_1(&mut self, challenge: &[F], fully_bound_beta_g2: F) {
864 let a_hg_mle_vec = fold_wiring_into_beta_mle_identity_gate(
865 &self.wiring,
866 &[challenge],
867 self.source_mle.num_free_vars(),
868 &[fully_bound_beta_g2],
869 );
870 let mut a_hg_mle = DenseMle::new_from_raw(a_hg_mle_vec, self.layer_id());
871 a_hg_mle.index_mle_indices(self.num_dataparallel_vars);
872
873 self.a_hg_mle_phase_1 = Some(a_hg_mle);
874 }
875
876 fn init_phase_1_rlc(&mut self, challenges: &[&[F]], random_coefficients: &[F]) {
881 let a_hg_mle_vec = fold_wiring_into_beta_mle_identity_gate(
882 &self.wiring,
883 challenges,
884 self.source_mle.num_free_vars(),
885 random_coefficients,
886 );
887 let mut a_hg_mle = DenseMle::new_from_raw(a_hg_mle_vec, self.layer_id());
888 a_hg_mle.index_mle_indices(self.num_dataparallel_vars);
889 self.a_hg_mle_phase_1 = Some(a_hg_mle);
890 }
891}