remainder/mle/
mle_description.rs

1use serde::{Deserialize, Serialize};
2use shared_types::{transcript::VerifierTranscript, Field};
3
4use crate::{
5    circuit_layout::CircuitEvalMap, expression::expr_errors::ExpressionError, layer::LayerId,
6};
7
8use super::{dense::DenseMle, verifier_mle::VerifierMle, MleIndex};
9
10use anyhow::{anyhow, Result};
11
12/// A metadata-only version of [crate::mle::dense::DenseMle] used in the Circuit
13/// Descrption.  A [MleDescription] is stored in the leaves of an `Expression<F,
14/// ExprDescription>` tree.
15#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
16#[serde(bound = "F: Field")]
17pub struct MleDescription<F: Field> {
18    /// Layer whose data this MLE's is a subset of.
19    layer_id: LayerId,
20
21    /// A list of indices where the free variables have been assigned an index.
22    var_indices: Vec<MleIndex<F>>,
23}
24
25impl<F: Field> MleDescription<F> {
26    /// Create a new [MleDescription] given its layer id and the [MleIndex]s that it holds.
27    /// This is effectively the "shape" of a [DenseMle].
28    pub fn new(layer_id: LayerId, var_indices: &[MleIndex<F>]) -> Self {
29        Self {
30            layer_id,
31            var_indices: var_indices.to_vec(),
32        }
33    }
34
35    /// Replace the current MLE indices stored with custom MLE indices. Most
36    /// useful in [crate::layer::matmult::MatMult], where we do index manipulation.
37    pub fn set_mle_indices(&mut self, new_mle_indices: Vec<MleIndex<F>>) {
38        self.var_indices = new_mle_indices;
39    }
40
41    /// Convert [MleIndex::Free] into [MleIndex::Indexed] with the correct
42    /// index labeling, given by start_index parameter.
43    pub fn index_mle_indices(&mut self, start_index: usize) {
44        let mut index_counter = start_index;
45        self.var_indices
46            .iter_mut()
47            .for_each(|mle_index| match mle_index {
48                MleIndex::Free => {
49                    let indexed_mle_index = MleIndex::Indexed(index_counter);
50                    index_counter += 1;
51                    *mle_index = indexed_mle_index;
52                }
53                MleIndex::Fixed(_bit) => {}
54                _ => panic!("We should not have indexed or bound bits at this point!"),
55            });
56    }
57
58    /// Returns the [LayerId] of this MleDescription.
59    pub fn layer_id(&self) -> LayerId {
60        self.layer_id
61    }
62
63    /// Returns the MLE indices of this MleDescription.
64    pub fn var_indices(&self) -> &[MleIndex<F>] {
65        &self.var_indices
66    }
67
68    /// The number of [MleIndex::Indexed] OR [MleIndex::Free] bits in this MLE.
69    pub fn num_free_vars(&self) -> usize {
70        self.var_indices.iter().fold(0, |acc, idx| {
71            acc + match idx {
72                MleIndex::Free => 1,
73                MleIndex::Indexed(_) => 1,
74                _ => 0,
75            }
76        })
77    }
78
79    /// Get the bits in the MLE that are fixed bits.
80    pub fn prefix_bits(&self) -> Vec<bool> {
81        self.var_indices
82            .iter()
83            .filter_map(|idx| match idx {
84                MleIndex::Fixed(bit) => Some(*bit),
85                _ => None,
86            })
87            .collect()
88    }
89
90    /// Convert this MLE into a [DenseMle] using the [CircuitEvalMap],
91    /// which holds information using the prefix bits and layer id
92    /// on the data that should be stored in this MLE.
93    pub fn into_dense_mle(&self, circuit_map: &CircuitEvalMap<F>) -> DenseMle<F> {
94        let data = circuit_map.get_data_from_circuit_mle(self).unwrap();
95        DenseMle::new_with_prefix_bits((*data).clone(), self.layer_id(), self.prefix_bits())
96    }
97
98    /// Bind the variable with index `var_index` to `value`. Note that since
99    /// [MleDescription] is the representation of a multilinear extension function
100    /// sans data, it need not alter its internal MLE evaluations in any way.
101    pub fn fix_variable(&mut self, var_index: usize, value: F) {
102        for mle_index in self.var_indices.iter_mut() {
103            if *mle_index == MleIndex::Indexed(var_index) {
104                mle_index.bind_index(value);
105            }
106        }
107    }
108
109    /// Gets the values of the bound and fixed MLE indices of this MLE,
110    /// panicking if the MLE is not fully bound.
111    pub fn get_claim_point(&self, challenges: &[F]) -> Vec<F> {
112        self.var_indices
113            .iter()
114            .map(|index| match index {
115                MleIndex::Bound(chal, _idx) => *chal,
116                MleIndex::Fixed(chal) => F::from(*chal as u64),
117                MleIndex::Indexed(i) => challenges[*i],
118                _ => panic!("DenseMleRefDesc contained free variables!"),
119            })
120            .collect()
121    }
122
123    /// Convert this MLE into a [VerifierMle], which represents a fully-bound MLE.
124    pub fn into_verifier_mle(
125        &self,
126        point: &[F],
127        transcript_reader: &mut impl VerifierTranscript<F>,
128    ) -> Result<VerifierMle<F>> {
129        let verifier_indices = self
130            .var_indices
131            .iter()
132            .map(|mle_index| match mle_index {
133                MleIndex::Indexed(idx) => Ok(MleIndex::Bound(point[*idx], *idx)),
134                MleIndex::Fixed(val) => Ok(MleIndex::Fixed(*val)),
135                _ => Err(anyhow!(ExpressionError::SelectorBitNotBoundError)),
136            })
137            .collect::<Result<Vec<MleIndex<F>>>>()?;
138
139        let eval = transcript_reader.consume_element("Fully bound MLE evaluation")?;
140
141        Ok(VerifierMle::new(self.layer_id, verifier_indices, eval))
142    }
143}