remainder/
prover.rs

1//! Modules that orchestrate creating a GKR Proof
2#![allow(clippy::type_complexity)]
3
4/// Includes boilerplate for creating a GKR circuit, i.e. creating a transcript, proving, verifying, etc.
5pub mod helpers;
6
7/// Includes various traits that define interfaces of a GKR Prover
8pub mod proof_system;
9
10/// Struct for representing a list of layers
11pub mod layers;
12
13use std::collections::{HashMap, HashSet};
14
15use self::layers::Layers;
16use crate::circuit_layout::{CircuitEvalMap, CircuitLocation};
17use crate::claims::claim_aggregation::{prover_aggregate_claims, verifier_aggregate_claims};
18use crate::claims::{Claim, ClaimTracker};
19use crate::expression::circuit_expr::filter_bookkeeping_table;
20use crate::input_layer::fiat_shamir_challenge::{
21    FiatShamirChallenge, FiatShamirChallengeDescription,
22};
23use crate::input_layer::{InputLayer, InputLayerDescription};
24use crate::layer::layer_enum::{LayerDescriptionEnum, VerifierLayerEnum};
25use crate::layer::{layer_enum::LayerEnum, LayerId};
26use crate::layer::{Layer, LayerDescription, VerifierLayer};
27use crate::mle::dense::DenseMle;
28use crate::mle::evals::MultilinearExtension;
29use crate::mle::mle_description::MleDescription;
30use crate::mle::mle_enum::MleEnum;
31use crate::output_layer::{OutputLayer, OutputLayerDescription};
32use crate::provable_circuit::ProvableCircuit;
33use crate::utils::mle::verify_claim;
34use ark_std::{end_timer, start_timer};
35use itertools::Itertools;
36use serde::{Deserialize, Serialize};
37use shared_types::config::global_config::global_claim_agg_strategy;
38use shared_types::config::ClaimAggregationStrategy;
39use shared_types::transcript::{ProverTranscript, TranscriptWriter};
40use shared_types::transcript::{TranscriptSponge, VerifierTranscript};
41use shared_types::{Field, Halo2FFTFriendlyField};
42use thiserror::Error;
43use tracing::{debug, info};
44use tracing::{instrument, span, Level};
45
46use anyhow::{anyhow, Result};
47
48/// Errors that can be generated during GKR proving.
49#[derive(Error, Debug, Clone)]
50pub enum GKRError {
51    #[error("No claims were found for layer {0:?}")]
52    /// No claims were found for layer
53    NoClaimsForLayer(LayerId),
54    #[error("Error when proving layer {0:?}")]
55    /// Error when proving layer
56    ErrorWhenProvingLayer(LayerId),
57    #[error("Error when verifying layer {0:?}")]
58    /// Error when verifying layer
59    ErrorWhenVerifyingLayer(LayerId),
60    /// The evaluation of the input layer doesn't match the value of the claim from the layer
61    #[error("Evaluation of input layer {0:?} doesn't match value of a claim originating from layer {0:?}.")]
62    EvaluationMismatch(LayerId, LayerId),
63    /// The public input layer values were not as expected by the verifier
64    #[error("Values for public input layer {0:?} were not as expected by the verifier.")]
65    PublicInputLayerValuesMismatch(LayerId),
66    /// The verifier's claim tracker was not empty at the end of the verification process
67    #[error("Verifier's claim tracker was not empty at the end of the verification process.")]
68    ClaimTrackerNotEmpty,
69
70    #[error("Error when verifying output layer")]
71    /// Error when verifying output layer
72    ErrorWhenVerifyingOutputLayer,
73    /// InputShred length mismatch
74    #[error("InputShred with NodeId {0} should have {1} variables, but has {2}")]
75    InputShredLengthMismatch(usize, usize, usize),
76}
77
78/// A proof of the sumcheck protocol; Outer vec is rounds, inner vec is evaluations
79/// this inner vec is none if there is no sumcheck proof
80#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct SumcheckProof<F>(pub Vec<Vec<F>>);
82
83impl<F: Field> From<Vec<Vec<F>>> for SumcheckProof<F> {
84    fn from(value: Vec<Vec<F>>) -> Self {
85        Self(value)
86    }
87}
88
89/// The witness of a GKR circuit, used to actually prove the circuit
90#[derive(Debug)]
91pub struct InstantiatedCircuit<F: Field> {
92    /// The intermediate layers of the circuit
93    pub layers: Layers<F, LayerEnum<F>>,
94    /// The output layers of the circuit
95    pub output_layers: Vec<OutputLayer<F>>,
96    /// The input layers of the circuit
97    pub input_layers: Vec<InputLayer<F>>,
98    /// The verifier challenges
99    pub fiat_shamir_challenges: Vec<FiatShamirChallenge<F>>,
100    /// Maps LayerId to the MLE of its values
101    pub layer_map: HashMap<LayerId, Vec<DenseMle<F>>>,
102}
103
104/// Assumes that the inputs have already been added to the transcript (if necessary).
105/// Returns the vector of claims on the input layers.
106pub fn prove_circuit<F: Halo2FFTFriendlyField, Tr: TranscriptSponge<F>>(
107    provable_circuit: &ProvableCircuit<F>,
108    transcript_writer: &mut TranscriptWriter<F, Tr>,
109) -> Result<Vec<Claim<F>>> {
110    // Note: no need to return the Transcript, since it is already in the TranscriptWriter!
111    // Note(Ben): this can't be an instance method, because it consumes the intermediate layers!
112    // Note(Ben): this is a GKR specific method.  So it makes sense for IT to define the challenge sampler, so that the circuit can be instantiated (rather than leaving this complexity to the calling context).
113
114    let mut challenge_sampler =
115        |size| transcript_writer.get_challenges("Verifier challenges", size);
116    let instantiated_circuit = provable_circuit
117        .get_gkr_circuit_description_ref()
118        .instantiate(provable_circuit.get_inputs_ref(), &mut challenge_sampler);
119
120    let InstantiatedCircuit {
121        input_layers,
122        mut output_layers,
123        layers,
124        fiat_shamir_challenges: _fiat_shamir_challenges,
125        mut layer_map,
126    } = instantiated_circuit;
127
128    // Maps a `LayerId` to a collection of claims made on that layer.
129    let mut claim_tracker = ClaimTracker::new();
130
131    // --------- STAGE 1: Output Claim Generation ---------
132    let claims_timer = start_timer!(|| "Output claims generation");
133    let output_claims_span = span!(Level::DEBUG, "output_claims_span").entered();
134
135    // Go through circuit output layers and grab claims on each.
136    for output in output_layers.iter_mut() {
137        let layer_id = output.layer_id();
138        info!("Output Layer: {layer_id:?}");
139
140        match output.get_mle() {
141            MleEnum::Dense(_) => {
142                panic!("We don't support DenseMLE as output layers for now")
143            }
144            // Just write a single zero into the transcript since the counts (layer size) are already included in the circuit description
145            MleEnum::Zero(_) => {
146                transcript_writer.append_elements("Output layer MLE evals", &[F::ZERO])
147            }
148        };
149
150        let challenges = transcript_writer
151            .get_challenges("Challenge on the output layer", output.num_free_vars());
152        output.fix_layer(&challenges)?;
153
154        let claim = output.get_claim()?;
155        claim_tracker.insert(claim.get_to_layer_id(), claim);
156    }
157
158    end_timer!(claims_timer);
159    output_claims_span.exit();
160
161    // --------- STAGE 2: Prove Intermediate Layers ---------
162    let intermediate_layers_timer = start_timer!(|| "ALL intermediate layers proof generation");
163    let all_layers_sumcheck_proving_span =
164        span!(Level::DEBUG, "all_layers_sumcheck_proving_span").entered();
165
166    // Collects all the prover messages for sumchecking over each layer, as
167    // well as all the prover messages for claim aggregation at the
168    // beginning of proving each layer.
169    for mut layer in layers.layers.into_iter().rev() {
170        let layer_id = layer.layer_id();
171        let layer_timer = start_timer!(|| format!("Generating proof for layer {layer_id:?}"));
172        info!("Proving Intermediate Layer: {layer_id:?}");
173
174        info!("Starting claim aggregation...");
175
176        let output_mles_from_layer = layer_map.remove(&layer_id).unwrap();
177        let layer_claims = claim_tracker.get(layer_id).unwrap();
178
179        // We always want to perform interpolative claim aggregation on MatMult layers.
180        if let LayerEnum::MatMult(_) = layer {
181            let claim_aggr_timer =
182                start_timer!(|| format!("Claim aggregation for layer {layer_id:?}"));
183            let layer_claim =
184                prover_aggregate_claims(layer_claims, output_mles_from_layer, transcript_writer)?;
185            end_timer!(claim_aggr_timer);
186
187            info!("Prove sumcheck message");
188            let sumcheck_msg_timer = start_timer!(|| format!(
189                "Compute sumcheck message for layer {:?}",
190                layer.layer_id()
191            ));
192
193            // Compute all sumcheck messages across this particular layer.
194            layer.prove(&[&layer_claim], transcript_writer)?;
195
196            end_timer!(sumcheck_msg_timer);
197        }
198        // Otherwise, we perform claim aggregation specified by the claim agg strategy.
199        else {
200            match global_claim_agg_strategy() {
201                ClaimAggregationStrategy::Interpolative => {
202                    let claim_aggr_timer =
203                        start_timer!(|| format!("Claim aggregation for layer {layer_id:?}"));
204                    let layer_claim = prover_aggregate_claims(
205                        layer_claims,
206                        output_mles_from_layer,
207                        transcript_writer,
208                    )?;
209                    end_timer!(claim_aggr_timer);
210
211                    info!("Prove sumcheck message");
212                    let sumcheck_msg_timer = start_timer!(|| format!(
213                        "Compute sumcheck message for layer {:?}",
214                        layer.layer_id()
215                    ));
216
217                    // Compute all sumcheck messages across this particular layer.
218                    layer.prove(&[&layer_claim], transcript_writer)?;
219
220                    end_timer!(sumcheck_msg_timer);
221                }
222                ClaimAggregationStrategy::RLC => {
223                    let sumcheck_msg_timer = start_timer!(|| format!(
224                        "Compute sumcheck message for layer {:?}",
225                        layer.layer_id()
226                    ));
227
228                    layer.prove(
229                        &layer_claims
230                            .iter()
231                            .map(|claim| claim.get_raw_claim())
232                            .collect_vec(),
233                        transcript_writer,
234                    )?;
235                    end_timer!(sumcheck_msg_timer);
236                }
237            }
238        }
239
240        for claim in layer.get_claims()? {
241            claim_tracker.insert(claim.get_to_layer_id(), claim);
242        }
243
244        end_timer!(layer_timer);
245    }
246
247    end_timer!(intermediate_layers_timer);
248    all_layers_sumcheck_proving_span.exit();
249
250    let input_layer_claims = input_layers
251        .iter()
252        .filter_map(|input_layer| claim_tracker.get(input_layer.layer_id))
253        .flatten()
254        .cloned()
255        .collect_vec();
256
257    Ok(input_layer_claims)
258}
259
260/// The complete description of a layered circuit whose output validity can be
261/// proven against a set of committed inputs.
262#[derive(Debug, Serialize, Deserialize, Hash, Clone)]
263#[serde(bound = "F: Field")]
264pub struct GKRCircuitDescription<F: Field> {
265    /// The circuit descriptions of the input layers.
266    pub input_layers: Vec<InputLayerDescription>,
267    /// The circuit descriptions of the verifier challengs
268    pub fiat_shamir_challenges: Vec<FiatShamirChallengeDescription<F>>,
269    /// The circuit descriptions of the intermediate layers.
270    pub intermediate_layers: Vec<LayerDescriptionEnum<F>>,
271    /// The circuit desriptions of the output layers.
272    pub output_layers: Vec<OutputLayerDescription<F>>,
273}
274
275impl<F: Field> GKRCircuitDescription<F> {
276    /// Label the MLE indices contained within a circuit description, starting
277    /// each layer with the start_index.
278    pub fn index_mle_indices(&mut self, start_index: usize) {
279        let GKRCircuitDescription {
280            input_layers: _,
281            fiat_shamir_challenges: _,
282            intermediate_layers,
283            output_layers,
284        } = self;
285        intermediate_layers
286            .iter_mut()
287            .for_each(|intermediate_layer| {
288                intermediate_layer.index_mle_indices(start_index);
289            });
290        output_layers.iter_mut().for_each(|output_layer| {
291            output_layer.index_mle_indices(start_index);
292        })
293    }
294
295    /// Returns an [InstantiatedCircuit] by populating the [GKRCircuitDescription] with data.
296    /// Assumes that the input data has already been added to the transcript.
297    ///
298    /// # Arguments:
299    /// * `input_data`: a [HashMap] mapping layer ids to the MLEs.
300    /// * `challenge_sampler`: a closure that takes a string and a usize and returns that many field
301    ///   elements; should be a wrapper of an instance method of the appropriate transcript.
302    pub fn instantiate(
303        &self,
304        input_data: &HashMap<LayerId, MultilinearExtension<F>>,
305        challenge_sampler: &mut impl FnMut(usize) -> Vec<F>,
306    ) -> InstantiatedCircuit<F> {
307        let GKRCircuitDescription {
308            input_layers: input_layer_descriptions,
309            fiat_shamir_challenges: fiat_shamir_challenge_descriptions,
310            intermediate_layers: intermediate_layer_descriptions,
311            output_layers: output_layer_descriptions,
312        } = self;
313
314        // Create a map that maps layer ID to a set of MLE descriptions that are
315        // expected to be compiled from its output. For example, if we have a
316        // layer whose first "half" (when MSB is 0) is used in a future layer,
317        // and its second half is also used in a future layer, we would expect
318        // both of these to be represented as MLE descriptions in the HashSet
319        // associated with this layer with the appropriate prefix bits.
320        let mut mle_claim_map = HashMap::<LayerId, HashSet<&MleDescription<F>>>::new();
321        // Do a forward pass through all of the intermediate layer descriptions
322        // and look into the "future" to see which parts of each layer are
323        // required for future layers.
324        intermediate_layer_descriptions
325            .iter()
326            .for_each(|intermediate_layer| {
327                let layer_source_circuit_mles = intermediate_layer.get_circuit_mles();
328                layer_source_circuit_mles
329                    .into_iter()
330                    .for_each(|circuit_mle| {
331                        let layer_id = circuit_mle.layer_id();
332                        mle_claim_map
333                            .entry(layer_id)
334                            .or_default()
335                            .insert(circuit_mle);
336                    })
337            });
338
339        // Do a forward pass through all of the intermediate layer descriptions
340        // and look into the "future" to see which parts of each layer are
341        // required for output layers.
342        output_layer_descriptions.iter().for_each(|output_layer| {
343            let layer_source_mle = &output_layer.mle;
344            let layer_id = layer_source_mle.layer_id();
345            mle_claim_map
346                .entry(layer_id)
347                .or_default()
348                .insert(&output_layer.mle);
349        });
350
351        // Step 1: populate the circuit map with all of the data necessary in
352        // order to instantiate the circuit.
353        let mut circuit_map = CircuitEvalMap::new();
354        let mut prover_input_layers: Vec<InputLayer<F>> = Vec::new();
355        let mut fiat_shamir_challenges = Vec::new();
356        // Step 1a: populate the circuit map by compiling the necessary data
357        // outputs for each of the input layers, while writing the commitments
358        // to them into the transcript.
359        input_layer_descriptions
360            .iter()
361            .for_each(|input_layer_description| {
362                let input_layer_id = input_layer_description.layer_id;
363                let combined_mle = input_data.get(&input_layer_id).unwrap();
364                let mle_outputs_necessary = mle_claim_map.get(&input_layer_id).unwrap();
365                // Compute all data outputs necessary for future layers for each
366                // input layer.
367                mle_outputs_necessary.iter().for_each(|mle_output| {
368                    let prefix_bits = mle_output.prefix_bits();
369                    let output = filter_bookkeeping_table(combined_mle, &prefix_bits);
370                    circuit_map.add_node(CircuitLocation::new(input_layer_id, prefix_bits), output);
371                });
372                let prover_input_layer = InputLayer {
373                    mle: combined_mle.clone(),
374                    layer_id: input_layer_id,
375                };
376                prover_input_layers.push(prover_input_layer);
377            });
378        // Step 1b: for each of the fiat shamir challenges, use the transcript
379        // in order to get the challenges and fill the layer.
380        fiat_shamir_challenge_descriptions
381            .iter()
382            .for_each(|fiat_shamir_challenge_description| {
383                let fiat_shamir_challenge_mle = MultilinearExtension::new(challenge_sampler(
384                    1 << fiat_shamir_challenge_description.num_bits,
385                ));
386                circuit_map.add_node(
387                    CircuitLocation::new(fiat_shamir_challenge_description.layer_id(), vec![]),
388                    fiat_shamir_challenge_mle.clone(),
389                );
390                fiat_shamir_challenges.push(FiatShamirChallenge {
391                    mle: fiat_shamir_challenge_mle,
392                    layer_id: fiat_shamir_challenge_description.layer_id(),
393                });
394            });
395
396        // Step 1c: Compute the data outputs, using the map from Layer ID to
397        // which Circuit MLEs are necessary to compile for this layer, for each
398        // of the intermediate layers.
399        intermediate_layer_descriptions
400            .iter()
401            .for_each(|intermediate_layer_description| {
402                let mle_outputs_necessary = mle_claim_map
403                    .get(&intermediate_layer_description.layer_id())
404                    .unwrap();
405                intermediate_layer_description
406                    .compute_data_outputs(mle_outputs_necessary, &mut circuit_map);
407            });
408
409        // Step 2: Using the fully populated circuit map, convert each of the
410        // layer descriptions into concretized layers. Step 2a: Concretize the
411        // intermediate layer descriptions.
412        let mut prover_intermediate_layers: Vec<LayerEnum<F>> =
413            Vec::with_capacity(intermediate_layer_descriptions.len());
414        intermediate_layer_descriptions
415            .iter()
416            .for_each(|intermediate_layer_description| {
417                let prover_intermediate_layer =
418                    intermediate_layer_description.convert_into_prover_layer(&circuit_map);
419                prover_intermediate_layers.push(prover_intermediate_layer)
420            });
421
422        // Step 2b: Concretize the output layer descriptions.
423        let mut prover_output_layers: Vec<OutputLayer<F>> = Vec::new();
424        output_layer_descriptions
425            .iter()
426            .for_each(|output_layer_description| {
427                let prover_output_layer =
428                    output_layer_description.into_prover_output_layer(&circuit_map);
429                prover_output_layers.push(prover_output_layer)
430            });
431
432        InstantiatedCircuit {
433            input_layers: prover_input_layers,
434            fiat_shamir_challenges,
435            layers: Layers::new_with_layers(prover_intermediate_layers),
436            output_layers: prover_output_layers,
437            layer_map: circuit_map.convert_to_layer_map(),
438        }
439    }
440
441    /// Verifies a GKR circuit proof produced by the `prove` method.
442    /// Assumes that the circuit description, all inputs and input commitments have already been added to transcript.
443    /// # Arguments
444    /// * `transcript_reader`: servers as the proof.
445    /// Returns claims on the input layers.
446    #[instrument(skip_all, err)]
447    pub fn verify(
448        &self,
449        transcript_reader: &mut impl VerifierTranscript<F>,
450    ) -> Result<Vec<Claim<F>>> {
451        // Get the verifier challenges from the transcript.
452        let fiat_shamir_challenges: Vec<FiatShamirChallenge<F>> = self
453            .fiat_shamir_challenges
454            .iter()
455            .map(|fs_desc| {
456                let values = transcript_reader
457                    .get_challenges("Verifier challenges", 1 << fs_desc.num_bits)
458                    .unwrap();
459                fs_desc.instantiate(values)
460            })
461            .collect();
462
463        // Claim tracker to keep track of GKR-style claims across all layers.
464        let mut claim_tracker = ClaimTracker::new();
465
466        // --------- Output Claim Generation ---------
467        let claims_timer = start_timer!(|| "Output claims generation");
468        let verifier_output_claims_span =
469            span!(Level::DEBUG, "verifier_output_claims_span").entered();
470
471        for circuit_output_layer in self.output_layers.iter() {
472            let layer_id = circuit_output_layer.layer_id();
473            info!("Verifying Output Layer: {layer_id:?}");
474
475            let verifier_output_layer = circuit_output_layer
476                .retrieve_mle_from_transcript_and_fix_layer(transcript_reader)?;
477
478            let claim = verifier_output_layer.get_claim()?;
479            claim_tracker.insert(claim.get_to_layer_id(), claim);
480        }
481
482        end_timer!(claims_timer);
483        verifier_output_claims_span.exit();
484
485        // --------- Verify Intermediate Layers ---------
486        let intermediate_layers_timer =
487            start_timer!(|| "ALL intermediate layers proof verification");
488
489        for layer in self.intermediate_layers.iter().rev() {
490            let layer_id = layer.layer_id();
491
492            info!("Intermediate Layer: {layer_id:?}");
493            let layer_timer = start_timer!(|| format!("Proof verification for layer {layer_id:?}"));
494
495            let layer_claims = claim_tracker.remove(layer_id).unwrap();
496
497            let verifier_layer = match global_claim_agg_strategy() {
498                ClaimAggregationStrategy::Interpolative => {
499                    let claim_aggr_timer =
500                        start_timer!(|| format!("Claim aggregation for layer {layer_id:?}"));
501                    let prev_claim = verifier_aggregate_claims(&layer_claims, transcript_reader)?;
502                    debug!("Aggregated claim: {:#?}", prev_claim);
503                    end_timer!(claim_aggr_timer);
504
505                    info!("Prove sumcheck message");
506                    let sumcheck_msg_timer = start_timer!(|| format!(
507                        "Compute sumcheck message for layer {:?}",
508                        layer.layer_id()
509                    ));
510
511                    // Performs the actual sumcheck verification step.
512                    let verifier_layer: VerifierLayerEnum<F> =
513                        layer.verify_rounds(&[&prev_claim], transcript_reader)?;
514
515                    end_timer!(sumcheck_msg_timer);
516
517                    verifier_layer
518                }
519                ClaimAggregationStrategy::RLC => {
520                    let sumcheck_msg_timer = start_timer!(|| format!(
521                        "Compute sumcheck message for layer {:?}",
522                        layer.layer_id()
523                    ));
524
525                    let verifier_layer = layer.verify_rounds(
526                        &layer_claims
527                            .iter()
528                            .map(|claim| claim.get_raw_claim())
529                            .collect_vec(),
530                        transcript_reader,
531                    )?;
532                    end_timer!(sumcheck_msg_timer);
533
534                    verifier_layer
535                }
536            };
537
538            for claim in verifier_layer.get_claims()? {
539                claim_tracker.insert(claim.get_to_layer_id(), claim);
540            }
541
542            end_timer!(layer_timer);
543        }
544
545        end_timer!(intermediate_layers_timer);
546
547        // --------- Verify claims on the verifier challenges ---------
548        let fiat_shamir_challenges_timer = start_timer!(|| "Verifier challenges proof generation");
549        for fiat_shamir_challenge in fiat_shamir_challenges {
550            if let Some(claims) = claim_tracker.remove(fiat_shamir_challenge.layer_id()) {
551                claims.iter().for_each(|claim| {
552                    verify_claim(&fiat_shamir_challenge.mle.to_vec(), claim.get_raw_claim());
553                });
554            } else {
555                return Err(anyhow!(GKRError::NoClaimsForLayer(
556                    fiat_shamir_challenge.layer_id()
557                )));
558            }
559        }
560        end_timer!(fiat_shamir_challenges_timer);
561
562        let input_layer_claims = self
563            .input_layers
564            .iter()
565            .flat_map(|input_layer| claim_tracker.remove(input_layer.layer_id).unwrap())
566            .collect_vec();
567
568        // Verify that there are no claims remaining in the claim tracker.
569        if !claim_tracker.is_empty() {
570            return Err(anyhow!(GKRError::ClaimTrackerNotEmpty));
571        }
572
573        Ok(input_layer_claims)
574    }
575}