1use itertools::Itertools;
21use serde::{Deserialize, Serialize};
22use std::{
23 cmp::max,
24 collections::HashMap,
25 ops::{Add, AddAssign, BitXor, Mul, MulAssign, Neg, Sub, SubAssign},
26};
27
28use shared_types::Field;
29
30use crate::layouter::{builder::CircuitMap, layouting::LayoutingError, nodes::NodeId};
31use remainder::{circuit_layout::CircuitLocation, expression::generic_expr::ExpressionNode};
32
33use remainder::{
34 expression::{circuit_expr::ExprDescription, generic_expr::Expression},
35 mle::{mle_description::MleDescription, MleIndex},
36 utils::mle::get_total_mle_indices,
37};
38
39use anyhow::{Ok, Result};
40
41#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq)]
47#[serde(bound = "F: Field")]
48pub enum AbstractExpression<F: Field> {
49 Constant(F),
50 Selector(
51 MleIndex<F>,
52 Box<AbstractExpression<F>>,
53 Box<AbstractExpression<F>>,
54 ),
55 Mle(NodeId),
56 Sum(Box<AbstractExpression<F>>, Box<AbstractExpression<F>>),
57 Product(Vec<NodeId>),
58 Scaled(Box<AbstractExpression<F>>, F),
59}
60
61impl<F: Field> AbstractExpression<F> {
66 pub fn traverse(
68 &self,
69 observer_fn: &mut impl FnMut(&AbstractExpression<F>) -> Result<()>,
70 ) -> Result<()> {
71 observer_fn(self)?;
72 match self {
73 AbstractExpression::Constant(_)
74 | AbstractExpression::Mle(_)
75 | AbstractExpression::Product(_) => Ok(()),
76 AbstractExpression::Scaled(exp, _) => exp.traverse(observer_fn),
77 AbstractExpression::Selector(_, lhs, rhs) => {
78 lhs.traverse(observer_fn)?;
79 rhs.traverse(observer_fn)
80 }
81 AbstractExpression::Sum(lhs, rhs) => {
82 lhs.traverse(observer_fn)?;
83 rhs.traverse(observer_fn)
84 }
85 }
86 }
87
88 pub fn get_sources(&self) -> Vec<NodeId> {
90 let mut sources = vec![];
91 let mut get_sources_closure = |expr_node: &AbstractExpression<F>| -> Result<()> {
92 if let AbstractExpression::Product(node_id_vec) = expr_node {
93 sources.extend(node_id_vec.iter());
94 } else if let AbstractExpression::Mle(node_id) = expr_node {
95 sources.push(*node_id);
96 }
97 Ok(())
98 };
99 self.traverse(&mut get_sources_closure).unwrap();
100 sources
101 }
102
103 pub fn get_num_vars(&self, num_vars_map: &HashMap<NodeId, usize>) -> Result<usize> {
105 match self {
106 AbstractExpression::Constant(_) => Ok(0),
107 AbstractExpression::Selector(_, lhs, rhs) => Ok(max(
108 lhs.get_num_vars(num_vars_map)? + 1,
109 rhs.get_num_vars(num_vars_map)? + 1,
110 )),
111 AbstractExpression::Mle(node_id) => Ok(*num_vars_map.get(node_id).unwrap()),
112 AbstractExpression::Sum(lhs, rhs) => Ok(max(
113 lhs.get_num_vars(num_vars_map)?,
114 rhs.get_num_vars(num_vars_map)?,
115 )),
116 AbstractExpression::Product(nodes) => Ok(nodes
117 .iter()
118 .map(|node_id| Ok(Some(*num_vars_map.get(node_id).unwrap())))
119 .fold_ok(None, max)?
120 .unwrap_or(0)),
121 AbstractExpression::Scaled(expr, _) => expr.get_num_vars(num_vars_map),
122 }
123 }
124
125 pub fn build_circuit_expr(
129 self,
130 circuit_map: &CircuitMap,
131 ) -> Result<Expression<F, ExprDescription>> {
132 let mut nodes = self.get_node_ids(vec![]);
134 nodes.sort();
135 nodes.dedup();
136
137 let mut node_map = HashMap::<NodeId, (usize, &CircuitLocation)>::new();
138
139 nodes.into_iter().for_each(|node_id| {
140 let (location, num_vars) = circuit_map
141 .get_location_num_vars_from_node_id(&node_id)
142 .unwrap();
143 node_map.insert(node_id, (*num_vars, location));
144 });
145
146 let expression_node = self.build_circuit_node(&node_map)?;
149
150 Ok(Expression::new(expression_node, ()))
151 }
152
153 pub fn select(self, rhs: Self) -> Self {
155 Self::Selector(MleIndex::Free, Box::new(self), Box::new(rhs))
156 }
157
158 pub fn select_seq<E: Clone + Into<AbstractExpression<F>>>(expressions: Vec<E>) -> Self {
160 let mut base = expressions[0].clone().into();
161 for e in expressions.into_iter().skip(1) {
162 base = Self::select(base, e.into());
163 }
164 base
165 }
166
167 pub fn binary_tree_selector<E: Into<AbstractExpression<F>>>(expressions: Vec<E>) -> Self {
172 assert!(expressions.len().is_power_of_two());
174 let mut expressions = expressions
175 .into_iter()
176 .map(|e| e.into())
177 .collect::<Vec<_>>();
178 while expressions.len() > 1 {
179 expressions = expressions
181 .into_iter()
182 .tuples()
183 .map(|(lhs, rhs)| Self::Selector(MleIndex::Free, Box::new(lhs), Box::new(rhs)))
184 .collect();
185 }
186 expressions[0].clone()
187 }
188
189 pub fn pow(pow: usize, node_id: Self) -> Self {
191 let base = node_id;
193 let mut result = base.clone();
194 for _ in 1..pow {
195 result *= base.clone();
196 }
197 result
198 }
199
200 pub fn products(node_ids: Vec<NodeId>) -> Self {
202 Self::Product(node_ids)
203 }
204
205 pub fn mult(lhs: Self, rhs: Self) -> Self {
207 let switch = |lhs, rhs| Self::mult(rhs, lhs);
208
209 match (&lhs, &rhs) {
211 (AbstractExpression::Constant(f), _) => AbstractExpression::Scaled(Box::new(rhs), *f),
213 (_, AbstractExpression::Constant(_)) => switch(lhs, rhs),
214
215 (AbstractExpression::Selector(..), _) => panic!("Multiplying a non-constant with a selector is not allowed! Create a separate sector or fold the operand into each branch!"),
217 (_, AbstractExpression::Selector(..)) => switch(lhs, rhs),
218
219 (AbstractExpression::Sum(x, y), _) => {
221 let xr = Self::mult(*x.clone(), rhs.clone());
222 let yr = Self::mult(*y.clone(), rhs);
223 AbstractExpression::Sum(Box::new(xr), Box::new(yr))
224 }
225 (_, AbstractExpression::Sum(..)) => switch(lhs, rhs),
226
227 (AbstractExpression::Scaled(x, c1), AbstractExpression::Scaled(y, c2)) => {
229 let xy = Self::mult(*x.clone(), *y.clone());
230 let c = *c1 * *c2;
231 AbstractExpression::Scaled(Box::new(xy), c)
232 }
233 (AbstractExpression::Scaled(x, c), _) => {
234 let xz = Self::mult(*x.clone(), rhs);
235 AbstractExpression::Scaled(Box::new(xz), *c)
236 }
237 (_, AbstractExpression::Scaled(..)) => switch(lhs, rhs),
238
239 (l, r) => {
241 let l_ids = match l {
242 AbstractExpression::Mle(id) => vec![*id],
243 AbstractExpression::Product(ids) => ids.clone(),
244 _ => unreachable!()
245 };
246 let r_ids = match r {
247 AbstractExpression::Mle(id) => vec![*id],
248 AbstractExpression::Product(ids) => ids.clone(),
249 _ => unreachable!()
250 };
251 let ids = [l_ids, r_ids].concat();
252 AbstractExpression::Product(ids)
253 }
254 }
255 }
256
257 pub fn mle(node_id: NodeId) -> Self {
259 AbstractExpression::Mle(node_id)
260 }
261
262 pub fn constant(constant: F) -> Self {
264 AbstractExpression::Constant(constant)
265 }
266
267 pub fn negated(expression: Self) -> Self {
269 AbstractExpression::Scaled(Box::new(expression), F::from(1).neg())
270 }
271
272 pub fn sum(lhs: Self, rhs: Self) -> Self {
274 AbstractExpression::Sum(Box::new(lhs), Box::new(rhs))
275 }
276
277 pub fn scaled(expression: AbstractExpression<F>, scale: F) -> Self {
279 AbstractExpression::Scaled(Box::new(expression), scale)
280 }
281}
282
283impl<F: Field> AbstractExpression<F> {
284 fn build_circuit_node(
285 self,
286 node_map: &HashMap<NodeId, (usize, &CircuitLocation)>,
287 ) -> Result<ExpressionNode<F, ExprDescription>> {
288 match self {
290 AbstractExpression::Constant(val) => Ok(ExpressionNode::Constant(val)),
291 AbstractExpression::Selector(mle_index, lhs, rhs) => {
292 let lhs = lhs.build_circuit_node(node_map)?;
293 let rhs = rhs.build_circuit_node(node_map)?;
294 Ok(ExpressionNode::Selector(
295 mle_index,
296 Box::new(lhs),
297 Box::new(rhs),
298 ))
299 }
300 AbstractExpression::Mle(node_id) => {
301 let (
302 num_vars,
303 CircuitLocation {
304 prefix_bits,
305 layer_id,
306 },
307 ) = node_map
308 .get(&node_id)
309 .ok_or(LayoutingError::DanglingNodeId(node_id))?;
310 let total_indices = get_total_mle_indices(prefix_bits, *num_vars);
311 let circuit_mle = MleDescription::new(*layer_id, &total_indices);
312 Ok(ExpressionNode::Mle(circuit_mle))
313 }
314 AbstractExpression::Sum(lhs, rhs) => {
315 let lhs = lhs.build_circuit_node(node_map)?;
316 let rhs = rhs.build_circuit_node(node_map)?;
317 Ok(ExpressionNode::Sum(Box::new(lhs), Box::new(rhs)))
318 }
319 AbstractExpression::Product(nodes) => {
320 let circuit_mles = nodes
321 .into_iter()
322 .map(|node_id| {
323 let (
324 num_vars,
325 CircuitLocation {
326 prefix_bits,
327 layer_id,
328 },
329 ) = node_map
330 .get(&node_id)
331 .ok_or(LayoutingError::DanglingNodeId(node_id))
332 .unwrap();
333 let total_indices = get_total_mle_indices::<F>(prefix_bits, *num_vars);
334 MleDescription::new(*layer_id, &total_indices)
335 })
336 .collect::<Vec<MleDescription<F>>>();
337 Ok(ExpressionNode::Product(circuit_mles))
338 }
339 AbstractExpression::Scaled(expr, scalar) => {
340 let expr = expr.build_circuit_node(node_map)?;
341 Ok(ExpressionNode::Scaled(Box::new(expr), scalar))
342 }
343 }
344 }
345
346 fn get_node_ids(&self, mut node_ids: Vec<NodeId>) -> Vec<NodeId> {
347 match self {
348 AbstractExpression::Constant(_) => node_ids,
349 AbstractExpression::Selector(_, lhs, rhs) => {
350 let node_ids = rhs.get_node_ids(node_ids);
351 lhs.get_node_ids(node_ids)
352 }
353 AbstractExpression::Mle(node_id) => {
354 node_ids.push(*node_id);
355 node_ids
356 }
357 AbstractExpression::Sum(lhs, rhs) => {
358 let node_ids = lhs.get_node_ids(node_ids);
359 rhs.get_node_ids(node_ids)
360 }
361 AbstractExpression::Product(nodes) => {
362 node_ids.extend(nodes.iter());
363 node_ids
364 }
365 AbstractExpression::Scaled(expr, _) => expr.get_node_ids(node_ids),
366 }
367 }
368}
369
370impl<F: Field> Neg for AbstractExpression<F> {
372 type Output = AbstractExpression<F>;
373 fn neg(self) -> Self::Output {
374 AbstractExpression::<F>::negated(self)
375 }
376}
377
378impl<F: Field> Neg for &AbstractExpression<F> {
379 type Output = AbstractExpression<F>;
380 fn neg(self) -> Self::Output {
381 AbstractExpression::<F>::negated(self.clone())
382 }
383}
384
385impl<F: Field> From<F> for AbstractExpression<F> {
386 fn from(f: F) -> Self {
387 AbstractExpression::<F>::constant(f)
388 }
389}
390
391impl<F: Field, Rhs: Into<AbstractExpression<F>>> Add<Rhs> for AbstractExpression<F> {
393 type Output = AbstractExpression<F>;
394 fn add(self, rhs: Rhs) -> Self::Output {
395 AbstractExpression::sum(self, rhs.into())
396 }
397}
398impl<F: Field, Rhs: Into<AbstractExpression<F>>> Add<Rhs> for &AbstractExpression<F> {
399 type Output = AbstractExpression<F>;
400 fn add(self, rhs: Rhs) -> Self::Output {
401 AbstractExpression::sum(self.clone(), rhs.into())
402 }
403}
404
405impl<F: Field, Rhs: Into<AbstractExpression<F>>> AddAssign<Rhs> for AbstractExpression<F> {
406 fn add_assign(&mut self, rhs: Rhs) {
407 *self = self.clone() + rhs;
408 }
409}
410
411impl<F: Field, Rhs: Into<AbstractExpression<F>>> Sub<Rhs> for AbstractExpression<F> {
412 type Output = AbstractExpression<F>;
413 fn sub(self, rhs: Rhs) -> Self::Output {
414 AbstractExpression::sum(self, rhs.into().neg())
415 }
416}
417impl<F: Field, Rhs: Into<AbstractExpression<F>>> Sub<Rhs> for &AbstractExpression<F> {
418 type Output = AbstractExpression<F>;
419 fn sub(self, rhs: Rhs) -> Self::Output {
420 AbstractExpression::sum(self.clone(), rhs.into().neg())
421 }
422}
423impl<F: Field, Rhs: Into<AbstractExpression<F>>> SubAssign<Rhs> for AbstractExpression<F> {
424 fn sub_assign(&mut self, rhs: Rhs) {
425 *self = self.clone() - rhs;
426 }
427}
428
429impl<F: Field, Rhs: Into<AbstractExpression<F>>> Mul<Rhs> for AbstractExpression<F> {
430 type Output = AbstractExpression<F>;
431 fn mul(self, rhs: Rhs) -> Self::Output {
432 AbstractExpression::mult(self, rhs.into())
433 }
434}
435impl<F: Field, Rhs: Into<AbstractExpression<F>>> Mul<Rhs> for &AbstractExpression<F> {
436 type Output = AbstractExpression<F>;
437 fn mul(self, rhs: Rhs) -> Self::Output {
438 AbstractExpression::mult(self.clone(), rhs.into())
439 }
440}
441impl<F: Field, Rhs: Into<AbstractExpression<F>>> MulAssign<Rhs> for AbstractExpression<F> {
442 fn mul_assign(&mut self, rhs: Rhs) {
443 *self = self.clone() * rhs;
444 }
445}
446
447impl<F: Field, Rhs: Into<AbstractExpression<F>>> BitXor<Rhs> for AbstractExpression<F> {
448 type Output = AbstractExpression<F>;
449 fn bitxor(self, rhs: Rhs) -> Self::Output {
450 let rhs_expr: AbstractExpression<F> = rhs.into();
451 self.clone() + rhs_expr.clone() - self * rhs_expr * F::from(2)
452 }
453}
454impl<F: Field, Rhs: Into<AbstractExpression<F>>> BitXor<Rhs> for &AbstractExpression<F> {
455 type Output = AbstractExpression<F>;
456 fn bitxor(self, rhs: Rhs) -> Self::Output {
457 let rhs_expr: &AbstractExpression<F> = &rhs.into();
458 self.clone() + rhs_expr.clone() - self.clone() * rhs_expr * F::from(2)
459 }
460}
461
462impl<F: Field> From<&AbstractExpression<F>> for AbstractExpression<F> {
463 fn from(val: &AbstractExpression<F>) -> Self {
464 val.clone()
465 }
466}
467
/// Builds an `AbstractExpression::Constant` from a field element.
///
/// The expansion names the type through `$crate`, so the macro works both
/// inside the defining crate and from downstream crates. A hard-coded crate
/// name does not resolve inside the defining crate.
#[macro_export]
macro_rules! const_expr {
    ($val:expr) => {{
        $crate::abstract_expr::AbstractExpression::Constant($val)
    }};
}
476
/// Builds a left-nested selector chain from one or more expressions via
/// `AbstractExpression::select_seq`: `sel_expr!(a, b, c)` selects between
/// `select(a, b)` and `c`.
///
/// NOTE: the expansion names a type `F`, so a field type called `F` must be in
/// scope at the call site. Paths go through `$crate`, so the macro also works
/// inside the defining crate.
#[macro_export]
macro_rules! sel_expr {
    ($($expr:expr),+ $(,)?) => {{
        let v = vec![$(::std::convert::Into::<$crate::abstract_expr::AbstractExpression<F>>::into($expr)),+];
        $crate::abstract_expr::AbstractExpression::<F>::select_seq(v)
    }};
}
488
489impl<F: std::fmt::Debug + Field> std::fmt::Debug for AbstractExpression<F> {
491 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492 match self {
493 AbstractExpression::Constant(scalar) => {
494 f.debug_tuple("Constant").field(scalar).finish()
495 }
496 AbstractExpression::Selector(index, a, b) => f
497 .debug_tuple("Selector")
498 .field(index)
499 .field(a)
500 .field(b)
501 .finish(),
502 AbstractExpression::Mle(mle) => f.debug_struct("Mle").field("mle", mle).finish(),
504 AbstractExpression::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
505 AbstractExpression::Product(a) => f.debug_tuple("Product").field(a).finish(),
506 AbstractExpression::Scaled(poly, scalar) => {
507 f.debug_tuple("Scaled").field(poly).field(scalar).finish()
508 }
509 }
510 }
511}