remainder/mle/
mle_enum.rs

1//! A wrapper `enum` type around various implementations of MLEs.
2
3use super::{dense::DenseMle, evals::EvaluationsIterator, zero::ZeroMle, MleIndex};
4use crate::{layer::LayerId, mle::Mle};
5use itertools::{repeat_n, Itertools};
6use serde::{Deserialize, Serialize};
7use shared_types::Field;
8
9/// A wrapper type for various kinds of MLEs.
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
11#[serde(bound = "F: Field")]
12pub enum MleEnum<F: Field> {
13    /// A [DenseMle] variant.
14    Dense(DenseMle<F>),
15    /// A [ZeroMle] variant.
16    Zero(ZeroMle<F>),
17}
18
19impl<F: Field> Mle<F> for MleEnum<F> {
20    fn len(&self) -> usize {
21        match self {
22            MleEnum::Dense(item) => item.len(),
23            MleEnum::Zero(item) => item.len(),
24        }
25    }
26
27    fn iter(&self) -> EvaluationsIterator<'_, F> {
28        match self {
29            MleEnum::Dense(item) => item.iter(),
30            MleEnum::Zero(item) => item.iter(),
31        }
32    }
33
34    fn first(&self) -> F {
35        match self {
36            MleEnum::Dense(item) => item.first(),
37            MleEnum::Zero(item) => item.first(),
38        }
39    }
40
41    fn value(&self) -> F {
42        match self {
43            MleEnum::Dense(item) => item.value(),
44            MleEnum::Zero(item) => item.value(),
45        }
46    }
47
48    fn get(&self, index: usize) -> Option<F> {
49        match self {
50            MleEnum::Dense(item) => item.get(index),
51            MleEnum::Zero(item) => item.get(index),
52        }
53    }
54
55    fn mle_indices(&self) -> &[super::MleIndex<F>] {
56        match self {
57            MleEnum::Dense(item) => item.mle_indices(),
58            MleEnum::Zero(item) => item.mle_indices(),
59        }
60    }
61
62    fn num_free_vars(&self) -> usize {
63        match self {
64            MleEnum::Dense(item) => item.num_free_vars(),
65            MleEnum::Zero(item) => item.num_free_vars(),
66        }
67    }
68
69    fn fix_variable(
70        &mut self,
71        round_index: usize,
72        challenge: F,
73    ) -> Option<crate::claims::RawClaim<F>> {
74        match self {
75            MleEnum::Dense(item) => item.fix_variable(round_index, challenge),
76            MleEnum::Zero(item) => item.fix_variable(round_index, challenge),
77        }
78    }
79
80    fn fix_variable_at_index(
81        &mut self,
82        indexed_bit_index: usize,
83        point: F,
84    ) -> Option<crate::claims::RawClaim<F>> {
85        match self {
86            MleEnum::Dense(item) => item.fix_variable_at_index(indexed_bit_index, point),
87            MleEnum::Zero(item) => item.fix_variable_at_index(indexed_bit_index, point),
88        }
89    }
90
91    fn index_mle_indices(&mut self, curr_index: usize) -> usize {
92        match self {
93            MleEnum::Dense(item) => item.index_mle_indices(curr_index),
94            MleEnum::Zero(item) => item.index_mle_indices(curr_index),
95        }
96    }
97
98    fn layer_id(&self) -> LayerId {
99        match self {
100            MleEnum::Dense(item) => item.layer_id(),
101            MleEnum::Zero(item) => item.layer_id(),
102        }
103    }
104
105    fn get_enum(self) -> MleEnum<F> {
106        self
107    }
108
109    fn get_padded_evaluations(&self) -> Vec<F> {
110        match self {
111            MleEnum::Dense(dense_mle) => dense_mle.mle.f.iter().collect_vec(),
112            MleEnum::Zero(zero_mle) => repeat_n(F::ZERO, 1 << zero_mle.num_vars).collect_vec(),
113        }
114    }
115
116    fn add_prefix_bits(&mut self, _new_bits: Vec<MleIndex<F>>) {
117        todo!()
118    }
119}
120
121impl<F: Field> From<DenseMle<F>> for MleEnum<F> {
122    fn from(value: DenseMle<F>) -> Self {
123        Self::Dense(value)
124    }
125}
126
127impl<F: Field> From<ZeroMle<F>> for MleEnum<F> {
128    fn from(value: ZeroMle<F>) -> Self {
129        Self::Zero(value)
130    }
131}