1use crate::mle::{verifier_mle::VerifierMle, MleIndex};
22use serde::{Deserialize, Serialize};
23use std::{
24 collections::{HashMap, HashSet},
25 fmt::Debug,
26};
27
28use shared_types::Field;
29
30use super::{
31 expr_errors::ExpressionError,
32 generic_expr::{Expression, ExpressionNode, ExpressionType},
33};
34
35use anyhow::{anyhow, Ok, Result};
36
37#[derive(Serialize, Deserialize, Clone, Debug)]
40pub struct VerifierExpr;
41
42impl<F: Field> ExpressionType<F> for VerifierExpr {
50 type MLENodeRepr = VerifierMle<F>;
51 type MleVec = ();
52}
53
54impl<F: Field> Expression<F, VerifierExpr> {
55 pub fn mle(mle: VerifierMle<F>) -> Self {
57 let mle_node = ExpressionNode::Mle(mle);
58
59 Expression::new(mle_node, ())
60 }
61
62 pub fn evaluate(&self) -> Result<F> {
64 let constant = |c| Ok(c);
65 let selector_column = |idx: &MleIndex<F>, lhs: Result<F>, rhs: Result<F>| -> Result<F> {
66 if let MleIndex::Bound(val, _) = idx {
68 return Ok(*val * rhs? + (F::ONE - val) * lhs?);
69 }
70 Err(anyhow!(ExpressionError::SelectorBitNotBoundError))
71 };
72 let mle_eval = |verifier_mle: &VerifierMle<F>| -> Result<F> { Ok(verifier_mle.value()) };
73 let sum = |lhs: Result<F>, rhs: Result<F>| Ok(lhs? + rhs?);
74 let product = |verifier_mles: &[VerifierMle<F>]| -> Result<F> {
75 verifier_mles
76 .iter()
77 .try_fold(F::ONE, |acc, verifier_mle| Ok(acc * verifier_mle.value()))
78 };
79 let scaled = |val: Result<F>, scalar: F| Ok(val? * scalar);
80
81 self.expression_node.reduce(
82 &constant,
83 &selector_column,
84 &mle_eval,
85 &sum,
86 &product,
87 &scaled,
88 )
89 }
90
91 pub fn get_all_nonlinear_rounds(&mut self) -> Vec<usize> {
94 let (expression_node, mle_vec) = self.deconstruct_mut();
95 let mut nonlinear_rounds: Vec<usize> =
96 expression_node.get_all_nonlinear_rounds(&mut vec![], mle_vec);
97 nonlinear_rounds.sort();
98 nonlinear_rounds
99 }
100}
101
102impl<F: Field> ExpressionNode<F, VerifierExpr> {
103 #[allow(clippy::too_many_arguments)]
106 pub fn reduce<T>(
107 &self,
108 constant: &impl Fn(F) -> T,
109 selector_column: &impl Fn(&MleIndex<F>, T, T) -> T,
110 mle_eval: &impl Fn(&<VerifierExpr as ExpressionType<F>>::MLENodeRepr) -> T,
111 sum: &impl Fn(T, T) -> T,
112 product: &impl Fn(&[<VerifierExpr as ExpressionType<F>>::MLENodeRepr]) -> T,
113 scaled: &impl Fn(T, F) -> T,
114 ) -> T {
115 match self {
116 ExpressionNode::Constant(scalar) => constant(*scalar),
117 ExpressionNode::Selector(index, a, b) => {
118 let lhs = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
119 let rhs = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
120 selector_column(index, lhs, rhs)
121 }
122 ExpressionNode::Mle(query) => mle_eval(query),
123 ExpressionNode::Sum(a, b) => {
124 let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
125 let b = b.reduce(constant, selector_column, mle_eval, sum, product, scaled);
126 sum(a, b)
127 }
128 ExpressionNode::Product(queries) => product(queries),
129 ExpressionNode::Scaled(a, f) => {
130 let a = a.reduce(constant, selector_column, mle_eval, sum, product, scaled);
131 scaled(a, *f)
132 }
133 }
134 }
135
136 pub fn get_all_nonlinear_rounds(
139 &self,
140 curr_nonlinear_indices: &mut Vec<usize>,
141 _mle_vec: &<VerifierExpr as ExpressionType<F>>::MleVec,
142 ) -> Vec<usize> {
143 let nonlinear_indices_in_node = {
144 match self {
145 ExpressionNode::Product(verifier_mles) => {
149 let mut product_nonlinear_indices: HashSet<usize> = HashSet::new();
150 let mut product_indices_counts: HashMap<MleIndex<F>, usize> = HashMap::new();
151
152 verifier_mles.iter().for_each(|verifier_mle| {
153 verifier_mle.var_indices().iter().for_each(|mle_index| {
154 let curr_count = {
155 if product_indices_counts.contains_key(mle_index) {
156 product_indices_counts.get(mle_index).unwrap()
157 } else {
158 &0
159 }
160 };
161 product_indices_counts.insert(mle_index.clone(), curr_count + 1);
162 })
163 });
164
165 product_indices_counts
166 .into_iter()
167 .for_each(|(mle_index, count)| {
168 if count > 1 {
169 if let MleIndex::Indexed(i) = mle_index {
170 product_nonlinear_indices.insert(i);
171 }
172 }
173 });
174
175 product_nonlinear_indices
176 }
177 ExpressionNode::Selector(_sel_index, a, b) => {
180 let mut sel_nonlinear_indices: HashSet<usize> = HashSet::new();
181 let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
182 let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
183 a_indices
184 .into_iter()
185 .zip(b_indices)
186 .for_each(|(a_mle_idx, b_mle_idx)| {
187 sel_nonlinear_indices.insert(a_mle_idx);
188 sel_nonlinear_indices.insert(b_mle_idx);
189 });
190 sel_nonlinear_indices
191 }
192 ExpressionNode::Sum(a, b) => {
193 let mut sum_nonlinear_indices: HashSet<usize> = HashSet::new();
194 let a_indices = a.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
195 let b_indices = b.get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec);
196 a_indices
197 .into_iter()
198 .zip(b_indices)
199 .for_each(|(a_mle_idx, b_mle_idx)| {
200 sum_nonlinear_indices.insert(a_mle_idx);
201 sum_nonlinear_indices.insert(b_mle_idx);
202 });
203 sum_nonlinear_indices
204 }
205 ExpressionNode::Scaled(a, _) => a
206 .get_all_nonlinear_rounds(curr_nonlinear_indices, _mle_vec)
207 .into_iter()
208 .collect(),
209 ExpressionNode::Constant(_) | ExpressionNode::Mle(_) => HashSet::new(),
210 }
211 };
212 nonlinear_indices_in_node.into_iter().for_each(|index| {
214 if !curr_nonlinear_indices.contains(&index) {
215 curr_nonlinear_indices.push(index);
216 }
217 });
218 curr_nonlinear_indices.clone()
219 }
220}
221
222impl<F: std::fmt::Debug + Field> std::fmt::Debug for Expression<F, VerifierExpr> {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 f.debug_struct("Circuit Expression")
225 .field("Expression_Node", &self.expression_node)
226 .finish()
227 }
228}
229
230impl<F: std::fmt::Debug + Field> std::fmt::Debug for ExpressionNode<F, VerifierExpr> {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 match self {
233 ExpressionNode::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(),
234 ExpressionNode::Selector(index, a, b) => f
235 .debug_tuple("Selector")
236 .field(index)
237 .field(a)
238 .field(b)
239 .finish(),
240 ExpressionNode::Mle(mle) => f.debug_struct("Circuit Mle").field("mle", mle).finish(),
242 ExpressionNode::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
243 ExpressionNode::Product(a) => f.debug_tuple("Product").field(a).finish(),
244 ExpressionNode::Scaled(poly, scalar) => {
245 f.debug_tuple("Scaled").field(poly).field(scalar).finish()
246 }
247 }
248 }
249}