1use serde::{Deserialize, Serialize};
7use shared_types::{
8 transcript::{TranscriptReaderError, VerifierTranscript},
9 Field,
10};
11use thiserror::Error;
12
13use crate::{
14 claims::Claim,
15 layer::{LayerError, LayerId},
16};
17
18use crate::{
19 circuit_layout::CircuitEvalMap,
20 claims::ClaimError,
21 mle::{
22 dense::DenseMle, mle_description::MleDescription, mle_enum::MleEnum,
23 verifier_mle::VerifierMle, zero::ZeroMle, Mle, MleIndex,
24 },
25};
26
27use anyhow::{anyhow, Result};
28
29#[cfg(test)]
31pub mod tests;
32
/// Prover-side output layer of the circuit.
///
/// Wraps the MLE holding this layer's outputs. The prover binds all of its
/// variables with `fix_layer` and then produces a claim on it with
/// `get_claim`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(bound = "F: Field")]
pub struct OutputLayer<F: Field> {
    /// The output MLE: either a dense MLE or a `ZeroMle` (an MLE whose
    /// value is treated as `F::ZERO`).
    mle: MleEnum<F>,
}
42
43impl<F: Field> From<DenseMle<F>> for OutputLayer<F> {
45 fn from(value: DenseMle<F>) -> Self {
46 Self {
47 mle: MleEnum::Dense(value),
48 }
49 }
50}
51
52impl<F: Field> From<ZeroMle<F>> for OutputLayer<F> {
53 fn from(value: ZeroMle<F>) -> Self {
54 Self {
55 mle: MleEnum::Zero(value),
56 }
57 }
58}
59
60impl<F: Field> OutputLayer<F> {
61 pub fn get_mle(&self) -> &MleEnum<F> {
63 &self.mle
64 }
65
66 pub fn new_zero(zero_mle: ZeroMle<F>) -> Self {
68 Self {
69 mle: MleEnum::Zero(zero_mle),
70 }
71 }
72
73 pub fn value(&self) -> Result<F> {
76 match &self.mle {
77 MleEnum::Dense(_) => unimplemented!(),
78 MleEnum::Zero(zero_mle) => {
79 if !zero_mle.is_fully_bounded() {
80 return Err(anyhow!(OutputLayerError::MleNotFullyBound));
81 }
82
83 Ok(F::ZERO)
84 }
85 }
86 }
87
88 pub fn layer_id(&self) -> LayerId {
91 self.mle.layer_id()
92 }
93
94 pub fn num_free_vars(&self) -> usize {
96 self.mle.num_free_vars()
97 }
98
99 pub fn is_fully_bounded(&self) -> bool {
101 self.mle.is_fully_bounded()
102 }
103
104 pub fn fix_layer(&mut self, challenges: &[F]) -> Result<()> {
108 let bits = self.mle.index_mle_indices(0);
109 if bits != challenges.len() {
110 return Err(anyhow!(LayerError::NumVarsMismatch(
111 self.mle.layer_id(),
112 bits,
113 challenges.len(),
114 )));
115 }
116 (0..bits)
117 .zip(challenges.iter())
118 .for_each(|(bit, challenge)| {
119 self.mle.fix_variable(bit, *challenge);
120 });
121 debug_assert!(self.is_fully_bounded());
122 Ok(())
123 }
124
125 pub fn get_claim(&mut self) -> Result<Claim<F>> {
127 if !self.mle.is_fully_bounded() {
128 return Err(anyhow!(LayerError::ClaimError(ClaimError::MleRefMleError)));
129 }
130
131 let mle_indices: Result<Vec<F>> = self
132 .mle
133 .mle_indices()
134 .iter()
135 .map(|index| {
136 index
137 .val()
138 .ok_or(anyhow!(LayerError::ClaimError(ClaimError::MleRefMleError)))
139 })
140 .collect();
141
142 let claim_value = self.mle.first();
143
144 Ok(Claim::new(
145 mle_indices?,
146 claim_value,
147 self.mle.layer_id(),
148 self.mle.layer_id(),
149 ))
150 }
151}
152
/// Circuit-description counterpart of [`OutputLayer`].
///
/// Describes which MLE forms an output layer without holding any concrete
/// evaluations. Those are looked up in a `CircuitEvalMap` when the prover's
/// `OutputLayer` is built.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash)]
#[serde(bound = "F: Field")]
pub struct OutputLayerDescription<F: Field> {
    /// Description of the output MLE (built from a layer id and MLE indices).
    pub mle: MleDescription<F>,

    /// Whether this output layer is expected to evaluate to zero.
    /// `new_zero` is the only implemented constructor and sets this to `true`.
    is_zero: bool,
}
164
165impl<F: Field> OutputLayerDescription<F> {
166 pub fn new_dense(_layer_id: LayerId, _mle_indices: &[MleIndex<F>]) -> Self {
169 unimplemented!()
171 }
172
173 pub fn new_zero(layer_id: LayerId, mle_indices: &[MleIndex<F>]) -> Self {
176 Self {
177 mle: MleDescription::new(layer_id, mle_indices),
178 is_zero: true,
179 }
180 }
181
182 pub fn is_zero(&self) -> bool {
185 self.is_zero
186 }
187
188 pub fn index_mle_indices(&mut self, start_index: usize) {
190 self.mle.index_mle_indices(start_index);
191 }
192
193 pub fn into_prover_output_layer(&self, circuit_map: &CircuitEvalMap<F>) -> OutputLayer<F> {
195 let output_mle = circuit_map.get_data_from_circuit_mle(&self.mle).unwrap();
196 let prefix_bits = self.mle.prefix_bits();
197 let prefix_bits_mle_index = prefix_bits
198 .iter()
199 .map(|bit| MleIndex::Fixed(*bit))
200 .collect();
201
202 if self.is_zero {
203 if !(output_mle.iter().all(|val| val == F::ZERO)) {
205 println!(
206 "WARNING: MLE for output layer {} is not zero",
207 self.mle.layer_id()
208 );
209 }
210 ZeroMle::new(
211 output_mle.num_vars(),
212 Some(prefix_bits_mle_index),
213 self.layer_id(),
214 )
215 .into()
216 } else {
217 DenseMle::new_with_prefix_bits(output_mle.clone(), self.layer_id(), prefix_bits).into()
218 }
219 }
220}
221
222impl<F: Field> OutputLayerDescription<F> {
223 pub fn layer_id(&self) -> LayerId {
226 self.mle.layer_id()
227 }
228
229 pub fn retrieve_mle_from_transcript_and_fix_layer(
233 &self,
234 transcript_reader: &mut impl VerifierTranscript<F>,
235 ) -> Result<VerifierOutputLayer<F>> {
236 assert!(self.is_zero());
238
239 let num_evals = 1;
240
241 let evals = transcript_reader.consume_elements("Output layer MLE evals", num_evals)?;
242
243 if evals != vec![F::ZERO] {
244 return Err(anyhow!(VerifierOutputLayerError::NonZeroEvalForZeroMle));
245 }
246
247 let bits = self.mle.num_free_vars();
248
249 let mut mle = self.mle.clone();
250
251 for bit in 0..bits {
253 let challenge = transcript_reader.get_challenge("Challenge on the output layer")?;
254 mle.fix_variable(bit, challenge);
255 }
256
257 debug_assert_eq!(mle.num_free_vars(), 0);
258
259 let verifier_output_layer =
260 VerifierOutputLayer::new_zero(self.mle.layer_id(), mle.var_indices(), F::ZERO);
261
262 Ok(verifier_output_layer)
263 }
264}
265
/// Verifier-side output layer, built after the output MLE's variables have
/// been bound to transcript challenges.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(bound = "F: Field")]
pub struct VerifierOutputLayer<F: Field> {
    /// The bound MLE (layer id, indices) together with its value.
    mle: VerifierMle<F>,

    /// Whether this is a zero output layer. `new_zero` is the only
    /// implemented constructor and sets this to `true`.
    is_zero: bool,
}
277
278impl<F: Field> VerifierOutputLayer<F> {
279 pub fn new_dense(_layer_id: LayerId, _mle_indices: &[MleIndex<F>]) -> Self {
282 unimplemented!()
284 }
285
286 pub fn new_zero(layer_id: LayerId, mle_indices: &[MleIndex<F>], value: F) -> Self {
289 Self {
290 mle: VerifierMle::new(layer_id, mle_indices.to_vec(), value),
291 is_zero: true,
292 }
293 }
294
295 pub fn is_zero(&self) -> bool {
298 self.is_zero
299 }
300
301 pub fn num_vars(&self) -> usize {
303 self.mle.num_vars()
304 }
305}
306
307impl<F: Field> VerifierOutputLayer<F> {
308 pub fn layer_id(&self) -> LayerId {
311 self.mle.layer_id()
312 }
313
314 pub fn get_claim(&self) -> Result<Claim<F>> {
316 assert!(self.is_zero());
318
319 let layer_id = self.layer_id();
320
321 let prefix_bits: Vec<MleIndex<F>> = self
322 .mle
323 .var_indices()
324 .iter()
325 .filter(|index| matches!(index, MleIndex::Fixed(_bit)))
326 .cloned()
327 .collect();
328
329 let claim_point: Vec<F> = self
330 .mle
331 .var_indices()
332 .iter()
333 .map(|index| {
334 index
335 .val()
336 .ok_or(anyhow!(LayerError::ClaimError(ClaimError::MleRefMleError)))
337 })
338 .collect::<Result<Vec<_>>>()?;
339
340 let num_vars = self.num_vars();
341 let num_prefix_bits = prefix_bits.len();
342 let num_free_vars = num_vars - num_prefix_bits;
343
344 let claim_value = self.mle.value();
345
346 let mut claim_mle = MleEnum::Zero(ZeroMle::new(num_free_vars, Some(prefix_bits), layer_id));
349 claim_mle.index_mle_indices(0);
350
351 for mle_index in self.mle.var_indices().iter() {
352 if let MleIndex::Bound(val, idx) = mle_index {
353 claim_mle.fix_variable(*idx, *val);
354 }
355 }
356
357 Ok(Claim::new(
358 claim_point,
359 claim_value,
360 self.mle.layer_id(),
361 self.mle.layer_id(),
362 ))
363 }
364}
365
/// Errors raised by the prover-side [`OutputLayer`].
#[derive(Error, Debug, Clone)]
pub enum OutputLayerError {
    /// The output MLE still had free variables when its value was requested.
    #[error("Expected fully-bound MLE")]
    MleNotFullyBound,
}
373
374#[derive(Error, Debug, Clone)]
376pub enum VerifierOutputLayerError {
377 #[error("Prover sent a non-zero value for a ZeroMle")]
379 NonZeroEvalForZeroMle,
380
381 #[error("Transcript Reader Error: {:0}", _0)]
383 TranscriptError(#[from] TranscriptReaderError),
384}