// remainder/prover/proof_system.rs
#[macro_export]
/// Generates a layer enum and forwards the layer traits to its variants.
///
/// Invoked as `layer_enum!(Name, (VariantA: TypeA), (VariantB: TypeB), ...)`
/// (a trailing comma after the last pair is accepted). The expansion:
///
/// * defines `pub enum NameEnum<F: Field>` with one `Box`ed variant per
///   `(Variant: Type)` pair, deriving `serde::Serialize`,
///   `serde::Deserialize`, `Clone` and `Debug`;
/// * implements `$crate::layer::Layer<F>` for `NameEnum<F>`;
/// * implements `$crate::layer::LayerDescription<F>` for
///   `NameDescriptionEnum<F>`, with `VerifierLayer = VerifierNameEnum<F>`;
/// * implements `$crate::layer::VerifierLayer<F>` for `VerifierNameEnum<F>`;
/// * implements `From<Type> for NameEnum<F>` for every variant type.
///
/// Every trait method is dispatched through an exhaustive `match` that calls
/// the same-named method on the wrapped layer.
///
/// # Requirements at the invocation site
///
/// `NameDescriptionEnum<F>` and `VerifierNameEnum<F>` are *not* defined by
/// this macro; they must already exist and use the same variant names as the
/// invocation. The expansion also refers, without a `$crate::` prefix, to
/// `Field`, `LayerEnum`, `VerifierLayerEnum` and `super::LayerId`, so those
/// paths must resolve in the module that invokes the macro. The invoking
/// crate must depend on `paste`, `serde` and `anyhow`.
macro_rules! layer_enum {
    ($type_name:ident, $(($var_name:ident: $variant:ty)),+ $(,)?) => {
        paste::paste! {
            #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
            #[serde(bound = "F: Field")]
            #[doc = r"Remainder generated trait enum"]
            pub enum [<$type_name Enum>]<F: Field> {
                $(
                    #[doc = "Remainder generated layer variant"]
                    $var_name(Box<$variant>),
                )*
            }

            impl<F: Field> $crate::layer::LayerDescription<F> for [<$type_name DescriptionEnum>]<F> {
                type VerifierLayer = [<Verifier $type_name Enum>]<F>;

                fn layer_id(&self) -> super::LayerId {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.layer_id(),
                        )*
                    }
                }

                fn compute_data_outputs(
                    &self,
                    mle_outputs_necessary: &std::collections::HashSet<&$crate::mle::mle_description::MleDescription<F>>,
                    circuit_map: &mut $crate::circuit_layout::CircuitEvalMap<F>,
                ) {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.compute_data_outputs(mle_outputs_necessary, circuit_map),
                        )*
                    }
                }

                fn verify_rounds(
                    &self,
                    claims: &[&$crate::claims::RawClaim<F>],
                    transcript: &mut impl $crate::shared_types::transcript::VerifierTranscript<F>,
                ) -> anyhow::Result<VerifierLayerEnum<F>> {
                    match self {
                        $(
                            // `?` converts the variant's error into `anyhow::Error`.
                            Self::$var_name(layer) => Ok(layer.verify_rounds(claims, transcript)?),
                        )*
                    }
                }

                fn sumcheck_round_indices(&self) -> Vec<usize> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.sumcheck_round_indices(),
                        )*
                    }
                }

                fn convert_into_verifier_layer(
                    &self,
                    sumcheck_bindings: &[F],
                    claim_points: &[&[F]],
                    transcript_reader: &mut impl $crate::shared_types::transcript::VerifierTranscript<F>,
                ) -> anyhow::Result<Self::VerifierLayer> {
                    match self {
                        $(
                            // Re-wrap the per-variant verifier layer in the
                            // matching variant of the verifier enum.
                            Self::$var_name(layer) => Ok(Self::VerifierLayer::$var_name(
                                layer.convert_into_verifier_layer(sumcheck_bindings, claim_points, transcript_reader)?,
                            )),
                        )*
                    }
                }

                fn get_circuit_mles(&self) -> Vec<&$crate::mle::mle_description::MleDescription<F>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_circuit_mles(),
                        )*
                    }
                }

                fn index_mle_indices(&mut self, start_index: usize) {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.index_mle_indices(start_index),
                        )*
                    }
                }

                fn convert_into_prover_layer(
                    &self,
                    circuit_map: &$crate::circuit_layout::CircuitEvalMap<F>,
                ) -> LayerEnum<F> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.convert_into_prover_layer(circuit_map),
                        )*
                    }
                }

                fn get_post_sumcheck_layer(
                    &self,
                    round_challenges: &[F],
                    claim_challenges: &[&[F]],
                    random_coefficients: &[F],
                ) -> $crate::layer::PostSumcheckLayer<F, Option<F>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_post_sumcheck_layer(round_challenges, claim_challenges, random_coefficients),
                        )*
                    }
                }

                fn max_degree(&self) -> usize {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.max_degree(),
                        )*
                    }
                }
            }

            impl<F: Field> $crate::layer::VerifierLayer<F> for [<Verifier $type_name Enum>]<F> {
                fn layer_id(&self) -> super::LayerId {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.layer_id(),
                        )*
                    }
                }

                fn get_claims(&self) -> anyhow::Result<Vec<$crate::claims::Claim<F>>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_claims(),
                        )*
                    }
                }
            }

            impl<F: Field> $crate::layer::Layer<F> for [<$type_name Enum>]<F> {
                fn layer_id(&self) -> super::LayerId {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.layer_id(),
                        )*
                    }
                }

                fn prove(
                    &mut self,
                    claims: &[&$crate::claims::RawClaim<F>],
                    transcript: &mut impl $crate::shared_types::transcript::ProverTranscript<F>,
                ) -> anyhow::Result<()> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.prove(claims, transcript),
                        )*
                    }
                }

                fn initialize(&mut self, claim_point: &[F]) -> anyhow::Result<()> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.initialize(claim_point),
                        )*
                    }
                }

                fn compute_round_sumcheck_message(
                    &mut self,
                    round_index: usize,
                    random_coefficients: &[F],
                ) -> anyhow::Result<Vec<F>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.compute_round_sumcheck_message(round_index, random_coefficients),
                        )*
                    }
                }

                fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> anyhow::Result<()> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.bind_round_variable(round_index, challenge),
                        )*
                    }
                }

                fn sumcheck_round_indices(&self) -> Vec<usize> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.sumcheck_round_indices(),
                        )*
                    }
                }

                fn max_degree(&self) -> usize {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.max_degree(),
                        )*
                    }
                }

                fn get_post_sumcheck_layer(
                    &self,
                    round_challenges: &[F],
                    claim_challenges: &[&[F]],
                    random_coefficients: &[F],
                ) -> $crate::layer::PostSumcheckLayer<F, F> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_post_sumcheck_layer(round_challenges, claim_challenges, random_coefficients),
                        )*
                    }
                }

                fn get_claims(&self) -> anyhow::Result<Vec<$crate::claims::Claim<F>>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_claims(),
                        )*
                    }
                }

                fn initialize_rlc(&mut self, random_coefficients: &[F], claims: &[&$crate::claims::RawClaim<F>]) {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.initialize_rlc(random_coefficients, claims),
                        )*
                    }
                }
            }

            $(
                impl<F: Field> From<$variant> for [<$type_name Enum>]<F> {
                    fn from(var: $variant) -> [<$type_name Enum>]<F> {
                        Self::$var_name(Box::new(var))
                    }
                }
            )*
        }
    }
}