remainder/layer.rs
1//! A layer is a combination of multiple MLEs with an expression.
2
3pub mod combine_mles;
4pub mod gate;
5pub mod identity_gate;
6pub mod layer_enum;
7pub mod matmult;
8pub mod product;
9pub mod regular_layer;
10
11use std::{collections::HashSet, fmt::Debug};
12
13use layer_enum::{LayerEnum, VerifierLayerEnum};
14use product::PostSumcheckLayer;
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17
18use crate::{
19 circuit_building_context::CircuitBuildingContext,
20 circuit_layout::CircuitEvalMap,
21 claims::{Claim, ClaimError, RawClaim},
22 expression::expr_errors::ExpressionError,
23 mle::mle_description::MleDescription,
24 sumcheck::InterpError,
25};
26use shared_types::{
27 transcript::{ProverTranscript, TranscriptReaderError, VerifierTranscript},
28 Field,
29};
30
31use anyhow::Result;
32
33/// Errors to do with working with a type implementing [Layer].
34#[derive(Error, Debug, Clone)]
35pub enum LayerError {
36 #[error("Layer isn't ready to prove")]
37 /// Layer isn't ready to prove
38 LayerNotReady,
39 #[error("Error with underlying expression: {0}")]
40 /// Error with underlying expression: {0}
41 ExpressionError(#[from] ExpressionError),
42 #[error("Error with aggregating curr layer")]
43 /// Error with aggregating curr layer
44 AggregationError,
45 #[error("Error with getting Claim: {0}")]
46 /// Error with getting Claim
47 ClaimError(#[from] ClaimError),
48 #[error("Error with verifying layer: {0}")]
49 /// Error with verifying layer
50 VerificationError(#[from] VerificationError),
51 #[error("InterpError: {0}")]
52 /// InterpError
53 InterpError(#[from] InterpError),
54 #[error("Transcript Error: {0}")]
55 /// Transcript Error
56 TranscriptError(#[from] TranscriptReaderError),
57 /// Incorrect number of variable bindings
58 #[error("Layer {0} requires {1} variable bindings, but {2} were provided")]
59 NumVarsMismatch(LayerId, usize, usize),
60}
61
62/// Errors to do with verifying a layer while working with a type implementing
63/// [VerifierLayer].
64#[derive(Error, Debug, Clone)]
65pub enum VerificationError {
66 #[error("The sum of the first evaluations do not equal the claim")]
67 /// The sum of the first evaluations do not equal the claim
68 SumcheckStartFailed,
69
70 #[error("The sum of the current rounds evaluations do not equal the previous round at a random point")]
71 /// The sum of the current rounds evaluations do not equal the previous round at a random point
72 SumcheckFailed,
73
74 #[error("The final rounds evaluations at r do not equal the oracle query")]
75 /// The final rounds evaluations at r do not equal the oracle query
76 FinalSumcheckFailed,
77
78 #[error("The Oracle query does not match the final claim")]
79 /// The Oracle query does not match the final claim
80 GKRClaimCheckFailed,
81
82 #[error(
83 "The Challenges generated during sumcheck don't match the claims in the given expression"
84 )]
85 ///The Challenges generated during sumcheck don't match the claims in the given expression
86 ChallengeCheckFailed,
87
88 /// Error with underlying expression
89 #[error("Error with underlying expression")]
90 ExpressionError,
91
92 /// Error while reading the transcript proof.
93 #[error("Error while reading the transcript proof")]
94 TranscriptError,
95
96 /// Interpolation Error.
97 #[error("Interpolation Error: {0}")]
98 InterpError(#[from] InterpError),
99}
100
101/// A layer is the smallest component of the GKR protocol.
102///
103/// Each `Layer` is a sub-protocol that takes in some `Claim` and creates a proof
104/// that the `Claim` is correct
105pub trait Layer<F: Field> {
106 /// Gets this layer's ID.
107 fn layer_id(&self) -> LayerId;
108
109 /// Initialize this layer and perform any necessary pre-computation: beta
110 /// table, number of rounds, etc.
111 fn initialize(&mut self, claim_point: &[F]) -> Result<()>;
112
113 /// Tries to prove `claims` for this layer. There is only a single
114 /// aggregated claim if our
115 /// [shared_types::config::ClaimAggregationStrategy] is
116 /// [shared_types::config::ClaimAggregationStrategy::Interpolative],
117 /// otherwise we have several claims we take the random linear
118 /// combination over.
119 ///
120 /// In the process of proving, it mutates itself binding the variables
121 /// of the expression that define the layer.
122 ///
123 /// If successful, the proof is implicitly included in the modified
124 /// transcript.
125 fn prove(
126 &mut self,
127 claims: &[&RawClaim<F>],
128 transcript: &mut impl ProverTranscript<F>,
129 ) -> Result<()>;
130
131 /// Return the evaluations of the univariate polynomial generated during this round of sumcheck.
132 ///
133 /// This must be called with a steadily incrementing round_index & with a securely generated challenge.
134 fn compute_round_sumcheck_message(
135 &mut self,
136 round_index: usize,
137 random_coefficients: &[F],
138 ) -> Result<Vec<F>>;
139
140 /// Mutate the underlying bookkeeping tables to "bind" the given `challenge` to the bit.
141 /// labeled with that `round_index`.
142 fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()>;
143
144 /// The list of sumcheck rounds this layer will prove, by index.
145 fn sumcheck_round_indices(&self) -> Vec<usize>;
146
147 /// The maximum degree for any univariate in the sumcheck protocol.
148 fn max_degree(&self) -> usize;
149
150 /// Get the [PostSumcheckLayer] for a certain layer, which is a struct which represents
151 /// the fully bound layer.
152 /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
153 fn get_post_sumcheck_layer(
154 &self,
155 round_challenges: &[F],
156 claim_challenges: &[&[F]],
157 random_coefficients: &[F],
158 ) -> PostSumcheckLayer<F, F>;
159
160 /// Generates and returns all claims that this layer makes onto previous
161 /// layers.
162 fn get_claims(&self) -> Result<Vec<Claim<F>>>;
163
164 /// Transforms the underlying expression in the layer to the expression that
165 /// must be sumchecked over to verify claims combined using the RLC claim
166 /// aggregation method presented in Libra (2019).
167 fn initialize_rlc(&mut self, random_coefficients: &[F], claims: &[&RawClaim<F>]);
168}
169
170/// A circuit-description counterpart of the GKR [Layer] trait.
171pub trait LayerDescription<F: Field> {
172 /// The associated type that the verifier uses to work with a layer of this
173 /// kind.
174 type VerifierLayer: VerifierLayer<F> + Debug + Serialize + for<'a> Deserialize<'a>;
175
176 /// Returns this layer's ID.
177 fn layer_id(&self) -> LayerId;
178
179 /// Tries to verify `claims` for this layer and returns a [VerifierLayer]
180 /// with a fully bound and evaluated expression.
181 ///
182 /// There is only a single aggregated claim if our
183 /// [shared_types::config::ClaimAggregationStrategy]
184 /// is Interpolative, otherwise we have several claims we take the random linear
185 /// combination over.
186 ///
187 /// The proof is implicitly included in the `transcript`.
188 fn verify_rounds(
189 &self,
190 claims: &[&RawClaim<F>],
191 transcript: &mut impl VerifierTranscript<F>,
192 ) -> Result<VerifierLayerEnum<F>>;
193
194 /// The list of sumcheck rounds this layer will prove, by index.
195 fn sumcheck_round_indices(&self) -> Vec<usize>;
196
197 /// Turns this [LayerDescription] into a [VerifierLayer] by taking the
198 /// `sumcheck_bindings` and `claim_points` and inserting them into the
199 /// expression to become a verifier expression.
200 fn convert_into_verifier_layer(
201 &self,
202 sumcheck_bindings: &[F],
203 claim_points: &[&[F]],
204 transcript: &mut impl VerifierTranscript<F>,
205 ) -> Result<Self::VerifierLayer>;
206
207 /// Gets the [PostSumcheckLayer] for this layer.
208 /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
209 fn get_post_sumcheck_layer(
210 &self,
211 round_challenges: &[F],
212 claim_challenges: &[&[F]],
213 random_coefficients: &[F],
214 ) -> PostSumcheckLayer<F, Option<F>>;
215
216 /// The maximum degree for any univariate in the sumcheck protocol.
217 fn max_degree(&self) -> usize;
218
219 /// Label the MLE indices, starting from the `start_index` by
220 /// converting [crate::mle::MleIndex::Free] to [crate::mle::MleIndex::Indexed].
221 fn index_mle_indices(&mut self, start_index: usize);
222
223 /// Given the [MleDescription]s of which outputs are expected of this layer, compute the data
224 /// that populates these bookkeeping tables and mutate the circuit map to reflect this.
225 fn compute_data_outputs(
226 &self,
227 mle_outputs_necessary: &HashSet<&MleDescription<F>>,
228 circuit_map: &mut CircuitEvalMap<F>,
229 );
230
231 /// The [MleDescription]s that make up the leaves of the expression in this layer.
232 fn get_circuit_mles(&self) -> Vec<&MleDescription<F>>;
233
234 /// Given a [CircuitEvalMap], turn this [LayerDescription] into a ProverLayer.
235 fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F>;
236}
237
238/// A verifier counterpart of a GKR [Layer] trait.
239pub trait VerifierLayer<F: Field> {
240 /// Returns this layer's ID.
241 fn layer_id(&self) -> LayerId;
242
243 /// Get the claims that this layer makes on other layers.
244 fn get_claims(&self) -> Result<Vec<Claim<F>>>;
245}
246
247#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, Serialize, Deserialize, Copy, PartialOrd)]
248/// The location of a layer within the GKR circuit
249pub enum LayerId {
250 /// An Mle located in the input layer
251 Input(usize),
252 /// A layer representing values sampled from the verifier via Fiat-Shamir
253 FiatShamirChallengeLayer(usize),
254 /// A layer between the output layer and input layers
255 Layer(usize),
256}
257
258/// Implement Display for LayerId, so that we can use it in error messages
259impl std::fmt::Display for LayerId {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 match self {
262 LayerId::Input(id) => write!(f, "Input Layer {id}"),
263 LayerId::Layer(id) => write!(f, "Layer {id}"),
264 LayerId::FiatShamirChallengeLayer(id) => {
265 write!(f, "Fiat-Shamir Challenge Layer {id}")
266 }
267 }
268 }
269}
270
271impl LayerId {
272 /// Creates a new LayerId representing an input layer.
273 pub fn next_input_layer_id() -> Self {
274 LayerId::Input(CircuitBuildingContext::next_input_layer_id())
275 }
276
277 /// Creates a new LayerId representing a layer.
278 pub fn next_layer_id() -> Self {
279 LayerId::Layer(CircuitBuildingContext::next_layer_id())
280 }
281
282 /// Creates a new LayerId representing a Fiat-Shamir challenge layer.
283 pub fn next_fiat_shamir_challenge_layer_id() -> Self {
284 LayerId::FiatShamirChallengeLayer(
285 CircuitBuildingContext::next_fiat_shamir_challenge_layer_id(),
286 )
287 }
288
289 /// Returns the underlying usize if self is a variant of type Input, otherwise panics.
290 pub fn get_raw_input_layer_id(&self) -> usize {
291 match self {
292 LayerId::Input(id) => *id,
293 _ => panic!("Expected LayerId::Input, found {self:?}"),
294 }
295 }
296
297 /// Returns the underlying usize if self is a variant of type Input, otherwise panics.
298 pub fn get_raw_layer_id(&self) -> usize {
299 match self {
300 LayerId::Layer(id) => *id,
301 _ => panic!("Expected LayerId::Layer, found {self:?}"),
302 }
303 }
304}