remainder/layer/
combine_mles.rs

//! This module contains the code used to combine the parts of a layer in order
//! to determine the evaluation of that layer's layerwise bookkeeping table at
//! a point.
4
5use crate::{
6    mle::{
7        dense::DenseMle,
8        evals::{Evaluations, MultilinearExtension},
9        Mle, MleIndex,
10    },
11    utils::mle::evaluate_mle_at_a_point_gray_codes,
12};
13use ark_std::cfg_iter_mut;
14
15use itertools::Itertools;
16
17use shared_types::Field;
18use thiserror::Error;
19
20use anyhow::{anyhow, Ok, Result};
21
22#[cfg(feature = "parallel")]
23use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
24
/// Type alias for an MLE evaluation and its prefix bits, as this type is used
/// throughout the combining code.
///
/// The first element is the evaluation of an MLE; the second element holds the
/// boolean values of that MLE's `Fixed` indices, in the order in which they
/// appear among the MLE's indices.
type MleEvaluationAndPrefixBits<F> = (F, Vec<bool>);
28
/// Errors that can arise while combining MLEs into a single evaluation.
#[derive(Error, Debug, Clone)]
pub enum CombineMleRefError {
    #[error("we have not fully combined all the mles because the list size is > 1")]
    /// We have not fully combined all the mles because the list size is > 1.
    /// Returned by `combine_mles_with_aggregate` when more than one entry is
    /// left once no entry has any prefix bits remaining.
    NotFullyCombined,
    #[error("we have an mle that is not fully fixed even after fixing on the challenge point")]
    /// We have an mle that is not fully fixed even after fixing on the
    /// challenge point.
    // NOTE(review): this variant is not constructed anywhere in this module;
    // presumably it is used by callers elsewhere -- confirm.
    MleRefNotFullyFixed,
}
40
41/// This fixes mles with shared points in the claims so that we don't repeatedly
42/// do so.
43pub fn pre_fix_mles<F: Field>(mles: &mut [DenseMle<F>], chal_point: &[F], common_idx: Vec<usize>) {
44    cfg_iter_mut!(mles).for_each(|mle| {
45        common_idx.iter().for_each(|chal_idx| {
46            if let MleIndex::Indexed(idx_bit_num) = mle.mle_indices()[*chal_idx] {
47                mle.fix_variable_at_index(idx_bit_num, chal_point[*chal_idx]);
48            }
49        });
50    });
51}
52
53/// Function that prepares all the mles to be combined. We simply index all the
54/// MLEs that are to be fixed and combined, and ensure that all fixed bits are
55/// always contiguous.
56pub fn get_indexed_layer_mles_to_combine<F: Field>(mles: Vec<DenseMle<F>>) -> Vec<DenseMle<F>> {
57    // We split all the mles with a free bit within the fixed bits. This is in
58    // order to ensure that all the fixed bits are truly "prefix" bits.
59    let mut mles_split = collapse_mles_with_free_in_prefix(mles);
60
61    // Index all the MLEs as they will be fixed throughout the combining
62    // process.
63    cfg_iter_mut!(mles_split).for_each(|mle| {
64        mle.index_mle_indices(0);
65    });
66    mles_split
67}
68
69/// This function takes in a list of mles, a challenge point we want to combine
70/// them under, and returns the final value in the bookkeeping table of the
71/// combined mle. This is equivalent to combining all of these mles according to
72/// their prefix bits, and then fixing variable on this combined mle (which is
73/// the layerwise bookkeeping table). Instead, we fix variable as we combine as
74/// this keeps the bookkeeping table sizes at one.
75pub fn combine_mles_with_aggregate<F: Field>(mles: &[DenseMle<F>], chal_point: &[F]) -> Result<F> {
76    // We go through all of the mles and fix variable in all of them given at
77    // the correct indices so that they are fully bound.
78    let fix_var_mles = mles
79        .iter()
80        .map(|mle| {
81            let point_to_bind = mle
82                .mle_indices
83                .iter()
84                .enumerate()
85                .filter_map(|(idx, mle_idx)| {
86                    if let MleIndex::Indexed(_idx_num) = mle_idx {
87                        Some(chal_point[idx])
88                    } else {
89                        None
90                    }
91                })
92                .collect_vec();
93
94            // Fully evaluate the MLE at a point using the gray codes algorithm.
95            let mle_evaluation = evaluate_mle_at_a_point_gray_codes(&mle.mle, &point_to_bind);
96            let prefix_bits = mle
97                .mle_indices
98                .iter()
99                .filter_map(|mle_index| {
100                    if let MleIndex::Fixed(prefix_bool) = mle_index {
101                        Some(*prefix_bool)
102                    } else {
103                        None
104                    }
105                })
106                .collect_vec();
107
108            (mle_evaluation, prefix_bits)
109        })
110        .collect_vec();
111
112    // Mutable variable that is overwritten every time we combine mles.
113    let mut updated_list = fix_var_mles;
114
115    // A loop that breaks when all the mles no longer have any fixed bits and
116    // only have free bits. This means we have fully combined the MLEs to form
117    // the layerwise bookkeeping table (but fully bound to a point).
118    loop {
119        // We first get the lsb fixed bit and the evaluation of the MLE that
120        // contributes to it.
121        let (mle_evaluation, mle_prefix_bits) = get_lsb_fixed_var(&updated_list);
122
123        // There are only 0 prefix bits for the MLE contributing to the lsb
124        // fixed bit if the MLEs have been fully combined.
125        if mle_prefix_bits.is_empty() {
126            break;
127        }
128
129        // Otherwise, overwrite updated_list to contain the combined MLE instead
130        // of the two MLEs contributing to the lsb fixed bit.
131        updated_list =
132            find_pair_and_combine(&updated_list, mle_prefix_bits, *mle_evaluation, chal_point);
133    }
134
135    // The list now should only have one combined mle, and its bookkeeping table
136    // should only have one value in it since we were binding variables as we
137    // were combining.
138    if updated_list.len() > 1 {
139        return Err(anyhow!(CombineMleRefError::NotFullyCombined));
140    }
141    let (full_eval, prefix_bits) = &updated_list[0];
142    assert_eq!(
143        prefix_bits.len(),
144        0,
145        "there should be no more prefix bits left after fully combining!"
146    );
147
148    Ok(*full_eval)
149}
150
151/// This function takes an MLE that has a free variable in between fixed
152/// variables, and it splits it into two MLEs, one where the free variable is
153/// replaced with `Fixed(false)`, and the other where it is replaced with
154/// `Fixed(true)`. This ensures that all the fixed bits are contiguous. NOTE we
155/// assume that this function is called on an mle that has a free variable
156/// within a bunch of fixed variables (note how it is used in the
157/// `collapse_mles_with_free_in_prefix` function)
158fn split_mle<F: Field>(mle: &DenseMle<F>) -> Vec<DenseMle<F>> {
159    // Get the index of the first free bit in the mle.
160    let first_free_idx: usize = mle.mle_indices().iter().enumerate().fold(
161        mle.mle_indices().len(),
162        |acc, (idx, mle_idx)| {
163            if let MleIndex::Free = mle_idx {
164                std::cmp::min(acc, idx)
165            } else {
166                acc
167            }
168        },
169    );
170
171    // Compute the correct indices, we have the first one be false, the second
172    // one as true instead of the free bit.
173    let first_indices = mle.mle_indices()[0..first_free_idx]
174        .iter()
175        .cloned()
176        .chain(std::iter::once(MleIndex::Fixed(false)))
177        .chain(mle.mle_indices()[first_free_idx + 1..].iter().cloned())
178        .collect_vec();
179    let second_indices = mle.mle_indices()[0..first_free_idx]
180        .iter()
181        .cloned()
182        .chain(std::iter::once(MleIndex::Fixed(true)))
183        .chain(mle.mle_indices()[first_free_idx + 1..].iter().cloned())
184        .collect_vec();
185
186    // Construct the first MLE in the pair.
187    let first_mle = DenseMle {
188        mle: MultilinearExtension::new_from_evals(Evaluations::<F>::new(
189            mle.num_free_vars() - 1,
190            mle.mle.iter().step_by(2).collect_vec(),
191        )),
192        mle_indices: first_indices,
193        layer_id: mle.layer_id,
194    };
195
196    // Second mle in the pair.
197    let second_mle = DenseMle {
198        mle: MultilinearExtension::new_from_evals(Evaluations::<F>::new(
199            mle.num_free_vars() - 1,
200            mle.mle.iter().skip(1).step_by(2).collect_vec(),
201        )),
202        mle_indices: second_indices,
203        layer_id: mle.layer_id,
204    };
205
206    vec![first_mle, second_mle]
207}
208
209/// This function will take a list of MLEs and updates the list to contain MLEs
210/// where all fixed bits are contiguous
211fn collapse_mles_with_free_in_prefix<F: Field>(mles: Vec<DenseMle<F>>) -> Vec<DenseMle<F>> {
212    mles.into_iter()
213        .flat_map(|mle| {
214            // This iterates through the mle indices to check whether there is a
215            // free bit within the fixed bits.
216            let (_, contains_free_in_fixed) = mle.mle_indices().iter().fold(
217                (false, false),
218                |(free_seen_so_far, fixed_after_free_so_far), mle_idx| match mle_idx {
219                    MleIndex::Free => (true, fixed_after_free_so_far),
220                    MleIndex::Fixed(_) => (free_seen_so_far, free_seen_so_far),
221                    _ => (free_seen_so_far, fixed_after_free_so_far),
222                },
223            );
224            // If true, we split, otherwise, we don't.
225            if contains_free_in_fixed {
226                split_mle(&mle)
227            } else {
228                vec![mle]
229            }
230        })
231        .collect()
232}
233
234/// Gets the index of the least significant bit (lsb) of the fixed bits out of a
235/// vector of MLEs.
236///
237/// In other words, this is the MLE evaluation pertaining to the MLE with the
238/// most fixed bits.
239fn get_lsb_fixed_var<F: Field>(
240    mles: &[MleEvaluationAndPrefixBits<F>],
241) -> &MleEvaluationAndPrefixBits<F> {
242    mles.iter()
243        .max_by_key(|(_mle_evaluation, prefix_bits)| prefix_bits.len())
244        .unwrap()
245}
246
247/// Given an MLE evaluation, and an option of a second MLE evaluation pair, this
248/// combines the two together this assumes that the first MLE evaluation and the
249/// second MLE evaluation are pairs, if the second MLE evaluation is a Some()
250///
251/// A pair consists of two MLE evaluations that match in every fixed bit except
252/// for the least significant one. This is because we combine in the reverse
253/// order that we split in terms of selectors, and we split in terms of
254/// selectors by doing huffman (most significant bit).
255///
256/// Example: if mle_evaluation_first has fixed bits true, true, false, its pair
257/// would have fixed bits true, true, true. When we combine them, the combined
258/// MLE has fixed bits true, true. We also simultaneously update the combined
259/// evaluation to use the challenge according to the index we combined at.
260///
261/// If there is no pair, then this is assumed to be an mle with all 0s.
262fn combine_pair<F: Field>(
263    mle_evaluation_first: F,
264    maybe_mle_evaluation_second: Option<F>,
265    prefix_vars_first: &[bool],
266    chal_point: &[F],
267) -> MleEvaluationAndPrefixBits<F> {
268    // If the second mle is None, we assume its bookkeeping table is all zeros.
269    // We are dealing with fully fixed mles, so we just use F::ZERO.
270    let mle_evaluation_second = maybe_mle_evaluation_second.unwrap_or(F::ZERO);
271
272    // Depending on whether the lsb fixed bit was true or false, we bind it to
273    // the correct challenge point at this index this is either the challenge
274    // point at the index, or one minus this value.
275    let bound_coord = if !prefix_vars_first.last().unwrap() {
276        F::ONE - chal_point[prefix_vars_first.len() - 1]
277    } else {
278        chal_point[prefix_vars_first.len() - 1]
279    };
280
281    // We compute the combined evaluation using the according index challenge
282    // point.
283    let new_eval =
284        bound_coord * mle_evaluation_first + (F::ONE - bound_coord) * mle_evaluation_second;
285
286    (
287        new_eval,
288        prefix_vars_first[..prefix_vars_first.len() - 1].to_vec(),
289    )
290}
291
292/// Given a list of mles, the lsb fixed var index, and the MLE evaluation that
293/// contributes to it, this will go through all of them and find its pair (if
294/// none exists, we assume it is 0) and combine the two it will then update the
295/// original list of MLEs to contain the combined MLE evaluation and remove the
296/// original ones that were paired.
297fn find_pair_and_combine<F: Field>(
298    all_refs: &[MleEvaluationAndPrefixBits<F>],
299    prefix_indices: &[bool],
300    mle_evaluation: F,
301    chal_point: &[F],
302) -> Vec<MleEvaluationAndPrefixBits<F>> {
303    // We want to compare all fixed bits except the one at the least significant
304    // bit index.
305    let indices_to_compare = &prefix_indices[0..prefix_indices.len() - 1];
306    let mut mle_eval_pair = None;
307    let mut all_refs_updated = Vec::new();
308
309    for (mle_eval, mle_indices) in all_refs {
310        let max_slice_idx = mle_indices.len();
311        let compare_indices =
312            &mle_indices[0..std::cmp::min(prefix_indices.len() - 1, max_slice_idx)];
313        // We want to make sure we aren't combining an mle with itself!
314        if (compare_indices == indices_to_compare) && (mle_indices != prefix_indices) {
315            mle_eval_pair = Some(*mle_eval);
316        } else if mle_indices != prefix_indices {
317            all_refs_updated.push((*mle_eval, mle_indices.to_vec()));
318        }
319    }
320
321    // Add the paired mle to the list and return this new updated list.
322    let new_mle_to_add = combine_pair(mle_evaluation, mle_eval_pair, prefix_indices, chal_point);
323    all_refs_updated.push(new_mle_to_add);
324    all_refs_updated
325}