1#![allow(clippy::type_complexity)]
3
4pub mod helpers;
6
7pub mod proof_system;
9
10pub mod layers;
12
13use std::collections::{HashMap, HashSet};
14
15use self::layers::Layers;
16use crate::circuit_layout::{CircuitEvalMap, CircuitLocation};
17use crate::claims::claim_aggregation::{prover_aggregate_claims, verifier_aggregate_claims};
18use crate::claims::{Claim, ClaimTracker};
19use crate::expression::circuit_expr::filter_bookkeeping_table;
20use crate::input_layer::fiat_shamir_challenge::{
21 FiatShamirChallenge, FiatShamirChallengeDescription,
22};
23use crate::input_layer::{InputLayer, InputLayerDescription};
24use crate::layer::layer_enum::{LayerDescriptionEnum, VerifierLayerEnum};
25use crate::layer::{layer_enum::LayerEnum, LayerId};
26use crate::layer::{Layer, LayerDescription, VerifierLayer};
27use crate::mle::dense::DenseMle;
28use crate::mle::evals::MultilinearExtension;
29use crate::mle::mle_description::MleDescription;
30use crate::mle::mle_enum::MleEnum;
31use crate::output_layer::{OutputLayer, OutputLayerDescription};
32use crate::provable_circuit::ProvableCircuit;
33use crate::utils::mle::verify_claim;
34use ark_std::{end_timer, start_timer};
35use itertools::Itertools;
36use serde::{Deserialize, Serialize};
37use shared_types::config::global_config::global_claim_agg_strategy;
38use shared_types::config::ClaimAggregationStrategy;
39use shared_types::transcript::{ProverTranscript, TranscriptWriter};
40use shared_types::transcript::{TranscriptSponge, VerifierTranscript};
41use shared_types::{Field, Halo2FFTFriendlyField};
42use thiserror::Error;
43use tracing::{debug, info};
44use tracing::{instrument, span, Level};
45
46use anyhow::{anyhow, Result};
47
48#[derive(Error, Debug, Clone)]
50pub enum GKRError {
51 #[error("No claims were found for layer {0:?}")]
52 NoClaimsForLayer(LayerId),
54 #[error("Error when proving layer {0:?}")]
55 ErrorWhenProvingLayer(LayerId),
57 #[error("Error when verifying layer {0:?}")]
58 ErrorWhenVerifyingLayer(LayerId),
60 #[error("Evaluation of input layer {0:?} doesn't match value of a claim originating from layer {0:?}.")]
62 EvaluationMismatch(LayerId, LayerId),
63 #[error("Values for public input layer {0:?} were not as expected by the verifier.")]
65 PublicInputLayerValuesMismatch(LayerId),
66 #[error("Verifier's claim tracker was not empty at the end of the verification process.")]
68 ClaimTrackerNotEmpty,
69
70 #[error("Error when verifying output layer")]
71 ErrorWhenVerifyingOutputLayer,
73 #[error("InputShred with NodeId {0} should have {1} variables, but has {2}")]
75 InputShredLengthMismatch(usize, usize, usize),
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct SumcheckProof<F>(pub Vec<Vec<F>>);
82
83impl<F: Field> From<Vec<Vec<F>>> for SumcheckProof<F> {
84 fn from(value: Vec<Vec<F>>) -> Self {
85 Self(value)
86 }
87}
88
89#[derive(Debug)]
91pub struct InstantiatedCircuit<F: Field> {
92 pub layers: Layers<F, LayerEnum<F>>,
94 pub output_layers: Vec<OutputLayer<F>>,
96 pub input_layers: Vec<InputLayer<F>>,
98 pub fiat_shamir_challenges: Vec<FiatShamirChallenge<F>>,
100 pub layer_map: HashMap<LayerId, Vec<DenseMle<F>>>,
102}
103
104pub fn prove_circuit<F: Halo2FFTFriendlyField, Tr: TranscriptSponge<F>>(
107 provable_circuit: &ProvableCircuit<F>,
108 transcript_writer: &mut TranscriptWriter<F, Tr>,
109) -> Result<Vec<Claim<F>>> {
110 let mut challenge_sampler =
115 |size| transcript_writer.get_challenges("Verifier challenges", size);
116 let instantiated_circuit = provable_circuit
117 .get_gkr_circuit_description_ref()
118 .instantiate(provable_circuit.get_inputs_ref(), &mut challenge_sampler);
119
120 let InstantiatedCircuit {
121 input_layers,
122 mut output_layers,
123 layers,
124 fiat_shamir_challenges: _fiat_shamir_challenges,
125 mut layer_map,
126 } = instantiated_circuit;
127
128 let mut claim_tracker = ClaimTracker::new();
130
131 let claims_timer = start_timer!(|| "Output claims generation");
133 let output_claims_span = span!(Level::DEBUG, "output_claims_span").entered();
134
135 for output in output_layers.iter_mut() {
137 let layer_id = output.layer_id();
138 info!("Output Layer: {layer_id:?}");
139
140 match output.get_mle() {
141 MleEnum::Dense(_) => {
142 panic!("We don't support DenseMLE as output layers for now")
143 }
144 MleEnum::Zero(_) => {
146 transcript_writer.append_elements("Output layer MLE evals", &[F::ZERO])
147 }
148 };
149
150 let challenges = transcript_writer
151 .get_challenges("Challenge on the output layer", output.num_free_vars());
152 output.fix_layer(&challenges)?;
153
154 let claim = output.get_claim()?;
155 claim_tracker.insert(claim.get_to_layer_id(), claim);
156 }
157
158 end_timer!(claims_timer);
159 output_claims_span.exit();
160
161 let intermediate_layers_timer = start_timer!(|| "ALL intermediate layers proof generation");
163 let all_layers_sumcheck_proving_span =
164 span!(Level::DEBUG, "all_layers_sumcheck_proving_span").entered();
165
166 for mut layer in layers.layers.into_iter().rev() {
170 let layer_id = layer.layer_id();
171 let layer_timer = start_timer!(|| format!("Generating proof for layer {layer_id:?}"));
172 info!("Proving Intermediate Layer: {layer_id:?}");
173
174 info!("Starting claim aggregation...");
175
176 let output_mles_from_layer = layer_map.remove(&layer_id).unwrap();
177 let layer_claims = claim_tracker.get(layer_id).unwrap();
178
179 if let LayerEnum::MatMult(_) = layer {
181 let claim_aggr_timer =
182 start_timer!(|| format!("Claim aggregation for layer {layer_id:?}"));
183 let layer_claim =
184 prover_aggregate_claims(layer_claims, output_mles_from_layer, transcript_writer)?;
185 end_timer!(claim_aggr_timer);
186
187 info!("Prove sumcheck message");
188 let sumcheck_msg_timer = start_timer!(|| format!(
189 "Compute sumcheck message for layer {:?}",
190 layer.layer_id()
191 ));
192
193 layer.prove(&[&layer_claim], transcript_writer)?;
195
196 end_timer!(sumcheck_msg_timer);
197 }
198 else {
200 match global_claim_agg_strategy() {
201 ClaimAggregationStrategy::Interpolative => {
202 let claim_aggr_timer =
203 start_timer!(|| format!("Claim aggregation for layer {layer_id:?}"));
204 let layer_claim = prover_aggregate_claims(
205 layer_claims,
206 output_mles_from_layer,
207 transcript_writer,
208 )?;
209 end_timer!(claim_aggr_timer);
210
211 info!("Prove sumcheck message");
212 let sumcheck_msg_timer = start_timer!(|| format!(
213 "Compute sumcheck message for layer {:?}",
214 layer.layer_id()
215 ));
216
217 layer.prove(&[&layer_claim], transcript_writer)?;
219
220 end_timer!(sumcheck_msg_timer);
221 }
222 ClaimAggregationStrategy::RLC => {
223 let sumcheck_msg_timer = start_timer!(|| format!(
224 "Compute sumcheck message for layer {:?}",
225 layer.layer_id()
226 ));
227
228 layer.prove(
229 &layer_claims
230 .iter()
231 .map(|claim| claim.get_raw_claim())
232 .collect_vec(),
233 transcript_writer,
234 )?;
235 end_timer!(sumcheck_msg_timer);
236 }
237 }
238 }
239
240 for claim in layer.get_claims()? {
241 claim_tracker.insert(claim.get_to_layer_id(), claim);
242 }
243
244 end_timer!(layer_timer);
245 }
246
247 end_timer!(intermediate_layers_timer);
248 all_layers_sumcheck_proving_span.exit();
249
250 let input_layer_claims = input_layers
251 .iter()
252 .filter_map(|input_layer| claim_tracker.get(input_layer.layer_id))
253 .flatten()
254 .cloned()
255 .collect_vec();
256
257 Ok(input_layer_claims)
258}
259
260#[derive(Debug, Serialize, Deserialize, Hash, Clone)]
263#[serde(bound = "F: Field")]
264pub struct GKRCircuitDescription<F: Field> {
265 pub input_layers: Vec<InputLayerDescription>,
267 pub fiat_shamir_challenges: Vec<FiatShamirChallengeDescription<F>>,
269 pub intermediate_layers: Vec<LayerDescriptionEnum<F>>,
271 pub output_layers: Vec<OutputLayerDescription<F>>,
273}
274
275impl<F: Field> GKRCircuitDescription<F> {
276 pub fn index_mle_indices(&mut self, start_index: usize) {
279 let GKRCircuitDescription {
280 input_layers: _,
281 fiat_shamir_challenges: _,
282 intermediate_layers,
283 output_layers,
284 } = self;
285 intermediate_layers
286 .iter_mut()
287 .for_each(|intermediate_layer| {
288 intermediate_layer.index_mle_indices(start_index);
289 });
290 output_layers.iter_mut().for_each(|output_layer| {
291 output_layer.index_mle_indices(start_index);
292 })
293 }
294
295 pub fn instantiate(
303 &self,
304 input_data: &HashMap<LayerId, MultilinearExtension<F>>,
305 challenge_sampler: &mut impl FnMut(usize) -> Vec<F>,
306 ) -> InstantiatedCircuit<F> {
307 let GKRCircuitDescription {
308 input_layers: input_layer_descriptions,
309 fiat_shamir_challenges: fiat_shamir_challenge_descriptions,
310 intermediate_layers: intermediate_layer_descriptions,
311 output_layers: output_layer_descriptions,
312 } = self;
313
314 let mut mle_claim_map = HashMap::<LayerId, HashSet<&MleDescription<F>>>::new();
321 intermediate_layer_descriptions
325 .iter()
326 .for_each(|intermediate_layer| {
327 let layer_source_circuit_mles = intermediate_layer.get_circuit_mles();
328 layer_source_circuit_mles
329 .into_iter()
330 .for_each(|circuit_mle| {
331 let layer_id = circuit_mle.layer_id();
332 mle_claim_map
333 .entry(layer_id)
334 .or_default()
335 .insert(circuit_mle);
336 })
337 });
338
339 output_layer_descriptions.iter().for_each(|output_layer| {
343 let layer_source_mle = &output_layer.mle;
344 let layer_id = layer_source_mle.layer_id();
345 mle_claim_map
346 .entry(layer_id)
347 .or_default()
348 .insert(&output_layer.mle);
349 });
350
351 let mut circuit_map = CircuitEvalMap::new();
354 let mut prover_input_layers: Vec<InputLayer<F>> = Vec::new();
355 let mut fiat_shamir_challenges = Vec::new();
356 input_layer_descriptions
360 .iter()
361 .for_each(|input_layer_description| {
362 let input_layer_id = input_layer_description.layer_id;
363 let combined_mle = input_data.get(&input_layer_id).unwrap();
364 let mle_outputs_necessary = mle_claim_map.get(&input_layer_id).unwrap();
365 mle_outputs_necessary.iter().for_each(|mle_output| {
368 let prefix_bits = mle_output.prefix_bits();
369 let output = filter_bookkeeping_table(combined_mle, &prefix_bits);
370 circuit_map.add_node(CircuitLocation::new(input_layer_id, prefix_bits), output);
371 });
372 let prover_input_layer = InputLayer {
373 mle: combined_mle.clone(),
374 layer_id: input_layer_id,
375 };
376 prover_input_layers.push(prover_input_layer);
377 });
378 fiat_shamir_challenge_descriptions
381 .iter()
382 .for_each(|fiat_shamir_challenge_description| {
383 let fiat_shamir_challenge_mle = MultilinearExtension::new(challenge_sampler(
384 1 << fiat_shamir_challenge_description.num_bits,
385 ));
386 circuit_map.add_node(
387 CircuitLocation::new(fiat_shamir_challenge_description.layer_id(), vec![]),
388 fiat_shamir_challenge_mle.clone(),
389 );
390 fiat_shamir_challenges.push(FiatShamirChallenge {
391 mle: fiat_shamir_challenge_mle,
392 layer_id: fiat_shamir_challenge_description.layer_id(),
393 });
394 });
395
396 intermediate_layer_descriptions
400 .iter()
401 .for_each(|intermediate_layer_description| {
402 let mle_outputs_necessary = mle_claim_map
403 .get(&intermediate_layer_description.layer_id())
404 .unwrap();
405 intermediate_layer_description
406 .compute_data_outputs(mle_outputs_necessary, &mut circuit_map);
407 });
408
409 let mut prover_intermediate_layers: Vec<LayerEnum<F>> =
413 Vec::with_capacity(intermediate_layer_descriptions.len());
414 intermediate_layer_descriptions
415 .iter()
416 .for_each(|intermediate_layer_description| {
417 let prover_intermediate_layer =
418 intermediate_layer_description.convert_into_prover_layer(&circuit_map);
419 prover_intermediate_layers.push(prover_intermediate_layer)
420 });
421
422 let mut prover_output_layers: Vec<OutputLayer<F>> = Vec::new();
424 output_layer_descriptions
425 .iter()
426 .for_each(|output_layer_description| {
427 let prover_output_layer =
428 output_layer_description.into_prover_output_layer(&circuit_map);
429 prover_output_layers.push(prover_output_layer)
430 });
431
432 InstantiatedCircuit {
433 input_layers: prover_input_layers,
434 fiat_shamir_challenges,
435 layers: Layers::new_with_layers(prover_intermediate_layers),
436 output_layers: prover_output_layers,
437 layer_map: circuit_map.convert_to_layer_map(),
438 }
439 }
440
441 #[instrument(skip_all, err)]
447 pub fn verify(
448 &self,
449 transcript_reader: &mut impl VerifierTranscript<F>,
450 ) -> Result<Vec<Claim<F>>> {
451 let fiat_shamir_challenges: Vec<FiatShamirChallenge<F>> = self
453 .fiat_shamir_challenges
454 .iter()
455 .map(|fs_desc| {
456 let values = transcript_reader
457 .get_challenges("Verifier challenges", 1 << fs_desc.num_bits)
458 .unwrap();
459 fs_desc.instantiate(values)
460 })
461 .collect();
462
463 let mut claim_tracker = ClaimTracker::new();
465
466 let claims_timer = start_timer!(|| "Output claims generation");
468 let verifier_output_claims_span =
469 span!(Level::DEBUG, "verifier_output_claims_span").entered();
470
471 for circuit_output_layer in self.output_layers.iter() {
472 let layer_id = circuit_output_layer.layer_id();
473 info!("Verifying Output Layer: {layer_id:?}");
474
475 let verifier_output_layer = circuit_output_layer
476 .retrieve_mle_from_transcript_and_fix_layer(transcript_reader)?;
477
478 let claim = verifier_output_layer.get_claim()?;
479 claim_tracker.insert(claim.get_to_layer_id(), claim);
480 }
481
482 end_timer!(claims_timer);
483 verifier_output_claims_span.exit();
484
485 let intermediate_layers_timer =
487 start_timer!(|| "ALL intermediate layers proof verification");
488
489 for layer in self.intermediate_layers.iter().rev() {
490 let layer_id = layer.layer_id();
491
492 info!("Intermediate Layer: {layer_id:?}");
493 let layer_timer = start_timer!(|| format!("Proof verification for layer {layer_id:?}"));
494
495 let layer_claims = claim_tracker.remove(layer_id).unwrap();
496
497 let verifier_layer = match global_claim_agg_strategy() {
498 ClaimAggregationStrategy::Interpolative => {
499 let claim_aggr_timer =
500 start_timer!(|| format!("Claim aggregation for layer {layer_id:?}"));
501 let prev_claim = verifier_aggregate_claims(&layer_claims, transcript_reader)?;
502 debug!("Aggregated claim: {:#?}", prev_claim);
503 end_timer!(claim_aggr_timer);
504
505 info!("Prove sumcheck message");
506 let sumcheck_msg_timer = start_timer!(|| format!(
507 "Compute sumcheck message for layer {:?}",
508 layer.layer_id()
509 ));
510
511 let verifier_layer: VerifierLayerEnum<F> =
513 layer.verify_rounds(&[&prev_claim], transcript_reader)?;
514
515 end_timer!(sumcheck_msg_timer);
516
517 verifier_layer
518 }
519 ClaimAggregationStrategy::RLC => {
520 let sumcheck_msg_timer = start_timer!(|| format!(
521 "Compute sumcheck message for layer {:?}",
522 layer.layer_id()
523 ));
524
525 let verifier_layer = layer.verify_rounds(
526 &layer_claims
527 .iter()
528 .map(|claim| claim.get_raw_claim())
529 .collect_vec(),
530 transcript_reader,
531 )?;
532 end_timer!(sumcheck_msg_timer);
533
534 verifier_layer
535 }
536 };
537
538 for claim in verifier_layer.get_claims()? {
539 claim_tracker.insert(claim.get_to_layer_id(), claim);
540 }
541
542 end_timer!(layer_timer);
543 }
544
545 end_timer!(intermediate_layers_timer);
546
547 let fiat_shamir_challenges_timer = start_timer!(|| "Verifier challenges proof generation");
549 for fiat_shamir_challenge in fiat_shamir_challenges {
550 if let Some(claims) = claim_tracker.remove(fiat_shamir_challenge.layer_id()) {
551 claims.iter().for_each(|claim| {
552 verify_claim(&fiat_shamir_challenge.mle.to_vec(), claim.get_raw_claim());
553 });
554 } else {
555 return Err(anyhow!(GKRError::NoClaimsForLayer(
556 fiat_shamir_challenge.layer_id()
557 )));
558 }
559 }
560 end_timer!(fiat_shamir_challenges_timer);
561
562 let input_layer_claims = self
563 .input_layers
564 .iter()
565 .flat_map(|input_layer| claim_tracker.remove(input_layer.layer_id).unwrap())
566 .collect_vec();
567
568 if !claim_tracker.is_empty() {
570 return Err(anyhow!(GKRError::ClaimTrackerNotEmpty));
571 }
572
573 Ok(input_layer_claims)
574 }
575}