remainder/mle/evals.rs
1#[cfg(test)]
2mod tests;
3
4use ark_std::{cfg_into_iter, log2};
5use itertools::{EitherOrBoth::*, Itertools};
6use ndarray::{Dimension, IxDyn};
7#[cfg(feature = "parallel")]
8use rayon::iter::{IntoParallelIterator, ParallelIterator};
9use serde::{Deserialize, Serialize};
10use shared_types::Field;
11use thiserror::Error;
12
13pub mod bit_packed_vector;
14
15use bit_packed_vector::BitPackedVector;
16use zeroize::Zeroize;
17
18use crate::utils::arithmetic::i64_to_field;
19
20use anyhow::{anyhow, Result};
21
#[derive(Error, Debug, Clone)]
/// Errors raised when the dimension information attached to an MLE is
/// inconsistent with itself or with the MLE it describes.
pub enum DimensionError {
    #[error("Dimensions: {0} do not match with number of axes: {1} as indicated by their names.")]
    /// The number of dimensions (`.0`) differs from the number of axis names
    /// (`.1`). Raised by `DimInfo::new`.
    DimensionMismatchError(usize, usize),
    #[error("Dimensions: {0} do not match with the numvar: {1} of the mle.")]
    /// The dimension-derived value (`.0`) does not match the MLE's number of
    /// variables (`.1`). NOTE(review): the construction site is not visible
    /// here; confirm exactly what `.0` measures.
    DimensionNumVarError(usize, usize),
    #[error("Trying to get the underlying mle as an ndarray, but there is no dimension info.")]
    /// The MLE has no dimension info, so it cannot be viewed as an ndarray.
    NoDimensionInfoError(),
}
35
#[derive(Clone, PartialEq, Serialize, Deserialize)]
/// Dimension information of an MLE: the size of every axis (an
/// [type@IxDyn]; see `ndarray` for more detailed documentation) together
/// with one human-readable name per axis.
pub struct DimInfo {
    /// Size of each axis. When built through `DimInfo::new`, its number of
    /// axes equals `axes_names.len()`.
    dims: IxDyn,
    /// One name per axis of `dims`, in the same order.
    axes_names: Vec<String>,
}
43
44impl DimInfo {
45 /// Creates a new DimInfo from the dimensions and the axes names.
46 pub fn new(dims: IxDyn, axes_names: Vec<String>) -> Result<Self> {
47 if dims.ndim() != axes_names.len() {
48 return Err(anyhow!(DimensionError::DimensionMismatchError(
49 dims.ndim(),
50 axes_names.len(),
51 )));
52 }
53 Ok(Self { dims, axes_names })
54 }
55}
56
57impl std::fmt::Debug for DimInfo {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("DimensionInfo")
60 .field("dim sizes", &self.dims.slice())
61 .field("axes names", &self.axes_names)
62 .finish()
63 }
64}
65
66// -------------- Various Helper Functions -----------------
67
/// Reverses the order of the `num_bits` least-significant bits of `value`.
/// All bits above position `num_bits` are left untouched.
/// # Example
/// ```ignore
/// assert_eq!(mirror_bits(4, 0b1110), 0b0111);
/// assert_eq!(mirror_bits(3, 0b1110), 0b1011);
/// assert_eq!(mirror_bits(2, 0b1110), 0b1101);
/// assert_eq!(mirror_bits(1, 0b1110), 0b1110);
/// assert_eq!(mirror_bits(0, 0b1110), 0b1110);
/// ```
fn mirror_bits(num_bits: usize, value: usize) -> usize {
    // Nothing to mirror; also avoids shifting by the full word width below.
    if num_bits == 0 {
        return value;
    }

    let low_mask = (1_usize << num_bits) - 1;
    // `reverse_bits` mirrors the whole word, which moves the low `num_bits`
    // bits to the top; shift them back down into the low positions.
    let mirrored_low = (value & low_mask).reverse_bits() >> (usize::BITS as usize - num_bits);

    (value & !low_mask) | mirrored_low
}
88
/// Stores a boolean function `f: {0, 1}^n -> F` represented as a list of up to
/// `2^n` evaluations of `f` on the boolean hypercube. The `n` variables are
/// indexed from `0` to `n-1` throughout the lifetime of the object.
///
/// NOTE(review): `PartialEq` is derived, so it compares the stored
/// representation field by field. Two equivalent representations that differ
/// only in how many trailing zeros were omitted may therefore compare unequal.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Evaluations<F: Field> {
    /// To understand how evaluations are stored, index `f`'s input bits as
    /// `f(b_0, b_1, ..., b_{n-1})`. Evaluations are ordered using the
    /// bit-string `b_0b_1...b_{n-2}b_{n-1}` as the key. This ordering is
    /// sometimes called "big-endian" because it resembles big-endian byte
    /// ordering.
    ///
    /// A suffix of contiguous evaluations that are all equal to `F::ZERO` may
    /// be omitted from this internal representation. This struct is not
    /// responsible for maintaining that property at all times.
    /// # Example
    /// * The evaluations of a 2-dimensional function are stored in the
    ///   following order: `[ f(0, 0), f(0, 1), f(1, 0), f(1, 1) ]`.
    /// * The evaluation table `[ 1, 0, 5, 0 ]` may be stored as `[1, 0, 5]` by
    ///   omitting the trailing zero. Both representations are valid.
    evals: BitPackedVector<F>,

    /// Number of input variables to `f`.
    ///
    /// Invariant: `0 <= evals.len() <= 2^num_vars`. It is checked with
    /// `debug_assert!` in the constructors. The length can be less than
    /// `2^num_vars` because of suffix omission.
    num_vars: usize,

    /// TODO(Makis): Is there a better way to handle this? The intent was to
    /// have an owner for the field's "zero" so that a reference to it could be
    /// returned for entries stored implicitly as a missing suffix.
    ///
    /// NOTE(review): the accessors in this file return field elements by
    /// value (`F` / `Option<F>`) and use `F::ZERO` directly, so this field
    /// does not appear to be read here. It is still part of the derived
    /// `PartialEq`, `Debug` and serde representations.
    zero: F,
}
121
122impl<F: Field> Zeroize for Evaluations<F> {
123 fn zeroize(&mut self) {
124 self.evals.zeroize();
125 self.num_vars.zeroize();
126 self.zero.zeroize();
127 }
128}
129
130impl<F: Field> Evaluations<F> {
131 /// Returns a representation of the constant function on zero variables
132 /// equal to `F::ZERO`.
133 pub fn new_zero() -> Self {
134 Self::new(0, vec![])
135 }
136
137 /// Builds an evaluation representation for a function `f: {0, 1}^num_vars
138 /// -> F` from a vector of evaluations in big-endian order (see
139 /// documentation comment for `Self::evals` for explanation).
140 ///
141 /// # Example
142 /// For a function `f: {0, 1}^2 -> F`, an evaluations table may be built as:
143 /// `Evaluations::new(2, vec![ f(0, 0), f(0, 1), f(1, 0), f(1, 1) ])`.
144 ///
145 /// An example of suffix omission is when `f(1, 0) == f(1, 1) == F::ZERO`.
146 /// In that case those zero values may be omitted and the following
147 /// statement generates an equivalent representation: `Evaluations::new(2,
148 /// vec![ f(0, 0), f(0, 1) ])`.
149 pub fn new(num_vars: usize, evals: Vec<F>) -> Self {
150 debug_assert!(evals.len() <= (1 << num_vars));
151
152 // debug_evals(&evals);
153
154 Evaluations::<F> {
155 evals: BitPackedVector::new(&evals),
156 num_vars,
157 zero: F::ZERO,
158 }
159 }
160
161 /// Builds an evaluation representation for a function `f: {0, 1}^num_vars
162 /// -> F` from a vector of evaluations in _little-endian_ order (see
163 /// documentation comment for `Self::evals` for explanation).
164 ///
165 /// # Example
166 /// For a function `f: {0, 1}^2 -> F`, an evaluations table may be built as:
167 /// `Evaluations::new_from_little_endian(2, vec![ f(0, 0), f(1, 0), f(0, 1),
168 /// f(1, 1) ])`.
169 pub fn new_from_little_endian(num_vars: usize, evals: &[F]) -> Self {
170 debug_assert!(evals.len() <= (1 << num_vars));
171
172 println!("New MLE (big-endian) on {} entries.", evals.len());
173
174 Self {
175 evals: BitPackedVector::new(&Self::flip_endianess(num_vars, evals)),
176 num_vars,
177 zero: F::ZERO,
178 }
179 }
180
181 /// Returns the number of variables of the current `Evalations`.
182 pub fn num_vars(&self) -> usize {
183 self.num_vars
184 }
185
186 /// Returns true if the boolean function has not free variables. Equivalent
187 /// to checking whether that [Self::num_vars] is equal to zero.
188 pub fn is_fully_bound(&self) -> bool {
189 self.num_vars == 0
190 }
191
192 /// Returns the first element of the bookkeeping table. This operation
193 /// should always be successful because even in the case that
194 /// [Self::num_vars] is zero, there is a non-zero number of vertices on the
195 /// boolean hypercube and hence there's at least one evaluation stored in
196 /// the bookkeeping table, either explicitly as a value inside
197 /// `Self::evals`, or implicitly if it's a `F::ZERO` that has been pruned as
198 /// part of a zero suffix.
199 pub fn first(&self) -> F {
200 self.evals.get(0).unwrap_or(F::ZERO)
201 }
202
203 /// If `self` represents a fully-bound boolean function (i.e.
204 /// [Self::num_vars] is zero), it returns its value. Otherwise panics.
205 pub fn value(&self) -> F {
206 assert!(self.is_fully_bound());
207 self.first()
208 }
209
210 /// Returns an iterator that traverses the evaluations in "big-endian"
211 /// order.
212 pub fn iter(&self) -> EvaluationsIterator<'_, F> {
213 EvaluationsIterator::<F> {
214 evals: self,
215 current_index: 0,
216 }
217 }
218
219 /// Temporary function returning the length of the internal representation.
220 #[allow(clippy::len_without_is_empty)]
221 pub fn len(&self) -> usize {
222 self.evals.len()
223 }
224
225 /// Temporary function for accessing a the `idx`-th element in the internal
226 /// representation.
227 pub fn get(&self, idx: usize) -> Option<F> {
228 self.evals.get(idx)
229 }
230
231 // -------- Helper Functions --------
232
233 /// Checks whether its arguments correspond to equivalent representations of
234 /// the same list of evaluations. Two representations are equivalent if
235 /// omitting the longest contiguous suffix of `F::ZERO`s from each results
236 /// in the same vectors.
237 #[allow(dead_code)]
238 fn equiv_repr(evals1: &BitPackedVector<F>, evals2: &BitPackedVector<F>) -> bool {
239 evals1
240 .iter()
241 .zip_longest(evals2.iter())
242 .all(|pair| match pair {
243 Both(l, r) => l == r,
244 Left(l) => l == F::ZERO,
245 Right(r) => r == F::ZERO,
246 })
247 }
248
249 /// Sorts the elements of `values` by their 0-based index transformed by
250 /// mirroring the `num_bits` LSBs. This operation effectively flips the
251 /// "endianess" of the index ordering. If `values.len() < 2^num_bits`, the
252 /// missing values are assumed to be zeros. The resulting vector is always
253 /// of size `2^num_bits`.
254 /// # Example
255 /// ```
256 /// use remainder::mle::evals::Evaluations;
257 /// use shared_types::Fr;
258 /// assert_eq!(Evaluations::flip_endianess(
259 /// 2,
260 /// &[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
261 /// vec![ Fr::from(1), Fr::from(3), Fr::from(2), Fr::from(4) ]
262 /// );
263 /// assert_eq!(Evaluations::flip_endianess(
264 /// 2,
265 /// &[ Fr::from(1), Fr::from(2) ]),
266 /// vec![ Fr::from(1), Fr::from(0), Fr::from(2), Fr::from(0) ]
267 /// );
268 /// ```
269 pub fn flip_endianess(num_bits: usize, values: &[F]) -> Vec<F> {
270 let num_evals = values.len();
271
272 let result: Vec<F> = cfg_into_iter!(0..(1 << num_bits))
273 .map(|idx| {
274 let mirrored_idx = mirror_bits(num_bits, idx);
275 if mirrored_idx >= num_evals {
276 F::ZERO
277 } else {
278 values[mirrored_idx]
279 }
280 })
281 .collect();
282
283 result
284 }
285}
286
/// An iterator over the explicitly stored evaluations of an [`Evaluations`]
/// table in "big-endian" order. Iteration stops at `Evaluations::len()`, so
/// the omitted zero suffix (if any) is not yielded.
pub struct EvaluationsIterator<'a, F: Field> {
    /// Reference to the original `Evaluations` struct.
    evals: &'a Evaluations<F>,

    /// Index of the next evaluation to be retrieved.
    current_index: usize,
}
295
296impl<F: Field> Iterator for EvaluationsIterator<'_, F> {
297 type Item = F;
298
299 fn next(&mut self) -> Option<Self::Item> {
300 if self.current_index < self.evals.len() {
301 let val = self.evals.get(self.current_index).unwrap();
302 self.current_index += 1;
303
304 Some(val)
305 } else {
306 None
307 }
308 }
309}
310
311impl<F: Field> Clone for EvaluationsIterator<'_, F> {
312 fn clone(&self) -> Self {
313 Self {
314 evals: self.evals,
315 current_index: self.current_index,
316 }
317 }
318}
319
/// An iterator over pairs of evaluations indexed by the vertices of a
/// projection of the boolean hypercube onto `num_vars - 1` dimensions. Each
/// item is the pair of evaluations whose indices differ only in the bit of the
/// projected variable (that bit `0` first, then `1`).
///
/// NOTE(review): the original documentation pointed to an
/// `Evaluations::project` method, which is not defined in this file. Also, no
/// constructor for this type is visible here (hence the `dead_code` allowance).
#[allow(dead_code)]
pub struct EvaluationsPairIterator<'a, F: Field> {
    /// Reference to the original bookkeeping table.
    evals: &'a Evaluations<F>,

    /// Mask selecting the bits of `current_pair_index` that lie below the
    /// projected variable's bit position. `lsb_mask + 1` is the single bit
    /// that gets inserted to form the pair.
    lsb_mask: usize,

    /// 0-based index of the next pair to be returned.
    ///
    /// Invariant: `current_pair_index \in [0, 2^(evals.num_vars() - 1)]`. When
    /// it equals `2^(evals.num_vars() - 1)`, the iterator has reached the end.
    current_pair_index: usize,
}
337
338impl<F: Field> Iterator for EvaluationsPairIterator<'_, F> {
339 type Item = (F, F);
340
341 fn next(&mut self) -> Option<Self::Item> {
342 let num_vars = self.evals.num_vars();
343 let num_pairs = 1_usize << (num_vars - 1);
344
345 if self.current_pair_index < num_pairs {
346 // Compute the two indices by inserting a `0` and a `1` respectively
347 // in the appropriate position of `current_pair_index`. For example,
348 // if this is an Iterator projecting on `fix_variable_index == 2`
349 // for an Evaluations table of `num_vars == 5`, then `lsb_mask ==
350 // 0b00011` (the `fix_variable_index` LSBs are on). When, for
351 // example `current_pair_index == 0b1010`, it is split into a "right
352 // part": `lsb_idx == 0b00 0 10`, and a "shifted left part":
353 // `msb_idx == 0b10 0 00`. The two parts are then combined with the
354 // middle bit on and off respectively: `idx1 == 0b10 0 10`, `idx2 ==
355 // 0b10 1 10`.
356 let lsb_idx = self.current_pair_index & self.lsb_mask;
357 let msb_idx = (self.current_pair_index & (!self.lsb_mask)) << 1;
358 let mid_idx = self.lsb_mask + 1;
359
360 let idx1 = lsb_idx | msb_idx;
361 let idx2 = lsb_idx | mid_idx | msb_idx;
362
363 self.current_pair_index += 1;
364
365 let val1 = self.evals.get(idx1).unwrap();
366 let val2 = self.evals.get(idx2).unwrap();
367
368 Some((val1, val2))
369 } else {
370 None
371 }
372 }
373}
374
/// Stores a function `\tilde{f}: F^n -> F`, the unique Multilinear Extension
/// (MLE) of a given function `f: {0, 1}^n -> F`:
/// ```text
/// \tilde{f}(x_0, ..., x_{n-1})
///     = \sum_{(b_0, ..., b_{n-1}) \in {0, 1}^n}
///         \tilde{beta}(x_0, ..., x_{n-1}, b_0, ..., b_{n-1})
///         * f(b_0, ..., b_{n-1}).
/// ```
/// where `\tilde{beta}` is the MLE of the equality function:
/// ```text
/// \tilde{beta}(x_0, ..., x_{n-1}, b_0, ..., b_{n-1})
///     = \prod_{i = 0}^{n-1} ( x_i * b_i + (1 - x_i) * (1 - b_i) )
/// ```
/// Internally, `f` is represented as a list of evaluations of `f` on the
/// boolean hypercube (see [`Evaluations`]). The variables are always indexed
/// from `0` to `n-1` for the current `n`.
///
/// When a variable is fixed to a constant, `n` decreases by one. The
/// remaining variables are re-indexed to `0..n-2`, keeping their relative
/// order. For example, fixing variable `1` of `f(b_0, b_1, b_2)` yields
/// `g(b_0, b_1) = f(b_0, r, b_1)`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct MultilinearExtension<F: Field> {
    /// The bookkeeping table with the evaluations of `f` on the hypercube.
    pub f: Evaluations<F>,
}
398
399impl<F: Field> Zeroize for MultilinearExtension<F> {
400 fn zeroize(&mut self) {
401 self.f.zeroize();
402 }
403}
404
405impl<F: Field> From<Vec<bool>> for MultilinearExtension<F> {
406 fn from(bools: Vec<bool>) -> Self {
407 let evals = bools
408 .into_iter()
409 .map(|b| if b { F::ONE } else { F::ZERO })
410 .collect();
411 MultilinearExtension::new(evals)
412 }
413}
414
415impl<F: Field> From<Vec<u32>> for MultilinearExtension<F> {
416 fn from(uints: Vec<u32>) -> Self {
417 let evals = uints.into_iter().map(|v| F::from(v as u64)).collect();
418 MultilinearExtension::new(evals)
419 }
420}
421
422impl<F: Field> From<Vec<u64>> for MultilinearExtension<F> {
423 fn from(uints: Vec<u64>) -> Self {
424 let evals = uints.into_iter().map(F::from).collect();
425 MultilinearExtension::new(evals)
426 }
427}
428
429impl<F: Field> From<Vec<i32>> for MultilinearExtension<F> {
430 fn from(ints: Vec<i32>) -> Self {
431 let evals = ints.into_iter().map(|v| i64_to_field(v as i64)).collect();
432 MultilinearExtension::new(evals)
433 }
434}
435
436impl<F: Field> From<Vec<i64>> for MultilinearExtension<F> {
437 fn from(ints: Vec<i64>) -> Self {
438 let evals = ints.into_iter().map(i64_to_field).collect();
439 MultilinearExtension::new(evals)
440 }
441}
442
443impl<F: Field> MultilinearExtension<F> {
444 /// Create a new MultilinearExtension from a [`Vec<F>`] of evaluations.
445 pub fn new(evals_vec: Vec<F>) -> Self {
446 let num_vars = log2(evals_vec.len()) as usize;
447 let evals = Evaluations::new(num_vars, evals_vec);
448 MultilinearExtension::new_from_evals(evals)
449 }
450
451 /// Generate a new MultilinearExtension from a representation `evals` of a
452 /// function `f`.
453 pub fn new_from_evals(evals: Evaluations<F>) -> Self {
454 Self { f: evals }
455 }
456
457 /// Creates a new mle which is all zeroes of a specific num_vars. In this
458 /// case the size of the evals and the num_vars will not match up
459 pub fn new_sized_zero(num_vars: usize) -> Self {
460 Self {
461 f: Evaluations {
462 evals: BitPackedVector::new(&[]),
463 num_vars,
464 zero: F::ZERO,
465 },
466 }
467 }
468
469 /// Returns an iterator accessing the evaluations defining this MLE in
470 /// "big-endian" order.
471 pub fn iter(&self) -> EvaluationsIterator<'_, F> {
472 self.f.iter()
473 }
474
475 /// Generate a Vector of the evaluations of `f` over the hypercube.
476 pub fn to_vec(&self) -> Vec<F> {
477 self.f.iter().collect()
478 }
479
480 /// Returns true if the MLE has not free variables. Equivalent to checking
481 /// whether that [Self::num_vars] is equal to zero.
482 pub fn is_fully_bound(&self) -> bool {
483 self.f.is_fully_bound()
484 }
485
486 /// Returns the first element of the bookkeeping table of this MLE,
487 /// corresponding to the value of the MLE when all varables are set to zero.
488 /// This operation never fails (see [Evaluations::first]).
489 pub fn first(&self) -> F {
490 self.f.first()
491 }
492
493 /// If `self` represents a fully-bound MLE (i.e. on zero variables), it
494 /// returns its value. Otherwise panics.
495 pub fn value(&self) -> F {
496 self.f.value()
497 }
498
499 /// Generates a representation for the MLE of the zero function on zero
500 /// variables.
501 pub fn new_zero() -> Self {
502 let zero_evals = Evaluations::new_zero();
503 Self::new_from_evals(zero_evals)
504 }
505
506 /// Returns `n`, the number of arguments `\tilde{f}` takes.
507 pub fn num_vars(&self) -> usize {
508 self.f.num_vars()
509 }
510
511 /// Returns the `idx`-th element, if `idx` is in the range `[0,
512 /// 2^self.num_vars)`.
513 pub fn get(&self, idx: usize) -> Option<F> {
514 if idx >= (1 << self.num_vars()) {
515 // `idx` is out of range.
516 None
517 } else if idx >= self.f.len() {
518 // `idx` is within range, but value is implicitly assumed to be
519 // zero.
520 Some(F::ZERO)
521 } else {
522 // `idx`-th position is stored explicitly in `self.f`
523 self.f.get(idx)
524 }
525 }
526
527 /// Evaluate `\tilde{f}` at `point \in F^n`.
528 /// # Panics
529 /// If `point` does not contain exactly `self.num_vars()` elements.
530 pub fn evaluate_at_point(&self, point: &[F]) -> F {
531 let n = self.num_vars();
532 assert_eq!(n, point.len());
533
534 // TODO: Provide better access mechanism.
535 self.f
536 .evals
537 .clone()
538 .iter() // was into_iter()
539 .enumerate()
540 .fold(F::ZERO, |acc, (idx, v)| {
541 let beta = (0..n).fold(F::ONE, |acc, i| {
542 let bit_i = idx & (1 << (n - 1 - i));
543 if bit_i > 0 {
544 acc * point[i]
545 } else {
546 acc * (F::ONE - point[i])
547 }
548 });
549 acc + v * beta
550 })
551 }
552
553 /// Returns the length of the evaluations vector.
554 #[allow(clippy::len_without_is_empty)]
555 pub fn len(&self) -> usize {
556 self.f.len()
557 }
558
559 /// Fix the 0-based `var_index`-th bit of `\tilde{f}` to an arbitrary field
560 /// element `point \in F` by destructively modifying `self`.
561 /// # Params
562 /// * `var_index`: A 0-based index of the input variable to be fixed.
563 /// * `point`: The field element to set `x_{var_index}` equal to.
564 /// # Example
565 /// If `self` represents a function `\tilde{f}: F^3 -> F`,
566 /// `self.fix_variable_at_index(1, r)` fixes the middle variable to `r \in
567 /// F`. After the invocation, `self` represents a function `\tilde{g}: F^2
568 /// -> F` defined as the multilinear extension of the following function:
569 /// `g(b_0, b_1) = \tilde{f}(b_0, r, b_1)`.
570 /// # Panics
571 /// if `var_index` is outside the interval `[0, self.num_vars())`.
572 pub fn fix_variable_at_index(&mut self, var_index: usize, point: F) {
573 let num_vars = self.num_vars();
574 let lsb_mask = (1_usize << (num_vars - 1 - var_index)) - 1;
575
576 let num_pairs = 1_usize << (num_vars - 1);
577
578 let new_evals: Vec<F> = cfg_into_iter!(0..num_pairs)
579 .map(|idx| {
580 // This iteration computes the value of
581 // `f'(idx[0], ..., idx[var_index-1], idx[var_index+1], ..., idx[num_vars - 1])`
582 // where `f'` is the resulting function after fixing the
583 // the `var_index`-th variable.
584 // To do this, we must combine the values of:
585 // `f(idx1) = f(idx[0], ..., idx[var_index-1], 0, idx[var_index+1], ..., idx[num_vars-1])`
586 // and
587 // `f(idx2) = f(idx[0], ..., idx[var_index-1], 1, idx[var_index+1], ..., idx[num_vars-1])`
588 // Below we compute `idx1` and `idx2` corresponding to the two
589 // indices above.
590
591 // Compute the two indices by inserting a `0` and a `1`
592 // respectively in the appropriate position of `idx`. For
593 // example, if `var_index == 2` and `self.num_vars == 5`, then
594 // `lsb_mask == 0b0011` (the `num_var - 1 - var_index` LSBs are
595 // on). When, for example `idx == 0b1010`, it is split into a
596 // "right part": `lsb_idx == 0b00 0 10`, and a "shifted left
597 // part": `msb_idx == 0b10 0 00`. The two parts are then
598 // combined with the middle bit on and off respectively: `idx1
599 // == 0b10 0 10`, `idx2 == 0b10 1 10`.
600 let lsb_idx = idx & lsb_mask;
601 let msb_idx = (idx & (!lsb_mask)) << 1;
602 let mid_idx = lsb_mask + 1;
603
604 let idx1 = lsb_idx | msb_idx;
605 let idx2 = lsb_idx | mid_idx | msb_idx;
606
607 let val1 = self.get(idx1).unwrap_or(F::ZERO);
608 let val2 = self.get(idx2).unwrap_or(F::ZERO);
609
610 val1 + (val2 - val1) * point
611 })
612 .collect();
613
614 debug_assert_eq!(new_evals.len(), 1 << (num_vars - 1));
615 self.f = Evaluations::new(num_vars - 1, new_evals);
616 }
617
618 /// Optimized version of `fix_variable_at_index` for `var_index == 0`.
619 /// # Panics
620 /// If `self.num_vars() == 0`.
621 pub fn fix_variable(&mut self, point: F) {
622 self.fix_variable_at_index(0, point);
623 }
624
625 /// Stacks the MLEs into a single MLE, assuming they are stored in a "big
626 /// endian" fashion.
627 pub fn stack_mles(mles: Vec<MultilinearExtension<F>>) -> MultilinearExtension<F> {
628 let first_len = mles[0].len();
629
630 if !mles.iter().all(|v| v.len() == first_len) {
631 panic!("All mles's underlying bookkeeping table must have the same length");
632 }
633
634 let out = mles.iter().flat_map(|mle| mle.to_vec()).collect();
635 Self::new(out)
636 }
637
638 /// Convert a [MultilinearExtension] into a vector of u8s.
639 /// Every element is padded to contain 8 bits.
640 pub fn convert_into_u8_vec(&self) -> Vec<u8> {
641 self.f
642 .iter()
643 .map(|field_element| {
644 let field_element_le_bytes = field_element.to_bytes_le();
645 let mut padded_u8 = [0u8; 1];
646 padded_u8.copy_from_slice(&field_element_le_bytes[..1]);
647 u8::from_le_bytes(padded_u8)
648 })
649 .collect_vec()
650 }
651
652 /// Convert a [MultilinearExtension] into a vector of u16s.
653 /// Every element is padded to contain 16 bits.
654 pub fn convert_into_u16_vec(&self) -> Vec<u16> {
655 self.f
656 .iter()
657 .map(|field_element| {
658 let field_element_le_bytes = field_element.to_bytes_le();
659 let mut padded_u16 = [0u8; 2];
660 padded_u16.copy_from_slice(&field_element_le_bytes[..2]);
661 u16::from_le_bytes(padded_u16)
662 })
663 .collect_vec()
664 }
665
666 /// Convert a [MultilinearExtension] into a vector of u32s.
667 /// Every element is padded to contain 32 bits.
668 pub fn convert_into_u32_vec(&self) -> Vec<u32> {
669 self.f
670 .iter()
671 .map(|field_element| {
672 let field_element_le_bytes = field_element.to_bytes_le();
673 let mut padded_u32 = [0u8; 4];
674 padded_u32.copy_from_slice(&field_element_le_bytes[..4]);
675 u32::from_le_bytes(padded_u32)
676 })
677 .collect_vec()
678 }
679
680 /// Convert a [MultilinearExtension] into a vector of u64s.
681 /// Every element is padded to contain 64 bits.
682 pub fn convert_into_u64_vec(&self) -> Vec<u64> {
683 self.f
684 .iter()
685 .map(|field_element| {
686 let field_element_le_bytes = field_element.to_bytes_le();
687 let mut padded_u64 = [0u8; 8];
688 padded_u64.copy_from_slice(&field_element_le_bytes[..8]);
689 u64::from_le_bytes(padded_u64)
690 })
691 .collect_vec()
692 }
693
694 /// Convert a [MultilinearExtension] into a vector of u128s.
695 /// Every element is padded to contain 128 bits.
696 pub fn convert_into_u128_vec(&self) -> Vec<u128> {
697 self.f
698 .iter()
699 .map(|field_element| {
700 let field_element_le_bytes = field_element.to_bytes_le();
701 let mut padded_u128 = [0u8; 16];
702 padded_u128.copy_from_slice(&field_element_le_bytes[..16]);
703 u128::from_le_bytes(padded_u128)
704 })
705 .collect_vec()
706 }
707}