remainder/mle/
evals.rs

1#[cfg(test)]
2mod tests;
3
4use ark_std::{cfg_into_iter, log2};
5use itertools::{EitherOrBoth::*, Itertools};
6use ndarray::{Dimension, IxDyn};
7#[cfg(feature = "parallel")]
8use rayon::iter::{IntoParallelIterator, ParallelIterator};
9use serde::{Deserialize, Serialize};
10use shared_types::Field;
11use thiserror::Error;
12
13pub mod bit_packed_vector;
14
15use bit_packed_vector::BitPackedVector;
16use zeroize::Zeroize;
17
18use crate::utils::arithmetic::i64_to_field;
19
20use anyhow::{anyhow, Result};
21
22#[derive(Error, Debug, Clone)]
23/// the errors associated with the dimension of the MLE.
24pub enum DimensionError {
25    #[error("Dimensions: {0} do not match with number of axes: {1} as indicated by their names.")]
26    /// The dimensions of the MLE do not match the number of axes.
27    DimensionMismatchError(usize, usize),
28    #[error("Dimensions: {0} do not match with the numvar: {1} of the mle.")]
29    /// The dimensions of the MLE do not match the number of variables.
30    DimensionNumVarError(usize, usize),
31    #[error("Trying to get the underlying mle as an ndarray, but there is no dimension info.")]
32    /// No dimension info of the MLE.
33    NoDimensionInfoError(),
34}
35
36#[derive(Clone, PartialEq, Serialize, Deserialize)]
37/// the dimension information of the MLE. contains the dim: [type@IxDyn], see ndarray
38/// for more detailed documentation and the names of the axes.
39pub struct DimInfo {
40    dims: IxDyn,
41    axes_names: Vec<String>,
42}
43
44impl DimInfo {
45    /// Creates a new DimInfo from the dimensions and the axes names.
46    pub fn new(dims: IxDyn, axes_names: Vec<String>) -> Result<Self> {
47        if dims.ndim() != axes_names.len() {
48            return Err(anyhow!(DimensionError::DimensionMismatchError(
49                dims.ndim(),
50                axes_names.len(),
51            )));
52        }
53        Ok(Self { dims, axes_names })
54    }
55}
56
57impl std::fmt::Debug for DimInfo {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("DimensionInfo")
60            .field("dim sizes", &self.dims.slice())
61            .field("axes names", &self.axes_names)
62            .finish()
63    }
64}
65
66// -------------- Various Helper Functions -----------------
67
/// Reverses the order of the `num_bits` least-significant bits of `value`,
/// leaving all higher bits untouched.
/// # Example
/// ```ignore
///     assert_eq!(mirror_bits(4, 0b1110), 0b0111);
///     assert_eq!(mirror_bits(3, 0b1110), 0b1011);
///     assert_eq!(mirror_bits(2, 0b1110), 0b1101);
///     assert_eq!(mirror_bits(1, 0b1110), 0b1110);
///     assert_eq!(mirror_bits(0, 0b1110), 0b1110);
/// ```
fn mirror_bits(num_bits: usize, value: usize) -> usize {
    // Pop bits off the bottom of `rest` and push them onto the bottom of
    // `mirrored`; after `num_bits` steps `rest` holds the untouched high bits.
    let (mirrored, rest) = (0..num_bits).fold((0_usize, value), |(mirrored, rest), _| {
        ((mirrored << 1) | (rest & 1), rest >> 1)
    });

    // Put the untouched high bits back above the mirrored window.
    mirrored | (rest << num_bits)
}
88
89/// Stores a boolean function `f: {0, 1}^n -> F` represented as a list of up to
90/// `2^n` evaluations of `f` on the boolean hypercube. The `n` variables are
91/// indexed from `0` to `n-1` throughout the lifetime of the object.
92#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93#[serde(bound = "F: Field")]
94pub struct Evaluations<F: Field> {
95    /// To understand how evaluations are stored, let's index `f`'s input bits
96    /// as follows: `f(b_0, b_1, ..., b_{n-1})`. Evaluations are ordered using
97    /// the bit-string `b_0b_1...b_{n-2}b_{n-1}` as key. This ordering is
98    /// sometimes referred to as "big-endian" due to its resemblance to
99    /// big-endian byte ordering. A suffix of contiguous evaluations all equal
100    /// to `F::ZERO` may be omitted in this internal representation but this
101    /// struct is not responsible for maintaining this property at all times.
102    /// # Example
103    /// * The evaluations of a 2-dimensional function are stored in the
104    ///   following order: `[ f(0, 0), f(0, 1), f(1, 0), f(1, 1) ]`.
105    /// * The evaluation table `[ 1, 0, 5, 0 ]` may be stored as `[1, 0, 5]` by
106    ///   omitting the trailing zero. Note that both representations are valid.
107    evals: BitPackedVector<F>,
108
109    /// Number of input variables to `f`. Invariant: `0 <= evals.len() <=
110    /// 2^num_vars`. The length can be less than `2^num_vars` due to suffix
111    /// omission.
112    num_vars: usize,
113
114    /// TODO(Makis): Is there a better way to handle this?? When accessing an
115    /// element of the bookkeping table, we return a reference to a field
116    /// element. In case the element is stored implicitly as a missing entry, we
117    /// need someone to own the "zero" of the field. If I make this a const, I'm
118    /// not sure how to initialize it.
119    zero: F,
120}
121
122impl<F: Field> Zeroize for Evaluations<F> {
123    fn zeroize(&mut self) {
124        self.evals.zeroize();
125        self.num_vars.zeroize();
126        self.zero.zeroize();
127    }
128}
129
130impl<F: Field> Evaluations<F> {
131    /// Returns a representation of the constant function on zero variables
132    /// equal to `F::ZERO`.
133    pub fn new_zero() -> Self {
134        Self::new(0, vec![])
135    }
136
137    /// Builds an evaluation representation for a function `f: {0, 1}^num_vars
138    /// -> F` from a vector of evaluations in big-endian order (see
139    /// documentation comment for `Self::evals` for explanation).
140    ///
141    /// # Example
142    /// For a function `f: {0, 1}^2 -> F`, an evaluations table may be built as:
143    /// `Evaluations::new(2, vec![ f(0, 0), f(0, 1), f(1, 0), f(1, 1) ])`.
144    ///
145    /// An example of suffix omission is when `f(1, 0) == f(1, 1) == F::ZERO`.
146    /// In that case those zero values may be omitted and the following
147    /// statement generates an equivalent representation: `Evaluations::new(2,
148    /// vec![ f(0, 0), f(0, 1) ])`.
149    pub fn new(num_vars: usize, evals: Vec<F>) -> Self {
150        debug_assert!(evals.len() <= (1 << num_vars));
151
152        // debug_evals(&evals);
153
154        Evaluations::<F> {
155            evals: BitPackedVector::new(&evals),
156            num_vars,
157            zero: F::ZERO,
158        }
159    }
160
161    /// Builds an evaluation representation for a function `f: {0, 1}^num_vars
162    /// -> F` from a vector of evaluations in _little-endian_ order (see
163    /// documentation comment for `Self::evals` for explanation).
164    ///
165    /// # Example
166    /// For a function `f: {0, 1}^2 -> F`, an evaluations table may be built as:
167    /// `Evaluations::new_from_little_endian(2, vec![ f(0, 0), f(1, 0), f(0, 1),
168    /// f(1, 1) ])`.
169    pub fn new_from_little_endian(num_vars: usize, evals: &[F]) -> Self {
170        debug_assert!(evals.len() <= (1 << num_vars));
171
172        println!("New MLE (big-endian) on {} entries.", evals.len());
173
174        Self {
175            evals: BitPackedVector::new(&Self::flip_endianess(num_vars, evals)),
176            num_vars,
177            zero: F::ZERO,
178        }
179    }
180
181    /// Returns the number of variables of the current `Evalations`.
182    pub fn num_vars(&self) -> usize {
183        self.num_vars
184    }
185
186    /// Returns true if the boolean function has not free variables. Equivalent
187    /// to checking whether that [Self::num_vars] is equal to zero.
188    pub fn is_fully_bound(&self) -> bool {
189        self.num_vars == 0
190    }
191
192    /// Returns the first element of the bookkeeping table. This operation
193    /// should always be successful because even in the case that
194    /// [Self::num_vars] is zero, there is a non-zero number of vertices on the
195    /// boolean hypercube and hence there's at least one evaluation stored in
196    /// the bookkeeping table, either explicitly as a value inside
197    /// `Self::evals`, or implicitly if it's a `F::ZERO` that has been pruned as
198    /// part of a zero suffix.
199    pub fn first(&self) -> F {
200        self.evals.get(0).unwrap_or(F::ZERO)
201    }
202
203    /// If `self` represents a fully-bound boolean function (i.e.
204    /// [Self::num_vars] is zero), it returns its value. Otherwise panics.
205    pub fn value(&self) -> F {
206        assert!(self.is_fully_bound());
207        self.first()
208    }
209
210    /// Returns an iterator that traverses the evaluations in "big-endian"
211    /// order.
212    pub fn iter(&self) -> EvaluationsIterator<'_, F> {
213        EvaluationsIterator::<F> {
214            evals: self,
215            current_index: 0,
216        }
217    }
218
219    /// Temporary function returning the length of the internal representation.
220    #[allow(clippy::len_without_is_empty)]
221    pub fn len(&self) -> usize {
222        self.evals.len()
223    }
224
225    /// Temporary function for accessing a the `idx`-th element in the internal
226    /// representation.
227    pub fn get(&self, idx: usize) -> Option<F> {
228        self.evals.get(idx)
229    }
230
231    // --------  Helper Functions --------
232
233    /// Checks whether its arguments correspond to equivalent representations of
234    /// the same list of evaluations. Two representations are equivalent if
235    /// omitting the longest contiguous suffix of `F::ZERO`s from each results
236    /// in the same vectors.
237    #[allow(dead_code)]
238    fn equiv_repr(evals1: &BitPackedVector<F>, evals2: &BitPackedVector<F>) -> bool {
239        evals1
240            .iter()
241            .zip_longest(evals2.iter())
242            .all(|pair| match pair {
243                Both(l, r) => l == r,
244                Left(l) => l == F::ZERO,
245                Right(r) => r == F::ZERO,
246            })
247    }
248
249    /// Sorts the elements of `values` by their 0-based index transformed by
250    /// mirroring the `num_bits` LSBs. This operation effectively flips the
251    /// "endianess" of the index ordering. If `values.len() < 2^num_bits`, the
252    /// missing values are assumed to be zeros. The resulting vector is always
253    /// of size `2^num_bits`.
254    /// # Example
255    /// ```
256    /// use remainder::mle::evals::Evaluations;
257    /// use shared_types::Fr;
258    /// assert_eq!(Evaluations::flip_endianess(
259    ///     2,
260    ///     &[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
261    ///     vec![ Fr::from(1), Fr::from(3), Fr::from(2), Fr::from(4) ]
262    /// );
263    /// assert_eq!(Evaluations::flip_endianess(
264    ///     2,
265    ///     &[ Fr::from(1), Fr::from(2) ]),
266    ///     vec![ Fr::from(1), Fr::from(0), Fr::from(2), Fr::from(0) ]
267    /// );
268    /// ```
269    pub fn flip_endianess(num_bits: usize, values: &[F]) -> Vec<F> {
270        let num_evals = values.len();
271
272        let result: Vec<F> = cfg_into_iter!(0..(1 << num_bits))
273            .map(|idx| {
274                let mirrored_idx = mirror_bits(num_bits, idx);
275                if mirrored_idx >= num_evals {
276                    F::ZERO
277                } else {
278                    values[mirrored_idx]
279                }
280            })
281            .collect();
282
283        result
284    }
285}
286
287/// An iterator over evaluations in a "big-endian" order.
288pub struct EvaluationsIterator<'a, F: Field> {
289    /// Reference to the original `Evaluations` struct.
290    evals: &'a Evaluations<F>,
291
292    /// Index of the next evaluation to be retrieved.
293    current_index: usize,
294}
295
296impl<F: Field> Iterator for EvaluationsIterator<'_, F> {
297    type Item = F;
298
299    fn next(&mut self) -> Option<Self::Item> {
300        if self.current_index < self.evals.len() {
301            let val = self.evals.get(self.current_index).unwrap();
302            self.current_index += 1;
303
304            Some(val)
305        } else {
306            None
307        }
308    }
309}
310
311impl<F: Field> Clone for EvaluationsIterator<'_, F> {
312    fn clone(&self) -> Self {
313        Self {
314            evals: self.evals,
315            current_index: self.current_index,
316        }
317    }
318}
319
320/// An iterator over evaluations indexed by vertices of a projection of the
321/// boolean hypercube on `num_vars - 1` dimensions. See documentation for
322/// `Evaluations::project` for more information.
323#[allow(dead_code)]
324pub struct EvaluationsPairIterator<'a, F: Field> {
325    /// Reference to original bookkeeping table.
326    evals: &'a Evaluations<F>,
327
328    /// A mask for isolating the `k` LSBs of the `current_eval_index` where `k`
329    /// is the dimension on which the original hypercube is projected on.
330    lsb_mask: usize,
331
332    /// 0-base index of the next element to be returned. Invariant:
333    /// `current_eval_index \in [0, 2^(evals.num_vars() - 1)]`. If equal to
334    /// `2^(evals.num_vars() - 1)`, the iterator has reached the end.
335    current_pair_index: usize,
336}
337
338impl<F: Field> Iterator for EvaluationsPairIterator<'_, F> {
339    type Item = (F, F);
340
341    fn next(&mut self) -> Option<Self::Item> {
342        let num_vars = self.evals.num_vars();
343        let num_pairs = 1_usize << (num_vars - 1);
344
345        if self.current_pair_index < num_pairs {
346            // Compute the two indices by inserting a `0` and a `1` respectively
347            // in the appropriate position of `current_pair_index`. For example,
348            // if this is an Iterator projecting on `fix_variable_index == 2`
349            // for an Evaluations table of `num_vars == 5`, then `lsb_mask ==
350            // 0b00011` (the `fix_variable_index` LSBs are on). When, for
351            // example `current_pair_index == 0b1010`, it is split into a "right
352            // part": `lsb_idx == 0b00 0 10`, and a "shifted left part":
353            // `msb_idx == 0b10 0 00`.  The two parts are then combined with the
354            // middle bit on and off respectively: `idx1 == 0b10 0 10`, `idx2 ==
355            // 0b10 1 10`.
356            let lsb_idx = self.current_pair_index & self.lsb_mask;
357            let msb_idx = (self.current_pair_index & (!self.lsb_mask)) << 1;
358            let mid_idx = self.lsb_mask + 1;
359
360            let idx1 = lsb_idx | msb_idx;
361            let idx2 = lsb_idx | mid_idx | msb_idx;
362
363            self.current_pair_index += 1;
364
365            let val1 = self.evals.get(idx1).unwrap();
366            let val2 = self.evals.get(idx2).unwrap();
367
368            Some((val1, val2))
369        } else {
370            None
371        }
372    }
373}
374
375/// Stores a function `\tilde{f}: F^n -> F`, the unique Multilinear Extension
376/// (MLE) of a given function `f: {0, 1}^n -> F`:
377/// ```text
378///     \tilde{f}(x_0, ..., x_{n-1})
379///         = \sum_{b_0, ..., b_{n-1} \in {0, 1}^n}
380///             \tilde{beta}(x_0, ..., x_{n-1}, b_0, ..., b_{n-1})
381///             * f(b_0, ..., b_{n-1}).
382/// ```
383/// where `\tilde{beta}` is the MLE of the equality function:
384/// ```text
385///     \tilde{beta}(x_0, ..., x_{n-1}, b_0, ..., b_{n-1})
386///         = \prod_{i  = 0}^{n-1} ( x_i * b_i + (1 - x_i) * (1 - b_i) )
387/// ```
388/// Internally, `f` is represented as a list of evaluations of `f` on the
389/// boolean hypercube. The `n` variables are indexed from `0` to `n-1`
390/// throughout the lifetime of the object even if `n` is modified by fixing a
391/// variable to a constant value.
392#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
393#[serde(bound = "F: Field")]
394pub struct MultilinearExtension<F: Field> {
395    /// The bookkeeping table with the evaluations of `f` on the hypercube.
396    pub f: Evaluations<F>,
397}
398
399impl<F: Field> Zeroize for MultilinearExtension<F> {
400    fn zeroize(&mut self) {
401        self.f.zeroize();
402    }
403}
404
405impl<F: Field> From<Vec<bool>> for MultilinearExtension<F> {
406    fn from(bools: Vec<bool>) -> Self {
407        let evals = bools
408            .into_iter()
409            .map(|b| if b { F::ONE } else { F::ZERO })
410            .collect();
411        MultilinearExtension::new(evals)
412    }
413}
414
415impl<F: Field> From<Vec<u32>> for MultilinearExtension<F> {
416    fn from(uints: Vec<u32>) -> Self {
417        let evals = uints.into_iter().map(|v| F::from(v as u64)).collect();
418        MultilinearExtension::new(evals)
419    }
420}
421
422impl<F: Field> From<Vec<u64>> for MultilinearExtension<F> {
423    fn from(uints: Vec<u64>) -> Self {
424        let evals = uints.into_iter().map(F::from).collect();
425        MultilinearExtension::new(evals)
426    }
427}
428
429impl<F: Field> From<Vec<i32>> for MultilinearExtension<F> {
430    fn from(ints: Vec<i32>) -> Self {
431        let evals = ints.into_iter().map(|v| i64_to_field(v as i64)).collect();
432        MultilinearExtension::new(evals)
433    }
434}
435
436impl<F: Field> From<Vec<i64>> for MultilinearExtension<F> {
437    fn from(ints: Vec<i64>) -> Self {
438        let evals = ints.into_iter().map(i64_to_field).collect();
439        MultilinearExtension::new(evals)
440    }
441}
442
443impl<F: Field> MultilinearExtension<F> {
444    /// Create a new MultilinearExtension from a [`Vec<F>`] of evaluations.
445    pub fn new(evals_vec: Vec<F>) -> Self {
446        let num_vars = log2(evals_vec.len()) as usize;
447        let evals = Evaluations::new(num_vars, evals_vec);
448        MultilinearExtension::new_from_evals(evals)
449    }
450
451    /// Generate a new MultilinearExtension from a representation `evals` of a
452    /// function `f`.
453    pub fn new_from_evals(evals: Evaluations<F>) -> Self {
454        Self { f: evals }
455    }
456
457    /// Creates a new mle which is all zeroes of a specific num_vars. In this
458    /// case the size of the evals and the num_vars will not match up
459    pub fn new_sized_zero(num_vars: usize) -> Self {
460        Self {
461            f: Evaluations {
462                evals: BitPackedVector::new(&[]),
463                num_vars,
464                zero: F::ZERO,
465            },
466        }
467    }
468
469    /// Returns an iterator accessing the evaluations defining this MLE in
470    /// "big-endian" order.
471    pub fn iter(&self) -> EvaluationsIterator<'_, F> {
472        self.f.iter()
473    }
474
475    /// Generate a Vector of the evaluations of `f` over the hypercube.
476    pub fn to_vec(&self) -> Vec<F> {
477        self.f.iter().collect()
478    }
479
480    /// Returns true if the MLE has not free variables. Equivalent to checking
481    /// whether that [Self::num_vars] is equal to zero.
482    pub fn is_fully_bound(&self) -> bool {
483        self.f.is_fully_bound()
484    }
485
486    /// Returns the first element of the bookkeeping table of this MLE,
487    /// corresponding to the value of the MLE when all varables are set to zero.
488    /// This operation never fails (see [Evaluations::first]).
489    pub fn first(&self) -> F {
490        self.f.first()
491    }
492
493    /// If `self` represents a fully-bound MLE (i.e. on zero variables), it
494    /// returns its value. Otherwise panics.
495    pub fn value(&self) -> F {
496        self.f.value()
497    }
498
499    /// Generates a representation for the MLE of the zero function on zero
500    /// variables.
501    pub fn new_zero() -> Self {
502        let zero_evals = Evaluations::new_zero();
503        Self::new_from_evals(zero_evals)
504    }
505
506    /// Returns `n`, the number of arguments `\tilde{f}` takes.
507    pub fn num_vars(&self) -> usize {
508        self.f.num_vars()
509    }
510
511    /// Returns the `idx`-th element, if `idx` is in the range `[0,
512    /// 2^self.num_vars)`.
513    pub fn get(&self, idx: usize) -> Option<F> {
514        if idx >= (1 << self.num_vars()) {
515            // `idx` is out of range.
516            None
517        } else if idx >= self.f.len() {
518            // `idx` is within range, but value is implicitly assumed to be
519            // zero.
520            Some(F::ZERO)
521        } else {
522            // `idx`-th position is stored explicitly in `self.f`
523            self.f.get(idx)
524        }
525    }
526
527    /// Evaluate `\tilde{f}` at `point \in F^n`.
528    /// # Panics
529    /// If `point` does not contain exactly `self.num_vars()` elements.
530    pub fn evaluate_at_point(&self, point: &[F]) -> F {
531        let n = self.num_vars();
532        assert_eq!(n, point.len());
533
534        // TODO: Provide better access mechanism.
535        self.f
536            .evals
537            .clone()
538            .iter() // was into_iter()
539            .enumerate()
540            .fold(F::ZERO, |acc, (idx, v)| {
541                let beta = (0..n).fold(F::ONE, |acc, i| {
542                    let bit_i = idx & (1 << (n - 1 - i));
543                    if bit_i > 0 {
544                        acc * point[i]
545                    } else {
546                        acc * (F::ONE - point[i])
547                    }
548                });
549                acc + v * beta
550            })
551    }
552
553    /// Returns the length of the evaluations vector.
554    #[allow(clippy::len_without_is_empty)]
555    pub fn len(&self) -> usize {
556        self.f.len()
557    }
558
559    /// Fix the 0-based `var_index`-th bit of `\tilde{f}` to an arbitrary field
560    /// element `point \in F` by destructively modifying `self`.
561    /// # Params
562    /// * `var_index`: A 0-based index of the input variable to be fixed.
563    /// * `point`: The field element to set `x_{var_index}` equal to.
564    /// # Example
565    /// If `self` represents a function `\tilde{f}: F^3 -> F`,
566    /// `self.fix_variable_at_index(1, r)` fixes the middle variable to `r \in
567    /// F`. After the invocation, `self` represents a function `\tilde{g}: F^2
568    /// -> F` defined as the multilinear extension of the following function:
569    /// `g(b_0, b_1) = \tilde{f}(b_0, r, b_1)`.
570    /// # Panics
571    /// if `var_index` is outside the interval `[0, self.num_vars())`.
572    pub fn fix_variable_at_index(&mut self, var_index: usize, point: F) {
573        let num_vars = self.num_vars();
574        let lsb_mask = (1_usize << (num_vars - 1 - var_index)) - 1;
575
576        let num_pairs = 1_usize << (num_vars - 1);
577
578        let new_evals: Vec<F> = cfg_into_iter!(0..num_pairs)
579            .map(|idx| {
580                // This iteration computes the value of
581                // `f'(idx[0], ..., idx[var_index-1], idx[var_index+1], ..., idx[num_vars - 1])`
582                // where `f'` is the resulting function after fixing the
583                // the `var_index`-th variable.
584                // To do this, we must combine the values of:
585                // `f(idx1) = f(idx[0], ..., idx[var_index-1], 0, idx[var_index+1], ..., idx[num_vars-1])`
586                // and
587                // `f(idx2) = f(idx[0], ..., idx[var_index-1], 1, idx[var_index+1], ..., idx[num_vars-1])`
588                // Below we compute `idx1` and `idx2` corresponding to the two
589                // indices above.
590
591                // Compute the two indices by inserting a `0` and a `1`
592                // respectively in the appropriate position of `idx`. For
593                // example, if `var_index == 2` and `self.num_vars == 5`, then
594                // `lsb_mask == 0b0011` (the `num_var - 1 - var_index` LSBs are
595                // on). When, for example `idx == 0b1010`, it is split into a
596                // "right part": `lsb_idx == 0b00 0 10`, and a "shifted left
597                // part": `msb_idx == 0b10 0 00`.  The two parts are then
598                // combined with the middle bit on and off respectively: `idx1
599                // == 0b10 0 10`, `idx2 == 0b10 1 10`.
600                let lsb_idx = idx & lsb_mask;
601                let msb_idx = (idx & (!lsb_mask)) << 1;
602                let mid_idx = lsb_mask + 1;
603
604                let idx1 = lsb_idx | msb_idx;
605                let idx2 = lsb_idx | mid_idx | msb_idx;
606
607                let val1 = self.get(idx1).unwrap_or(F::ZERO);
608                let val2 = self.get(idx2).unwrap_or(F::ZERO);
609
610                val1 + (val2 - val1) * point
611            })
612            .collect();
613
614        debug_assert_eq!(new_evals.len(), 1 << (num_vars - 1));
615        self.f = Evaluations::new(num_vars - 1, new_evals);
616    }
617
618    /// Optimized version of `fix_variable_at_index` for `var_index == 0`.
619    /// # Panics
620    /// If `self.num_vars() == 0`.
621    pub fn fix_variable(&mut self, point: F) {
622        self.fix_variable_at_index(0, point);
623    }
624
625    /// Stacks the MLEs into a single MLE, assuming they are stored in a "big
626    /// endian" fashion.
627    pub fn stack_mles(mles: Vec<MultilinearExtension<F>>) -> MultilinearExtension<F> {
628        let first_len = mles[0].len();
629
630        if !mles.iter().all(|v| v.len() == first_len) {
631            panic!("All mles's underlying bookkeeping table must have the same length");
632        }
633
634        let out = mles.iter().flat_map(|mle| mle.to_vec()).collect();
635        Self::new(out)
636    }
637
638    /// Convert a [MultilinearExtension] into a vector of u8s.
639    /// Every element is padded to contain 8 bits.
640    pub fn convert_into_u8_vec(&self) -> Vec<u8> {
641        self.f
642            .iter()
643            .map(|field_element| {
644                let field_element_le_bytes = field_element.to_bytes_le();
645                let mut padded_u8 = [0u8; 1];
646                padded_u8.copy_from_slice(&field_element_le_bytes[..1]);
647                u8::from_le_bytes(padded_u8)
648            })
649            .collect_vec()
650    }
651
652    /// Convert a [MultilinearExtension] into a vector of u16s.
653    /// Every element is padded to contain 16 bits.
654    pub fn convert_into_u16_vec(&self) -> Vec<u16> {
655        self.f
656            .iter()
657            .map(|field_element| {
658                let field_element_le_bytes = field_element.to_bytes_le();
659                let mut padded_u16 = [0u8; 2];
660                padded_u16.copy_from_slice(&field_element_le_bytes[..2]);
661                u16::from_le_bytes(padded_u16)
662            })
663            .collect_vec()
664    }
665
666    /// Convert a [MultilinearExtension] into a vector of u32s.
667    /// Every element is padded to contain 32 bits.
668    pub fn convert_into_u32_vec(&self) -> Vec<u32> {
669        self.f
670            .iter()
671            .map(|field_element| {
672                let field_element_le_bytes = field_element.to_bytes_le();
673                let mut padded_u32 = [0u8; 4];
674                padded_u32.copy_from_slice(&field_element_le_bytes[..4]);
675                u32::from_le_bytes(padded_u32)
676            })
677            .collect_vec()
678    }
679
680    /// Convert a [MultilinearExtension] into a vector of u64s.
681    /// Every element is padded to contain 64 bits.
682    pub fn convert_into_u64_vec(&self) -> Vec<u64> {
683        self.f
684            .iter()
685            .map(|field_element| {
686                let field_element_le_bytes = field_element.to_bytes_le();
687                let mut padded_u64 = [0u8; 8];
688                padded_u64.copy_from_slice(&field_element_le_bytes[..8]);
689                u64::from_le_bytes(padded_u64)
690            })
691            .collect_vec()
692    }
693
694    /// Convert a [MultilinearExtension] into a vector of u128s.
695    /// Every element is padded to contain 128 bits.
696    pub fn convert_into_u128_vec(&self) -> Vec<u128> {
697        self.f
698            .iter()
699            .map(|field_element| {
700                let field_element_le_bytes = field_element.to_bytes_le();
701                let mut padded_u128 = [0u8; 16];
702                padded_u128.copy_from_slice(&field_element_le_bytes[..16]);
703                u128::from_le_bytes(padded_u128)
704            })
705            .collect_vec()
706    }
707}