remainder/claims/
claim_aggregation.rs

//! A set of functions providing the interface for performing claim aggregation.
2
3use std::cmp::max;
4
5use ark_std::{cfg_into_iter, end_timer, start_timer};
6use shared_types::{
7    config::global_config::global_prover_claim_agg_constant_column_optimization,
8    transcript::{ProverTranscript, VerifierTranscript},
9    Field,
10};
11use tracing::{debug, info};
12
13use crate::{
14    claims::{Claim, RawClaim},
15    layer::combine_mles::{
16        combine_mles_with_aggregate, get_indexed_layer_mles_to_combine, pre_fix_mles,
17    },
18    mle::dense::DenseMle,
19    sumcheck::evaluate_at_a_point,
20};
21
22use super::claim_group::ClaimGroup;
23
24use anyhow::{Ok, Result};
25
26#[cfg(feature = "parallel")]
27use rayon::iter::{IntoParallelIterator, ParallelIterator};
28
29/// Performs claim aggregation on the prover side and, if successful, returns a
30/// single, raw aggregated claim.
31///
32/// # Parameters
33///
34/// * `claims`: a slice of claims, all on the same layer (same `to_layer_id`),
35///   to be aggregated into one.
36/// * `layer`: the GKR layer that this claim group is making claims on; the
37///   layer whose ID matches the `to_layer_id` field of all elements in
38///   `claims`.
39/// * `output_mles_from_layer`: the compiled bookkeeping tables that result from
40///   this layer, in order to aggregate into the layerwise bookkeeping table.
41/// * `transcript_writer`: is used to post the interpolation polynomial
42///   evaluations and generate challenges.
43pub fn prover_aggregate_claims<F: Field>(
44    claims: &[Claim<F>],
45    output_mles_from_layer: Vec<DenseMle<F>>,
46    transcript_writer: &mut impl ProverTranscript<F>,
47) -> Result<RawClaim<F>> {
48    let num_claims = claims.len();
49    debug_assert!(num_claims > 0);
50    info!("High-level claim aggregation on {num_claims} claims.");
51
52    let claim_preproc_timer = start_timer!(|| "Claim preprocessing".to_string());
53
54    let fixed_output_mles = get_indexed_layer_mles_to_combine(output_mles_from_layer);
55
56    let claim_groups = ClaimGroup::form_claim_groups(claims.to_vec());
57
58    debug!("Grouped claims for aggregation: ");
59    for group in &claim_groups {
60        debug!("GROUP: {:#?}", group.get_raw_claims());
61    }
62
63    end_timer!(claim_preproc_timer);
64    let intermediate_timer = start_timer!(|| "Intermediate claim aggregation.".to_string());
65
66    let intermediate_claims = claim_groups
67        .into_iter()
68        .map(|claim_group| claim_group.prover_aggregate(&fixed_output_mles, transcript_writer))
69        .collect::<Result<Vec<_>>>()?;
70
71    end_timer!(intermediate_timer);
72    let final_timer = start_timer!(|| "Final stage aggregation.".to_string());
73
74    let intermediate_claims_group = ClaimGroup::new_from_raw_claims(intermediate_claims).unwrap();
75
76    // Finally, aggregate all intermediate claims.
77    let claim =
78        intermediate_claims_group.prover_aggregate(&fixed_output_mles, transcript_writer)?;
79
80    end_timer!(final_timer);
81    Ok(claim)
82}
83
84/// Returns an upper bound on the number of evaluations needed to represent the
85/// polynomial `P(x) = W(l(x))` where `W : F^n -> F` is a multilinear polynomial
86/// on `n` variables and `l : F -> F^n` is such that:
87///  * `l(0) = claim_vecs[0]`,
88///  * `l(1) = claim_vecs[1]`,
89///  * ...,
90///  * `l(m-1) = claim_vecs[m-1]`.
91///
92/// It is guaranteed that the returned value is at least `num_claims =
93/// claim_vecs.len()`.
94///
95/// # Panics
96///  if `claim_vecs` is empty.
97pub fn get_num_wlx_evaluations<F: Field>(
98    claim_vecs: &[Vec<F>],
99) -> (usize, Option<Vec<usize>>, Vec<usize>) {
100    let num_claims = claim_vecs.len();
101    let num_vars = claim_vecs[0].len();
102
103    debug!("Smart num_evals");
104    let mut num_constant_columns = num_vars as i64;
105    let mut common_idx = vec![];
106    let mut non_common_idx = vec![];
107    #[allow(clippy::needless_range_loop)]
108    for j in 0..num_vars {
109        let mut degree_reduced = true;
110        for i in 1..num_claims {
111            if claim_vecs[i][j] != claim_vecs[i - 1][j] {
112                num_constant_columns -= 1;
113                degree_reduced = false;
114                non_common_idx.push(j);
115                break;
116            }
117        }
118        if degree_reduced {
119            common_idx.push(j);
120        }
121    }
122    assert!(num_constant_columns >= 0);
123    debug!("degree_reduction = {}", num_constant_columns);
124
125    // Evaluate the P(x) := W(l(x)) polynomial at deg(P) + 1
126    // points. W : F^n -> F is a multi-linear polynomial on
127    // `num_vars` variables and l : F -> F^n is a canonical
128    // polynomial passing through `num_claims` points so its degree is
129    // at most `num_claims - 1`. This imposes an upper
130    // bound of `num_vars * (num_claims - 1)` to the degree of P.
131    // However, the actual degree of P might be lower.
132    // For any coordinate `i` such that all claims agree
133    // on that coordinate, we can quickly deduce that `l_i(x)` is a
134    // constant polynomial of degree zero instead of `num_claims -
135    // 1` which brings down the total degree by the same amount.
136    let num_evals =
137        (num_vars) * (num_claims - 1) + 1 - (num_constant_columns as usize) * (num_claims - 1);
138    debug!("num_evals originally = {}", num_evals);
139    (max(num_evals, num_claims), Some(common_idx), non_common_idx)
140}
141
142/// Returns a vector of evaluations of this layer's MLE on a sequence of
143/// points computed by interpolating a polynomial that passes through the
144/// points of `claims_vecs`.
145pub fn get_wlx_evaluations<F: Field>(
146    claim_vecs: &[Vec<F>],
147    claimed_vals: &[F],
148    claim_mles: Vec<DenseMle<F>>,
149    num_claims: usize,
150    num_idx: usize,
151) -> Result<Vec<F>> {
152    // get the number of evaluations
153
154    let (num_evals, common_idx) = if global_prover_claim_agg_constant_column_optimization() {
155        let (num_evals, common_idx, _) = get_num_wlx_evaluations(claim_vecs);
156        (num_evals, common_idx)
157    } else {
158        (((num_claims - 1) * num_idx) + 1, None)
159    };
160
161    let mut claim_mles = claim_mles;
162
163    if let Some(common_idx) = common_idx {
164        pre_fix_mles(&mut claim_mles, &claim_vecs[0], common_idx);
165    }
166
167    // we already have the first #claims evaluations, get the next num_evals - #claims evaluations
168    let next_evals: Vec<F> = cfg_into_iter!(num_claims..num_evals)
169        .map(|idx| {
170            // get the challenge l(idx)
171            let new_chal: Vec<F> = cfg_into_iter!(0..num_idx)
172                .map(|claim_idx| {
173                    let evals: Vec<F> = cfg_into_iter!(claim_vecs)
174                        .map(|claim| claim[claim_idx])
175                        .collect();
176                    evaluate_at_a_point(&evals, F::from(idx as u64)).unwrap()
177                })
178                .collect();
179
180            let wlx_eval_on_mle = combine_mles_with_aggregate(&claim_mles, &new_chal);
181            wlx_eval_on_mle.unwrap()
182        })
183        .collect();
184
185    // concat this with the first k evaluations from the claims to
186    // get num_evals evaluations
187    let mut wlx_evals = claimed_vals.to_vec();
188    wlx_evals.extend(&next_evals);
189    Ok(wlx_evals)
190}
191
192/// Performs claim aggregation on the verifier side.
193/// * `claims`: a group of claims, all on the same layer (same `to_layer_id`),
194///   to be aggregated into one.
195/// * `transcript_reader`: is used to retrieve the wlx evaluations and generate
196///   challenges.
197///
198/// # Returns
199///
200/// If successful, returns a single aggregated claim.
201pub fn verifier_aggregate_claims<F: Field>(
202    claims: &[Claim<F>],
203    transcript_reader: &mut impl VerifierTranscript<F>,
204) -> Result<RawClaim<F>> {
205    let num_claims = claims.len();
206    debug_assert!(num_claims > 0);
207    info!("High-level claim aggregation on {num_claims} claims.");
208
209    let claim_preproc_timer = start_timer!(|| "Claim preprocessing".to_string());
210
211    let claim_groups = ClaimGroup::form_claim_groups(claims.to_vec());
212
213    debug!("Grouped claims for aggregation: ");
214    for group in &claim_groups {
215        debug!("GROUP:");
216        for claim in group.get_raw_claims() {
217            debug!("{:#?}", claim);
218        }
219    }
220
221    end_timer!(claim_preproc_timer);
222    let intermediate_timer = start_timer!(|| "Intermediate claim aggregation.".to_string());
223
224    let intermediate_claims = claim_groups
225        .into_iter()
226        .map(|claim_group| claim_group.verifier_aggregate(transcript_reader))
227        .collect::<Result<Vec<_>>>()?;
228
229    end_timer!(intermediate_timer);
230    let final_timer = start_timer!(|| "Final stage aggregation.".to_string());
231
232    // Finally, aggregate all intermediate claims.
233    let intermediate_claim_group = ClaimGroup::new_from_raw_claims(intermediate_claims).unwrap();
234    let claim = intermediate_claim_group.verifier_aggregate(transcript_reader)?;
235
236    end_timer!(final_timer);
237    Ok(claim)
238}