1use crate::{
6 circuit_layout::CircuitEvalMap,
7 layer::{
8 gate::BinaryOperation,
9 product::{PostSumcheckLayer, Product},
10 },
11 mle::{
12 evals::MultilinearExtension, mle_description::MleDescription, verifier_mle::VerifierMle,
13 MleIndex,
14 },
15};
16use ark_std::log2;
17use itertools::Itertools;
18use serde::{Deserialize, Serialize};
19use std::{
20 cmp::max,
21 collections::{HashMap, HashSet},
22 fmt::Debug,
23 ops::{Add, Mul, Neg, Sub},
24};
25
26use shared_types::{transcript::VerifierTranscript, Field};
27
28use super::{
29 expr_errors::ExpressionError,
30 generic_expr::{Expression, ExpressionNode, ExpressionType},
31 prover_expr::ProverExpr,
32 verifier_expr::VerifierExpr,
33};
34
35use anyhow::{anyhow, Result};
36
37#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
40pub struct ExprDescription;
41
42impl<F: Field> ExpressionType<F> for ExprDescription {
46 type MLENodeRepr = MleDescription<F>;
47 type MleVec = ();
48}
49
50impl<F: Field> Expression<F, ExprDescription> {
51 pub fn bind(
55 &self,
56 point: &[F],
57 transcript_reader: &mut impl VerifierTranscript<F>,
58 ) -> Result<Expression<F, VerifierExpr>> {
59 Ok(Expression::new(
60 self.expression_node
61 .into_verifier_node(point, transcript_reader)?,
62 (),
63 ))
64 }
65
66 pub fn from_mle_desc(mle_desc: MleDescription<F>) -> Self {
69 Self {
70 expression_node: ExpressionNode::<F, ExprDescription>::Mle(mle_desc),
71 mle_vec: (),
72 }
73 }
74
75 pub fn get_all_nonlinear_rounds(&self) -> Vec<usize> {
78 self.expression_node
79 .get_all_nonlinear_rounds(&mut vec![], &self.mle_vec)
80 .into_iter()
81 .sorted()
82 .collect()
83 }
84
85 pub fn get_all_rounds(&self) -> Vec<usize> {
88 self.expression_node
89 .get_all_rounds(&mut vec![], &self.mle_vec)
90 .into_iter()
91 .sorted()
92 .collect()
93 }
94
95 pub fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
97 let circuit_mles = self.expression_node.get_circuit_mles();
98 circuit_mles
99 }
100
101 pub fn index_mle_vars(&mut self, start_index: usize) {
103 self.expression_node.index_mle_vars(start_index);
104 }
105
106 pub fn into_prover_expression(
109 &self,
110 circuit_map: &CircuitEvalMap<F>,
111 ) -> Expression<F, ProverExpr> {
112 self.expression_node.into_prover_expression(circuit_map)
113 }
114
115 pub fn get_post_sumcheck_layer(
118 &self,
119 multiplier: F,
120 challenges: &[F],
121 ) -> PostSumcheckLayer<F, Option<F>> {
122 self.expression_node
123 .get_post_sumcheck_layer(multiplier, challenges, &self.mle_vec)
124 }
125
126 pub fn get_max_degree(&self) -> usize {
128 self.expression_node.get_max_degree(&self.mle_vec)
129 }
130
131 pub fn get_round_degree(&self, curr_round: usize) -> usize {
134 let mut round_degree = 1;
136
137 let mut get_degree_closure = |expr: &ExpressionNode<F, ExprDescription>,
138 _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec|
139 -> Result<()> {
140 let round_degree = &mut round_degree;
141
142 if let ExpressionNode::Product(circuit_mles) = expr {
144 let mut product_round_degree: usize = 0;
145 for circuit_mle in circuit_mles {
146 let mle_indices = circuit_mle.var_indices();
147 for mle_index in mle_indices {
148 if *mle_index == MleIndex::Indexed(curr_round) {
149 product_round_degree += 1;
150 break;
151 }
152 }
153 }
154 if *round_degree < product_round_degree {
155 *round_degree = product_round_degree;
156 }
157 }
158 Ok(())
159 };
160
161 self.traverse(&mut get_degree_closure).unwrap();
162 round_degree + 1
164 }
165}
166
167impl<F: Field> ExpressionNode<F, ExprDescription> {
168 pub fn into_verifier_node(
171 &self,
172 point: &[F],
173 transcript_reader: &mut impl VerifierTranscript<F>,
174 ) -> Result<ExpressionNode<F, VerifierExpr>> {
175 match self {
176 ExpressionNode::Constant(scalar) => Ok(ExpressionNode::Constant(*scalar)),
177 ExpressionNode::Selector(index, lhs, rhs) => match index {
178 MleIndex::Indexed(idx) => Ok(ExpressionNode::Selector(
179 MleIndex::Bound(point[*idx], *idx),
180 Box::new(lhs.into_verifier_node(point, transcript_reader)?),
181 Box::new(rhs.into_verifier_node(point, transcript_reader)?),
182 )),
183 _ => Err(anyhow!(ExpressionError::SelectorBitNotBoundError)),
184 },
185 ExpressionNode::Mle(circuit_mle) => Ok(ExpressionNode::Mle(
186 circuit_mle.into_verifier_mle(point, transcript_reader)?,
187 )),
188 ExpressionNode::Sum(lhs, rhs) => Ok(ExpressionNode::Sum(
189 Box::new(lhs.into_verifier_node(point, transcript_reader)?),
190 Box::new(rhs.into_verifier_node(point, transcript_reader)?),
191 )),
192 ExpressionNode::Product(circuit_mles) => {
193 let verifier_mles: Vec<VerifierMle<F>> = circuit_mles
194 .iter()
195 .map(|circuit_mle| circuit_mle.into_verifier_mle(point, transcript_reader))
196 .collect::<Result<Vec<VerifierMle<F>>>>()?;
197
198 Ok(ExpressionNode::Product(verifier_mles))
199 }
200 ExpressionNode::Scaled(circuit_mle, scalar) => Ok(ExpressionNode::Scaled(
201 Box::new(circuit_mle.into_verifier_node(point, transcript_reader)?),
202 *scalar,
203 )),
204 }
205 }
206
207 pub fn compute_bookkeeping_table(
211 &self,
212 circuit_map: &CircuitEvalMap<F>,
213 ) -> Option<MultilinearExtension<F>> {
214 let output_data: Option<MultilinearExtension<F>> = match self {
215 ExpressionNode::Mle(circuit_mle) => {
216 let maybe_mle = circuit_map.get_data_from_circuit_mle(circuit_mle);
217 if let Ok(mle) = maybe_mle {
218 Some(mle.clone())
219 } else {
220 return None;
221 }
222 }
223 ExpressionNode::Product(circuit_mles) => {
224 let mle_bookkeeping_tables = circuit_mles
225 .iter()
226 .map(|circuit_mle| {
227 circuit_map
228 .get_data_from_circuit_mle(circuit_mle) .map(|data| data.to_vec()) })
231 .collect::<Result<Vec<Vec<F>>>>() .ok()?;
233 Some(evaluate_bookkeeping_tables_given_operation(
234 &mle_bookkeeping_tables,
235 BinaryOperation::Mul,
236 ))
237 }
238 ExpressionNode::Sum(a, b) => {
239 let a_bookkeeping_table = a.compute_bookkeeping_table(circuit_map)?;
240 let b_bookkeeping_table = b.compute_bookkeeping_table(circuit_map)?;
241 Some(evaluate_bookkeeping_tables_given_operation(
242 &[
243 (a_bookkeeping_table.to_vec()),
244 (b_bookkeeping_table.to_vec()),
245 ],
246 BinaryOperation::Add,
247 ))
248 }
249 ExpressionNode::Scaled(a, scale) => {
250 let a_bookkeeping_table = a.compute_bookkeeping_table(circuit_map)?;
251 Some(MultilinearExtension::new(
252 a_bookkeeping_table
253 .iter()
254 .map(|elem| elem * scale)
255 .collect_vec(),
256 ))
257 }
258 ExpressionNode::Selector(_mle_index, a, b) => {
259 let a_bookkeeping_table = a.compute_bookkeeping_table(circuit_map)?;
260 let b_bookkeeping_table = b.compute_bookkeeping_table(circuit_map)?;
261 assert_eq!(
262 a_bookkeeping_table.num_vars(),
263 b_bookkeeping_table.num_vars()
264 );
265 Some(MultilinearExtension::new(
266 a_bookkeeping_table
267 .iter()
268 .chain(b_bookkeeping_table.iter())
269 .collect_vec(),
270 ))
271 }
272 ExpressionNode::Constant(value) => Some(MultilinearExtension::new(vec![*value])),
273 };
274
275 output_data
276 }
277
278 #[allow(clippy::too_many_arguments)]
281 pub fn reduce<T>(
282 &self,
283 constant: &mut impl FnMut(F) -> T,
284 selector_column: &mut impl FnMut(&MleIndex<F>, T, T) -> T,
285 mle_eval: &mut impl FnMut(&<ExprDescription as ExpressionType<F>>::MLENodeRepr) -> T,
286 sum: &mut impl FnMut(T, T) -> T,
287 product: &mut impl FnMut(&[<ExprDescription as ExpressionType<F>>::MLENodeRepr]) -> T,
288 scaled: &mut impl FnMut(T, F) -> T,
289 ) -> T {
290 match self {
291 ExpressionNode::Constant(scalar) => constant(*scalar),
292 ExpressionNode::Selector(index, a, b) => {
293 let lhs = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
294 let rhs = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
295 selector_column(index, lhs, rhs)
296 }
297 ExpressionNode::Mle(query) => mle_eval(query),
298 ExpressionNode::Sum(a, b) => {
299 let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
300 let b = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
301 sum(a, b)
302 }
303 ExpressionNode::Product(queries) => product(queries),
304 ExpressionNode::Scaled(a, f) => {
305 let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
306 scaled(a, *f)
307 }
308 }
309 }
310
311 pub fn get_all_nonlinear_rounds(
314 &self,
315 curr_nonlinear_indices: &mut Vec<usize>,
316 _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec,
317 ) -> Vec<usize> {
318 let nonlinear_indices_in_node = {
319 match self {
320 ExpressionNode::Product(verifier_mles) => {
324 let mut product_nonlinear_indices: HashSet<usize> = HashSet::new();
325 let mut product_indices_counts: HashMap<MleIndex<F>, usize> = HashMap::new();
326
327 verifier_mles.iter().for_each(|verifier_mle| {
328 verifier_mle.var_indices().iter().for_each(|mle_index| {
329 let curr_count = {
330 if product_indices_counts.contains_key(mle_index) {
331 product_indices_counts.get(mle_index).unwrap()
332 } else {
333 &0
334 }
335 };
336 product_indices_counts.insert(mle_index.clone(), curr_count + 1);
337 })
338 });
339
340 product_indices_counts
341 .into_iter()
342 .for_each(|(mle_index, count)| {
343 if count > 1 {
344 if let MleIndex::Indexed(i) = mle_index {
345 product_nonlinear_indices.insert(i);
346 } else if let MleIndex::Bound(_, i) = mle_index {
347 product_nonlinear_indices.insert(i);
348 }
349 }
350 });
351
352 product_nonlinear_indices
353 }
354 ExpressionNode::Selector(_sel_index, a, b) => {
357 let mut sel_nonlinear_indices: HashSet<usize> = HashSet::new();
358 let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
359 let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
360 a_indices
361 .into_iter()
362 .zip(b_indices)
363 .for_each(|(a_mle_idx, b_mle_idx)| {
364 sel_nonlinear_indices.insert(a_mle_idx);
365 sel_nonlinear_indices.insert(b_mle_idx);
366 });
367 sel_nonlinear_indices
368 }
369 ExpressionNode::Sum(a, b) => {
370 let mut sum_nonlinear_indices: HashSet<usize> = HashSet::new();
371 let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
372 let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
373 a_indices
374 .into_iter()
375 .zip(b_indices)
376 .for_each(|(a_mle_idx, b_mle_idx)| {
377 sum_nonlinear_indices.insert(a_mle_idx);
378 sum_nonlinear_indices.insert(b_mle_idx);
379 });
380 sum_nonlinear_indices
381 }
382 ExpressionNode::Scaled(a, _) => a
383 .get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec)
384 .into_iter()
385 .collect(),
386 ExpressionNode::Constant(_) | ExpressionNode::Mle(_) => HashSet::new(),
387 }
388 };
389 nonlinear_indices_in_node.into_iter().for_each(|index| {
391 if !curr_nonlinear_indices.contains(&index) {
392 curr_nonlinear_indices.push(index);
393 }
394 });
395 curr_nonlinear_indices.clone()
396 }
397
398 pub(crate) fn get_all_rounds(
401 &self,
402 curr_indices: &mut Vec<usize>,
403 _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec,
404 ) -> Vec<usize> {
405 let indices_in_node = {
406 match self {
407 ExpressionNode::Product(verifier_mles) => {
410 let mut product_indices: HashSet<usize> = HashSet::new();
411 verifier_mles.iter().for_each(|mle| {
412 mle.var_indices().iter().for_each(|mle_index| {
413 if let MleIndex::Indexed(i) = mle_index {
414 product_indices.insert(*i);
415 }
416 })
417 });
418 product_indices
419 }
420 ExpressionNode::Mle(verifier_mle) => verifier_mle
422 .var_indices()
423 .iter()
424 .filter_map(|mle_index| match mle_index {
425 MleIndex::Indexed(i) => Some(*i),
426 _ => None,
427 })
428 .collect(),
429 ExpressionNode::Selector(sel_index, a, b) => {
432 let mut sel_indices: HashSet<usize> = HashSet::new();
433 if let MleIndex::Indexed(i) = sel_index {
434 sel_indices.insert(*i);
435 };
436
437 let a_indices = a.get_all_rounds(curr_indices, _mle_vec);
438 let b_indices = b.get_all_rounds(curr_indices, _mle_vec);
439 a_indices
440 .into_iter()
441 .zip(b_indices)
442 .for_each(|(a_mle_idx, b_mle_idx)| {
443 sel_indices.insert(a_mle_idx);
444 sel_indices.insert(b_mle_idx);
445 });
446 sel_indices
447 }
448 ExpressionNode::Sum(a, b) => {
450 let mut sum_indices: HashSet<usize> = HashSet::new();
451 let a_indices = a.get_all_rounds(curr_indices, _mle_vec);
452 let b_indices = b.get_all_rounds(curr_indices, _mle_vec);
453 a_indices
454 .into_iter()
455 .zip(b_indices)
456 .for_each(|(a_mle_idx, b_mle_idx)| {
457 sum_indices.insert(a_mle_idx);
458 sum_indices.insert(b_mle_idx);
459 });
460 sum_indices
461 }
462 ExpressionNode::Scaled(a, _) => a
464 .get_all_rounds(curr_indices, _mle_vec)
465 .into_iter()
466 .collect(),
467 ExpressionNode::Constant(_) => HashSet::new(),
469 }
470 };
471 indices_in_node.into_iter().for_each(|index| {
474 if !curr_indices.contains(&index) {
475 curr_indices.push(index);
476 }
477 });
478 curr_indices.clone()
479 }
480
481 pub fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
483 let mut circuit_mles: Vec<&MleDescription<F>> = vec![];
484 match self {
485 ExpressionNode::Selector(_mle_index, a, b) => {
486 circuit_mles.extend(a.get_circuit_mles());
487 circuit_mles.extend(b.get_circuit_mles());
488 }
489 ExpressionNode::Sum(a, b) => {
490 circuit_mles.extend(a.get_circuit_mles());
491 circuit_mles.extend(b.get_circuit_mles());
492 }
493 ExpressionNode::Mle(mle) => {
494 circuit_mles.push(mle);
495 }
496 ExpressionNode::Product(mles) => mles.iter().for_each(|mle| circuit_mles.push(mle)),
497 ExpressionNode::Scaled(a, _scale_factor) => {
498 circuit_mles.extend(a.get_circuit_mles());
499 }
500 ExpressionNode::Constant(_constant) => {}
501 }
502 circuit_mles
503 }
504
505 pub fn index_mle_vars(&mut self, start_index: usize) {
507 match self {
508 ExpressionNode::Selector(mle_index, a, b) => {
509 match mle_index {
510 MleIndex::Free => *mle_index = MleIndex::Indexed(start_index),
511 MleIndex::Fixed(_bit) => {}
512 _ => panic!("should not have indexed or bound bits at this point!"),
513 };
514 a.index_mle_vars(start_index + 1);
515 b.index_mle_vars(start_index + 1);
516 }
517 ExpressionNode::Sum(a, b) => {
518 a.index_mle_vars(start_index);
519 b.index_mle_vars(start_index);
520 }
521 ExpressionNode::Mle(mle) => {
522 mle.index_mle_indices(start_index);
523 }
524 ExpressionNode::Product(mles) => {
525 mles.iter_mut()
526 .for_each(|mle| mle.index_mle_indices(start_index));
527 }
528 ExpressionNode::Scaled(a, _scale_factor) => {
529 a.index_mle_vars(start_index);
530 }
531 ExpressionNode::Constant(_constant) => {}
532 }
533 }
534
535 pub fn into_prover_expression(
537 &self,
538 circuit_map: &CircuitEvalMap<F>,
539 ) -> Expression<F, ProverExpr> {
540 match self {
541 ExpressionNode::Selector(_mle_index, a, b) => a
542 .into_prover_expression(circuit_map)
543 .select(b.into_prover_expression(circuit_map)),
544 ExpressionNode::Sum(a, b) => {
545 a.into_prover_expression(circuit_map) + b.into_prover_expression(circuit_map)
546 }
547 ExpressionNode::Mle(mle) => {
548 let prover_mle = mle.into_dense_mle(circuit_map);
549 prover_mle.expression()
550 }
551 ExpressionNode::Product(mles) => {
552 let dense_mles = mles
553 .iter()
554 .map(|mle| mle.into_dense_mle(circuit_map))
555 .collect_vec();
556 Expression::<F, ProverExpr>::products(dense_mles)
557 }
558 ExpressionNode::Scaled(a, scale_factor) => {
559 a.into_prover_expression(circuit_map) * *scale_factor
560 }
561 ExpressionNode::Constant(constant) => Expression::<F, ProverExpr>::constant(*constant),
562 }
563 }
564
565 pub fn get_post_sumcheck_layer(
569 &self,
570 multiplier: F,
571 challenges: &[F],
572 _mle_vec: &<VerifierExpr as ExpressionType<F>>::MleVec,
573 ) -> PostSumcheckLayer<F, Option<F>> {
574 let mut products: Vec<Product<F, Option<F>>> = vec![];
575 match self {
576 ExpressionNode::Selector(mle_index, a, b) => {
577 let idx_val = match mle_index {
578 MleIndex::Indexed(idx) => challenges[*idx],
579 MleIndex::Bound(chal, _idx) => *chal,
580 _ => panic!("should not have any other index here"),
583 };
584 let left_side_acc = multiplier * (F::ONE - idx_val);
585 let right_side_acc = multiplier * (idx_val);
586 products.extend(
587 a.get_post_sumcheck_layer(left_side_acc, challenges, _mle_vec)
588 .0,
589 );
590 products.extend(
591 b.get_post_sumcheck_layer(right_side_acc, challenges, _mle_vec)
592 .0,
593 );
594 }
595 ExpressionNode::Sum(a, b) => {
596 products.extend(
597 a.get_post_sumcheck_layer(multiplier, challenges, _mle_vec)
598 .0,
599 );
600 products.extend(
601 b.get_post_sumcheck_layer(multiplier, challenges, _mle_vec)
602 .0,
603 );
604 }
605 ExpressionNode::Mle(mle) => {
606 products.push(Product::<F, Option<F>>::new(
607 std::slice::from_ref(mle),
608 multiplier,
609 challenges,
610 ));
611 }
612 ExpressionNode::Product(mles) => {
613 let product = Product::<F, Option<F>>::new(mles, multiplier, challenges);
614 products.push(product);
615 }
616 ExpressionNode::Scaled(a, scale_factor) => {
617 let acc = multiplier * scale_factor;
618 products.extend(a.get_post_sumcheck_layer(acc, challenges, _mle_vec).0);
619 }
620 ExpressionNode::Constant(constant) => {
621 products.push(Product::<F, Option<F>>::new(
622 &[],
623 *constant * multiplier,
624 challenges,
625 ));
626 }
627 }
628 PostSumcheckLayer(products)
629 }
630
631 fn get_max_degree(&self, _mle_vec: &<ExprDescription as ExpressionType<F>>::MleVec) -> usize {
633 match self {
634 ExpressionNode::Selector(_, a, b) | ExpressionNode::Sum(a, b) => {
635 let a_degree = a.get_max_degree(_mle_vec);
636 let b_degree = b.get_max_degree(_mle_vec);
637 max(a_degree, b_degree)
638 }
639 ExpressionNode::Mle(_) => {
640 1
642 }
643 ExpressionNode::Product(mles) => {
644 mles.len()
646 }
647 ExpressionNode::Scaled(a, _) => a.get_max_degree(_mle_vec),
648 ExpressionNode::Constant(_) => 1,
649 }
650 }
651
652 fn get_num_vars(&self) -> usize {
659 match self {
660 ExpressionNode::Constant(_) => 0,
661 ExpressionNode::Selector(_, lhs, rhs) => {
662 max(lhs.get_num_vars() + 1, rhs.get_num_vars() + 1)
663 }
664 ExpressionNode::Mle(circuit_mle_desc) => circuit_mle_desc.num_free_vars(),
665 ExpressionNode::Sum(lhs, rhs) => max(lhs.get_num_vars(), rhs.get_num_vars()),
666 ExpressionNode::Product(nodes) => nodes.iter().fold(0, |cur_max, circuit_mle_desc| {
667 max(cur_max, circuit_mle_desc.num_free_vars())
668 }),
669 ExpressionNode::Scaled(expr, _) => expr.get_num_vars(),
670 }
671 }
672}
673
674impl<F: Field> Expression<F, ExprDescription> {
675 pub fn num_vars(&self) -> usize {
681 self.expression_node.get_num_vars()
682 }
683
684 pub fn products(circuit_mle_descs: Vec<MleDescription<F>>) -> Self {
688 let product_node = ExpressionNode::Product(circuit_mle_descs);
689
690 Expression::new(product_node, ())
691 }
692
693 pub fn select(self, rhs: Expression<F, ExprDescription>) -> Self {
709 let (lhs_node, _) = self.deconstruct();
710 let (rhs_node, _) = rhs.deconstruct();
711
712 let num_left_selectors = max(0, rhs_node.get_num_vars() - lhs_node.get_num_vars());
714 let num_right_selectors = max(0, lhs_node.get_num_vars() - rhs_node.get_num_vars());
715
716 let lhs_subtree = if num_left_selectors > 0 {
717 (0..num_left_selectors).fold(lhs_node, |cur_subtree, _| {
719 ExpressionNode::Selector(
720 MleIndex::Free,
721 Box::new(cur_subtree),
722 Box::new(ExpressionNode::Constant(F::ZERO)),
723 )
724 })
725 } else {
726 lhs_node
727 };
728
729 let rhs_subtree = if num_right_selectors > 0 {
730 (0..num_right_selectors).fold(rhs_node, |cur_subtree, _| {
732 ExpressionNode::Selector(
733 MleIndex::Free,
734 Box::new(cur_subtree),
735 Box::new(ExpressionNode::Constant(F::ZERO)),
736 )
737 })
738 } else {
739 rhs_node
740 };
741
742 debug_assert_eq!(lhs_subtree.get_num_vars(), rhs_subtree.get_num_vars());
744
745 let concat_node =
747 ExpressionNode::Selector(MleIndex::Free, Box::new(lhs_subtree), Box::new(rhs_subtree));
748
749 Expression::new(concat_node, ())
750 }
751
752 pub fn binary_tree_selector(expressions: Vec<Self>) -> Self {
757 assert!(expressions.len().is_power_of_two());
759 let mut expressions = expressions;
760 while expressions.len() > 1 {
761 expressions = expressions
763 .into_iter()
764 .tuples()
765 .map(|(lhs, rhs)| {
766 let (lhs_node, _) = lhs.deconstruct();
767 let (rhs_node, _) = rhs.deconstruct();
768
769 let selector_node = ExpressionNode::Selector(
770 MleIndex::Free,
771 Box::new(lhs_node),
772 Box::new(rhs_node),
773 );
774
775 Expression::new(selector_node, ())
776 })
777 .collect();
778 }
779 expressions[0].clone()
780 }
781
782 pub fn constant(constant: F) -> Self {
784 let mle_node = ExpressionNode::Constant(constant);
785
786 Expression::new(mle_node, ())
787 }
788
789 pub fn negated(expression: Self) -> Self {
791 let (node, _) = expression.deconstruct();
792
793 let mle_node = ExpressionNode::Scaled(Box::new(node), F::from(1).neg());
794
795 Expression::new(mle_node, ())
796 }
797
798 pub fn sum(lhs: Self, rhs: Self) -> Self {
800 let (lhs_node, _) = lhs.deconstruct();
801 let (rhs_node, _) = rhs.deconstruct();
802
803 let sum_node = ExpressionNode::Sum(Box::new(lhs_node), Box::new(rhs_node));
804
805 Expression::new(sum_node, ())
806 }
807
808 pub fn scaled(expression: Expression<F, ExprDescription>, scale: F) -> Self {
810 let (node, _) = expression.deconstruct();
811
812 Expression::new(ExpressionNode::Scaled(Box::new(node), scale), ())
813 }
814}
815
816pub fn filter_bookkeeping_table<F: Field>(
821 bookkeeping_table: &MultilinearExtension<F>,
822 unfiltered_prefix_bits: &[bool],
823) -> MultilinearExtension<F> {
824 let current_table = bookkeeping_table.to_vec();
825 let mut current_table_len = current_table.len();
826 let filtered_table = unfiltered_prefix_bits
827 .iter()
828 .fold(current_table, |acc, bit| {
829 let acc = if *bit {
830 acc.into_iter().skip(current_table_len / 2).collect_vec()
831 } else {
832 acc.into_iter().take(current_table_len / 2).collect_vec()
833 };
834 current_table_len /= 2;
835 acc
836 });
837 MultilinearExtension::new(filtered_table)
838}
839
840pub(crate) fn evaluate_bookkeeping_tables_given_operation<F: Field>(
843 mle_bookkeeping_tables: &[Vec<F>],
844 binary_operation: BinaryOperation,
845) -> MultilinearExtension<F> {
846 let max_num_vars = mle_bookkeeping_tables
847 .iter()
848 .map(|bookkeeping_table| log2(bookkeeping_table.len()))
849 .max()
850 .unwrap();
851
852 let mut output_table = vec![F::ZERO; 1 << max_num_vars];
853 (0..1 << (max_num_vars)).for_each(|index| {
854 let evaluated_data_point = mle_bookkeeping_tables
855 .iter()
856 .map(|mle_bookkeeping_table| {
857 let zero = F::ZERO;
858 let index = if log2(mle_bookkeeping_table.len()) < max_num_vars {
859 let max = 1 << log2(mle_bookkeeping_table.len());
860 let multiple = (1 << max_num_vars) / max;
861 index / multiple
862 } else {
863 index
864 };
865 let value = *mle_bookkeeping_table.get(index).unwrap_or(&zero);
866 value
867 })
868 .reduce(|acc, value| binary_operation.perform_operation(acc, value))
869 .unwrap();
870 output_table[index] = evaluated_data_point;
871 });
872 MultilinearExtension::new(output_table)
873}
874
875impl<F: Field> Neg for Expression<F, ExprDescription> {
876 type Output = Expression<F, ExprDescription>;
877 fn neg(self) -> Self::Output {
878 Expression::<F, ExprDescription>::negated(self)
879 }
880}
881
882impl<F: Field> Add for Expression<F, ExprDescription> {
884 type Output = Expression<F, ExprDescription>;
885 fn add(self, rhs: Expression<F, ExprDescription>) -> Expression<F, ExprDescription> {
886 Expression::<F, ExprDescription>::sum(self, rhs)
887 }
888}
889
890impl<F: Field> Sub for Expression<F, ExprDescription> {
891 type Output = Expression<F, ExprDescription>;
892 fn sub(self, rhs: Expression<F, ExprDescription>) -> Expression<F, ExprDescription> {
893 self.add(rhs.neg())
894 }
895}
896
897impl<F: Field> Mul<F> for Expression<F, ExprDescription> {
898 type Output = Expression<F, ExprDescription>;
899 fn mul(self, rhs: F) -> Self::Output {
900 Expression::<F, ExprDescription>::scaled(self, rhs)
901 }
902}
903
904impl<F: std::fmt::Debug + Field> std::fmt::Debug for Expression<F, ExprDescription> {
905 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
906 f.debug_struct("Circuit Expression")
907 .field("Expression_Node", &self.expression_node)
908 .finish()
909 }
910}
911
912impl<F: std::fmt::Debug + Field> std::fmt::Debug for ExpressionNode<F, ExprDescription> {
913 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
914 match self {
915 ExpressionNode::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(),
916 ExpressionNode::Selector(index, a, b) => f
917 .debug_tuple("Selector")
918 .field(index)
919 .field(a)
920 .field(b)
921 .finish(),
922 ExpressionNode::Mle(mle) => f.debug_struct("Circuit Mle").field("mle", mle).finish(),
924 ExpressionNode::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
925 ExpressionNode::Product(a) => f.debug_tuple("Product").field(a).finish(),
926 ExpressionNode::Scaled(poly, scalar) => {
927 f.debug_tuple("Scaled").field(poly).field(scalar).finish()
928 }
929 }
930 }
931}