remainder/claims/claim_group.rs

1use ark_std::cfg_into_iter;
2use itertools::Itertools;
3use shared_types::{
4    config::global_config::global_verifier_claim_agg_constant_column_optimization,
5    transcript::{ProverTranscript, VerifierTranscript},
6    Field,
7};
8use tracing::{debug, info};
9
10use crate::{
11    claims::{
12        claim_aggregation::{get_num_wlx_evaluations, get_wlx_evaluations},
13        ClaimError,
14    },
15    mle::dense::DenseMle,
16    sumcheck::evaluate_at_a_point,
17};
18
19use super::{Claim, RawClaim};
20
21use anyhow::{anyhow, Ok, Result};
22
23#[cfg(feature = "parallel")]
24use rayon::iter::{IntoParallelIterator, ParallelIterator};
25
26/// Stores a collection of claims and provides an API running claim aggregation
27/// algorithms on them.
28/// The current implementation introduces up to 3x memory redundancy in order to
29/// achieve faster access times.
30/// Invariant: All claims are on the same number of variables.
31#[derive(Clone, Debug)]
32pub struct ClaimGroup<F: Field> {
33    /// A vector of raw claims in F^n.
34    pub claims: Vec<RawClaim<F>>,
35
36    /// A 2D matrix with the claim's points as its rows.
37    claim_points_matrix: Vec<Vec<F>>,
38
39    /// The points in `claims` is effectively a matrix of elements in F. We also
40    /// store the transpose of this matrix for convenient access.
41    claim_points_transpose: Vec<Vec<F>>,
42
43    /// A vector of `self.get_num_claims()` elements. For each claim i,
44    /// `result_vector[i]` stores the expected result of the i-th claim.
45    result_vector: Vec<F>,
46}
47
48impl<F: Field> ClaimGroup<F> {
49    /// Generates a [ClaimGroup] from a collection of [Claim]s.
50    /// All claims agree on the `Claim::to_layer_id` field and returns
51    /// [ClaimError::LayerIdMismatch] otherwise.  Returns
52    /// [ClaimError::NumVarsMismatch] if the collection of claims do not all
53    /// agree on the number of variables.
54    pub fn new(claims: Vec<Claim<F>>) -> Result<Self> {
55        let num_claims = claims.len();
56        if num_claims == 0 {
57            return Ok(Self {
58                claims: vec![],
59                claim_points_matrix: vec![],
60                claim_points_transpose: vec![],
61                result_vector: vec![],
62            });
63        }
64        // Check all claims match on the `to_layer_id` field.
65        let layer_id = claims[0].get_to_layer_id();
66        if !claims
67            .iter()
68            .all(|claim| claim.get_to_layer_id() == layer_id)
69        {
70            return Err(anyhow!(ClaimError::LayerIdMismatch));
71        }
72
73        Self::new_from_raw_claims(claims.into_iter().map(Into::into).collect())
74    }
75
76    /// Generates a new [ClaimGroup] from a collection of [RawClaim]s.
77    /// Returns [ClaimError::NumVarsMismatch] if the collection of claims
78    /// do not all agree on the number of variables.
79    pub fn new_from_raw_claims(claims: Vec<RawClaim<F>>) -> Result<Self> {
80        let num_claims = claims.len();
81
82        if num_claims == 0 {
83            return Ok(Self {
84                claims: vec![],
85                claim_points_matrix: vec![],
86                claim_points_transpose: vec![],
87                result_vector: vec![],
88            });
89        }
90
91        let num_vars = claims[0].get_num_vars();
92
93        // Check all claims match on the number of variables.
94        if !claims.iter().all(|claim| claim.get_num_vars() == num_vars) {
95            return Err(anyhow!(ClaimError::NumVarsMismatch));
96        }
97
98        // Populate the points_matrix
99        let points_matrix: Vec<_> = claims
100            .iter()
101            .map(|claim| -> Vec<F> { claim.get_point().to_vec() })
102            .collect();
103
104        // Compute the claim points transpose.
105        let claim_points_transpose: Vec<Vec<F>> = (0..num_vars)
106            .map(|j| (0..num_claims).map(|i| claims[i].get_point()[j]).collect())
107            .collect();
108
109        // Compute the result vector.
110        let result_vector: Vec<F> = (0..num_claims).map(|i| claims[i].get_eval()).collect();
111
112        Ok(Self {
113            claims,
114            claim_points_matrix: points_matrix,
115            claim_points_transpose,
116            result_vector,
117        })
118    }
119
120    /// Returns the number of claims stored in this group.
121    pub fn get_num_claims(&self) -> usize {
122        self.claims.len()
123    }
124
125    /// Returns true if the group contains no claims.
126    pub fn is_empty(&self) -> bool {
127        self.claims.is_empty()
128    }
129
130    /// Returns the number of indices of the claims stored.
131    /// Panics if no claims present.
132    pub fn get_num_vars(&self) -> usize {
133        self.claims[0].get_num_vars()
134    }
135
136    /// Returns a reference to a vector of `self.get_num_claims()` elements, the
137    /// j-th entry of which is the i-th coordinate of the j-th claim's point. In
138    /// other words, it returns the i-th column of the matrix containing the
139    /// claim points as its rows.
140    /// # Panics
141    /// When i is not in the range: 0 <= i < `self.get_num_vars()`.
142    pub fn get_points_column(&self, i: usize) -> &Vec<F> {
143        &self.claim_points_transpose[i]
144    }
145
146    /// Returns a reference to an "m x n" matrix where n = `self.get_num_vars()`
147    /// and m = `self.get_num_claims()` with the claim points as its rows.
148    pub fn get_claim_points_matrix(&self) -> &Vec<Vec<F>> {
149        &self.claim_points_matrix
150    }
151
152    /// Returns a reference to a vector with m = `self.get_num_claims()`
153    /// elements containing the results of all claims.
154    pub fn get_results(&self) -> &Vec<F> {
155        &self.result_vector
156    }
157
158    /// Returns a reference to the i-th claim.
159    pub fn get_raw_claim(&self, i: usize) -> &RawClaim<F> {
160        &self.claims[i]
161    }
162
163    /// Returns a reference to a vector of claims contained in this group.
164    pub fn get_raw_claims(&self) -> &[RawClaim<F>] {
165        &self.claims
166    }
167
168    /// Returns `claims` sorted by `from_layer_id` to prepare them for grouping.
169    /// Also performs claim de-duplication by eliminating copies of claims
170    /// on the same point.
171    fn preprocess_claims(mut claims: Vec<Claim<F>>) -> Vec<Claim<F>> {
172        // Sort claims on the `from_layer_id` field.
173        claims.sort_by(|claim1, claim2| {
174            claim1
175                .get_from_layer_id()
176                .partial_cmp(&claim2.get_from_layer_id())
177                .unwrap()
178        });
179
180        // Perform claim de-duplication
181        let claims = claims
182            .into_iter()
183            .unique_by(|c| c.get_point().to_vec())
184            .collect_vec();
185
186        claims
187    }
188
189    /// Partition `claims` into groups to be aggregated together.
190    pub fn form_claim_groups(claims: Vec<Claim<F>>) -> Vec<Self> {
191        // Sort claims by `from_layer_id` and remove duplicates.
192        let claims = Self::preprocess_claims(claims);
193
194        let num_claims = claims.len();
195        let mut claim_group_vec: Vec<Self> = vec![];
196
197        // Identify runs of claims with the same `from_layer_id` field.
198        let mut start_index = 0;
199        for idx in 1..num_claims {
200            if claims[idx].get_from_layer_id() != claims[idx - 1].get_from_layer_id() {
201                let end_index = idx;
202                claim_group_vec.push(Self::new(claims[start_index..end_index].to_vec()).unwrap());
203                start_index = idx;
204            }
205        }
206
207        // Process the last group.
208        let end_index = num_claims;
209        claim_group_vec.push(Self::new(claims[start_index..end_index].to_vec()).unwrap());
210
211        claim_group_vec
212    }
213
214    /// Computes the aggregated challenge point by interpolating a polynomial
215    /// passing through all the points in the claim group and then evaluating
216    /// it at `r_star`.
217    /// More precicely, if `self.claims` contains `m` points `[u_0, u_1, ...,
218    /// u_{m-1}]` where each `u_i \in F^n`, it computes a univariate polynomial
219    /// vector `l : F -> F^n` such that `l(0) = u_0, l(1) = u_1, ..., l(m-1) =
220    /// u_{m-1}` using Lagrange interpolation, then evaluates `l` on `r_star`
221    /// and returns it.
222    ///
223    /// # Requires
224    /// `self.claims_points` should be non-empty, otherwise a
225    /// [ClaimError::ClaimAggroError] is returned.
226    /// Using the ClaimGroup abstraction here is not ideal since we are only
227    /// operating on the points and not on the results. However, the ClaimGroup API
228    /// is convenient for accessing columns and makes the implementation more
229    /// readable. We should consider alternative designs.
230    fn compute_aggregated_challenges(&self, r_star: F) -> Result<Vec<F>> {
231        if self.is_empty() {
232            return Err(anyhow!(ClaimError::ClaimAggroError));
233        }
234
235        let num_vars = self.get_num_vars();
236
237        // Compute r = l(r*) by performing Lagrange interpolation on each coordinate
238        // using `evaluate_at_a_point`.
239        let r: Vec<F> = cfg_into_iter!(0..num_vars)
240            .map(|idx| {
241                let evals = self.get_points_column(idx);
242                // Interpolate the value l(r*) from the values
243                // l(0), l(1), ..., l(m-1) where m = # of claims.
244                evaluate_at_a_point(evals, r_star).unwrap()
245            })
246            .collect();
247
248        Ok(r)
249    }
250
251    /// Performs claim aggregation on the prover side for this claim group in a
252    /// single stage -- this is the standard "Thaler13" claim aggregation
253    /// without any heuristic optimizations.
254    ///
255    /// # Parameters
256    /// * `layer_mles`: the compiled bookkeeping tables from this layer, which
257    ///   when aggregated appropriately with their prefix bits, make up the
258    ///   layerwise bookkeeping table.
259    /// * `layer`: the layer whose output MLE is being made a claim on. Each of the
260    ///   claims are aggregated into one claim, whose validity is reduced to the
261    ///   validity of a claim in a future layer throught he sumcheck protocol.
262    /// * `transcript_writer`: is used to post wlx evaluations and generate
263    ///   challenges.
264    ///
265    /// # Returns
266    ///
267    /// If successful, returns a single aggregated claim.
268    pub fn prover_aggregate(
269        &self,
270        layer_mles: &[DenseMle<F>],
271        transcript_writer: &mut impl ProverTranscript<F>,
272    ) -> Result<RawClaim<F>> {
273        let num_claims = self.get_num_claims();
274        debug_assert!(num_claims > 0);
275        info!("ClaimGroup aggregation on {num_claims} claims.");
276
277        // Do nothing if there is only one claim.
278        if num_claims == 1 {
279            debug!("Received 1 claim. Doing nothing.");
280            return Ok(self.claims[0].clone());
281        }
282        assert!(self.get_claim_points_matrix().len() > 1);
283
284        // Aggregate claims by performing the claim aggregation protocol.
285        // First compute V_i(l(x)).
286        let wlx_evaluations = get_wlx_evaluations(
287            self.get_claim_points_matrix(),
288            self.get_results(),
289            layer_mles.to_vec(),
290            num_claims,
291            self.get_num_vars(),
292        )
293        .unwrap();
294        let relevant_wlx_evaluations = wlx_evaluations[num_claims..].to_vec();
295
296        // Append evaluations to the transcript before sampling a challenge.
297        transcript_writer.append_elements(
298            "Claim aggregation interpolation polynomial evaluations",
299            &relevant_wlx_evaluations,
300        );
301
302        // Next, sample `r^\star` from the transcript.
303        let agg_chal = transcript_writer
304            .get_challenge("Challenge for claim aggregation interpolation polynomial");
305        debug!("Aggregate challenge: {:#?}", agg_chal);
306
307        let aggregated_challenges = self.compute_aggregated_challenges(agg_chal).unwrap();
308        let claimed_val = evaluate_at_a_point(&wlx_evaluations, agg_chal).unwrap();
309
310        debug!("Aggregating claims: {:#?}", self.get_raw_claims());
311
312        let claim = RawClaim::new(aggregated_challenges, claimed_val);
313        debug!("Low level aggregated claim:\n{:#?}", &claim);
314
315        Ok(claim)
316    }
317
318    /// Performs claim aggregation on the verifier side for this claim group in
319    /// a single stage -- this is the standard "Thaler13" claim aggregation
320    /// without any heuristic optimizations.
321    ///
322    /// # Parameters
323    /// * `transcript_reader`: is used to retrieve wlx evaluations and generate
324    ///   challenges.
325    ///
326    /// # Returns
327    /// If successful, returns a single aggregated claim.
328    pub fn verifier_aggregate(
329        &self,
330        transcript_reader: &mut impl VerifierTranscript<F>,
331    ) -> Result<RawClaim<F>> {
332        let num_claims = self.get_num_claims();
333        debug_assert!(num_claims > 0);
334        info!("Low-level claim aggregation on {num_claims} claims.");
335
336        // Do nothing if there is only one claim.
337        if num_claims == 1 {
338            debug!("Received 1 claim. Doing nothing.");
339            return Ok(self.get_raw_claim(0).clone());
340        }
341
342        // Aggregate claims by performing the claim aggregation protocol.
343        // First retrieve V_i(l(x)).
344
345        let num_wlx_evaluations = if global_verifier_claim_agg_constant_column_optimization() {
346            let (num_wlx_evaluations, _, _) =
347                get_num_wlx_evaluations(self.get_claim_points_matrix());
348            num_wlx_evaluations
349        } else {
350            ((num_claims - 1) * self.get_num_vars()) + 1
351        };
352
353        let num_relevant_wlx_evaluations = num_wlx_evaluations - num_claims;
354        let relevant_wlx_evaluations = transcript_reader.consume_elements(
355            "Claim aggregation interpolation polynomial evaluations",
356            num_relevant_wlx_evaluations,
357        )?;
358        let wlx_evaluations = self
359            .get_results()
360            .clone()
361            .into_iter()
362            .chain(relevant_wlx_evaluations.clone())
363            .collect_vec();
364
365        // Next, sample `r^\star` from the transcript.
366        let agg_chal = transcript_reader
367            .get_challenge("Challenge for claim aggregation interpolation polynomial")?;
368        debug!("Aggregate challenge: {:#?}", agg_chal);
369
370        let aggregated_challenges = self.compute_aggregated_challenges(agg_chal).unwrap();
371        let claimed_val = evaluate_at_a_point(&wlx_evaluations, agg_chal).unwrap();
372
373        debug!("Aggregating claims: {:#?}", self.get_raw_claims());
374
375        let claim = RawClaim::new(aggregated_challenges, claimed_val);
376        debug!("Low level aggregated claim:\n{:#?}", &claim);
377
378        Ok(claim)
379    }
380}