// remainder/prover/proof_system.rs

/// Generates an enum covering every possible layer type, plus forwarding trait
/// impls that dispatch each method call to the wrapped layer.
///
/// `$type_name` is used as a *prefix*: `layer_enum!(Layer, ...)` defines
/// `LayerEnum<F>`, not `Layer<F>`. For each `(Variant: Type)` pair the macro emits:
///
/// * a variant `Variant(Box<Type>)` on `[<$type_name Enum>]<F>`, which derives
///   `serde::Serialize`, `serde::Deserialize`, `Clone` and `Debug`;
/// * an `impl From<Type> for [<$type_name Enum>]<F>` that boxes the value;
/// * one match arm in each of the forwarding impls below.
///
/// Forwarding impls generated (each arm just calls the same method on the
/// boxed layer):
///
/// * `$crate::layer::Layer<F>` for `[<$type_name Enum>]<F>`;
/// * `$crate::layer::LayerDescription<F>` for `[<$type_name DescriptionEnum>]<F>`;
/// * `$crate::layer::VerifierLayer<F>` for `[<Verifier $type_name Enum>]<F>`.
///
/// The description enum and the verifier enum are *not* defined by this macro.
/// They must already exist at the invocation site with exactly the same variant
/// names. Their variant payloads must provide the forwarded methods, normally by
/// implementing the corresponding trait.
///
/// # Invocation-site requirements
///
/// The expansion uses these names without a `$crate::` prefix, so they must
/// resolve where the macro is invoked: `Field`, `LayerEnum`,
/// `VerifierLayerEnum`, and the `paste`, `serde` and `anyhow` crates.
///
/// # Usage
///
/// ```ignore
/// layer_enum!(EnumName, (FirstVariant: LayerType), (SecondVariant: SecondLayerType));
/// // Defines `EnumNameEnum<F>`. A trailing comma after the last pair is accepted.
/// ```
#[macro_export]
macro_rules! layer_enum {
    ($type_name:ident, $(($var_name:ident: $variant:ty)),+ $(,)?) => {
        paste::paste! {
            #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
            #[serde(bound = "F: Field")]
            #[doc = r"Remainder generated layer enum"]
            pub enum [<$type_name Enum>]<F: Field> {
                $(
                    #[doc = "Remainder generated layer variant"]
                    $var_name(Box<$variant>),
                )*
            }

            impl<F: Field> $crate::layer::LayerDescription<F> for [<$type_name DescriptionEnum>]<F> {
                type VerifierLayer = [<Verifier $type_name Enum>]<F>;

                // `$crate` path so resolution does not depend on the invoking module's parent.
                fn layer_id(&self) -> $crate::layer::LayerId {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.layer_id(),
                        )*
                    }
                }

                fn compute_data_outputs(
                    &self,
                    mle_outputs_necessary: &std::collections::HashSet<&$crate::mle::mle_description::MleDescription<F>>,
                    circuit_map: &mut $crate::circuit_layout::CircuitEvalMap<F>,
                ) {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.compute_data_outputs(mle_outputs_necessary, circuit_map),
                        )*
                    }
                }

                // NOTE(review): the return type is the concrete `VerifierLayerEnum<F>`
                // rather than `Self::VerifierLayer`. This presumably mirrors the
                // trait's declared signature; confirm before generalizing.
                fn verify_rounds(
                    &self,
                    claims: &[&$crate::claims::RawClaim<F>],
                    transcript: &mut impl $crate::shared_types::transcript::VerifierTranscript<F>,
                ) -> anyhow::Result<VerifierLayerEnum<F>> {
                    match self {
                        $(
                            // The variant already returns this `Result`; no `Ok(..?)` re-wrap.
                            Self::$var_name(layer) => layer.verify_rounds(claims, transcript),
                        )*
                    }
                }

                fn sumcheck_round_indices(&self) -> Vec<usize> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.sumcheck_round_indices(),
                        )*
                    }
                }

                fn convert_into_verifier_layer(
                    &self,
                    sumcheck_bindings: &[F],
                    claim_points: &[&[F]],
                    transcript_reader: &mut impl $crate::shared_types::transcript::VerifierTranscript<F>,
                ) -> anyhow::Result<Self::VerifierLayer> {
                    match self {
                        $(
                            // Re-wrap the variant's verifier layer in the matching
                            // variant of the verifier enum.
                            Self::$var_name(layer) => Ok(Self::VerifierLayer::$var_name(
                                layer.convert_into_verifier_layer(sumcheck_bindings, claim_points, transcript_reader)?,
                            )),
                        )*
                    }
                }

                fn get_circuit_mles(&self) -> Vec<&$crate::mle::mle_description::MleDescription<F>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_circuit_mles(),
                        )*
                    }
                }

                fn index_mle_indices(&mut self, start_index: usize) {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.index_mle_indices(start_index),
                        )*
                    }
                }

                // NOTE(review): `LayerEnum<F>` is named concretely here as well,
                // presumably to match the trait signature; confirm.
                fn convert_into_prover_layer(
                    &self,
                    circuit_map: &$crate::circuit_layout::CircuitEvalMap<F>,
                ) -> LayerEnum<F> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.convert_into_prover_layer(circuit_map),
                        )*
                    }
                }

                fn get_post_sumcheck_layer(
                    &self,
                    round_challenges: &[F],
                    claim_challenges: &[&[F]],
                    random_coefficients: &[F],
                ) -> $crate::layer::PostSumcheckLayer<F, Option<F>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_post_sumcheck_layer(round_challenges, claim_challenges, random_coefficients),
                        )*
                    }
                }

                fn max_degree(&self) -> usize {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.max_degree(),
                        )*
                    }
                }
            }

            impl<F: Field> $crate::layer::VerifierLayer<F> for [<Verifier $type_name Enum>]<F> {
                fn layer_id(&self) -> $crate::layer::LayerId {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.layer_id(),
                        )*
                    }
                }

                fn get_claims(&self) -> anyhow::Result<Vec<$crate::claims::Claim<F>>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_claims(),
                        )*
                    }
                }
            }

            impl<F: Field> $crate::layer::Layer<F> for [<$type_name Enum>]<F> {
                fn layer_id(&self) -> $crate::layer::LayerId {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.layer_id(),
                        )*
                    }
                }

                fn prove(
                    &mut self,
                    claims: &[&$crate::claims::RawClaim<F>],
                    transcript: &mut impl $crate::shared_types::transcript::ProverTranscript<F>,
                ) -> anyhow::Result<()> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.prove(claims, transcript),
                        )*
                    }
                }

                fn initialize(&mut self, claim_point: &[F]) -> anyhow::Result<()> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.initialize(claim_point),
                        )*
                    }
                }

                fn compute_round_sumcheck_message(
                    &mut self,
                    round_index: usize,
                    random_coefficients: &[F],
                ) -> anyhow::Result<Vec<F>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.compute_round_sumcheck_message(round_index, random_coefficients),
                        )*
                    }
                }

                fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> anyhow::Result<()> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.bind_round_variable(round_index, challenge),
                        )*
                    }
                }

                fn sumcheck_round_indices(&self) -> Vec<usize> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.sumcheck_round_indices(),
                        )*
                    }
                }

                fn max_degree(&self) -> usize {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.max_degree(),
                        )*
                    }
                }

                fn get_post_sumcheck_layer(
                    &self,
                    round_challenges: &[F],
                    claim_challenges: &[&[F]],
                    random_coefficients: &[F],
                ) -> $crate::layer::PostSumcheckLayer<F, F> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_post_sumcheck_layer(round_challenges, claim_challenges, random_coefficients),
                        )*
                    }
                }

                fn get_claims(&self) -> anyhow::Result<Vec<$crate::claims::Claim<F>>> {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.get_claims(),
                        )*
                    }
                }

                fn initialize_rlc(&mut self, random_coefficients: &[F], claims: &[&$crate::claims::RawClaim<F>]) {
                    match self {
                        $(
                            Self::$var_name(layer) => layer.initialize_rlc(random_coefficients, claims),
                        )*
                    }
                }
            }

            // Allow `concrete_layer.into()` to produce the boxed enum variant.
            $(
                impl<F: Field> From<$variant> for [<$type_name Enum>]<F> {
                    fn from(var: $variant) -> [<$type_name Enum>]<F> {
                        Self::$var_name(Box::new(var))
                    }
                }
            )*
        }
    }
}