remainder/mle/
zero.rs

//! A space-efficient implementation of a multilinear extension (MLE) whose evaluations are all zero.
2
3use itertools::{repeat_n, Itertools};
4use serde::{Deserialize, Serialize};
5
6use crate::claims::RawClaim;
7use crate::layer::LayerId;
8use shared_types::Field;
9
10use super::evals::{Evaluations, EvaluationsIterator};
11use super::Mle;
12use super::{mle_enum::MleEnum, MleIndex};
13
14/// An MLE that contains only zeros; typically used for the output layer.
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16#[serde(bound = "F: Field")]
17pub struct ZeroMle<F: Field> {
18    pub(crate) mle_indices: Vec<MleIndex<F>>,
19    /// Number of non-fixed variables within this MLE
20    /// (warning: this gets modified destructively DURING sumcheck).
21    pub(crate) num_vars: usize,
22    pub(crate) layer_id: LayerId,
23    pub(crate) zero: [F; 1],
24    pub(crate) zero_eval: Evaluations<F>,
25    pub(crate) indexed: bool,
26}
27
28impl<F: Field> ZeroMle<F> {
29    /// Constructs a new `ZeroMle` on `num_vars` variables with the
30    /// appropriate `prefix_bits` for a layer with ID `layer_id`.
31    pub fn new(num_vars: usize, prefix_bits: Option<Vec<MleIndex<F>>>, layer_id: LayerId) -> Self {
32        let mle_indices = prefix_bits
33            .into_iter()
34            .flatten()
35            .chain(repeat_n(MleIndex::Free, num_vars))
36            .collect_vec();
37
38        Self {
39            mle_indices,
40            num_vars,
41            layer_id,
42            zero: [F::ZERO],
43            zero_eval: Evaluations::new(num_vars, vec![F::ZERO]),
44            indexed: false,
45        }
46    }
47}
48
49impl<F: Field> Mle<F> for ZeroMle<F> {
50    fn mle_indices(&self) -> &[MleIndex<F>] {
51        &self.mle_indices
52    }
53
54    fn num_free_vars(&self) -> usize {
55        self.num_vars
56    }
57
58    fn fix_variable(&mut self, round_index: usize, challenge: F) -> Option<RawClaim<F>> {
59        for mle_index in self.mle_indices.iter_mut() {
60            if *mle_index == MleIndex::Indexed(round_index) {
61                mle_index.bind_index(challenge);
62            }
63        }
64
65        // One fewer free variable to sumcheck through
66        self.num_vars -= 1;
67
68        if self.num_vars == 0 {
69            let send_claim = RawClaim::new(
70                self.mle_indices
71                    .iter()
72                    .map(|index| index.val().unwrap())
73                    .collect_vec(),
74                F::ZERO,
75            );
76            Some(send_claim)
77        } else {
78            None
79        }
80    }
81
82    fn fix_variable_at_index(&mut self, indexed_bit_index: usize, point: F) -> Option<RawClaim<F>> {
83        self.fix_variable(indexed_bit_index, point)
84    }
85
86    fn index_mle_indices(&mut self, curr_index: usize) -> usize {
87        let mut new_indices = 0;
88        for mle_index in self.mle_indices.iter_mut() {
89            if *mle_index == MleIndex::Free {
90                *mle_index = MleIndex::Indexed(curr_index + new_indices);
91                new_indices += 1;
92            }
93        }
94
95        curr_index + new_indices
96    }
97
98    fn layer_id(&self) -> LayerId {
99        self.layer_id
100    }
101
102    fn get_enum(self) -> MleEnum<F> {
103        MleEnum::Zero(self)
104    }
105
106    #[doc = " Get the padded set of evaluations over the boolean hypercube; useful for"]
107    #[doc = " constructing the input layer."]
108    fn get_padded_evaluations(&self) -> Vec<F> {
109        todo!()
110    }
111
112    #[doc = " Mutates the MLE in order to set the prefix bits. This is needed when we"]
113    #[doc = " are working with dataparallel circuits and new bits need to be added."]
114    fn add_prefix_bits(&mut self, _new_bits: Vec<MleIndex<F>>) {
115        todo!()
116    }
117
118    fn len(&self) -> usize {
119        1
120    }
121
122    fn iter(&self) -> EvaluationsIterator<'_, F> {
123        self.zero_eval.iter()
124    }
125
126    fn first(&self) -> F {
127        F::ZERO
128    }
129
130    fn value(&self) -> F {
131        assert!(self.is_fully_bounded());
132        F::ZERO
133    }
134
135    fn get(&self, index: usize) -> Option<F> {
136        if index < self.len() {
137            Some(F::ZERO)
138        } else {
139            None
140        }
141    }
142}