remainder/layer/combine_mles.rs
//! This module contains the code used to combine the parts (MLEs) of a layer
//! in order to determine the evaluation of the layer's layerwise bookkeeping
//! table at a point.
4
5use crate::{
6 mle::{
7 dense::DenseMle,
8 evals::{Evaluations, MultilinearExtension},
9 Mle, MleIndex,
10 },
11 utils::mle::evaluate_mle_at_a_point_gray_codes,
12};
13use ark_std::cfg_iter_mut;
14
15use itertools::Itertools;
16
17use shared_types::Field;
18use thiserror::Error;
19
20use anyhow::{anyhow, Ok, Result};
21
22#[cfg(feature = "parallel")]
23use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
24
/// A fully bound MLE evaluation together with that MLE's prefix bits (the
/// values of its `Fixed` indices, in order). This pairing is passed around
/// throughout the combining code.
type MleEvaluationAndPrefixBits<F> = (F, Vec<bool>);
28
29/// Error handling for gate mle construction
30#[derive(Error, Debug, Clone)]
31pub enum CombineMleRefError {
32 #[error("we have not fully combined all the mles because the list size is > 1")]
33 /// We have not fully combined all the mles because the list size is > 1.
34 NotFullyCombined,
35 #[error("we have an mle that is not fully fixed even after fixing on the challenge point")]
36 /// We have an mle that is not fully fixed even after fixing on the
37 /// challenge point.
38 MleRefNotFullyFixed,
39}
40
41/// This fixes mles with shared points in the claims so that we don't repeatedly
42/// do so.
43pub fn pre_fix_mles<F: Field>(mles: &mut [DenseMle<F>], chal_point: &[F], common_idx: Vec<usize>) {
44 cfg_iter_mut!(mles).for_each(|mle| {
45 common_idx.iter().for_each(|chal_idx| {
46 if let MleIndex::Indexed(idx_bit_num) = mle.mle_indices()[*chal_idx] {
47 mle.fix_variable_at_index(idx_bit_num, chal_point[*chal_idx]);
48 }
49 });
50 });
51}
52
53/// Function that prepares all the mles to be combined. We simply index all the
54/// MLEs that are to be fixed and combined, and ensure that all fixed bits are
55/// always contiguous.
56pub fn get_indexed_layer_mles_to_combine<F: Field>(mles: Vec<DenseMle<F>>) -> Vec<DenseMle<F>> {
57 // We split all the mles with a free bit within the fixed bits. This is in
58 // order to ensure that all the fixed bits are truly "prefix" bits.
59 let mut mles_split = collapse_mles_with_free_in_prefix(mles);
60
61 // Index all the MLEs as they will be fixed throughout the combining
62 // process.
63 cfg_iter_mut!(mles_split).for_each(|mle| {
64 mle.index_mle_indices(0);
65 });
66 mles_split
67}
68
69/// This function takes in a list of mles, a challenge point we want to combine
70/// them under, and returns the final value in the bookkeeping table of the
71/// combined mle. This is equivalent to combining all of these mles according to
72/// their prefix bits, and then fixing variable on this combined mle (which is
73/// the layerwise bookkeeping table). Instead, we fix variable as we combine as
74/// this keeps the bookkeeping table sizes at one.
75pub fn combine_mles_with_aggregate<F: Field>(mles: &[DenseMle<F>], chal_point: &[F]) -> Result<F> {
76 // We go through all of the mles and fix variable in all of them given at
77 // the correct indices so that they are fully bound.
78 let fix_var_mles = mles
79 .iter()
80 .map(|mle| {
81 let point_to_bind = mle
82 .mle_indices
83 .iter()
84 .enumerate()
85 .filter_map(|(idx, mle_idx)| {
86 if let MleIndex::Indexed(_idx_num) = mle_idx {
87 Some(chal_point[idx])
88 } else {
89 None
90 }
91 })
92 .collect_vec();
93
94 // Fully evaluate the MLE at a point using the gray codes algorithm.
95 let mle_evaluation = evaluate_mle_at_a_point_gray_codes(&mle.mle, &point_to_bind);
96 let prefix_bits = mle
97 .mle_indices
98 .iter()
99 .filter_map(|mle_index| {
100 if let MleIndex::Fixed(prefix_bool) = mle_index {
101 Some(*prefix_bool)
102 } else {
103 None
104 }
105 })
106 .collect_vec();
107
108 (mle_evaluation, prefix_bits)
109 })
110 .collect_vec();
111
112 // Mutable variable that is overwritten every time we combine mles.
113 let mut updated_list = fix_var_mles;
114
115 // A loop that breaks when all the mles no longer have any fixed bits and
116 // only have free bits. This means we have fully combined the MLEs to form
117 // the layerwise bookkeeping table (but fully bound to a point).
118 loop {
119 // We first get the lsb fixed bit and the evaluation of the MLE that
120 // contributes to it.
121 let (mle_evaluation, mle_prefix_bits) = get_lsb_fixed_var(&updated_list);
122
123 // There are only 0 prefix bits for the MLE contributing to the lsb
124 // fixed bit if the MLEs have been fully combined.
125 if mle_prefix_bits.is_empty() {
126 break;
127 }
128
129 // Otherwise, overwrite updated_list to contain the combined MLE instead
130 // of the two MLEs contributing to the lsb fixed bit.
131 updated_list =
132 find_pair_and_combine(&updated_list, mle_prefix_bits, *mle_evaluation, chal_point);
133 }
134
135 // The list now should only have one combined mle, and its bookkeeping table
136 // should only have one value in it since we were binding variables as we
137 // were combining.
138 if updated_list.len() > 1 {
139 return Err(anyhow!(CombineMleRefError::NotFullyCombined));
140 }
141 let (full_eval, prefix_bits) = &updated_list[0];
142 assert_eq!(
143 prefix_bits.len(),
144 0,
145 "there should be no more prefix bits left after fully combining!"
146 );
147
148 Ok(*full_eval)
149}
150
151/// This function takes an MLE that has a free variable in between fixed
152/// variables, and it splits it into two MLEs, one where the free variable is
153/// replaced with `Fixed(false)`, and the other where it is replaced with
154/// `Fixed(true)`. This ensures that all the fixed bits are contiguous. NOTE we
155/// assume that this function is called on an mle that has a free variable
156/// within a bunch of fixed variables (note how it is used in the
157/// `collapse_mles_with_free_in_prefix` function)
158fn split_mle<F: Field>(mle: &DenseMle<F>) -> Vec<DenseMle<F>> {
159 // Get the index of the first free bit in the mle.
160 let first_free_idx: usize = mle.mle_indices().iter().enumerate().fold(
161 mle.mle_indices().len(),
162 |acc, (idx, mle_idx)| {
163 if let MleIndex::Free = mle_idx {
164 std::cmp::min(acc, idx)
165 } else {
166 acc
167 }
168 },
169 );
170
171 // Compute the correct indices, we have the first one be false, the second
172 // one as true instead of the free bit.
173 let first_indices = mle.mle_indices()[0..first_free_idx]
174 .iter()
175 .cloned()
176 .chain(std::iter::once(MleIndex::Fixed(false)))
177 .chain(mle.mle_indices()[first_free_idx + 1..].iter().cloned())
178 .collect_vec();
179 let second_indices = mle.mle_indices()[0..first_free_idx]
180 .iter()
181 .cloned()
182 .chain(std::iter::once(MleIndex::Fixed(true)))
183 .chain(mle.mle_indices()[first_free_idx + 1..].iter().cloned())
184 .collect_vec();
185
186 // Construct the first MLE in the pair.
187 let first_mle = DenseMle {
188 mle: MultilinearExtension::new_from_evals(Evaluations::<F>::new(
189 mle.num_free_vars() - 1,
190 mle.mle.iter().step_by(2).collect_vec(),
191 )),
192 mle_indices: first_indices,
193 layer_id: mle.layer_id,
194 };
195
196 // Second mle in the pair.
197 let second_mle = DenseMle {
198 mle: MultilinearExtension::new_from_evals(Evaluations::<F>::new(
199 mle.num_free_vars() - 1,
200 mle.mle.iter().skip(1).step_by(2).collect_vec(),
201 )),
202 mle_indices: second_indices,
203 layer_id: mle.layer_id,
204 };
205
206 vec![first_mle, second_mle]
207}
208
209/// This function will take a list of MLEs and updates the list to contain MLEs
210/// where all fixed bits are contiguous
211fn collapse_mles_with_free_in_prefix<F: Field>(mles: Vec<DenseMle<F>>) -> Vec<DenseMle<F>> {
212 mles.into_iter()
213 .flat_map(|mle| {
214 // This iterates through the mle indices to check whether there is a
215 // free bit within the fixed bits.
216 let (_, contains_free_in_fixed) = mle.mle_indices().iter().fold(
217 (false, false),
218 |(free_seen_so_far, fixed_after_free_so_far), mle_idx| match mle_idx {
219 MleIndex::Free => (true, fixed_after_free_so_far),
220 MleIndex::Fixed(_) => (free_seen_so_far, free_seen_so_far),
221 _ => (free_seen_so_far, fixed_after_free_so_far),
222 },
223 );
224 // If true, we split, otherwise, we don't.
225 if contains_free_in_fixed {
226 split_mle(&mle)
227 } else {
228 vec![mle]
229 }
230 })
231 .collect()
232}
233
234/// Gets the index of the least significant bit (lsb) of the fixed bits out of a
235/// vector of MLEs.
236///
237/// In other words, this is the MLE evaluation pertaining to the MLE with the
238/// most fixed bits.
239fn get_lsb_fixed_var<F: Field>(
240 mles: &[MleEvaluationAndPrefixBits<F>],
241) -> &MleEvaluationAndPrefixBits<F> {
242 mles.iter()
243 .max_by_key(|(_mle_evaluation, prefix_bits)| prefix_bits.len())
244 .unwrap()
245}
246
247/// Given an MLE evaluation, and an option of a second MLE evaluation pair, this
248/// combines the two together this assumes that the first MLE evaluation and the
249/// second MLE evaluation are pairs, if the second MLE evaluation is a Some()
250///
251/// A pair consists of two MLE evaluations that match in every fixed bit except
252/// for the least significant one. This is because we combine in the reverse
253/// order that we split in terms of selectors, and we split in terms of
254/// selectors by doing huffman (most significant bit).
255///
256/// Example: if mle_evaluation_first has fixed bits true, true, false, its pair
257/// would have fixed bits true, true, true. When we combine them, the combined
258/// MLE has fixed bits true, true. We also simultaneously update the combined
259/// evaluation to use the challenge according to the index we combined at.
260///
261/// If there is no pair, then this is assumed to be an mle with all 0s.
262fn combine_pair<F: Field>(
263 mle_evaluation_first: F,
264 maybe_mle_evaluation_second: Option<F>,
265 prefix_vars_first: &[bool],
266 chal_point: &[F],
267) -> MleEvaluationAndPrefixBits<F> {
268 // If the second mle is None, we assume its bookkeeping table is all zeros.
269 // We are dealing with fully fixed mles, so we just use F::ZERO.
270 let mle_evaluation_second = maybe_mle_evaluation_second.unwrap_or(F::ZERO);
271
272 // Depending on whether the lsb fixed bit was true or false, we bind it to
273 // the correct challenge point at this index this is either the challenge
274 // point at the index, or one minus this value.
275 let bound_coord = if !prefix_vars_first.last().unwrap() {
276 F::ONE - chal_point[prefix_vars_first.len() - 1]
277 } else {
278 chal_point[prefix_vars_first.len() - 1]
279 };
280
281 // We compute the combined evaluation using the according index challenge
282 // point.
283 let new_eval =
284 bound_coord * mle_evaluation_first + (F::ONE - bound_coord) * mle_evaluation_second;
285
286 (
287 new_eval,
288 prefix_vars_first[..prefix_vars_first.len() - 1].to_vec(),
289 )
290}
291
292/// Given a list of mles, the lsb fixed var index, and the MLE evaluation that
293/// contributes to it, this will go through all of them and find its pair (if
294/// none exists, we assume it is 0) and combine the two it will then update the
295/// original list of MLEs to contain the combined MLE evaluation and remove the
296/// original ones that were paired.
297fn find_pair_and_combine<F: Field>(
298 all_refs: &[MleEvaluationAndPrefixBits<F>],
299 prefix_indices: &[bool],
300 mle_evaluation: F,
301 chal_point: &[F],
302) -> Vec<MleEvaluationAndPrefixBits<F>> {
303 // We want to compare all fixed bits except the one at the least significant
304 // bit index.
305 let indices_to_compare = &prefix_indices[0..prefix_indices.len() - 1];
306 let mut mle_eval_pair = None;
307 let mut all_refs_updated = Vec::new();
308
309 for (mle_eval, mle_indices) in all_refs {
310 let max_slice_idx = mle_indices.len();
311 let compare_indices =
312 &mle_indices[0..std::cmp::min(prefix_indices.len() - 1, max_slice_idx)];
313 // We want to make sure we aren't combining an mle with itself!
314 if (compare_indices == indices_to_compare) && (mle_indices != prefix_indices) {
315 mle_eval_pair = Some(*mle_eval);
316 } else if mle_indices != prefix_indices {
317 all_refs_updated.push((*mle_eval, mle_indices.to_vec()));
318 }
319 }
320
321 // Add the paired mle to the list and return this new updated list.
322 let new_mle_to_add = combine_pair(mle_evaluation, mle_eval_pair, prefix_indices, chal_point);
323 all_refs_updated.push(new_mle_to_add);
324 all_refs_updated
325}