1#[cfg(test)]
2mod tests;
3
4use std::fmt::Debug;
5
6use ark_std::log2;
7use itertools::{repeat_n, Itertools};
8
9use serde::{Deserialize, Serialize};
10
11use super::{evals::EvaluationsIterator, mle_enum::MleEnum, Mle, MleIndex};
12use crate::{
13 claims::RawClaim,
14 mle::evals::{Evaluations, MultilinearExtension},
15};
16use crate::{
17 expression::{generic_expr::Expression, prover_expr::ProverExpr},
18 layer::LayerId,
19};
20use shared_types::Field;
21
/// A multilinear extension stored as its (possibly unpadded) list of evaluations,
/// together with one [MleIndex] per variable position describing whether that position
/// is a fixed prefix bit, a free/indexed variable, or an already-bound variable.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(bound = "F: Field")]
pub struct DenseMle<F: Field> {
    /// The layer this MLE is associated with.
    pub layer_id: LayerId,
    /// The underlying multilinear extension. Its variables presumably correspond, in
    /// order, to the not-yet-bound (`Free` / `Indexed`) entries of `mle_indices`
    /// (this is what the constructors set up) — NOTE(review): confirm for all callers.
    pub mle: MultilinearExtension<F>,
    /// Per-position index bookkeeping: `Fixed` prefix bits first (as built by the
    /// constructors), followed by entries for the MLE's variables.
    pub mle_indices: Vec<MleIndex<F>>,
}
33
34impl<F: Field> Mle<F> for DenseMle<F> {
35 fn num_free_vars(&self) -> usize {
36 self.mle.num_vars()
37 }
38
39 fn get_padded_evaluations(&self) -> Vec<F> {
40 let size: usize = 1 << self.mle.num_vars();
41 let padding = size - self.mle.len();
42
43 self.mle.iter().chain(repeat_n(F::ZERO, padding)).collect()
44 }
45
46 fn add_prefix_bits(&mut self, mut new_bits: Vec<MleIndex<F>>) {
47 new_bits.extend(self.mle_indices.clone());
48 self.mle_indices.clone_from(&new_bits);
49 }
50
51 fn layer_id(&self) -> LayerId {
52 self.layer_id
53 }
54
55 fn len(&self) -> usize {
56 self.mle.len()
57 }
58
59 fn iter(&self) -> EvaluationsIterator<'_, F> {
60 self.mle.iter()
61 }
62
63 fn mle_indices(&self) -> &[MleIndex<F>] {
64 &self.mle_indices
65 }
66
67 fn fix_variable_at_index(&mut self, indexed_bit_index: usize, point: F) -> Option<RawClaim<F>> {
68 let (index_found, bit_count) =
80 self.mle_indices
81 .iter_mut()
82 .fold((false, 0), |state, mle_index| {
83 if state.0 {
84 state
86 } else if let MleIndex::Indexed(current_bit_index) = *mle_index {
87 if current_bit_index == indexed_bit_index {
88 mle_index.bind_index(point);
91 (true, state.1 + 1)
92 } else {
93 (false, state.1 + 1)
96 }
97 } else {
98 state
101 }
102 });
103
104 assert!(index_found);
105 debug_assert!(1 <= bit_count && bit_count <= self.num_free_vars());
106
107 self.mle.fix_variable_at_index(bit_count - 1, point);
108
109 if self.is_fully_bounded() {
110 let fixed_claim_return = RawClaim::new(
111 self.mle_indices
112 .iter()
113 .map(|index| index.val().unwrap())
114 .collect_vec(),
115 self.mle.value(),
116 );
117 Some(fixed_claim_return)
118 } else {
119 None
120 }
121 }
122
123 fn fix_variable(&mut self, index: usize, binding: F) -> Option<RawClaim<F>> {
127 for mle_index in self.mle_indices.iter_mut() {
128 if *mle_index == MleIndex::Indexed(index) {
129 mle_index.bind_index(binding);
130 }
131 }
132 self.mle.fix_variable(binding);
134
135 if self.is_fully_bounded() {
136 let fixed_claim_return = RawClaim::new(
137 self.mle_indices
138 .iter()
139 .map(|index| index.val().unwrap())
140 .collect_vec(),
141 self.mle.value(),
142 );
143 Some(fixed_claim_return)
144 } else {
145 None
146 }
147 }
148
149 fn index_mle_indices(&mut self, curr_index: usize) -> usize {
150 let mut new_indices = 0;
151 for mle_index in self.mle_indices.iter_mut() {
152 if *mle_index == MleIndex::Free {
153 *mle_index = MleIndex::Indexed(curr_index + new_indices);
154 new_indices += 1;
155 }
156 }
157
158 curr_index + new_indices
159 }
160
161 fn get_enum(self) -> MleEnum<F> {
162 MleEnum::Dense(self)
163 }
164
165 fn get(&self, index: usize) -> Option<F> {
166 self.mle.get(index)
167 }
168
169 fn first(&self) -> F {
170 self.mle.first()
171 }
172
173 fn value(&self) -> F {
174 self.mle.value()
175 }
176}
177
178impl<F: Field> DenseMle<F> {
179 pub fn new_with_prefix_bits(
182 data: MultilinearExtension<F>,
183 layer_id: LayerId,
184 prefix_bits: Vec<bool>,
185 ) -> Self {
186 let free_bits = data.num_vars();
187
188 let mle_indices: Vec<MleIndex<F>> = prefix_bits
189 .into_iter()
190 .map(|bit| MleIndex::Fixed(bit))
191 .chain((0..free_bits).map(|_| MleIndex::Free))
192 .collect();
193 Self {
194 layer_id,
195 mle: data,
196 mle_indices,
197 }
198 }
199
200 pub fn new_with_indices(data: &[F], layer_id: LayerId, mle_indices: &[MleIndex<F>]) -> Self {
207 let mut mle = DenseMle::new_from_raw(data.to_vec(), layer_id);
208
209 let all_indices_free_or_fixed = mle_indices.iter().all(|index| {
210 index == &MleIndex::Free
211 || index == &MleIndex::Fixed(true)
212 || index == &MleIndex::Fixed(false)
213 });
214 assert!(all_indices_free_or_fixed);
215
216 mle.mle_indices = mle_indices.to_vec();
217 mle
218 }
219
220 pub fn new_from_iter(iter: impl Iterator<Item = F>, layer_id: LayerId) -> Self {
232 let items = iter.collect_vec();
233 let num_free_vars = log2(items.len()) as usize;
234
235 let mle_indices: Vec<MleIndex<F>> = ((0..num_free_vars).map(|_| MleIndex::Free)).collect();
236
237 let current_mle =
238 MultilinearExtension::new_from_evals(Evaluations::<F>::new(num_free_vars, items));
239 Self {
240 layer_id,
241 mle: current_mle,
242 mle_indices,
243 }
244 }
245
246 pub fn new_from_raw(items: Vec<F>, layer_id: LayerId) -> Self {
258 let num_free_vars = log2(items.len()) as usize;
259
260 let mle_indices: Vec<MleIndex<F>> = ((0..num_free_vars).map(|_| MleIndex::Free)).collect();
261
262 let current_mle =
263 MultilinearExtension::new_from_evals(Evaluations::<F>::new(num_free_vars, items));
264
265 Self {
266 layer_id,
267 mle: current_mle,
268 mle_indices,
269 }
270 }
271
272 pub fn new_from_multilinear_extension(
277 mle: MultilinearExtension<F>,
278 layer_id: LayerId,
279 prefix_vars: Option<Vec<bool>>,
280 maybe_starting_var_index: Option<usize>,
281 ) -> Self {
282 let mle_indices: Vec<MleIndex<F>> = prefix_vars
283 .unwrap_or_default()
284 .into_iter()
285 .map(|prefix_var| MleIndex::Fixed(prefix_var))
286 .chain((0..mle.num_vars()).map(|_| MleIndex::Free))
287 .collect();
288 let mut ret = Self {
289 layer_id,
290 mle,
291 mle_indices,
292 };
293 if let Some(starting_var_index) = maybe_starting_var_index {
294 ret.index_mle_indices(starting_var_index);
295 }
296 ret
297 }
298
299 pub fn combine_mles(mles: Vec<DenseMle<F>>) -> DenseMle<F> {
301 let first_mle_num_vars = mles[0].num_free_vars();
302 let all_same_num_vars = mles
303 .iter()
304 .all(|mle| mle.num_free_vars() == first_mle_num_vars);
305 assert!(all_same_num_vars);
306 let layer_id = mles[0].layer_id;
307 let mle_flattened = mles.into_iter().flat_map(|mle| mle.into_iter());
308
309 Self::new_from_iter(mle_flattened, layer_id)
310 }
311
312 pub fn expression(self) -> Expression<F, ProverExpr> {
314 Expression::<F, ProverExpr>::mle(self)
315 }
316
317 pub fn get_bound_point(&self) -> Vec<F> {
322 self.mle_indices()
323 .iter()
324 .map(|index| match index {
325 MleIndex::Bound(chal, _) => *chal,
326 MleIndex::Fixed(chal) => F::from(*chal as u64),
327 _ => panic!("MLE index not bound"),
328 })
329 .collect()
330 }
331}
332
333impl<F: Field> IntoIterator for DenseMle<F> {
334 type Item = F;
335
336 type IntoIter = std::vec::IntoIter<Self::Item>;
337
338 fn into_iter(self) -> Self::IntoIter {
339 self.mle.iter().collect::<Vec<F>>().into_iter()
341 }
342}