1pub mod gate_helpers;
6use std::{cmp::max, collections::HashSet};
7
8use gate_helpers::{
9 compute_fully_bound_binary_gate_function, fold_binary_gate_wiring_into_mles_phase_1,
10 fold_binary_gate_wiring_into_mles_phase_2,
11};
12use itertools::Itertools;
13use serde::{Deserialize, Serialize};
14use shared_types::{
15 config::{global_config::global_claim_agg_strategy, ClaimAggregationStrategy},
16 transcript::{ProverTranscript, VerifierTranscript},
17 Field,
18};
19
20use crate::{
21 circuit_layout::{CircuitEvalMap, CircuitLocation},
22 claims::{Claim, ClaimError, RawClaim},
23 layer::{
24 product::{PostSumcheckLayer, Product},
25 Layer, LayerError, LayerId, VerificationError,
26 },
27 mle::{
28 betavalues::BetaValues, dense::DenseMle, evals::MultilinearExtension,
29 mle_description::MleDescription, verifier_mle::VerifierMle, Mle, MleIndex,
30 },
31 sumcheck::{evaluate_at_a_point, SumcheckEvals},
32};
33
34use anyhow::{anyhow, Ok, Result};
35
36pub use self::gate_helpers::{
37 compute_sumcheck_message_data_parallel_gate, compute_sumcheck_message_no_beta_table,
38 index_mle_indices_gate, GateError,
39};
40
41use super::{
42 layer_enum::{LayerEnum, VerifierLayerEnum},
43 LayerDescription, VerifierLayer,
44};
45
46#[derive(PartialEq, Serialize, Deserialize, Clone, Debug, Copy)]
47
48#[derive(Hash)]
51pub enum BinaryOperation {
52 Add,
54
55 Mul,
57}
58
59impl BinaryOperation {
60 pub fn perform_operation<F: Field>(&self, a: F, b: F) -> F {
62 match self {
63 BinaryOperation::Add => a + b,
64 BinaryOperation::Mul => a * b,
65 }
66 }
67}
68
/// Prover-side state of a gate layer, whose output at index `z` is the sum of
/// `gate_operation(lhs[x], rhs[y])` over the wiring triples `(z, x, y)`.
///
/// Sumcheck runs in three stages: the dataparallel variables are bound first, then the
/// lhs ("phase 1", `u`) variables, then the rhs ("phase 2", `v`) variables.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "F: Field")]
pub struct GateLayer<F: Field> {
    /// Identifier of this layer.
    pub layer_id: LayerId,
    /// Number of leading variables that select a dataparallel copy of the wiring; these
    /// are bound in the first sumcheck rounds.
    pub num_dataparallel_vars: usize,
    /// Wiring as `(z, x, y)` triples: output index `z`, lhs index `x`, rhs index `y`.
    pub nonzero_gates: Vec<(u32, u32, u32)>,
    /// Left-hand input MLE.
    pub lhs: DenseMle<F>,
    /// Right-hand input MLE.
    pub rhs: DenseMle<F>,
    /// Products of MLEs summed over during phase 1 (each inner `Vec` is one product).
    /// `[0][1]` is the lhs MLE. `None` until phase 1 is initialised.
    pub phase_1_mles: Option<Vec<Vec<DenseMle<F>>>>,
    /// Products of MLEs summed over during phase 2. `[0][1]` is read back as the reduced
    /// rhs MLE after sumcheck. `None` until phase 2 is initialised.
    pub phase_2_mles: Option<Vec<Vec<DenseMle<F>>>>,
    /// Operation applied to each wired `(lhs, rhs)` pair.
    pub gate_operation: BinaryOperation,
    /// One `BetaValues` per claim, built over that claim's first
    /// `num_dataparallel_vars` coordinates.
    beta_g2_vec: Option<Vec<BetaValues<F>>>,
    /// The claim points being proven, one per claim.
    g_vec: Option<Vec<Vec<F>>>,
    /// Number of phase-1 rounds: the lhs free variables minus the dataparallel ones,
    /// fixed at construction.
    num_rounds_phase1: usize,
}
102
103impl<F: Field> Layer<F> for GateLayer<F> {
104 fn layer_id(&self) -> LayerId {
106 self.layer_id
107 }
108
109 fn prove(
110 &mut self,
111 claims: &[&RawClaim<F>],
112 transcript_writer: &mut impl ProverTranscript<F>,
113 ) -> Result<()> {
114 let original_lhs_num_free_vars = self.lhs.num_free_vars();
115 let original_rhs_num_free_vars = self.rhs.num_free_vars();
116 let random_coefficients = match global_claim_agg_strategy() {
117 ClaimAggregationStrategy::Interpolative => {
118 assert_eq!(claims.len(), 1);
119 self.initialize(claims[0].get_point())?;
120 vec![F::ONE]
121 }
122 ClaimAggregationStrategy::RLC => {
123 let random_coefficients =
124 transcript_writer.get_challenges("RLC Claim Agg Coefficients", claims.len());
125 self.initialize_rlc(&random_coefficients, claims);
126 random_coefficients
127 }
128 };
129 let sumcheck_indices = self.sumcheck_round_indices();
130 (sumcheck_indices.iter()).for_each(|round_idx| {
131 let sumcheck_message = self
132 .compute_round_sumcheck_message(*round_idx, &random_coefficients)
133 .unwrap();
134 transcript_writer
135 .append_elements("Sumcheck round univariate evaluations", &sumcheck_message);
136 let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
137 self.bind_round_variable(*round_idx, challenge).unwrap();
138 });
139
140 if original_lhs_num_free_vars - self.num_dataparallel_vars == 0 {
142 match global_claim_agg_strategy() {
143 ClaimAggregationStrategy::Interpolative => {
144 self.init_phase_1(
145 self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec(),
146 );
147 }
148 ClaimAggregationStrategy::RLC => {
149 self.init_phase_1_rlc(
150 &self
151 .g_vec
152 .as_ref()
153 .unwrap()
154 .clone()
155 .iter()
156 .map(|challenge| &challenge[self.num_dataparallel_vars..])
157 .collect_vec(),
158 &random_coefficients,
159 );
160 }
161 }
162 }
163 if original_rhs_num_free_vars - self.num_dataparallel_vars == 0 {
164 let f2 = &self.phase_1_mles.as_ref().unwrap()[0][1];
165 let f2_at_u = f2.value();
166 let u_challenges = &f2
167 .mle_indices()
168 .iter()
169 .filter_map(|mle_index| match mle_index {
170 MleIndex::Bound(value, _idx) => Some(*value),
171 MleIndex::Fixed(_) => None,
172 _ => panic!("Should not have any unbound values"),
173 })
174 .collect_vec()[self.num_dataparallel_vars..];
175
176 match global_claim_agg_strategy() {
177 ClaimAggregationStrategy::Interpolative => {
178 let g_challenges =
179 self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec();
180 self.init_phase_2(u_challenges, f2_at_u, &g_challenges);
181 }
182 ClaimAggregationStrategy::RLC => {
183 self.init_phase_2_rlc(
184 u_challenges,
185 f2_at_u,
186 &self
187 .g_vec
188 .as_ref()
189 .unwrap()
190 .clone()
191 .iter()
192 .map(|claim| &claim[self.num_dataparallel_vars..])
193 .collect_vec(),
194 &random_coefficients,
195 );
196 }
197 }
198 }
199
200 let lhs_reduced = &self.phase_1_mles.as_ref().unwrap()[0][1];
203 let rhs_reduced = &self.phase_2_mles.as_ref().unwrap()[0][1];
204 transcript_writer.append("Fully bound MLE evaluation", lhs_reduced.value());
205 transcript_writer.append("Fully bound MLE evaluation", rhs_reduced.value());
207
208 Ok(())
209 }
210
211 fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
212 self.beta_g2_vec = Some(vec![BetaValues::new(
213 claim_point[..self.num_dataparallel_vars]
214 .iter()
215 .copied()
216 .enumerate()
217 .collect(),
218 )]);
219 self.g_vec = Some(vec![claim_point.to_vec()]);
220 self.lhs.index_mle_indices(0);
221 self.rhs.index_mle_indices(0);
222
223 Ok(())
224 }
225
226 fn initialize_rlc(&mut self, _random_coefficients: &[F], claims: &[&RawClaim<F>]) {
227 self.lhs.index_mle_indices(0);
228 self.rhs.index_mle_indices(0);
229 let (g_vec, beta_g2_vec): (Vec<Vec<F>>, Vec<BetaValues<F>>) = claims
230 .iter()
231 .map(|claim| {
232 (
233 claim.get_point().to_vec(),
234 BetaValues::new(
235 claim.get_point()[..self.num_dataparallel_vars]
236 .iter()
237 .copied()
238 .enumerate()
239 .collect(),
240 ),
241 )
242 })
243 .unzip();
244 self.g_vec = Some(g_vec);
245 self.beta_g2_vec = Some(beta_g2_vec);
246 }
247
248 fn compute_round_sumcheck_message(
249 &mut self,
250 round_index: usize,
251 random_coefficients: &[F],
252 ) -> Result<Vec<F>> {
253 let rounds_before_phase_2 = self.num_dataparallel_vars + self.num_rounds_phase1;
254
255 if round_index < self.num_dataparallel_vars {
256 Ok(compute_sumcheck_message_data_parallel_gate(
258 &self.lhs,
259 &self.rhs,
260 self.gate_operation,
261 &self.nonzero_gates,
262 self.num_dataparallel_vars - round_index,
263 &self
264 .g_vec
265 .as_ref()
266 .unwrap()
267 .iter()
268 .map(|challenge| &challenge[round_index..])
269 .collect_vec(),
270 &self
271 .beta_g2_vec
272 .as_ref()
273 .unwrap()
274 .iter()
275 .zip(random_coefficients)
276 .map(|(beta_values, random_coeff)| {
277 *random_coeff * beta_values.fold_updated_values()
278 })
279 .collect_vec(),
280 )
281 .unwrap())
282 } else if round_index < rounds_before_phase_2 {
283 if round_index == self.num_dataparallel_vars {
284 match global_claim_agg_strategy() {
285 ClaimAggregationStrategy::Interpolative => {
286 self.init_phase_1(
287 self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec(),
288 );
289 }
290 ClaimAggregationStrategy::RLC => {
291 self.init_phase_1_rlc(
292 &self
293 .g_vec
294 .as_ref()
295 .unwrap()
296 .clone()
297 .iter()
298 .map(|challenge| &challenge[self.num_dataparallel_vars..])
299 .collect_vec(),
300 random_coefficients,
301 );
302 }
303 }
304 }
305 let max_deg = self
306 .phase_1_mles
307 .as_ref()
308 .unwrap()
309 .iter()
310 .fold(0, |acc, elem| max(acc, elem.len()));
311
312 let init_mles: Vec<Vec<&DenseMle<F>>> = self
313 .phase_1_mles
314 .as_ref()
315 .unwrap()
316 .iter()
317 .map(|mle_vec| {
318 let mle_reference: Vec<&DenseMle<F>> = mle_vec.iter().collect();
319 mle_reference
320 })
321 .collect();
322 let evals_vec = init_mles
323 .iter()
324 .map(|mle_vec| {
325 compute_sumcheck_message_no_beta_table(mle_vec, round_index, max_deg).unwrap()
326 })
327 .collect_vec();
328 let final_evals = evals_vec
329 .clone()
330 .into_iter()
331 .skip(1)
332 .fold(SumcheckEvals(evals_vec[0].clone()), |acc, elem| {
333 acc + SumcheckEvals(elem)
334 });
335
336 Ok(final_evals.0)
337 } else {
338 if round_index == rounds_before_phase_2 {
339 let f2 = &self.phase_1_mles.as_ref().unwrap()[0][1];
340 let f2_at_u = f2.value();
341 let u_challenges = &f2
342 .mle_indices()
343 .iter()
344 .filter_map(|mle_index| match mle_index {
345 MleIndex::Bound(value, _idx) => Some(*value),
346 MleIndex::Fixed(_) => None,
347 _ => panic!("Should not have any unbound values"),
348 })
349 .collect_vec()[self.num_dataparallel_vars..];
350
351 match global_claim_agg_strategy() {
352 ClaimAggregationStrategy::Interpolative => {
353 let g_challenges =
354 self.g_vec.as_ref().unwrap()[0][self.num_dataparallel_vars..].to_vec();
355 self.init_phase_2(u_challenges, f2_at_u, &g_challenges);
356 }
357 ClaimAggregationStrategy::RLC => {
358 self.init_phase_2_rlc(
359 u_challenges,
360 f2_at_u,
361 &self
362 .g_vec
363 .as_ref()
364 .unwrap()
365 .clone()
366 .iter()
367 .map(|claim| &claim[self.num_dataparallel_vars..])
368 .collect_vec(),
369 random_coefficients,
370 );
371 }
372 }
373 }
374 if self.phase_2_mles.as_ref().unwrap()[0][1].num_free_vars() > 0 {
375 let max_deg = self
377 .phase_2_mles
378 .as_ref()
379 .unwrap()
380 .iter()
381 .fold(0, |acc, elem| max(acc, elem.len()));
382
383 let init_mles: Vec<Vec<&DenseMle<F>>> = self
384 .phase_2_mles
385 .as_ref()
386 .unwrap()
387 .iter()
388 .map(|mle_vec| {
389 let mle_references: Vec<&DenseMle<F>> = mle_vec.iter().collect();
390 mle_references
391 })
392 .collect();
393 let evals_vec = init_mles
394 .iter()
395 .map(|mle_vec| {
396 compute_sumcheck_message_no_beta_table(
397 mle_vec,
398 round_index - self.num_rounds_phase1,
399 max_deg,
400 )
401 .unwrap()
402 })
403 .collect_vec();
404 let final_evals = evals_vec
405 .clone()
406 .into_iter()
407 .skip(1)
408 .fold(SumcheckEvals(evals_vec[0].clone()), |acc, elem| {
409 acc + SumcheckEvals(elem)
410 });
411 Ok(final_evals.0)
412 } else {
413 Ok(vec![])
414 }
415 }
416 }
417
418 fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
419 if round_index < self.num_dataparallel_vars {
420 self.beta_g2_vec
421 .as_mut()
422 .unwrap()
423 .iter_mut()
424 .for_each(|beta| beta.beta_update(round_index, challenge));
425 self.lhs.fix_variable(round_index, challenge);
426 self.rhs.fix_variable(round_index, challenge);
427
428 Ok(())
429 } else if round_index < self.num_rounds_phase1 + self.num_dataparallel_vars {
430 let mles = self.phase_1_mles.as_mut().unwrap();
431 mles.iter_mut().for_each(|mle_vec| {
432 mle_vec.iter_mut().for_each(|mle| {
433 mle.fix_variable(round_index, challenge);
434 })
435 });
436 Ok(())
437 } else {
438 let round_index = round_index - self.num_rounds_phase1;
439 let mles = self.phase_2_mles.as_mut().unwrap();
440 mles.iter_mut().for_each(|mle_vec| {
441 mle_vec.iter_mut().for_each(|mle| {
442 mle.fix_variable(round_index, challenge);
443 })
444 });
445 Ok(())
446 }
447 }
448
449 fn sumcheck_round_indices(&self) -> Vec<usize> {
450 let num_u = self.lhs.mle_indices().iter().fold(0_usize, |acc, idx| {
451 acc + match idx {
452 MleIndex::Fixed(_) => 0,
453 _ => 1,
454 }
455 }) - self.num_dataparallel_vars;
456 let num_v = self.rhs.mle_indices().iter().fold(0_usize, |acc, idx| {
457 acc + match idx {
458 MleIndex::Fixed(_) => 0,
459 _ => 1,
460 }
461 }) - self.num_dataparallel_vars;
462
463 (0..num_u + num_v + self.num_dataparallel_vars).collect_vec()
464 }
465
466 fn max_degree(&self) -> usize {
467 match self.gate_operation {
468 BinaryOperation::Add => 2,
469 BinaryOperation::Mul => {
470 if self.num_dataparallel_vars != 0 {
471 3
472 } else {
473 2
474 }
475 }
476 }
477 }
478
479 fn get_post_sumcheck_layer(
480 &self,
481 round_challenges: &[F],
482 claim_challenges: &[&[F]],
483 random_coefficients: &[F],
484 ) -> super::product::PostSumcheckLayer<F, F> {
485 assert_eq!(claim_challenges.len(), random_coefficients.len());
486 let lhs_mle = &self.phase_1_mles.as_ref().unwrap()[0][1];
487 let rhs_mle = &self.phase_2_mles.as_ref().unwrap()[0][1];
488
489 let g2_challenges_vec = claim_challenges
490 .iter()
491 .map(|claim_chal| &claim_chal[..self.num_dataparallel_vars])
492 .collect_vec();
493 let g1_challenges_vec = claim_challenges
494 .iter()
495 .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
496 .collect_vec();
497
498 let dataparallel_sumcheck_challenges =
499 round_challenges[..self.num_dataparallel_vars].to_vec();
500 let first_u_challenges = round_challenges
501 [self.num_dataparallel_vars..self.num_dataparallel_vars + self.num_rounds_phase1]
502 .to_vec();
503 let last_v_challenges =
504 round_challenges[self.num_dataparallel_vars + self.num_rounds_phase1..].to_vec();
505 let random_coefficients_scaled_by_beta_bound = g2_challenges_vec
506 .iter()
507 .zip(random_coefficients)
508 .map(|(g2_challenges, random_coeff)| {
509 let beta_bound = if self.num_dataparallel_vars != 0 {
510 BetaValues::compute_beta_over_two_challenges(
511 g2_challenges,
512 &dataparallel_sumcheck_challenges,
513 )
514 } else {
515 F::ONE
516 };
517 beta_bound * random_coeff
518 })
519 .collect_vec();
520
521 let f_1_uv = compute_fully_bound_binary_gate_function(
522 &first_u_challenges,
523 &last_v_challenges,
524 &g1_challenges_vec,
525 &self.nonzero_gates,
526 &random_coefficients_scaled_by_beta_bound,
527 );
528
529 match self.gate_operation {
530 BinaryOperation::Add => PostSumcheckLayer(vec![
531 Product::<F, F>::new(std::slice::from_ref(lhs_mle), f_1_uv),
532 Product::<F, F>::new(std::slice::from_ref(rhs_mle), f_1_uv),
533 ]),
534 BinaryOperation::Mul => PostSumcheckLayer(vec![Product::<F, F>::new(
535 &[lhs_mle.clone(), rhs_mle.clone()],
536 f_1_uv,
537 )]),
538 }
539 }
540
541 fn get_claims(&self) -> Result<Vec<Claim<F>>> {
542 let lhs_reduced = self.phase_1_mles.clone().unwrap()[0][1].clone();
543 let rhs_reduced = self.phase_2_mles.clone().unwrap()[0][1].clone();
544
545 let mut claims = vec![];
546
547 let mut fixed_mle_indices_u: Vec<F> = vec![];
549 for index in lhs_reduced.mle_indices() {
550 fixed_mle_indices_u.push(
551 index
552 .val()
553 .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))?,
554 );
555 }
556 let val = lhs_reduced.value();
557 let claim: Claim<F> = Claim::new(
558 fixed_mle_indices_u,
559 val,
560 self.layer_id(),
561 self.lhs.layer_id(),
562 );
563 claims.push(claim);
564
565 let mut fixed_mle_indices_v: Vec<F> = vec![];
567 for index in rhs_reduced.mle_indices() {
568 fixed_mle_indices_v.push(
569 index
570 .val()
571 .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))?,
572 );
573 }
574 let val = rhs_reduced.value();
575 let claim: Claim<F> = Claim::new(
576 fixed_mle_indices_v,
577 val,
578 self.layer_id(),
579 self.rhs.layer_id(),
580 );
581 claims.push(claim);
582
583 Ok(claims)
584 }
585}
586
/// Circuit-level description of a gate layer: the wiring plus the input MLE
/// descriptions, from which the prover-side `GateLayer` and the verifier-side
/// `VerifierGateLayer` are built.
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "F: Field")]
pub struct GateLayerDescription<F: Field> {
    /// Identifier of this layer.
    id: LayerId,

    /// Operation applied to each wired `(lhs, rhs)` pair.
    gate_operation: BinaryOperation,

    /// Wiring as `(z, x, y)` triples: within each dataparallel copy, output index `z`
    /// accumulates `gate_operation(lhs[x], rhs[y])`.
    nonzero_gates: Vec<(u32, u32, u32)>,

    /// Description of the left-hand input MLE.
    lhs_mle: MleDescription<F>,

    /// Description of the right-hand input MLE.
    rhs_mle: MleDescription<F>,

    /// Number of dataparallel variables (0 when the layer is not dataparallel).
    num_dataparallel_vars: usize,
}
614
615impl<F: Field> GateLayerDescription<F> {
616 pub fn new(
618 num_dataparallel_vars: Option<usize>,
619 wiring: Vec<(u32, u32, u32)>,
620 lhs_circuit_mle: MleDescription<F>,
621 rhs_circuit_mle: MleDescription<F>,
622 gate_layer_id: LayerId,
623 gate_operation: BinaryOperation,
624 ) -> Self {
625 GateLayerDescription {
626 id: gate_layer_id,
627 gate_operation,
628 nonzero_gates: wiring,
629 lhs_mle: lhs_circuit_mle,
630 rhs_mle: rhs_circuit_mle,
631 num_dataparallel_vars: num_dataparallel_vars.unwrap_or(0),
632 }
633 }
634}
635
// Number of univariate evaluations the verifier reads from the transcript for one
// sumcheck round of a gate layer (see `verify_rounds`). The count is keyed by the gate
// operation and by whether the round binds a dataparallel variable. Presumably each is the
// round polynomial's degree plus one — TODO confirm against the prover's message sizes.

/// Evaluations per dataparallel round of a `Mul` gate layer.
const DATAPARALLEL_ROUND_MUL_NUM_EVALS: usize = 4;
/// Evaluations per dataparallel round of an `Add` gate layer.
const DATAPARALLEL_ROUND_ADD_NUM_EVALS: usize = 3;
/// Evaluations per non-dataparallel (phase 1 / phase 2) round of a `Mul` gate layer.
const NON_DATAPARALLEL_ROUND_MUL_NUM_EVALS: usize = 3;
/// Evaluations per non-dataparallel (phase 1 / phase 2) round of an `Add` gate layer.
const NON_DATAPARALLEL_ROUND_ADD_NUM_EVALS: usize = 3;
644
645impl<F: Field> LayerDescription<F> for GateLayerDescription<F> {
646 type VerifierLayer = VerifierGateLayer<F>;
647
648 fn layer_id(&self) -> LayerId {
650 self.id
651 }
652
653 fn verify_rounds(
654 &self,
655 claims: &[&RawClaim<F>],
656 transcript_reader: &mut impl VerifierTranscript<F>,
657 ) -> Result<VerifierLayerEnum<F>> {
658 let mut challenges = vec![];
660
661 let random_coefficients = match global_claim_agg_strategy() {
663 ClaimAggregationStrategy::Interpolative => {
664 assert_eq!(claims.len(), 1);
665 vec![F::ONE]
666 }
667 ClaimAggregationStrategy::RLC => {
668 transcript_reader.get_challenges("RLC Claim Agg Coefficients", claims.len())?
669 }
670 };
671
672 let num_u = self.lhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
676 acc + match idx {
677 MleIndex::Fixed(_) => 0,
678 _ => 1,
679 }
680 }) - self.num_dataparallel_vars;
681 let num_v = self.rhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
682 acc + match idx {
683 MleIndex::Fixed(_) => 0,
684 _ => 1,
685 }
686 }) - self.num_dataparallel_vars;
687
688 let mut sumcheck_messages: Vec<Vec<F>> = vec![];
690
691 let first_round_num_evals = match (self.gate_operation, self.num_dataparallel_vars) {
693 (BinaryOperation::Add, 0) => NON_DATAPARALLEL_ROUND_ADD_NUM_EVALS,
694 (BinaryOperation::Mul, 0) => NON_DATAPARALLEL_ROUND_MUL_NUM_EVALS,
695 (BinaryOperation::Add, _) => DATAPARALLEL_ROUND_ADD_NUM_EVALS,
696 (BinaryOperation::Mul, _) => DATAPARALLEL_ROUND_MUL_NUM_EVALS,
697 };
698 let first_round_sumcheck_messages = transcript_reader.consume_elements(
699 "Sumcheck round univariate evaluations",
700 first_round_num_evals,
701 )?;
702 sumcheck_messages.push(first_round_sumcheck_messages.clone());
703
704 match global_claim_agg_strategy() {
705 ClaimAggregationStrategy::Interpolative => {
706 if first_round_sumcheck_messages[0] + first_round_sumcheck_messages[1]
709 != claims[0].get_eval()
710 {
711 return Err(anyhow!(VerificationError::SumcheckStartFailed));
712 }
713 }
714 ClaimAggregationStrategy::RLC => {
715 let rlc_claim_eval = random_coefficients
716 .iter()
717 .zip(claims)
718 .fold(F::ZERO, |acc, (rlc_val, claim)| {
719 acc + *rlc_val * claim.get_eval()
720 });
721 if first_round_sumcheck_messages[0] + first_round_sumcheck_messages[1]
722 != rlc_claim_eval
723 {
724 return Err(anyhow!(VerificationError::SumcheckStartFailed));
725 }
726 }
727 }
728
729 for sumcheck_round_idx in 1..self.num_dataparallel_vars + num_u + num_v {
733 let challenge = transcript_reader
735 .get_challenge("Sumcheck round challenge")
736 .unwrap();
737 let g_i_minus_1_evals = sumcheck_messages[sumcheck_messages.len() - 1].clone();
738
739 let prev_at_r = evaluate_at_a_point(&g_i_minus_1_evals, challenge).unwrap();
741
742 let univariate_num_evals = match (
744 sumcheck_round_idx < self.num_dataparallel_vars, self.gate_operation,
746 ) {
747 (true, BinaryOperation::Add) => DATAPARALLEL_ROUND_ADD_NUM_EVALS,
748 (true, BinaryOperation::Mul) => DATAPARALLEL_ROUND_MUL_NUM_EVALS,
749 (false, BinaryOperation::Add) => NON_DATAPARALLEL_ROUND_ADD_NUM_EVALS,
750 (false, BinaryOperation::Mul) => NON_DATAPARALLEL_ROUND_MUL_NUM_EVALS,
751 };
752
753 let curr_evals = transcript_reader
754 .consume_elements(
755 "Sumcheck round univariate evaluations",
756 univariate_num_evals,
757 )
758 .unwrap();
759
760 if prev_at_r != curr_evals[0] + curr_evals[1] {
762 dbg!(&sumcheck_round_idx);
763 return Err(anyhow!(VerificationError::SumcheckFailed));
764 };
765
766 sumcheck_messages.push(curr_evals);
768 challenges.push(challenge);
770 }
771
772 let final_chal = transcript_reader
774 .get_challenge("Sumcheck round challenge")
775 .unwrap();
776 challenges.push(final_chal);
777
778 let verifier_gate_layer = self
781 .convert_into_verifier_layer(
782 &challenges,
783 &claims.iter().map(|claim| claim.get_point()).collect_vec(),
784 transcript_reader,
785 )
786 .unwrap();
787 let final_result = verifier_gate_layer.evaluate(
788 &claims.iter().map(|claim| claim.get_point()).collect_vec(),
789 &random_coefficients,
790 );
791
792 let g_n_evals = sumcheck_messages[sumcheck_messages.len() - 1].clone();
794 let prev_at_r = evaluate_at_a_point(&g_n_evals, final_chal).unwrap();
795
796 if final_result != prev_at_r {
798 return Err(anyhow!(VerificationError::FinalSumcheckFailed));
799 }
800
801 Ok(VerifierLayerEnum::Gate(verifier_gate_layer))
802 }
803
804 fn sumcheck_round_indices(&self) -> Vec<usize> {
805 let num_u = self.lhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
806 acc + match idx {
807 MleIndex::Fixed(_) => 0,
808 _ => 1,
809 }
810 }) - self.num_dataparallel_vars;
811 let num_v = self.rhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
812 acc + match idx {
813 MleIndex::Fixed(_) => 0,
814 _ => 1,
815 }
816 }) - self.num_dataparallel_vars;
817 (0..num_u + num_v + self.num_dataparallel_vars).collect_vec()
818 }
819
820 fn convert_into_verifier_layer(
821 &self,
822 sumcheck_bindings: &[F],
823 claim_points: &[&[F]],
824 transcript_reader: &mut impl VerifierTranscript<F>,
825 ) -> Result<Self::VerifierLayer> {
826 let num_u = self.lhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
830 acc + match idx {
831 MleIndex::Fixed(_) => 0,
832 _ => 1,
833 }
834 }) - self.num_dataparallel_vars;
835 let num_v = self.rhs_mle.var_indices().iter().fold(0_usize, |acc, idx| {
836 acc + match idx {
837 MleIndex::Fixed(_) => 0,
838 _ => 1,
839 }
840 }) - self.num_dataparallel_vars;
841
842 let mut sumcheck_bindings_vec = sumcheck_bindings.to_vec();
845 let last_v_challenges = sumcheck_bindings_vec.split_off(self.num_dataparallel_vars + num_u);
846 let first_u_challenges = sumcheck_bindings_vec.split_off(self.num_dataparallel_vars);
847 let dataparallel_challenges = sumcheck_bindings_vec;
848
849 assert_eq!(last_v_challenges.len(), num_v);
850
851 let lhs_challenges = dataparallel_challenges
853 .iter()
854 .chain(first_u_challenges.iter())
855 .copied()
856 .collect_vec();
857 let rhs_challenges = dataparallel_challenges
858 .iter()
859 .chain(last_v_challenges.iter())
860 .copied()
861 .collect_vec();
862
863 let lhs_verifier_mle = self
864 .lhs_mle
865 .into_verifier_mle(&lhs_challenges, transcript_reader)
866 .unwrap();
867 let rhs_verifier_mle = self
868 .rhs_mle
869 .into_verifier_mle(&rhs_challenges, transcript_reader)
870 .unwrap();
871
872 let verifier_gate_layer = VerifierGateLayer {
875 layer_id: self.layer_id(),
876 gate_operation: self.gate_operation,
877 wiring: self.nonzero_gates.clone(),
878 lhs_mle: lhs_verifier_mle,
879 rhs_mle: rhs_verifier_mle,
880 num_dataparallel_rounds: self.num_dataparallel_vars,
881 claim_challenge_points: claim_points
882 .iter()
883 .cloned()
884 .map(|claim| claim.to_vec())
885 .collect_vec(),
886 dataparallel_sumcheck_challenges: dataparallel_challenges,
887 first_u_challenges,
888 last_v_challenges,
889 };
890
891 Ok(verifier_gate_layer)
892 }
893
894 fn get_post_sumcheck_layer(
895 &self,
896 round_challenges: &[F],
897 claim_challenges: &[&[F]],
898 random_coefficients: &[F],
899 ) -> super::product::PostSumcheckLayer<F, Option<F>> {
900 let num_rounds_phase1 = self.lhs_mle.num_free_vars() - self.num_dataparallel_vars;
901
902 let g2_challenges_vec = claim_challenges
903 .iter()
904 .map(|claim_chal| &claim_chal[..self.num_dataparallel_vars])
905 .collect_vec();
906 let g1_challenges_vec = claim_challenges
907 .iter()
908 .map(|claim_chal| &claim_chal[self.num_dataparallel_vars..])
909 .collect_vec();
910
911 let dataparallel_sumcheck_challenges =
912 round_challenges[..self.num_dataparallel_vars].to_vec();
913 let first_u_challenges = round_challenges
914 [self.num_dataparallel_vars..self.num_dataparallel_vars + num_rounds_phase1]
915 .to_vec();
916 let last_v_challenges =
917 round_challenges[self.num_dataparallel_vars + num_rounds_phase1..].to_vec();
918 let random_coefficients_scaled_by_beta_bound = g2_challenges_vec
919 .iter()
920 .zip(random_coefficients)
921 .map(|(g2_challenges, random_coeff)| {
922 let beta_bound = if self.num_dataparallel_vars != 0 {
923 BetaValues::compute_beta_over_two_challenges(
924 g2_challenges,
925 &dataparallel_sumcheck_challenges,
926 )
927 } else {
928 F::ONE
929 };
930 beta_bound * random_coeff
931 })
932 .collect_vec();
933
934 let f_1_uv = compute_fully_bound_binary_gate_function(
935 &first_u_challenges,
936 &last_v_challenges,
937 &g1_challenges_vec,
938 &self.nonzero_gates,
939 &random_coefficients_scaled_by_beta_bound,
940 );
941 let lhs_challenges = &round_challenges[..self.num_dataparallel_vars + num_rounds_phase1];
942 let rhs_challenges = &round_challenges[..self.num_dataparallel_vars]
943 .iter()
944 .copied()
945 .chain(round_challenges[self.num_dataparallel_vars + num_rounds_phase1..].to_vec())
946 .collect_vec();
947
948 match self.gate_operation {
949 BinaryOperation::Add => PostSumcheckLayer(vec![
950 Product::<F, Option<F>>::new(
951 std::slice::from_ref(&self.lhs_mle),
952 f_1_uv,
953 lhs_challenges,
954 ),
955 Product::<F, Option<F>>::new(
956 std::slice::from_ref(&self.rhs_mle),
957 f_1_uv,
958 rhs_challenges,
959 ),
960 ]),
961 BinaryOperation::Mul => {
962 PostSumcheckLayer(vec![Product::<F, Option<F>>::new_from_mul_gate(
963 &[self.lhs_mle.clone(), self.rhs_mle.clone()],
964 f_1_uv,
965 &[lhs_challenges, rhs_challenges],
966 )])
967 }
968 }
969 }
970
971 fn max_degree(&self) -> usize {
972 match self.gate_operation {
973 BinaryOperation::Add => 2,
974 BinaryOperation::Mul => {
975 if self.num_dataparallel_vars != 0 {
976 3
977 } else {
978 2
979 }
980 }
981 }
982 }
983
984 fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
985 vec![&self.lhs_mle, &self.rhs_mle]
986 }
987
988 fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
989 let lhs_mle = self.lhs_mle.into_dense_mle(circuit_map);
990 let rhs_mle = self.rhs_mle.into_dense_mle(circuit_map);
991 let num_dataparallel_vars = if self.num_dataparallel_vars == 0 {
992 None
993 } else {
994 Some(self.num_dataparallel_vars)
995 };
996 let gate_layer = GateLayer::new(
997 num_dataparallel_vars,
998 self.nonzero_gates.clone(),
999 lhs_mle,
1000 rhs_mle,
1001 self.gate_operation,
1002 self.layer_id(),
1003 );
1004 gate_layer.into()
1005 }
1006
1007 fn index_mle_indices(&mut self, start_index: usize) {
1008 self.lhs_mle.index_mle_indices(start_index);
1009 self.rhs_mle.index_mle_indices(start_index);
1010 }
1011
1012 fn compute_data_outputs(
1013 &self,
1014 mle_outputs_necessary: &HashSet<&MleDescription<F>>,
1015 circuit_map: &mut CircuitEvalMap<F>,
1016 ) {
1017 assert_eq!(mle_outputs_necessary.len(), 1);
1018 let mle_output_necessary = mle_outputs_necessary.iter().next().unwrap();
1019
1020 let max_gate_val = self
1021 .nonzero_gates
1022 .iter()
1023 .fold(&0, |acc, (z, _, _)| std::cmp::max(acc, z));
1024
1025 let num_dataparallel_vals = 1 << (self.num_dataparallel_vars);
1028 let res_table_num_entries =
1029 ((max_gate_val + 1) * num_dataparallel_vals).next_power_of_two();
1030
1031 let lhs_data = circuit_map
1032 .get_data_from_circuit_mle(&self.lhs_mle)
1033 .unwrap();
1034 let rhs_data = circuit_map
1035 .get_data_from_circuit_mle(&self.rhs_mle)
1036 .unwrap();
1037
1038 let num_gate_outputs_per_dataparallel_instance = (max_gate_val + 1).next_power_of_two();
1039 let mut res_table = vec![F::ZERO; res_table_num_entries as usize];
1040 (0..num_dataparallel_vals).for_each(|idx| {
1041 self.nonzero_gates.iter().for_each(|(z_ind, x_ind, y_ind)| {
1042 let zero = F::ZERO;
1043 let f2_val = lhs_data
1044 .f
1045 .get(
1046 (idx * (1 << (lhs_data.num_vars() - self.num_dataparallel_vars)) + x_ind)
1047 as usize,
1048 )
1049 .unwrap_or(zero);
1050 let f3_val = rhs_data
1051 .f
1052 .get(
1053 (idx * (1 << (rhs_data.num_vars() - self.num_dataparallel_vars)) + y_ind)
1054 as usize,
1055 )
1056 .unwrap_or(zero);
1057 res_table[(num_gate_outputs_per_dataparallel_instance * idx + z_ind) as usize] +=
1058 self.gate_operation.perform_operation(f2_val, f3_val);
1059 });
1060 });
1061
1062 let output_data = MultilinearExtension::new(res_table);
1063 assert_eq!(
1064 output_data.num_vars(),
1065 mle_output_necessary.var_indices().len()
1066 );
1067
1068 circuit_map.add_node(CircuitLocation::new(self.layer_id(), vec![]), output_data);
1069 }
1070}
1071
1072impl<F: Field> VerifierGateLayer<F> {
1073 pub fn evaluate(&self, claims: &[&[F]], random_coefficients: &[F]) -> F {
1075 assert_eq!(random_coefficients.len(), claims.len());
1076 let scaled_random_coeffs = claims
1077 .iter()
1078 .zip(random_coefficients)
1079 .map(|(claim, random_coeff)| {
1080 let beta_bound = BetaValues::compute_beta_over_two_challenges(
1081 &claim[..self.num_dataparallel_rounds],
1082 &self.dataparallel_sumcheck_challenges,
1083 );
1084 beta_bound * random_coeff
1085 })
1086 .collect_vec();
1087
1088 let f_1_uv = compute_fully_bound_binary_gate_function(
1089 &self.first_u_challenges,
1090 &self.last_v_challenges,
1091 &claims
1092 .iter()
1093 .map(|claim| &claim[self.num_dataparallel_rounds..])
1094 .collect_vec(),
1095 &self.wiring,
1096 &scaled_random_coeffs,
1097 );
1098
1099 f_1_uv
1101 * self
1102 .gate_operation
1103 .perform_operation(self.lhs_mle.value(), self.rhs_mle.value())
1104 }
1105}
1106
/// Verifier-side view of a gate layer after its sumcheck rounds, holding the challenges
/// and bound input MLEs needed for the final check and for producing input claims.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "F: Field")]
pub struct VerifierGateLayer<F: Field> {
    /// Identifier of this layer.
    layer_id: LayerId,

    /// Operation applied to each wired `(lhs, rhs)` pair.
    gate_operation: BinaryOperation,

    /// Wiring as `(z, x, y)` gate triples.
    wiring: Vec<(u32, u32, u32)>,

    /// Lhs input MLE bound at the dataparallel challenges followed by the `u` challenges.
    lhs_mle: VerifierMle<F>,

    /// Rhs input MLE bound at the dataparallel challenges followed by the `v` challenges.
    rhs_mle: VerifierMle<F>,

    /// The claim points this layer was verified against, one per claim.
    claim_challenge_points: Vec<Vec<F>>,

    /// Number of dataparallel sumcheck rounds.
    num_dataparallel_rounds: usize,

    /// Challenges from the dataparallel sumcheck rounds.
    dataparallel_sumcheck_challenges: Vec<F>,

    /// Challenges from the phase-1 (lhs, `u`) rounds.
    first_u_challenges: Vec<F>,

    /// Challenges from the phase-2 (rhs, `v`) rounds.
    last_v_challenges: Vec<F>,
}
1145
1146impl<F: Field> VerifierLayer<F> for VerifierGateLayer<F> {
1147 fn layer_id(&self) -> LayerId {
1148 self.layer_id
1149 }
1150
1151 fn get_claims(&self) -> Result<Vec<Claim<F>>> {
1152 let lhs_vars = self.lhs_mle.var_indices();
1155 let lhs_point = lhs_vars
1156 .iter()
1157 .map(|idx| match idx {
1158 MleIndex::Bound(chal, _bit_idx) => *chal,
1159 MleIndex::Fixed(val) => {
1160 if *val {
1161 F::ONE
1162 } else {
1163 F::ZERO
1164 }
1165 }
1166 _ => panic!("Error: Not fully bound"),
1167 })
1168 .collect_vec();
1169 let lhs_val = self.lhs_mle.value();
1170
1171 let lhs_claim: Claim<F> =
1172 Claim::new(lhs_point, lhs_val, self.layer_id(), self.lhs_mle.layer_id());
1173
1174 let rhs_vars: &[MleIndex<F>] = self.rhs_mle.var_indices();
1177 let rhs_point = rhs_vars
1178 .iter()
1179 .map(|idx| match idx {
1180 MleIndex::Bound(chal, _bit_idx) => *chal,
1181 MleIndex::Fixed(val) => {
1182 if *val {
1183 F::ONE
1184 } else {
1185 F::ZERO
1186 }
1187 }
1188 _ => panic!("Error: Not fully bound"),
1189 })
1190 .collect_vec();
1191 let rhs_val = self.rhs_mle.value();
1192
1193 let rhs_claim: Claim<F> =
1194 Claim::new(rhs_point, rhs_val, self.layer_id(), self.rhs_mle.layer_id());
1195
1196 Ok(vec![lhs_claim, rhs_claim])
1197 }
1198}
1199
1200impl<F: Field> GateLayer<F> {
1201 pub fn new(
1218 num_dataparallel_vars: Option<usize>,
1219 nonzero_gates: Vec<(u32, u32, u32)>,
1220 lhs: DenseMle<F>,
1221 rhs: DenseMle<F>,
1222 gate_operation: BinaryOperation,
1223 layer_id: LayerId,
1224 ) -> Self {
1225 let num_dataparallel_vars = num_dataparallel_vars.unwrap_or(0);
1226 let num_rounds_phase1 = lhs.num_free_vars() - num_dataparallel_vars;
1227
1228 GateLayer {
1229 num_dataparallel_vars,
1230 nonzero_gates,
1231 lhs,
1232 rhs,
1233 layer_id,
1234 phase_1_mles: None,
1235 phase_2_mles: None,
1236 gate_operation,
1237 beta_g2_vec: None,
1238 g_vec: None,
1239 num_rounds_phase1,
1240 }
1241 }
1242
1243 fn init_phase_1(&mut self, challenges: Vec<F>) {
1247 let beta_g2_fully_bound = self.beta_g2_vec.as_ref().unwrap()[0].fold_updated_values();
1248
1249 let (a_hg_lhs_vec, a_hg_rhs_vec) = fold_binary_gate_wiring_into_mles_phase_1(
1250 &self.nonzero_gates,
1251 &[&challenges],
1252 &self.lhs,
1253 &self.rhs,
1254 &[beta_g2_fully_bound],
1255 self.gate_operation,
1256 );
1257
1258 let mut phase_1_mles = match self.gate_operation {
1262 BinaryOperation::Add => {
1263 vec![
1264 vec![
1265 DenseMle::new_from_raw(a_hg_lhs_vec, LayerId::Input(0)),
1266 self.lhs.clone(),
1267 ],
1268 vec![DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0))],
1269 ]
1270 }
1271 BinaryOperation::Mul => {
1272 vec![vec![
1273 DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0)),
1274 self.lhs.clone(),
1275 ]]
1276 }
1277 };
1278
1279 phase_1_mles.iter_mut().for_each(|mle_vec| {
1280 index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1281 });
1282 self.phase_1_mles = Some(phase_1_mles);
1283 }
1284
1285 fn init_phase_1_rlc(&mut self, challenges: &[&[F]], random_coefficients: &[F]) {
1286 let random_coefficients_scaled_by_beta_g2 = self
1287 .beta_g2_vec
1288 .as_ref()
1289 .unwrap()
1290 .iter()
1291 .zip(random_coefficients)
1292 .map(|(beta_values, random_coeff)| {
1293 assert!(beta_values.is_fully_bounded());
1294 beta_values.fold_updated_values() * random_coeff
1295 })
1296 .collect_vec();
1297
1298 let (a_hg_lhs_vec, a_hg_rhs_vec) = fold_binary_gate_wiring_into_mles_phase_1(
1299 &self.nonzero_gates,
1300 challenges,
1301 &self.lhs,
1302 &self.rhs,
1303 &random_coefficients_scaled_by_beta_g2,
1304 self.gate_operation,
1305 );
1306
1307 let mut phase_1_mles = match self.gate_operation {
1311 BinaryOperation::Add => {
1312 vec![
1313 vec![
1314 DenseMle::new_from_raw(a_hg_lhs_vec, LayerId::Input(0)),
1315 self.lhs.clone(),
1316 ],
1317 vec![DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0))],
1318 ]
1319 }
1320 BinaryOperation::Mul => {
1321 vec![vec![
1322 DenseMle::new_from_raw(a_hg_rhs_vec, LayerId::Input(0)),
1323 self.lhs.clone(),
1324 ]]
1325 }
1326 };
1327
1328 phase_1_mles.iter_mut().for_each(|mle_vec| {
1329 index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1330 });
1331 self.phase_1_mles = Some(phase_1_mles);
1332 }
1333
1334 fn init_phase_2(&mut self, u_claim: &[F], f_at_u: F, g1_claim_points: &[F]) {
1338 let beta_g2_fully_bound = self.beta_g2_vec.as_ref().unwrap()[0].fold_updated_values();
1339
1340 let (a_f1_lhs, a_f1_rhs) = fold_binary_gate_wiring_into_mles_phase_2(
1341 &self.nonzero_gates,
1342 f_at_u,
1343 u_claim,
1344 &[g1_claim_points],
1345 &[beta_g2_fully_bound],
1346 self.rhs.num_free_vars(),
1347 self.gate_operation,
1348 );
1349
1350 let mut phase_2_mles = match self.gate_operation {
1352 BinaryOperation::Add => {
1353 vec![
1354 vec![
1355 DenseMle::new_from_raw(a_f1_rhs, LayerId::Input(0)),
1356 self.rhs.clone(),
1357 ],
1358 vec![DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0))],
1359 ]
1360 }
1361 BinaryOperation::Mul => {
1362 vec![vec![
1363 DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0)),
1364 self.rhs.clone(),
1365 ]]
1366 }
1367 };
1368
1369 phase_2_mles.iter_mut().for_each(|mle_vec| {
1370 index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1371 });
1372 self.phase_2_mles = Some(phase_2_mles);
1373 }
1374
1375 fn init_phase_2_rlc(
1376 &mut self,
1377 u_claim: &[F],
1378 f_at_u: F,
1379 g1_claim_points: &[&[F]],
1380 random_coefficients: &[F],
1381 ) {
1382 let random_coefficients_scaled_by_beta_g2 = self
1383 .beta_g2_vec
1384 .as_ref()
1385 .unwrap()
1386 .iter()
1387 .zip(random_coefficients)
1388 .map(|(beta_values, random_coeff)| {
1389 assert!(beta_values.is_fully_bounded());
1390 beta_values.fold_updated_values() * random_coeff
1391 })
1392 .collect_vec();
1393
1394 let (a_f1_lhs, a_f1_rhs) = fold_binary_gate_wiring_into_mles_phase_2(
1395 &self.nonzero_gates,
1396 f_at_u,
1397 u_claim,
1398 g1_claim_points,
1399 &random_coefficients_scaled_by_beta_g2,
1400 self.rhs.num_free_vars(),
1401 self.gate_operation,
1402 );
1403
1404 let mut phase_2_mles = match self.gate_operation {
1406 BinaryOperation::Add => {
1407 vec![
1408 vec![
1409 DenseMle::new_from_raw(a_f1_rhs, LayerId::Input(0)),
1410 self.rhs.clone(),
1411 ],
1412 vec![DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0))],
1413 ]
1414 }
1415 BinaryOperation::Mul => {
1416 vec![vec![
1417 DenseMle::new_from_raw(a_f1_lhs, LayerId::Input(0)),
1418 self.rhs.clone(),
1419 ]]
1420 }
1421 };
1422
1423 phase_2_mles.iter_mut().for_each(|mle_vec| {
1424 index_mle_indices_gate(mle_vec, self.num_dataparallel_vars);
1425 });
1426 self.phase_2_mles = Some(phase_2_mles.clone());
1427 }
1428}
1429
1430pub fn compute_gate_data_outputs<F: Field>(
1448 wiring: Vec<(u32, u32, u32)>,
1449 num_dataparallel_bits: usize,
1450 lhs_data: &MultilinearExtension<F>,
1451 rhs_data: &MultilinearExtension<F>,
1452 gate_operation: BinaryOperation,
1453) -> MultilinearExtension<F> {
1454 let max_gate_val = wiring
1455 .iter()
1456 .fold(&0, |acc, (z, _, _)| std::cmp::max(acc, z));
1457
1458 let num_dataparallel_vals = 1 << num_dataparallel_bits;
1461 let res_table_num_entries = ((max_gate_val + 1) * num_dataparallel_vals).next_power_of_two();
1462 let num_gate_outputs_per_dataparallel_instance = (max_gate_val + 1).next_power_of_two();
1463
1464 let mut res_table = vec![F::ZERO; res_table_num_entries as usize];
1465 (0..num_dataparallel_vals).for_each(|idx| {
1467 wiring.iter().for_each(|(z_ind, x_ind, y_ind)| {
1468 let zero = F::ZERO;
1469 let f2_val = lhs_data
1470 .f
1471 .get((idx * (1 << (lhs_data.num_vars() - num_dataparallel_bits)) + x_ind) as usize)
1472 .unwrap_or(zero);
1473 let f3_val = rhs_data
1474 .f
1475 .get((idx * (1 << (rhs_data.num_vars() - num_dataparallel_bits)) + y_ind) as usize)
1476 .unwrap_or(zero);
1477 res_table[(num_gate_outputs_per_dataparallel_instance * idx + z_ind) as usize] +=
1478 gate_operation.perform_operation(f2_val, f3_val);
1479 });
1480 });
1481
1482 MultilinearExtension::new(res_table)
1483}