1pub mod claim_group;
5
6#[cfg(test)]
8pub mod tests;
9
10pub mod claim_aggregation;
11
12use std::{collections::HashMap, fmt};
13
14use shared_types::Field;
15use thiserror::Error;
16
17use serde::{Deserialize, Serialize};
18
19use crate::layer::LayerId;
20
21#[derive(Error, Debug, Clone)]
23pub enum ClaimError {
24 #[error("MLE indices must all be fixed")]
26 ClaimMleIndexError,
27
28 #[error("MLE within MleRef has multiple values within it")]
30 MleRefMleError,
31
32 #[error("Error aggregating claims")]
34 ClaimAggroError,
35
36 #[error("All claims in a group should agree on the number of variables")]
38 NumVarsMismatch,
39
40 #[error("All claims in a group should agree the destination layer field")]
42 LayerIdMismatch,
43
44 #[error("Zero MLE refs cannot be used as intermediate values within a circuit")]
46 IntermediateZeroMLERefError,
47}
48
49#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
56#[serde(bound = "F: Field")]
57pub struct RawClaim<F: Field> {
58 point: Vec<F>,
60
61 evaluation: F,
63}
64
65impl<F: Field> RawClaim<F> {
66 pub fn new(point: Vec<F>, evaluation: F) -> Self {
68 Self { point, evaluation }
69 }
70
71 pub fn get_num_vars(&self) -> usize {
73 self.point.len()
74 }
75
76 pub fn get_point(&self) -> &[F] {
78 &self.point
79 }
80
81 pub fn get_eval(&self) -> F {
83 self.evaluation
84 }
85}
86
87#[derive(Clone, Serialize, Deserialize, PartialEq)]
93#[serde(bound = "F: Field")]
94pub struct Claim<F: Field> {
95 claim: RawClaim<F>,
97
98 from_layer_id: LayerId,
100
101 to_layer_id: LayerId,
104}
105
106impl<F: Field> Claim<F> {
107 pub fn new(point: Vec<F>, evaluation: F, from_layer_id: LayerId, to_layer_id: LayerId) -> Self {
110 Self {
111 claim: RawClaim::new(point, evaluation),
112 from_layer_id,
113 to_layer_id,
114 }
115 }
116
117 pub fn get_num_vars(&self) -> usize {
119 self.claim.get_num_vars()
120 }
121
122 pub fn get_point(&self) -> &[F] {
124 self.claim.get_point()
125 }
126
127 pub fn get_eval(&self) -> F {
129 self.claim.get_eval()
130 }
131
132 pub fn get_from_layer_id(&self) -> LayerId {
134 self.from_layer_id
135 }
136
137 pub fn get_to_layer_id(&self) -> LayerId {
139 self.to_layer_id
140 }
141
142 pub fn get_raw_claim(&self) -> &RawClaim<F> {
144 &self.claim
145 }
146}
147
148impl<F: Field> From<Claim<F>> for RawClaim<F> {
149 fn from(value: Claim<F>) -> Self {
150 Self {
151 point: value.claim.point,
152 evaluation: value.claim.evaluation,
153 }
154 }
155}
156
157impl<F: fmt::Debug + Field> fmt::Debug for Claim<F> {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 f.debug_struct("Claim")
160 .field("point", &self.get_point().to_vec())
161 .field("result", &self.claim.get_eval())
162 .field("from_layer_id", &self.from_layer_id)
163 .field("to_layer_id", &self.to_layer_id)
164 .finish()
165 }
166}
167
168pub struct ClaimTracker<F: Field> {
173 claim_map: HashMap<LayerId, Vec<Claim<F>>>,
175}
176
177impl<F: Field> ClaimTracker<F> {
178 pub fn new() -> Self {
180 Self {
181 claim_map: HashMap::<LayerId, Vec<Claim<F>>>::new(),
182 }
183 }
184
185 pub fn new_with_capacity(capacity: usize) -> Self {
188 Self {
189 claim_map: HashMap::<LayerId, Vec<Claim<F>>>::with_capacity(capacity),
190 }
191 }
192
193 pub fn insert(&mut self, layer_id: LayerId, claim: Claim<F>) {
196 debug_assert_eq!(claim.get_to_layer_id(), layer_id);
197
198 if let Some(claims) = self.claim_map.get_mut(&layer_id) {
199 claims.push(claim);
200 } else {
201 self.claim_map.insert(layer_id, vec![claim]);
202 }
203 }
204
205 pub fn get(&self, layer_id: LayerId) -> Option<&Vec<Claim<F>>> {
209 self.claim_map.get(&layer_id)
210 }
211
212 pub fn remove(&mut self, layer_id: LayerId) -> Option<Vec<Claim<F>>> {
214 self.claim_map.remove(&layer_id)
215 }
216
217 pub fn is_empty(&self) -> bool {
219 self.claim_map.is_empty()
220 }
221}
222
223impl<F: Field> Default for ClaimTracker<F> {
226 fn default() -> Self {
227 ClaimTracker::new()
228 }
229}