remainder/
layer.rs

1//! A layer is a combination of multiple MLEs with an expression.
2
3pub mod combine_mles;
4pub mod gate;
5pub mod identity_gate;
6pub mod layer_enum;
7pub mod matmult;
8pub mod product;
9pub mod regular_layer;
10
11use std::{collections::HashSet, fmt::Debug};
12
13use layer_enum::{LayerEnum, VerifierLayerEnum};
14use product::PostSumcheckLayer;
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17
18use crate::{
19    circuit_building_context::CircuitBuildingContext,
20    circuit_layout::CircuitEvalMap,
21    claims::{Claim, ClaimError, RawClaim},
22    expression::expr_errors::ExpressionError,
23    mle::mle_description::MleDescription,
24    sumcheck::InterpError,
25};
26use shared_types::{
27    transcript::{ProverTranscript, TranscriptReaderError, VerifierTranscript},
28    Field,
29};
30
31use anyhow::Result;
32
33/// Errors to do with working with a type implementing [Layer].
34#[derive(Error, Debug, Clone)]
35pub enum LayerError {
36    #[error("Layer isn't ready to prove")]
37    /// Layer isn't ready to prove
38    LayerNotReady,
39    #[error("Error with underlying expression: {0}")]
40    /// Error with underlying expression: {0}
41    ExpressionError(#[from] ExpressionError),
42    #[error("Error with aggregating curr layer")]
43    /// Error with aggregating curr layer
44    AggregationError,
45    #[error("Error with getting Claim: {0}")]
46    /// Error with getting Claim
47    ClaimError(#[from] ClaimError),
48    #[error("Error with verifying layer: {0}")]
49    /// Error with verifying layer
50    VerificationError(#[from] VerificationError),
51    #[error("InterpError: {0}")]
52    /// InterpError
53    InterpError(#[from] InterpError),
54    #[error("Transcript Error: {0}")]
55    /// Transcript Error
56    TranscriptError(#[from] TranscriptReaderError),
57    /// Incorrect number of variable bindings
58    #[error("Layer {0} requires {1} variable bindings, but {2} were provided")]
59    NumVarsMismatch(LayerId, usize, usize),
60}
61
62/// Errors to do with verifying a layer while working with a type implementing
63/// [VerifierLayer].
64#[derive(Error, Debug, Clone)]
65pub enum VerificationError {
66    #[error("The sum of the first evaluations do not equal the claim")]
67    /// The sum of the first evaluations do not equal the claim
68    SumcheckStartFailed,
69
70    #[error("The sum of the current rounds evaluations do not equal the previous round at a random point")]
71    /// The sum of the current rounds evaluations do not equal the previous round at a random point
72    SumcheckFailed,
73
74    #[error("The final rounds evaluations at r do not equal the oracle query")]
75    /// The final rounds evaluations at r do not equal the oracle query
76    FinalSumcheckFailed,
77
78    #[error("The Oracle query does not match the final claim")]
79    /// The Oracle query does not match the final claim
80    GKRClaimCheckFailed,
81
82    #[error(
83        "The Challenges generated during sumcheck don't match the claims in the given expression"
84    )]
85    ///The Challenges generated during sumcheck don't match the claims in the given expression
86    ChallengeCheckFailed,
87
88    /// Error with underlying expression
89    #[error("Error with underlying expression")]
90    ExpressionError,
91
92    /// Error while reading the transcript proof.
93    #[error("Error while reading the transcript proof")]
94    TranscriptError,
95
96    /// Interpolation Error.
97    #[error("Interpolation Error: {0}")]
98    InterpError(#[from] InterpError),
99}
100
101/// A layer is the smallest component of the GKR protocol.
102///
103/// Each `Layer` is a sub-protocol that takes in some `Claim` and creates a proof
104/// that the `Claim` is correct
105pub trait Layer<F: Field> {
106    /// Gets this layer's ID.
107    fn layer_id(&self) -> LayerId;
108
109    /// Initialize this layer and perform any necessary pre-computation: beta
110    /// table, number of rounds, etc.
111    fn initialize(&mut self, claim_point: &[F]) -> Result<()>;
112
113    /// Tries to prove `claims` for this layer. There is only a single
114    /// aggregated claim if our
115    /// [shared_types::config::ClaimAggregationStrategy] is
116    /// [shared_types::config::ClaimAggregationStrategy::Interpolative],
117    /// otherwise we have several claims we take the random linear
118    /// combination over.
119    ///
120    /// In the process of proving, it mutates itself binding the variables
121    /// of the expression that define the layer.
122    ///
123    /// If successful, the proof is implicitly included in the modified
124    /// transcript.
125    fn prove(
126        &mut self,
127        claims: &[&RawClaim<F>],
128        transcript: &mut impl ProverTranscript<F>,
129    ) -> Result<()>;
130
131    /// Return the evaluations of the univariate polynomial generated during this round of sumcheck.
132    ///
133    /// This must be called with a steadily incrementing round_index & with a securely generated challenge.
134    fn compute_round_sumcheck_message(
135        &mut self,
136        round_index: usize,
137        random_coefficients: &[F],
138    ) -> Result<Vec<F>>;
139
140    /// Mutate the underlying bookkeeping tables to "bind" the given `challenge` to the bit.
141    /// labeled with that `round_index`.
142    fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()>;
143
144    /// The list of sumcheck rounds this layer will prove, by index.
145    fn sumcheck_round_indices(&self) -> Vec<usize>;
146
147    /// The maximum degree for any univariate in the sumcheck protocol.
148    fn max_degree(&self) -> usize;
149
150    /// Get the [PostSumcheckLayer] for a certain layer, which is a struct which represents
151    /// the fully bound layer.
152    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
153    fn get_post_sumcheck_layer(
154        &self,
155        round_challenges: &[F],
156        claim_challenges: &[&[F]],
157        random_coefficients: &[F],
158    ) -> PostSumcheckLayer<F, F>;
159
160    /// Generates and returns all claims that this layer makes onto previous
161    /// layers.
162    fn get_claims(&self) -> Result<Vec<Claim<F>>>;
163
164    /// Transforms the underlying expression in the layer to the expression that
165    /// must be sumchecked over to verify claims combined using the RLC claim
166    /// aggregation method presented in Libra (2019).
167    fn initialize_rlc(&mut self, random_coefficients: &[F], claims: &[&RawClaim<F>]);
168}
169
170/// A circuit-description counterpart of the GKR [Layer] trait.
171pub trait LayerDescription<F: Field> {
172    /// The associated type that the verifier uses to work with a layer of this
173    /// kind.
174    type VerifierLayer: VerifierLayer<F> + Debug + Serialize + for<'a> Deserialize<'a>;
175
176    /// Returns this layer's ID.
177    fn layer_id(&self) -> LayerId;
178
179    /// Tries to verify `claims` for this layer and returns a [VerifierLayer]
180    /// with a fully bound and evaluated expression.
181    ///
182    /// There is only a single aggregated claim if our
183    /// [shared_types::config::ClaimAggregationStrategy]
184    /// is Interpolative, otherwise we have several claims we take the random linear
185    /// combination over.
186    ///
187    /// The proof is implicitly included in the `transcript`.
188    fn verify_rounds(
189        &self,
190        claims: &[&RawClaim<F>],
191        transcript: &mut impl VerifierTranscript<F>,
192    ) -> Result<VerifierLayerEnum<F>>;
193
194    /// The list of sumcheck rounds this layer will prove, by index.
195    fn sumcheck_round_indices(&self) -> Vec<usize>;
196
197    /// Turns this [LayerDescription] into a [VerifierLayer] by taking the
198    /// `sumcheck_bindings` and `claim_points` and inserting them into the
199    /// expression to become a verifier expression.
200    fn convert_into_verifier_layer(
201        &self,
202        sumcheck_bindings: &[F],
203        claim_points: &[&[F]],
204        transcript: &mut impl VerifierTranscript<F>,
205    ) -> Result<Self::VerifierLayer>;
206
207    /// Gets the [PostSumcheckLayer] for this layer.
208    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
209    fn get_post_sumcheck_layer(
210        &self,
211        round_challenges: &[F],
212        claim_challenges: &[&[F]],
213        random_coefficients: &[F],
214    ) -> PostSumcheckLayer<F, Option<F>>;
215
216    /// The maximum degree for any univariate in the sumcheck protocol.
217    fn max_degree(&self) -> usize;
218
219    /// Label the MLE indices, starting from the `start_index` by
220    /// converting [crate::mle::MleIndex::Free] to [crate::mle::MleIndex::Indexed].
221    fn index_mle_indices(&mut self, start_index: usize);
222
223    /// Given the [MleDescription]s of which outputs are expected of this layer, compute the data
224    /// that populates these bookkeeping tables and mutate the circuit map to reflect this.
225    fn compute_data_outputs(
226        &self,
227        mle_outputs_necessary: &HashSet<&MleDescription<F>>,
228        circuit_map: &mut CircuitEvalMap<F>,
229    );
230
231    /// The [MleDescription]s that make up the leaves of the expression in this layer.
232    fn get_circuit_mles(&self) -> Vec<&MleDescription<F>>;
233
234    /// Given a [CircuitEvalMap], turn this [LayerDescription] into a ProverLayer.
235    fn convert_into_prover_layer(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F>;
236}
237
238/// A verifier counterpart of a GKR [Layer] trait.
239pub trait VerifierLayer<F: Field> {
240    /// Returns this layer's ID.
241    fn layer_id(&self) -> LayerId;
242
243    /// Get the claims that this layer makes on other layers.
244    fn get_claims(&self) -> Result<Vec<Claim<F>>>;
245}
246
247#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, Serialize, Deserialize, Copy, PartialOrd)]
248/// The location of a layer within the GKR circuit
249pub enum LayerId {
250    /// An Mle located in the input layer
251    Input(usize),
252    /// A layer representing values sampled from the verifier via Fiat-Shamir
253    FiatShamirChallengeLayer(usize),
254    /// A layer between the output layer and input layers
255    Layer(usize),
256}
257
258/// Implement Display for LayerId, so that we can use it in error messages
259impl std::fmt::Display for LayerId {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        match self {
262            LayerId::Input(id) => write!(f, "Input Layer {id}"),
263            LayerId::Layer(id) => write!(f, "Layer {id}"),
264            LayerId::FiatShamirChallengeLayer(id) => {
265                write!(f, "Fiat-Shamir Challenge Layer {id}")
266            }
267        }
268    }
269}
270
271impl LayerId {
272    /// Creates a new LayerId representing an input layer.
273    pub fn next_input_layer_id() -> Self {
274        LayerId::Input(CircuitBuildingContext::next_input_layer_id())
275    }
276
277    /// Creates a new LayerId representing a layer.
278    pub fn next_layer_id() -> Self {
279        LayerId::Layer(CircuitBuildingContext::next_layer_id())
280    }
281
282    /// Creates a new LayerId representing a Fiat-Shamir challenge layer.
283    pub fn next_fiat_shamir_challenge_layer_id() -> Self {
284        LayerId::FiatShamirChallengeLayer(
285            CircuitBuildingContext::next_fiat_shamir_challenge_layer_id(),
286        )
287    }
288
289    /// Returns the underlying usize if self is a variant of type Input, otherwise panics.
290    pub fn get_raw_input_layer_id(&self) -> usize {
291        match self {
292            LayerId::Input(id) => *id,
293            _ => panic!("Expected LayerId::Input, found {self:?}"),
294        }
295    }
296
297    /// Returns the underlying usize if self is a variant of type Input, otherwise panics.
298    pub fn get_raw_layer_id(&self) -> usize {
299        match self {
300            LayerId::Layer(id) => *id,
301            _ => panic!("Expected LayerId::Layer, found {self:?}"),
302        }
303    }
304}