remainder/
output_layer.rs

//! A GKR output layer is a "virtual layer": it is not assigned a fresh
//! [LayerId]. Instead, it is associated with some other intermediate/input
//! layer whose [LayerId] it inherits. The MLE it stores is a restriction of an
//! MLE defining its associated layer.
5
6use serde::{Deserialize, Serialize};
7use shared_types::{
8    transcript::{TranscriptReaderError, VerifierTranscript},
9    Field,
10};
11use thiserror::Error;
12
13use crate::{
14    claims::Claim,
15    layer::{LayerError, LayerId},
16};
17
18use crate::{
19    circuit_layout::CircuitEvalMap,
20    claims::ClaimError,
21    mle::{
22        dense::DenseMle, mle_description::MleDescription, mle_enum::MleEnum,
23        verifier_mle::VerifierMle, zero::ZeroMle, Mle, MleIndex,
24    },
25};
26
27use anyhow::{anyhow, Result};
28
29// Unit tests for Output Layers.
30#[cfg(test)]
31pub mod tests;
32
33/// Output layers are "virtual layers" in the sense that they are not assigned a
34/// separate [LayerId]. Instead they are associated with the ID of an existing
35/// intermediate/input layer on which they generate claims for.
36/// Contains an [MleEnum] which can be either a [DenseMle] or a [ZeroMle].
37#[derive(Serialize, Deserialize, Debug, Clone)]
38#[serde(bound = "F: Field")]
39pub struct OutputLayer<F: Field> {
40    mle: MleEnum<F>,
41}
42
43/// Required for output layer shenanigans within `layout`
44impl<F: Field> From<DenseMle<F>> for OutputLayer<F> {
45    fn from(value: DenseMle<F>) -> Self {
46        Self {
47            mle: MleEnum::Dense(value),
48        }
49    }
50}
51
52impl<F: Field> From<ZeroMle<F>> for OutputLayer<F> {
53    fn from(value: ZeroMle<F>) -> Self {
54        Self {
55            mle: MleEnum::Zero(value),
56        }
57    }
58}
59
60impl<F: Field> OutputLayer<F> {
61    /// Returns the MLE contained within.
62    pub fn get_mle(&self) -> &MleEnum<F> {
63        &self.mle
64    }
65
66    /// Generate a new [OutputLayer] from a [ZeroMle].
67    pub fn new_zero(zero_mle: ZeroMle<F>) -> Self {
68        Self {
69            mle: MleEnum::Zero(zero_mle),
70        }
71    }
72
73    /// If the MLE is fully-bound, returns its evaluation.
74    /// Otherwise, it returns an [OutputLayerError].
75    pub fn value(&self) -> Result<F> {
76        match &self.mle {
77            MleEnum::Dense(_) => unimplemented!(),
78            MleEnum::Zero(zero_mle) => {
79                if !zero_mle.is_fully_bounded() {
80                    return Err(anyhow!(OutputLayerError::MleNotFullyBound));
81                }
82
83                Ok(F::ZERO)
84            }
85        }
86    }
87
88    /// Returns the [LayerId] of the intermediate/input layer that this output
89    /// layer is associated with.
90    pub fn layer_id(&self) -> LayerId {
91        self.mle.layer_id()
92    }
93
94    /// Number of free variables.
95    pub fn num_free_vars(&self) -> usize {
96        self.mle.num_free_vars()
97    }
98
99    /// Whether the output layer is fully bounded
100    pub fn is_fully_bounded(&self) -> bool {
101        self.mle.is_fully_bounded()
102    }
103
104    /// Fix the variables of this output layer to random challenges sampled
105    /// from the transcript.
106    /// Expects `self.num_free_vars()` challenges.
107    pub fn fix_layer(&mut self, challenges: &[F]) -> Result<()> {
108        let bits = self.mle.index_mle_indices(0);
109        if bits != challenges.len() {
110            return Err(anyhow!(LayerError::NumVarsMismatch(
111                self.mle.layer_id(),
112                bits,
113                challenges.len(),
114            )));
115        }
116        (0..bits)
117            .zip(challenges.iter())
118            .for_each(|(bit, challenge)| {
119                self.mle.fix_variable(bit, *challenge);
120            });
121        debug_assert!(self.is_fully_bounded());
122        Ok(())
123    }
124
125    /// Extract a claim on this output layer by extracting the bindings from the fixed variables.
126    pub fn get_claim(&mut self) -> Result<Claim<F>> {
127        if !self.mle.is_fully_bounded() {
128            return Err(anyhow!(LayerError::ClaimError(ClaimError::MleRefMleError)));
129        }
130
131        let mle_indices: Result<Vec<F>> = self
132            .mle
133            .mle_indices()
134            .iter()
135            .map(|index| {
136                index
137                    .val()
138                    .ok_or(anyhow!(LayerError::ClaimError(ClaimError::MleRefMleError)))
139            })
140            .collect();
141
142        let claim_value = self.mle.first();
143
144        Ok(Claim::new(
145            mle_indices?,
146            claim_value,
147            self.mle.layer_id(),
148            self.mle.layer_id(),
149        ))
150    }
151}
152
153/// The circuit description type for the defaul Output Layer consisting of an
154/// MLE.
155#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash)]
156#[serde(bound = "F: Field")]
157pub struct OutputLayerDescription<F: Field> {
158    /// The metadata of this MLE: indices and associated layer.
159    pub mle: MleDescription<F>,
160
161    /// Whether this is an MLE that is supposed to evaluate to zero.
162    is_zero: bool,
163}
164
165impl<F: Field> OutputLayerDescription<F> {
166    /// Generate an output layer containing a verifier equivalent of a
167    /// [DenseMle], with a given `layer_id` and `mle_indices`.
168    pub fn new_dense(_layer_id: LayerId, _mle_indices: &[MleIndex<F>]) -> Self {
169        // We do not allow `DenseMle`s at this point.
170        unimplemented!()
171    }
172
173    /// Generate an output layer containing a verifier equivalent of a
174    /// [ZeroMle], with a given `layer_id` and `mle_indices`.
175    pub fn new_zero(layer_id: LayerId, mle_indices: &[MleIndex<F>]) -> Self {
176        Self {
177            mle: MleDescription::new(layer_id, mle_indices),
178            is_zero: true,
179        }
180    }
181
182    /// Determine whether the MLE Output layer contains an MLE whose
183    /// coefficients are all 0.
184    pub fn is_zero(&self) -> bool {
185        self.is_zero
186    }
187
188    /// Label the MLE indices in this layer, starting from the start_index.
189    pub fn index_mle_indices(&mut self, start_index: usize) {
190        self.mle.index_mle_indices(start_index);
191    }
192
193    /// Convert this into the prover view of an output layer, using the [CircuitEvalMap].
194    pub fn into_prover_output_layer(&self, circuit_map: &CircuitEvalMap<F>) -> OutputLayer<F> {
195        let output_mle = circuit_map.get_data_from_circuit_mle(&self.mle).unwrap();
196        let prefix_bits = self.mle.prefix_bits();
197        let prefix_bits_mle_index = prefix_bits
198            .iter()
199            .map(|bit| MleIndex::Fixed(*bit))
200            .collect();
201
202        if self.is_zero {
203            // Ensure that the calculated output MLE is all zeroes.
204            if !(output_mle.iter().all(|val| val == F::ZERO)) {
205                println!(
206                    "WARNING: MLE for output layer {} is not zero",
207                    self.mle.layer_id()
208                );
209            }
210            ZeroMle::new(
211                output_mle.num_vars(),
212                Some(prefix_bits_mle_index),
213                self.layer_id(),
214            )
215            .into()
216        } else {
217            DenseMle::new_with_prefix_bits(output_mle.clone(), self.layer_id(), prefix_bits).into()
218        }
219    }
220}
221
222impl<F: Field> OutputLayerDescription<F> {
223    /// Returns the [LayerId] of the intermediate/input layer that his output
224    /// layer is associated with.
225    pub fn layer_id(&self) -> LayerId {
226        self.mle.layer_id()
227    }
228
229    /// Retrieve the MLE evaluations from the transcript and fix the variables
230    /// of this output layer to random challenges sampled from the transcript.
231    /// Returns a description of the layer ready to be used by the verifier.
232    pub fn retrieve_mle_from_transcript_and_fix_layer(
233        &self,
234        transcript_reader: &mut impl VerifierTranscript<F>,
235    ) -> Result<VerifierOutputLayer<F>> {
236        // We do not yet handle DenseMle.
237        assert!(self.is_zero());
238
239        let num_evals = 1;
240
241        let evals = transcript_reader.consume_elements("Output layer MLE evals", num_evals)?;
242
243        if evals != vec![F::ZERO] {
244            return Err(anyhow!(VerifierOutputLayerError::NonZeroEvalForZeroMle));
245        }
246
247        let bits = self.mle.num_free_vars();
248
249        let mut mle = self.mle.clone();
250
251        // Evaluate each output MLE at a random challenge point.
252        for bit in 0..bits {
253            let challenge = transcript_reader.get_challenge("Challenge on the output layer")?;
254            mle.fix_variable(bit, challenge);
255        }
256
257        debug_assert_eq!(mle.num_free_vars(), 0);
258
259        let verifier_output_layer =
260            VerifierOutputLayer::new_zero(self.mle.layer_id(), mle.var_indices(), F::ZERO);
261
262        Ok(verifier_output_layer)
263    }
264}
265
266/// The verifier counterpart type for the defaul Output Layer consisting of an
267/// MLE.
268#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
269#[serde(bound = "F: Field")]
270pub struct VerifierOutputLayer<F: Field> {
271    /// A description of this layer's fully-bound MLE.
272    mle: VerifierMle<F>,
273
274    /// Whether this layer's MLE is supposed to evaluate to zero.
275    is_zero: bool,
276}
277
278impl<F: Field> VerifierOutputLayer<F> {
279    /// Generate an output layer containing a verifier equivalent of a
280    /// [DenseMle], with a given `layer_id` and `mle_indices`.
281    pub fn new_dense(_layer_id: LayerId, _mle_indices: &[MleIndex<F>]) -> Self {
282        // We do not allow `DenseMle`s at this point.
283        unimplemented!()
284    }
285
286    /// Generate an output layer containing a verifier equivalent of a
287    /// [ZeroMle], with a given `layer_id`, `mle_indices` and `value`.
288    pub fn new_zero(layer_id: LayerId, mle_indices: &[MleIndex<F>], value: F) -> Self {
289        Self {
290            mle: VerifierMle::new(layer_id, mle_indices.to_vec(), value),
291            is_zero: true,
292        }
293    }
294
295    /// Determine whether this output layer represents an MLE
296    /// whose coefficients are all 0.
297    pub fn is_zero(&self) -> bool {
298        self.is_zero
299    }
300
301    /// The number of variables used to represent the underlying MLE.
302    pub fn num_vars(&self) -> usize {
303        self.mle.num_vars()
304    }
305}
306
307impl<F: Field> VerifierOutputLayer<F> {
308    /// Returns the [LayerId] of the intermediate/input layer that this output
309    /// layer is associated with.
310    pub fn layer_id(&self) -> LayerId {
311        self.mle.layer_id()
312    }
313
314    /// Extract a claim on this output layer by extracting the bindings from the fixed variables.
315    pub fn get_claim(&self) -> Result<Claim<F>> {
316        // We do not support non-zero MLEs on Output Layers at this point!
317        assert!(self.is_zero());
318
319        let layer_id = self.layer_id();
320
321        let prefix_bits: Vec<MleIndex<F>> = self
322            .mle
323            .var_indices()
324            .iter()
325            .filter(|index| matches!(index, MleIndex::Fixed(_bit)))
326            .cloned()
327            .collect();
328
329        let claim_point: Vec<F> = self
330            .mle
331            .var_indices()
332            .iter()
333            .map(|index| {
334                index
335                    .val()
336                    .ok_or(anyhow!(LayerError::ClaimError(ClaimError::MleRefMleError)))
337            })
338            .collect::<Result<Vec<_>>>()?;
339
340        let num_vars = self.num_vars();
341        let num_prefix_bits = prefix_bits.len();
342        let num_free_vars = num_vars - num_prefix_bits;
343
344        let claim_value = self.mle.value();
345
346        // The verifier is expecting to receive a fully-bound [MleRef]. Start
347        // with an unindexed MLE, index it, and then bound its variables.
348        let mut claim_mle = MleEnum::Zero(ZeroMle::new(num_free_vars, Some(prefix_bits), layer_id));
349        claim_mle.index_mle_indices(0);
350
351        for mle_index in self.mle.var_indices().iter() {
352            if let MleIndex::Bound(val, idx) = mle_index {
353                claim_mle.fix_variable(*idx, *val);
354            }
355        }
356
357        Ok(Claim::new(
358            claim_point,
359            claim_value,
360            self.mle.layer_id(),
361            self.mle.layer_id(),
362        ))
363    }
364}
365
366/// Errors to do with working with a type implementing [OutputLayer].
367#[derive(Error, Debug, Clone)]
368pub enum OutputLayerError {
369    /// Expected fully-bound MLE.
370    #[error("Expected fully-bound MLE")]
371    MleNotFullyBound,
372}
373
374/// Errors to do with working with a type implementing [VerifierOutputLayer].
375#[derive(Error, Debug, Clone)]
376pub enum VerifierOutputLayerError {
377    /// Prover sent a non-zero value for a ZeroMle.
378    #[error("Prover sent a non-zero value for a ZeroMle")]
379    NonZeroEvalForZeroMle,
380
381    /// Transcript Reader Error during verification.
382    #[error("Transcript Reader Error: {:0}", _0)]
383    TranscriptError(#[from] TranscriptReaderError),
384}