// remainder/claims/claim_aggregation.rs

use std::cmp::max;

use ark_std::{cfg_into_iter, end_timer, start_timer};
use shared_types::{
    config::global_config::global_prover_claim_agg_constant_column_optimization,
    transcript::{ProverTranscript, VerifierTranscript},
    Field,
};
use tracing::{debug, info};

use crate::{
    claims::{Claim, RawClaim},
    layer::combine_mles::{
        combine_mles_with_aggregate, get_indexed_layer_mles_to_combine, pre_fix_mles,
    },
    mle::dense::DenseMle,
    sumcheck::evaluate_at_a_point,
};

use super::claim_group::ClaimGroup;

use anyhow::{ensure, Ok, Result};

#[cfg(feature = "parallel")]
use rayon::iter::{IntoParallelIterator, ParallelIterator};
29pub fn prover_aggregate_claims<F: Field>(
44 claims: &[Claim<F>],
45 output_mles_from_layer: Vec<DenseMle<F>>,
46 transcript_writer: &mut impl ProverTranscript<F>,
47) -> Result<RawClaim<F>> {
48 let num_claims = claims.len();
49 debug_assert!(num_claims > 0);
50 info!("High-level claim aggregation on {num_claims} claims.");
51
52 let claim_preproc_timer = start_timer!(|| "Claim preprocessing".to_string());
53
54 let fixed_output_mles = get_indexed_layer_mles_to_combine(output_mles_from_layer);
55
56 let claim_groups = ClaimGroup::form_claim_groups(claims.to_vec());
57
58 debug!("Grouped claims for aggregation: ");
59 for group in &claim_groups {
60 debug!("GROUP: {:#?}", group.get_raw_claims());
61 }
62
63 end_timer!(claim_preproc_timer);
64 let intermediate_timer = start_timer!(|| "Intermediate claim aggregation.".to_string());
65
66 let intermediate_claims = claim_groups
67 .into_iter()
68 .map(|claim_group| claim_group.prover_aggregate(&fixed_output_mles, transcript_writer))
69 .collect::<Result<Vec<_>>>()?;
70
71 end_timer!(intermediate_timer);
72 let final_timer = start_timer!(|| "Final stage aggregation.".to_string());
73
74 let intermediate_claims_group = ClaimGroup::new_from_raw_claims(intermediate_claims).unwrap();
75
76 let claim =
78 intermediate_claims_group.prover_aggregate(&fixed_output_mles, transcript_writer)?;
79
80 end_timer!(final_timer);
81 Ok(claim)
82}
83
84pub fn get_num_wlx_evaluations<F: Field>(
98 claim_vecs: &[Vec<F>],
99) -> (usize, Option<Vec<usize>>, Vec<usize>) {
100 let num_claims = claim_vecs.len();
101 let num_vars = claim_vecs[0].len();
102
103 debug!("Smart num_evals");
104 let mut num_constant_columns = num_vars as i64;
105 let mut common_idx = vec![];
106 let mut non_common_idx = vec![];
107 #[allow(clippy::needless_range_loop)]
108 for j in 0..num_vars {
109 let mut degree_reduced = true;
110 for i in 1..num_claims {
111 if claim_vecs[i][j] != claim_vecs[i - 1][j] {
112 num_constant_columns -= 1;
113 degree_reduced = false;
114 non_common_idx.push(j);
115 break;
116 }
117 }
118 if degree_reduced {
119 common_idx.push(j);
120 }
121 }
122 assert!(num_constant_columns >= 0);
123 debug!("degree_reduction = {}", num_constant_columns);
124
125 let num_evals =
137 (num_vars) * (num_claims - 1) + 1 - (num_constant_columns as usize) * (num_claims - 1);
138 debug!("num_evals originally = {}", num_evals);
139 (max(num_evals, num_claims), Some(common_idx), non_common_idx)
140}
141
142pub fn get_wlx_evaluations<F: Field>(
146 claim_vecs: &[Vec<F>],
147 claimed_vals: &[F],
148 claim_mles: Vec<DenseMle<F>>,
149 num_claims: usize,
150 num_idx: usize,
151) -> Result<Vec<F>> {
152 let (num_evals, common_idx) = if global_prover_claim_agg_constant_column_optimization() {
155 let (num_evals, common_idx, _) = get_num_wlx_evaluations(claim_vecs);
156 (num_evals, common_idx)
157 } else {
158 (((num_claims - 1) * num_idx) + 1, None)
159 };
160
161 let mut claim_mles = claim_mles;
162
163 if let Some(common_idx) = common_idx {
164 pre_fix_mles(&mut claim_mles, &claim_vecs[0], common_idx);
165 }
166
167 let next_evals: Vec<F> = cfg_into_iter!(num_claims..num_evals)
169 .map(|idx| {
170 let new_chal: Vec<F> = cfg_into_iter!(0..num_idx)
172 .map(|claim_idx| {
173 let evals: Vec<F> = cfg_into_iter!(claim_vecs)
174 .map(|claim| claim[claim_idx])
175 .collect();
176 evaluate_at_a_point(&evals, F::from(idx as u64)).unwrap()
177 })
178 .collect();
179
180 let wlx_eval_on_mle = combine_mles_with_aggregate(&claim_mles, &new_chal);
181 wlx_eval_on_mle.unwrap()
182 })
183 .collect();
184
185 let mut wlx_evals = claimed_vals.to_vec();
188 wlx_evals.extend(&next_evals);
189 Ok(wlx_evals)
190}
191
192pub fn verifier_aggregate_claims<F: Field>(
202 claims: &[Claim<F>],
203 transcript_reader: &mut impl VerifierTranscript<F>,
204) -> Result<RawClaim<F>> {
205 let num_claims = claims.len();
206 debug_assert!(num_claims > 0);
207 info!("High-level claim aggregation on {num_claims} claims.");
208
209 let claim_preproc_timer = start_timer!(|| "Claim preprocessing".to_string());
210
211 let claim_groups = ClaimGroup::form_claim_groups(claims.to_vec());
212
213 debug!("Grouped claims for aggregation: ");
214 for group in &claim_groups {
215 debug!("GROUP:");
216 for claim in group.get_raw_claims() {
217 debug!("{:#?}", claim);
218 }
219 }
220
221 end_timer!(claim_preproc_timer);
222 let intermediate_timer = start_timer!(|| "Intermediate claim aggregation.".to_string());
223
224 let intermediate_claims = claim_groups
225 .into_iter()
226 .map(|claim_group| claim_group.verifier_aggregate(transcript_reader))
227 .collect::<Result<Vec<_>>>()?;
228
229 end_timer!(intermediate_timer);
230 let final_timer = start_timer!(|| "Final stage aggregation.".to_string());
231
232 let intermediate_claim_group = ClaimGroup::new_from_raw_claims(intermediate_claims).unwrap();
234 let claim = intermediate_claim_group.verifier_aggregate(transcript_reader)?;
235
236 end_timer!(final_timer);
237 Ok(claim)
238}