remainder/claims/claim_group.rs
1use ark_std::cfg_into_iter;
2use itertools::Itertools;
3use shared_types::{
4 config::global_config::global_verifier_claim_agg_constant_column_optimization,
5 transcript::{ProverTranscript, VerifierTranscript},
6 Field,
7};
8use tracing::{debug, info};
9
10use crate::{
11 claims::{
12 claim_aggregation::{get_num_wlx_evaluations, get_wlx_evaluations},
13 ClaimError,
14 },
15 mle::dense::DenseMle,
16 sumcheck::evaluate_at_a_point,
17};
18
19use super::{Claim, RawClaim};
20
21use anyhow::{anyhow, Ok, Result};
22
23#[cfg(feature = "parallel")]
24use rayon::iter::{IntoParallelIterator, ParallelIterator};
25
26/// Stores a collection of claims and provides an API running claim aggregation
27/// algorithms on them.
28/// The current implementation introduces up to 3x memory redundancy in order to
29/// achieve faster access times.
30/// Invariant: All claims are on the same number of variables.
31#[derive(Clone, Debug)]
32pub struct ClaimGroup<F: Field> {
33 /// A vector of raw claims in F^n.
34 pub claims: Vec<RawClaim<F>>,
35
36 /// A 2D matrix with the claim's points as its rows.
37 claim_points_matrix: Vec<Vec<F>>,
38
39 /// The points in `claims` is effectively a matrix of elements in F. We also
40 /// store the transpose of this matrix for convenient access.
41 claim_points_transpose: Vec<Vec<F>>,
42
43 /// A vector of `self.get_num_claims()` elements. For each claim i,
44 /// `result_vector[i]` stores the expected result of the i-th claim.
45 result_vector: Vec<F>,
46}
47
48impl<F: Field> ClaimGroup<F> {
49 /// Generates a [ClaimGroup] from a collection of [Claim]s.
50 /// All claims agree on the `Claim::to_layer_id` field and returns
51 /// [ClaimError::LayerIdMismatch] otherwise. Returns
52 /// [ClaimError::NumVarsMismatch] if the collection of claims do not all
53 /// agree on the number of variables.
54 pub fn new(claims: Vec<Claim<F>>) -> Result<Self> {
55 let num_claims = claims.len();
56 if num_claims == 0 {
57 return Ok(Self {
58 claims: vec![],
59 claim_points_matrix: vec![],
60 claim_points_transpose: vec![],
61 result_vector: vec![],
62 });
63 }
64 // Check all claims match on the `to_layer_id` field.
65 let layer_id = claims[0].get_to_layer_id();
66 if !claims
67 .iter()
68 .all(|claim| claim.get_to_layer_id() == layer_id)
69 {
70 return Err(anyhow!(ClaimError::LayerIdMismatch));
71 }
72
73 Self::new_from_raw_claims(claims.into_iter().map(Into::into).collect())
74 }
75
76 /// Generates a new [ClaimGroup] from a collection of [RawClaim]s.
77 /// Returns [ClaimError::NumVarsMismatch] if the collection of claims
78 /// do not all agree on the number of variables.
79 pub fn new_from_raw_claims(claims: Vec<RawClaim<F>>) -> Result<Self> {
80 let num_claims = claims.len();
81
82 if num_claims == 0 {
83 return Ok(Self {
84 claims: vec![],
85 claim_points_matrix: vec![],
86 claim_points_transpose: vec![],
87 result_vector: vec![],
88 });
89 }
90
91 let num_vars = claims[0].get_num_vars();
92
93 // Check all claims match on the number of variables.
94 if !claims.iter().all(|claim| claim.get_num_vars() == num_vars) {
95 return Err(anyhow!(ClaimError::NumVarsMismatch));
96 }
97
98 // Populate the points_matrix
99 let points_matrix: Vec<_> = claims
100 .iter()
101 .map(|claim| -> Vec<F> { claim.get_point().to_vec() })
102 .collect();
103
104 // Compute the claim points transpose.
105 let claim_points_transpose: Vec<Vec<F>> = (0..num_vars)
106 .map(|j| (0..num_claims).map(|i| claims[i].get_point()[j]).collect())
107 .collect();
108
109 // Compute the result vector.
110 let result_vector: Vec<F> = (0..num_claims).map(|i| claims[i].get_eval()).collect();
111
112 Ok(Self {
113 claims,
114 claim_points_matrix: points_matrix,
115 claim_points_transpose,
116 result_vector,
117 })
118 }
119
120 /// Returns the number of claims stored in this group.
121 pub fn get_num_claims(&self) -> usize {
122 self.claims.len()
123 }
124
125 /// Returns true if the group contains no claims.
126 pub fn is_empty(&self) -> bool {
127 self.claims.is_empty()
128 }
129
130 /// Returns the number of indices of the claims stored.
131 /// Panics if no claims present.
132 pub fn get_num_vars(&self) -> usize {
133 self.claims[0].get_num_vars()
134 }
135
136 /// Returns a reference to a vector of `self.get_num_claims()` elements, the
137 /// j-th entry of which is the i-th coordinate of the j-th claim's point. In
138 /// other words, it returns the i-th column of the matrix containing the
139 /// claim points as its rows.
140 /// # Panics
141 /// When i is not in the range: 0 <= i < `self.get_num_vars()`.
142 pub fn get_points_column(&self, i: usize) -> &Vec<F> {
143 &self.claim_points_transpose[i]
144 }
145
146 /// Returns a reference to an "m x n" matrix where n = `self.get_num_vars()`
147 /// and m = `self.get_num_claims()` with the claim points as its rows.
148 pub fn get_claim_points_matrix(&self) -> &Vec<Vec<F>> {
149 &self.claim_points_matrix
150 }
151
152 /// Returns a reference to a vector with m = `self.get_num_claims()`
153 /// elements containing the results of all claims.
154 pub fn get_results(&self) -> &Vec<F> {
155 &self.result_vector
156 }
157
158 /// Returns a reference to the i-th claim.
159 pub fn get_raw_claim(&self, i: usize) -> &RawClaim<F> {
160 &self.claims[i]
161 }
162
163 /// Returns a reference to a vector of claims contained in this group.
164 pub fn get_raw_claims(&self) -> &[RawClaim<F>] {
165 &self.claims
166 }
167
168 /// Returns `claims` sorted by `from_layer_id` to prepare them for grouping.
169 /// Also performs claim de-duplication by eliminating copies of claims
170 /// on the same point.
171 fn preprocess_claims(mut claims: Vec<Claim<F>>) -> Vec<Claim<F>> {
172 // Sort claims on the `from_layer_id` field.
173 claims.sort_by(|claim1, claim2| {
174 claim1
175 .get_from_layer_id()
176 .partial_cmp(&claim2.get_from_layer_id())
177 .unwrap()
178 });
179
180 // Perform claim de-duplication
181 let claims = claims
182 .into_iter()
183 .unique_by(|c| c.get_point().to_vec())
184 .collect_vec();
185
186 claims
187 }
188
189 /// Partition `claims` into groups to be aggregated together.
190 pub fn form_claim_groups(claims: Vec<Claim<F>>) -> Vec<Self> {
191 // Sort claims by `from_layer_id` and remove duplicates.
192 let claims = Self::preprocess_claims(claims);
193
194 let num_claims = claims.len();
195 let mut claim_group_vec: Vec<Self> = vec![];
196
197 // Identify runs of claims with the same `from_layer_id` field.
198 let mut start_index = 0;
199 for idx in 1..num_claims {
200 if claims[idx].get_from_layer_id() != claims[idx - 1].get_from_layer_id() {
201 let end_index = idx;
202 claim_group_vec.push(Self::new(claims[start_index..end_index].to_vec()).unwrap());
203 start_index = idx;
204 }
205 }
206
207 // Process the last group.
208 let end_index = num_claims;
209 claim_group_vec.push(Self::new(claims[start_index..end_index].to_vec()).unwrap());
210
211 claim_group_vec
212 }
213
214 /// Computes the aggregated challenge point by interpolating a polynomial
215 /// passing through all the points in the claim group and then evaluating
216 /// it at `r_star`.
217 /// More precicely, if `self.claims` contains `m` points `[u_0, u_1, ...,
218 /// u_{m-1}]` where each `u_i \in F^n`, it computes a univariate polynomial
219 /// vector `l : F -> F^n` such that `l(0) = u_0, l(1) = u_1, ..., l(m-1) =
220 /// u_{m-1}` using Lagrange interpolation, then evaluates `l` on `r_star`
221 /// and returns it.
222 ///
223 /// # Requires
224 /// `self.claims_points` should be non-empty, otherwise a
225 /// [ClaimError::ClaimAggroError] is returned.
226 /// Using the ClaimGroup abstraction here is not ideal since we are only
227 /// operating on the points and not on the results. However, the ClaimGroup API
228 /// is convenient for accessing columns and makes the implementation more
229 /// readable. We should consider alternative designs.
230 fn compute_aggregated_challenges(&self, r_star: F) -> Result<Vec<F>> {
231 if self.is_empty() {
232 return Err(anyhow!(ClaimError::ClaimAggroError));
233 }
234
235 let num_vars = self.get_num_vars();
236
237 // Compute r = l(r*) by performing Lagrange interpolation on each coordinate
238 // using `evaluate_at_a_point`.
239 let r: Vec<F> = cfg_into_iter!(0..num_vars)
240 .map(|idx| {
241 let evals = self.get_points_column(idx);
242 // Interpolate the value l(r*) from the values
243 // l(0), l(1), ..., l(m-1) where m = # of claims.
244 evaluate_at_a_point(evals, r_star).unwrap()
245 })
246 .collect();
247
248 Ok(r)
249 }
250
251 /// Performs claim aggregation on the prover side for this claim group in a
252 /// single stage -- this is the standard "Thaler13" claim aggregation
253 /// without any heuristic optimizations.
254 ///
255 /// # Parameters
256 /// * `layer_mles`: the compiled bookkeeping tables from this layer, which
257 /// when aggregated appropriately with their prefix bits, make up the
258 /// layerwise bookkeeping table.
259 /// * `layer`: the layer whose output MLE is being made a claim on. Each of the
260 /// claims are aggregated into one claim, whose validity is reduced to the
261 /// validity of a claim in a future layer throught he sumcheck protocol.
262 /// * `transcript_writer`: is used to post wlx evaluations and generate
263 /// challenges.
264 ///
265 /// # Returns
266 ///
267 /// If successful, returns a single aggregated claim.
268 pub fn prover_aggregate(
269 &self,
270 layer_mles: &[DenseMle<F>],
271 transcript_writer: &mut impl ProverTranscript<F>,
272 ) -> Result<RawClaim<F>> {
273 let num_claims = self.get_num_claims();
274 debug_assert!(num_claims > 0);
275 info!("ClaimGroup aggregation on {num_claims} claims.");
276
277 // Do nothing if there is only one claim.
278 if num_claims == 1 {
279 debug!("Received 1 claim. Doing nothing.");
280 return Ok(self.claims[0].clone());
281 }
282 assert!(self.get_claim_points_matrix().len() > 1);
283
284 // Aggregate claims by performing the claim aggregation protocol.
285 // First compute V_i(l(x)).
286 let wlx_evaluations = get_wlx_evaluations(
287 self.get_claim_points_matrix(),
288 self.get_results(),
289 layer_mles.to_vec(),
290 num_claims,
291 self.get_num_vars(),
292 )
293 .unwrap();
294 let relevant_wlx_evaluations = wlx_evaluations[num_claims..].to_vec();
295
296 // Append evaluations to the transcript before sampling a challenge.
297 transcript_writer.append_elements(
298 "Claim aggregation interpolation polynomial evaluations",
299 &relevant_wlx_evaluations,
300 );
301
302 // Next, sample `r^\star` from the transcript.
303 let agg_chal = transcript_writer
304 .get_challenge("Challenge for claim aggregation interpolation polynomial");
305 debug!("Aggregate challenge: {:#?}", agg_chal);
306
307 let aggregated_challenges = self.compute_aggregated_challenges(agg_chal).unwrap();
308 let claimed_val = evaluate_at_a_point(&wlx_evaluations, agg_chal).unwrap();
309
310 debug!("Aggregating claims: {:#?}", self.get_raw_claims());
311
312 let claim = RawClaim::new(aggregated_challenges, claimed_val);
313 debug!("Low level aggregated claim:\n{:#?}", &claim);
314
315 Ok(claim)
316 }
317
318 /// Performs claim aggregation on the verifier side for this claim group in
319 /// a single stage -- this is the standard "Thaler13" claim aggregation
320 /// without any heuristic optimizations.
321 ///
322 /// # Parameters
323 /// * `transcript_reader`: is used to retrieve wlx evaluations and generate
324 /// challenges.
325 ///
326 /// # Returns
327 /// If successful, returns a single aggregated claim.
328 pub fn verifier_aggregate(
329 &self,
330 transcript_reader: &mut impl VerifierTranscript<F>,
331 ) -> Result<RawClaim<F>> {
332 let num_claims = self.get_num_claims();
333 debug_assert!(num_claims > 0);
334 info!("Low-level claim aggregation on {num_claims} claims.");
335
336 // Do nothing if there is only one claim.
337 if num_claims == 1 {
338 debug!("Received 1 claim. Doing nothing.");
339 return Ok(self.get_raw_claim(0).clone());
340 }
341
342 // Aggregate claims by performing the claim aggregation protocol.
343 // First retrieve V_i(l(x)).
344
345 let num_wlx_evaluations = if global_verifier_claim_agg_constant_column_optimization() {
346 let (num_wlx_evaluations, _, _) =
347 get_num_wlx_evaluations(self.get_claim_points_matrix());
348 num_wlx_evaluations
349 } else {
350 ((num_claims - 1) * self.get_num_vars()) + 1
351 };
352
353 let num_relevant_wlx_evaluations = num_wlx_evaluations - num_claims;
354 let relevant_wlx_evaluations = transcript_reader.consume_elements(
355 "Claim aggregation interpolation polynomial evaluations",
356 num_relevant_wlx_evaluations,
357 )?;
358 let wlx_evaluations = self
359 .get_results()
360 .clone()
361 .into_iter()
362 .chain(relevant_wlx_evaluations.clone())
363 .collect_vec();
364
365 // Next, sample `r^\star` from the transcript.
366 let agg_chal = transcript_reader
367 .get_challenge("Challenge for claim aggregation interpolation polynomial")?;
368 debug!("Aggregate challenge: {:#?}", agg_chal);
369
370 let aggregated_challenges = self.compute_aggregated_challenges(agg_chal).unwrap();
371 let claimed_val = evaluate_at_a_point(&wlx_evaluations, agg_chal).unwrap();
372
373 debug!("Aggregating claims: {:#?}", self.get_raw_claims());
374
375 let claim = RawClaim::new(aggregated_challenges, claimed_val);
376 debug!("Low level aggregated claim:\n{:#?}", &claim);
377
378 Ok(claim)
379 }
380}