1use itertools::{repeat_n, Itertools};
4use serde::{Deserialize, Serialize};
5
6use crate::claims::RawClaim;
7use crate::layer::LayerId;
8use shared_types::Field;
9
10use super::evals::{Evaluations, EvaluationsIterator};
11use super::Mle;
12use super::{mle_enum::MleEnum, MleIndex};
13
/// An MLE whose value is `F::ZERO` at every point.
///
/// Only a single zero evaluation is stored (`zero_eval`); the variable structure
/// is tracked separately through `mle_indices` and `num_vars`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
// Override serde's inferred bounds so (de)serialization only requires `F: Field`.
#[serde(bound = "F: Field")]
pub struct ZeroMle<F: Field> {
    /// Optional prefix bits followed by this MLE's own variables
    /// (created as `MleIndex::Free` by `new`).
    pub(crate) mle_indices: Vec<MleIndex<F>>,
    /// Number of variables that have not yet been fixed; reported by
    /// `num_free_vars` and decremented by `fix_variable`.
    pub(crate) num_vars: usize,
    /// Identifier of the layer this MLE belongs to.
    pub(crate) layer_id: LayerId,
    /// Single-element array holding `F::ZERO` (set by `new`).
    pub(crate) zero: [F; 1],
    /// One-entry evaluation table containing `F::ZERO`; `iter` walks this table.
    pub(crate) zero_eval: Evaluations<F>,
    /// Initialized to `false` by `new`.
    /// NOTE(review): presumably tracks whether indices were indexed, but nothing
    /// in this file updates it — confirm intended use.
    pub(crate) indexed: bool,
}
27
28impl<F: Field> ZeroMle<F> {
29 pub fn new(num_vars: usize, prefix_bits: Option<Vec<MleIndex<F>>>, layer_id: LayerId) -> Self {
32 let mle_indices = prefix_bits
33 .into_iter()
34 .flatten()
35 .chain(repeat_n(MleIndex::Free, num_vars))
36 .collect_vec();
37
38 Self {
39 mle_indices,
40 num_vars,
41 layer_id,
42 zero: [F::ZERO],
43 zero_eval: Evaluations::new(num_vars, vec![F::ZERO]),
44 indexed: false,
45 }
46 }
47}
48
49impl<F: Field> Mle<F> for ZeroMle<F> {
50 fn mle_indices(&self) -> &[MleIndex<F>] {
51 &self.mle_indices
52 }
53
54 fn num_free_vars(&self) -> usize {
55 self.num_vars
56 }
57
58 fn fix_variable(&mut self, round_index: usize, challenge: F) -> Option<RawClaim<F>> {
59 for mle_index in self.mle_indices.iter_mut() {
60 if *mle_index == MleIndex::Indexed(round_index) {
61 mle_index.bind_index(challenge);
62 }
63 }
64
65 self.num_vars -= 1;
67
68 if self.num_vars == 0 {
69 let send_claim = RawClaim::new(
70 self.mle_indices
71 .iter()
72 .map(|index| index.val().unwrap())
73 .collect_vec(),
74 F::ZERO,
75 );
76 Some(send_claim)
77 } else {
78 None
79 }
80 }
81
82 fn fix_variable_at_index(&mut self, indexed_bit_index: usize, point: F) -> Option<RawClaim<F>> {
83 self.fix_variable(indexed_bit_index, point)
84 }
85
86 fn index_mle_indices(&mut self, curr_index: usize) -> usize {
87 let mut new_indices = 0;
88 for mle_index in self.mle_indices.iter_mut() {
89 if *mle_index == MleIndex::Free {
90 *mle_index = MleIndex::Indexed(curr_index + new_indices);
91 new_indices += 1;
92 }
93 }
94
95 curr_index + new_indices
96 }
97
98 fn layer_id(&self) -> LayerId {
99 self.layer_id
100 }
101
102 fn get_enum(self) -> MleEnum<F> {
103 MleEnum::Zero(self)
104 }
105
106 #[doc = " Get the padded set of evaluations over the boolean hypercube; useful for"]
107 #[doc = " constructing the input layer."]
108 fn get_padded_evaluations(&self) -> Vec<F> {
109 todo!()
110 }
111
112 #[doc = " Mutates the MLE in order to set the prefix bits. This is needed when we"]
113 #[doc = " are working with dataparallel circuits and new bits need to be added."]
114 fn add_prefix_bits(&mut self, _new_bits: Vec<MleIndex<F>>) {
115 todo!()
116 }
117
118 fn len(&self) -> usize {
119 1
120 }
121
122 fn iter(&self) -> EvaluationsIterator<'_, F> {
123 self.zero_eval.iter()
124 }
125
126 fn first(&self) -> F {
127 F::ZERO
128 }
129
130 fn value(&self) -> F {
131 assert!(self.is_fully_bounded());
132 F::ZERO
133 }
134
135 fn get(&self, index: usize) -> Option<F> {
136 if index < self.len() {
137 Some(F::ZERO)
138 } else {
139 None
140 }
141 }
142}