1use shared_types::Field;
5
6use super::LayerId;
7use crate::mle::dense::DenseMle;
8use crate::mle::mle_description::MleDescription;
9use crate::mle::Mle;
10
11#[derive(Debug)]
14pub struct PostSumcheckLayer<F: Field, T>(pub Vec<Product<F, T>>);
15
16impl<F: Field> PostSumcheckLayer<F, F> {
19 pub fn evaluate_scalar(&self) -> F {
21 self.0.iter().fold(F::ZERO, |acc, product| {
22 acc + product.get_result() * product.coefficient
23 })
24 }
25}
26
27#[derive(Debug)]
31pub struct Product<F: Field, T> {
32 pub intermediates: Vec<Intermediate<F, T>>,
35 pub coefficient: F,
37}
38
39impl<F: Field> Product<F, Option<F>> {
40 pub fn new(mles: &[MleDescription<F>], coefficient: F, bindings: &[F]) -> Self {
42 if mles.is_empty() {
43 return Product {
44 intermediates: vec![Intermediate::Composite {
45 value: Some(F::ONE),
46 }],
47 coefficient,
48 };
49 }
50 let mut intermediates = vec![Self::build_atom(&mles[0], bindings)];
51 mles.iter().skip(1).for_each(|mle| {
52 intermediates.push(Self::build_atom(mle, bindings));
53 intermediates.push(Intermediate::Composite { value: None });
54 });
55 Product {
56 intermediates,
57 coefficient,
58 }
59 }
60
61 pub fn new_from_mul_gate(
63 mles: &[MleDescription<F>],
64 coefficient: F,
65 bindings: &[&[F]],
66 ) -> Self {
67 if mles.is_empty() {
68 return Product {
69 intermediates: vec![Intermediate::Composite {
70 value: Some(F::ONE),
71 }],
72 coefficient,
73 };
74 }
75 let mut intermediates = vec![Self::build_atom(&mles[0], bindings[0])];
76 mles.iter().enumerate().skip(1).for_each(|(idx, mle_ref)| {
77 intermediates.push(Self::build_atom(mle_ref, bindings[idx]));
78 intermediates.push(Intermediate::Composite { value: None });
79 });
80 Product {
81 intermediates,
82 coefficient,
83 }
84 }
85
86 fn build_atom(mle: &MleDescription<F>, bindings: &[F]) -> Intermediate<F, Option<F>> {
88 Intermediate::Atom {
89 layer_id: mle.layer_id(),
90 point: mle.get_claim_point(bindings),
91 value: None,
92 }
93 }
94}
95
96impl<F: Field> Product<F, F> {
97 pub fn new(mles: &[DenseMle<F>], coefficient: F) -> Self {
100 assert!(mles.iter().all(|mle| mle.is_fully_bounded()));
102 if mles.is_empty() {
103 return Product {
104 intermediates: vec![Intermediate::Composite { value: F::ONE }],
105 coefficient,
106 };
107 }
108 let mut intermediates = vec![Self::build_atom(&mles[0])];
109 let _ = mles.iter().skip(1).fold(mles[0].value(), |acc, mle| {
110 let prod_val = acc * mle.value();
111 intermediates.push(Self::build_atom(mle));
112 intermediates.push(Intermediate::Composite { value: prod_val });
113 prod_val
114 });
115 Product {
116 intermediates,
117 coefficient,
118 }
119 }
120
121 fn build_atom(mle: &DenseMle<F>) -> Intermediate<F, F> {
123 Intermediate::Atom {
124 layer_id: mle.layer_id,
125 point: mle.get_bound_point(),
126 value: mle.value(),
127 }
128 }
129}
130
131#[derive(Clone, Debug)]
132pub enum Intermediate<F: Field, T> {
135 Atom {
137 layer_id: LayerId,
139 point: Vec<F>,
141 value: T,
143 },
144 Composite {
146 value: T,
148 },
149}
150
151impl<F: Field, T: Copy> PostSumcheckLayer<F, T> {
152 pub fn get_values(&self) -> Vec<T> {
155 self.0
156 .iter()
157 .flat_map(|product| {
158 product
159 .intermediates
160 .iter()
161 .map(|pp| match pp {
162 Intermediate::Atom { value, .. } => *value,
163 Intermediate::Composite { value } => *value,
164 })
165 .collect::<Vec<T>>()
166 })
167 .collect()
168 }
169
170 pub fn get_coefficients(&self) -> Vec<F> {
172 self.0.iter().map(|product| product.coefficient).collect()
173 }
174}
175
176pub fn new_with_values<F: Field, S, T: Clone>(
179 post_sumcheck_layer: &PostSumcheckLayer<F, S>,
180 values: &[T],
181) -> PostSumcheckLayer<F, T> {
182 let total_len: usize = post_sumcheck_layer
183 .0
184 .iter()
185 .map(|product| product.intermediates.len())
186 .sum();
187 assert_eq!(total_len, values.len());
188 let mut start = 0;
189 PostSumcheckLayer(
190 post_sumcheck_layer
191 .0
192 .iter()
193 .map(|product| {
194 let end = start + product.intermediates.len();
195 let product_values = values[start..end].to_vec();
196 start = end;
197 new_with_values_single(product, product_values)
198 })
199 .collect(),
200 )
201}
202
203fn new_with_values_single<F: Field, S, T>(
206 product: &Product<F, S>,
207 values: Vec<T>,
208) -> Product<F, T> {
209 assert_eq!(product.intermediates.len(), values.len());
210 Product {
211 coefficient: product.coefficient,
212 intermediates: product
213 .intermediates
214 .iter()
215 .zip(values)
216 .map(|(pp, value)| match pp {
217 Intermediate::Atom {
218 layer_id, point, ..
219 } => Intermediate::Atom {
220 layer_id: *layer_id,
221 point: point.clone(),
222 value,
223 },
224 Intermediate::Composite { .. } => Intermediate::Composite { value },
225 })
226 .collect(),
227 }
228}
229
230impl<F: Field, T: Clone> Product<F, T> {
231 pub fn get_result(&self) -> T {
234 let last = &self.intermediates[self.intermediates.len() - 1];
235 match last {
236 Intermediate::Atom { value, .. } => {
237 assert_eq!(self.intermediates.len(), 1);
239 value.clone()
240 }
241 Intermediate::Composite { value, .. } => value.clone(),
242 }
243 }
244
245 pub fn get_product_triples(&self) -> Option<Vec<(T, T, T)>> {
247 if self.intermediates.len() > 1 {
248 assert!(self.intermediates.len() >= 3);
249 let values = self
250 .intermediates
251 .iter()
252 .map(|pp| match pp {
253 Intermediate::Atom { value, .. } => value.clone(),
254 Intermediate::Composite { value, .. } => value.clone(),
255 })
256 .collect::<Vec<_>>();
257 Some(
258 values
259 .windows(3)
260 .map(|window| {
261 if let [x, y, z] = window {
262 (x.clone(), y.clone(), z.clone())
263 } else {
264 unreachable!()
265 }
266 })
267 .collect::<Vec<_>>(),
268 )
269 } else {
270 None
271 }
272 }
273}