remainder/mle/
dense.rs

1#[cfg(test)]
2mod tests;
3
4use std::fmt::Debug;
5
6use ark_std::log2;
7use itertools::{repeat_n, Itertools};
8
9use serde::{Deserialize, Serialize};
10
11use super::{evals::EvaluationsIterator, mle_enum::MleEnum, Mle, MleIndex};
12use crate::{
13    claims::RawClaim,
14    mle::evals::{Evaluations, MultilinearExtension},
15};
16use crate::{
17    expression::{generic_expr::Expression, prover_expr::ProverExpr},
18    layer::LayerId,
19};
20use shared_types::Field;
21
22/// An implementation of an [Mle] using a dense representation.
23#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
24#[serde(bound = "F: Field")]
25pub struct DenseMle<F: Field> {
26    /// The ID of the layer this data belongs to.
27    pub layer_id: LayerId,
28    /// A representation of the MLE on its current state.
29    pub mle: MultilinearExtension<F>,
30    /// The MleIndices `current_mle`.
31    pub mle_indices: Vec<MleIndex<F>>,
32}
33
34impl<F: Field> Mle<F> for DenseMle<F> {
35    fn num_free_vars(&self) -> usize {
36        self.mle.num_vars()
37    }
38
39    fn get_padded_evaluations(&self) -> Vec<F> {
40        let size: usize = 1 << self.mle.num_vars();
41        let padding = size - self.mle.len();
42
43        self.mle.iter().chain(repeat_n(F::ZERO, padding)).collect()
44    }
45
46    fn add_prefix_bits(&mut self, mut new_bits: Vec<MleIndex<F>>) {
47        new_bits.extend(self.mle_indices.clone());
48        self.mle_indices.clone_from(&new_bits);
49    }
50
51    fn layer_id(&self) -> LayerId {
52        self.layer_id
53    }
54
55    fn len(&self) -> usize {
56        self.mle.len()
57    }
58
59    fn iter(&self) -> EvaluationsIterator<'_, F> {
60        self.mle.iter()
61    }
62
63    fn mle_indices(&self) -> &[MleIndex<F>] {
64        &self.mle_indices
65    }
66
67    fn fix_variable_at_index(&mut self, indexed_bit_index: usize, point: F) -> Option<RawClaim<F>> {
68        // Bind the `MleIndex::IndexedBit(index)` to the challenge `point`.
69
70        // First, find the bit corresponding to `index` and compute its absolute
71        // index. For example, if `mle_indices` is equal to
72        // `[MleIndex::Fixed(0), MleIndex::Bound(42, 0), MleIndex::IndexedBit(1),
73        // MleIndex::Bound(17, 2) MleIndex::IndexedBit(3))]`
74        // then `fix_variable_at_index(3, r)` will fix `IndexedBit(3)`, which is
75        // the 2nd indexed bit, to `r`
76
77        // Count of the bit we're fixing. In the above example
78        // `bit_count == 2`.
79        let (index_found, bit_count) =
80            self.mle_indices
81                .iter_mut()
82                .fold((false, 0), |state, mle_index| {
83                    if state.0 {
84                        // Index already found; do nothing.
85                        state
86                    } else if let MleIndex::Indexed(current_bit_index) = *mle_index {
87                        if current_bit_index == indexed_bit_index {
88                            // Found the indexed bit in the current index;
89                            // bind it and increment the bit count.
90                            mle_index.bind_index(point);
91                            (true, state.1 + 1)
92                        } else {
93                            // Index not yet found but this is an indexed
94                            // bit; increasing bit count.
95                            (false, state.1 + 1)
96                        }
97                    } else {
98                        // Index not yet found but the current bit is not an
99                        // indexed bit; do nothing.
100                        state
101                    }
102                });
103
104        assert!(index_found);
105        debug_assert!(1 <= bit_count && bit_count <= self.num_free_vars());
106
107        self.mle.fix_variable_at_index(bit_count - 1, point);
108
109        if self.is_fully_bounded() {
110            let fixed_claim_return = RawClaim::new(
111                self.mle_indices
112                    .iter()
113                    .map(|index| index.val().unwrap())
114                    .collect_vec(),
115                self.mle.value(),
116            );
117            Some(fixed_claim_return)
118        } else {
119            None
120        }
121    }
122
123    /// Bind the bit `index` to the value `binding`.
124    /// If this was the last unbound variable, then return a Claim object giving the fully specified
125    /// evaluation point and the (single) value of the bookkeeping table.  Otherwise, return None.
126    fn fix_variable(&mut self, index: usize, binding: F) -> Option<RawClaim<F>> {
127        for mle_index in self.mle_indices.iter_mut() {
128            if *mle_index == MleIndex::Indexed(index) {
129                mle_index.bind_index(binding);
130            }
131        }
132        // Update the bookkeeping table.
133        self.mle.fix_variable(binding);
134
135        if self.is_fully_bounded() {
136            let fixed_claim_return = RawClaim::new(
137                self.mle_indices
138                    .iter()
139                    .map(|index| index.val().unwrap())
140                    .collect_vec(),
141                self.mle.value(),
142            );
143            Some(fixed_claim_return)
144        } else {
145            None
146        }
147    }
148
149    fn index_mle_indices(&mut self, curr_index: usize) -> usize {
150        let mut new_indices = 0;
151        for mle_index in self.mle_indices.iter_mut() {
152            if *mle_index == MleIndex::Free {
153                *mle_index = MleIndex::Indexed(curr_index + new_indices);
154                new_indices += 1;
155            }
156        }
157
158        curr_index + new_indices
159    }
160
161    fn get_enum(self) -> MleEnum<F> {
162        MleEnum::Dense(self)
163    }
164
165    fn get(&self, index: usize) -> Option<F> {
166        self.mle.get(index)
167    }
168
169    fn first(&self) -> F {
170        self.mle.first()
171    }
172
173    fn value(&self) -> F {
174        self.mle.value()
175    }
176}
177
178impl<F: Field> DenseMle<F> {
179    /// Constructs a new `DenseMle` with specified prefix_bits
180    /// todo: change this to create a DenseMle with already specified IndexedBits
181    pub fn new_with_prefix_bits(
182        data: MultilinearExtension<F>,
183        layer_id: LayerId,
184        prefix_bits: Vec<bool>,
185    ) -> Self {
186        let free_bits = data.num_vars();
187
188        let mle_indices: Vec<MleIndex<F>> = prefix_bits
189            .into_iter()
190            .map(|bit| MleIndex::Fixed(bit))
191            .chain((0..free_bits).map(|_| MleIndex::Free))
192            .collect();
193        Self {
194            layer_id,
195            mle: data,
196            mle_indices,
197        }
198    }
199
200    /// Constructs a new `DenseMle` with specified MLE indices, normally when we are
201    /// trying to construct a new MLE based off of a previous MLE, such as in
202    /// [crate::layer::matmult::MatMult], but want to preserve the "prefix vars."
203    ///
204    /// The MLE should not have ever been mutated if this function is ever called, so none of the
205    /// indices should ever be Indexed here.
206    pub fn new_with_indices(data: &[F], layer_id: LayerId, mle_indices: &[MleIndex<F>]) -> Self {
207        let mut mle = DenseMle::new_from_raw(data.to_vec(), layer_id);
208
209        let all_indices_free_or_fixed = mle_indices.iter().all(|index| {
210            index == &MleIndex::Free
211                || index == &MleIndex::Fixed(true)
212                || index == &MleIndex::Fixed(false)
213        });
214        assert!(all_indices_free_or_fixed);
215
216        mle.mle_indices = mle_indices.to_vec();
217        mle
218    }
219
220    /// Constructs a new `DenseMle` from a bookkeeping table represented by
221    /// [`Iterator<Item = F>`] and [LayerId].
222    ///
223    /// # Example
224    /// ```
225    ///     use remainder::layer::LayerId;
226    ///     use shared_types::Fr;
227    ///     use remainder::mle::dense::DenseMle;
228    ///
229    ///     DenseMle::<Fr>::new_from_iter(vec![Fr::one()].into_iter(), LayerId::Input(0));
230    /// ```
231    pub fn new_from_iter(iter: impl Iterator<Item = F>, layer_id: LayerId) -> Self {
232        let items = iter.collect_vec();
233        let num_free_vars = log2(items.len()) as usize;
234
235        let mle_indices: Vec<MleIndex<F>> = ((0..num_free_vars).map(|_| MleIndex::Free)).collect();
236
237        let current_mle =
238            MultilinearExtension::new_from_evals(Evaluations::<F>::new(num_free_vars, items));
239        Self {
240            layer_id,
241            mle: current_mle,
242            mle_indices,
243        }
244    }
245
246    /// Constructs a new `DenseMle` from a bookkeeping table represented by
247    /// [`Vec<F>`] and [LayerId].
248    ///
249    /// # Example
250    /// ```
251    ///     use remainder::layer::LayerId;
252    ///     use shared_types::Fr;
253    ///     use remainder::mle::dense::DenseMle;
254    ///
255    ///     DenseMle::<Fr>::new_from_raw(vec![Fr::one()], LayerId::Input(0));
256    /// ```
257    pub fn new_from_raw(items: Vec<F>, layer_id: LayerId) -> Self {
258        let num_free_vars = log2(items.len()) as usize;
259
260        let mle_indices: Vec<MleIndex<F>> = ((0..num_free_vars).map(|_| MleIndex::Free)).collect();
261
262        let current_mle =
263            MultilinearExtension::new_from_evals(Evaluations::<F>::new(num_free_vars, items));
264
265        Self {
266            layer_id,
267            mle: current_mle,
268            mle_indices,
269        }
270    }
271
272    /// Constructs a new [DenseMle] from a [MultilinearExtension], additionally
273    /// being able to specify the prefix vars and layer ID.
274    ///
275    /// Optionally gives back an indexed [DenseMle].
276    pub fn new_from_multilinear_extension(
277        mle: MultilinearExtension<F>,
278        layer_id: LayerId,
279        prefix_vars: Option<Vec<bool>>,
280        maybe_starting_var_index: Option<usize>,
281    ) -> Self {
282        let mle_indices: Vec<MleIndex<F>> = prefix_vars
283            .unwrap_or_default()
284            .into_iter()
285            .map(|prefix_var| MleIndex::Fixed(prefix_var))
286            .chain((0..mle.num_vars()).map(|_| MleIndex::Free))
287            .collect();
288        let mut ret = Self {
289            layer_id,
290            mle,
291            mle_indices,
292        };
293        if let Some(starting_var_index) = maybe_starting_var_index {
294            ret.index_mle_indices(starting_var_index);
295        }
296        ret
297    }
298
299    /// Merges the MLEs into a single MLE by simply concatenating them.
300    pub fn combine_mles(mles: Vec<DenseMle<F>>) -> DenseMle<F> {
301        let first_mle_num_vars = mles[0].num_free_vars();
302        let all_same_num_vars = mles
303            .iter()
304            .all(|mle| mle.num_free_vars() == first_mle_num_vars);
305        assert!(all_same_num_vars);
306        let layer_id = mles[0].layer_id;
307        let mle_flattened = mles.into_iter().flat_map(|mle| mle.into_iter());
308
309        Self::new_from_iter(mle_flattened, layer_id)
310    }
311
312    /// Creates an expression from the current MLE.
313    pub fn expression(self) -> Expression<F, ProverExpr> {
314        Expression::<F, ProverExpr>::mle(self)
315    }
316
317    /// Returns the evaluation challenges for a fully-bound MLE.
318    ///
319    /// Note that this function panics if a particular challenge is neither
320    /// fixed nor bound!
321    pub fn get_bound_point(&self) -> Vec<F> {
322        self.mle_indices()
323            .iter()
324            .map(|index| match index {
325                MleIndex::Bound(chal, _) => *chal,
326                MleIndex::Fixed(chal) => F::from(*chal as u64),
327                _ => panic!("MLE index not bound"),
328            })
329            .collect()
330    }
331}
332
333impl<F: Field> IntoIterator for DenseMle<F> {
334    type Item = F;
335
336    type IntoIter = std::vec::IntoIter<Self::Item>;
337
338    fn into_iter(self) -> Self::IntoIter {
339        // TEMPORARY: get_evals_vector()
340        self.mle.iter().collect::<Vec<F>>().into_iter()
341    }
342}