remainder/mle/
mle_description.rs1use serde::{Deserialize, Serialize};
2use shared_types::{transcript::VerifierTranscript, Field};
3
4use crate::{
5 circuit_layout::CircuitEvalMap, expression::expr_errors::ExpressionError, layer::LayerId,
6};
7
8use super::{dense::DenseMle, verifier_mle::VerifierMle, MleIndex};
9
10use anyhow::{anyhow, Result};
11
12#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
16#[serde(bound = "F: Field")]
17pub struct MleDescription<F: Field> {
18 layer_id: LayerId,
20
21 var_indices: Vec<MleIndex<F>>,
23}
24
25impl<F: Field> MleDescription<F> {
26 pub fn new(layer_id: LayerId, var_indices: &[MleIndex<F>]) -> Self {
29 Self {
30 layer_id,
31 var_indices: var_indices.to_vec(),
32 }
33 }
34
35 pub fn set_mle_indices(&mut self, new_mle_indices: Vec<MleIndex<F>>) {
38 self.var_indices = new_mle_indices;
39 }
40
41 pub fn index_mle_indices(&mut self, start_index: usize) {
44 let mut index_counter = start_index;
45 self.var_indices
46 .iter_mut()
47 .for_each(|mle_index| match mle_index {
48 MleIndex::Free => {
49 let indexed_mle_index = MleIndex::Indexed(index_counter);
50 index_counter += 1;
51 *mle_index = indexed_mle_index;
52 }
53 MleIndex::Fixed(_bit) => {}
54 _ => panic!("We should not have indexed or bound bits at this point!"),
55 });
56 }
57
58 pub fn layer_id(&self) -> LayerId {
60 self.layer_id
61 }
62
63 pub fn var_indices(&self) -> &[MleIndex<F>] {
65 &self.var_indices
66 }
67
68 pub fn num_free_vars(&self) -> usize {
70 self.var_indices.iter().fold(0, |acc, idx| {
71 acc + match idx {
72 MleIndex::Free => 1,
73 MleIndex::Indexed(_) => 1,
74 _ => 0,
75 }
76 })
77 }
78
79 pub fn prefix_bits(&self) -> Vec<bool> {
81 self.var_indices
82 .iter()
83 .filter_map(|idx| match idx {
84 MleIndex::Fixed(bit) => Some(*bit),
85 _ => None,
86 })
87 .collect()
88 }
89
90 pub fn into_dense_mle(&self, circuit_map: &CircuitEvalMap<F>) -> DenseMle<F> {
94 let data = circuit_map.get_data_from_circuit_mle(self).unwrap();
95 DenseMle::new_with_prefix_bits((*data).clone(), self.layer_id(), self.prefix_bits())
96 }
97
98 pub fn fix_variable(&mut self, var_index: usize, value: F) {
102 for mle_index in self.var_indices.iter_mut() {
103 if *mle_index == MleIndex::Indexed(var_index) {
104 mle_index.bind_index(value);
105 }
106 }
107 }
108
109 pub fn get_claim_point(&self, challenges: &[F]) -> Vec<F> {
112 self.var_indices
113 .iter()
114 .map(|index| match index {
115 MleIndex::Bound(chal, _idx) => *chal,
116 MleIndex::Fixed(chal) => F::from(*chal as u64),
117 MleIndex::Indexed(i) => challenges[*i],
118 _ => panic!("DenseMleRefDesc contained free variables!"),
119 })
120 .collect()
121 }
122
123 pub fn into_verifier_mle(
125 &self,
126 point: &[F],
127 transcript_reader: &mut impl VerifierTranscript<F>,
128 ) -> Result<VerifierMle<F>> {
129 let verifier_indices = self
130 .var_indices
131 .iter()
132 .map(|mle_index| match mle_index {
133 MleIndex::Indexed(idx) => Ok(MleIndex::Bound(point[*idx], *idx)),
134 MleIndex::Fixed(val) => Ok(MleIndex::Fixed(*val)),
135 _ => Err(anyhow!(ExpressionError::SelectorBitNotBoundError)),
136 })
137 .collect::<Result<Vec<MleIndex<F>>>>()?;
138
139 let eval = transcript_reader.consume_element("Fully bound MLE evaluation")?;
140
141 Ok(VerifierMle::new(self.layer_id, verifier_indices, eval))
142 }
143}