frontend/
abstract_expr.rs

1//! The "circuit-builder's" view of an expression. In particular, it represents
2//! a template of polynomial relationships between an output computational
3//! graph node and outputs from source computational graph nodes.
4//!
5//! WARNING: because [AbstractExpression] does *not* contain any semblance of MLE
6//! sizes nor indices, it can thus represent an entire class of polynomial
7//! relationships, depending on its circuit-time instantiation. For example,
8//! the simple relationship of
9//!
10//! ```ignore
11//!     node_id_1 + node_id_2
12//! ```
13//!
14//! can refer to \widetilde{V}_{i}(x_1, ..., x_m) + \widetilde{V}_{j}(x_1, ..., x_n)
15//! where:
16//! * m > n, i.e. the second MLE's "data" is "wrapped around" via repetition
17//! * m = n, i.e. the resulting bookkeeping table is the element-wise sum of the two
18//! * m < n, i.e. the first MLE's "data" is "wrapped around" via repetition
19
20use itertools::Itertools;
21use serde::{Deserialize, Serialize};
22use std::{
23    cmp::max,
24    collections::HashMap,
25    ops::{Add, AddAssign, BitXor, Mul, MulAssign, Neg, Sub, SubAssign},
26};
27
28use shared_types::Field;
29
30use crate::layouter::{builder::CircuitMap, layouting::LayoutingError, nodes::NodeId};
31use remainder::{circuit_layout::CircuitLocation, expression::generic_expr::ExpressionNode};
32
33use remainder::{
34    expression::{circuit_expr::ExprDescription, generic_expr::Expression},
35    mle::{mle_description::MleDescription, MleIndex},
36    utils::mle::get_total_mle_indices,
37};
38
39use anyhow::{Ok, Result};
40
41/// See [ExpressionNode] for more details. Note that this implementation is
42/// somewhat redundant with [Expression] and [ExpressionNode], but the
43/// separation allows for more flexibility with respect to this particular
44/// frontend being able to create polynomial relationships in any way it
45/// chooses, so long as those representations are compile-able into [Expression].
46#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq)]
47#[serde(bound = "F: Field")]
48pub enum AbstractExpression<F: Field> {
49    Constant(F),
50    Selector(
51        MleIndex<F>,
52        Box<AbstractExpression<F>>,
53        Box<AbstractExpression<F>>,
54    ),
55    Mle(NodeId),
56    Sum(Box<AbstractExpression<F>>, Box<AbstractExpression<F>>),
57    Product(Vec<NodeId>),
58    Scaled(Box<AbstractExpression<F>>, F),
59}
60
61//  comments for Phase II:
//  This will be the circuit "pre-data" stage
63//  will take care of building a prover expression
64//  building the most memory efficient denseMleRefs dictionaries, etc.
65impl<F: Field> AbstractExpression<F> {
66    /// Traverses the expression and applies the observer function to all nodes.
67    pub fn traverse(
68        &self,
69        observer_fn: &mut impl FnMut(&AbstractExpression<F>) -> Result<()>,
70    ) -> Result<()> {
71        observer_fn(self)?;
72        match self {
73            AbstractExpression::Constant(_)
74            | AbstractExpression::Mle(_)
75            | AbstractExpression::Product(_) => Ok(()),
76            AbstractExpression::Scaled(exp, _) => exp.traverse(observer_fn),
77            AbstractExpression::Selector(_, lhs, rhs) => {
78                lhs.traverse(observer_fn)?;
79                rhs.traverse(observer_fn)
80            }
81            AbstractExpression::Sum(lhs, rhs) => {
82                lhs.traverse(observer_fn)?;
83                rhs.traverse(observer_fn)
84            }
85        }
86    }
87
88    /// find all the sources this expression depend on
89    pub fn get_sources(&self) -> Vec<NodeId> {
90        let mut sources = vec![];
91        let mut get_sources_closure = |expr_node: &AbstractExpression<F>| -> Result<()> {
92            if let AbstractExpression::Product(node_id_vec) = expr_node {
93                sources.extend(node_id_vec.iter());
94            } else if let AbstractExpression::Mle(node_id) = expr_node {
95                sources.push(*node_id);
96            }
97            Ok(())
98        };
99        self.traverse(&mut get_sources_closure).unwrap();
100        sources
101    }
102
103    /// Computes the num_vars of this expression (how many rounds of sumcheck it would take to prove)
104    pub fn get_num_vars(&self, num_vars_map: &HashMap<NodeId, usize>) -> Result<usize> {
105        match self {
106            AbstractExpression::Constant(_) => Ok(0),
107            AbstractExpression::Selector(_, lhs, rhs) => Ok(max(
108                lhs.get_num_vars(num_vars_map)? + 1,
109                rhs.get_num_vars(num_vars_map)? + 1,
110            )),
111            AbstractExpression::Mle(node_id) => Ok(*num_vars_map.get(node_id).unwrap()),
112            AbstractExpression::Sum(lhs, rhs) => Ok(max(
113                lhs.get_num_vars(num_vars_map)?,
114                rhs.get_num_vars(num_vars_map)?,
115            )),
116            AbstractExpression::Product(nodes) => Ok(nodes
117                .iter()
118                .map(|node_id| Ok(Some(*num_vars_map.get(node_id).unwrap())))
119                .fold_ok(None, max)?
120                .unwrap_or(0)),
121            AbstractExpression::Scaled(expr, _) => expr.get_num_vars(num_vars_map),
122        }
123    }
124
125    /// Convert the abstract expression into a circuit expression, which
126    /// stores information on the shape of the expression, using the
127    /// [CircuitMap].
128    pub fn build_circuit_expr(
129        self,
130        circuit_map: &CircuitMap,
131    ) -> Result<Expression<F, ExprDescription>> {
132        // First we get all the mles that this expression will need to store
133        let mut nodes = self.get_node_ids(vec![]);
134        nodes.sort();
135        nodes.dedup();
136
137        let mut node_map = HashMap::<NodeId, (usize, &CircuitLocation)>::new();
138
139        nodes.into_iter().for_each(|node_id| {
140            let (location, num_vars) = circuit_map
141                .get_location_num_vars_from_node_id(&node_id)
142                .unwrap();
143            node_map.insert(node_id, (*num_vars, location));
144        });
145
146        // Then we replace the NodeIds in the AbstractExpr w/ indices of our stored MLEs
147
148        let expression_node = self.build_circuit_node(&node_map)?;
149
150        Ok(Expression::new(expression_node, ()))
151    }
152
153    /// See documentation for `select()` function within [remainder::expression::circuit_expr::ExprDescription]
154    pub fn select(self, rhs: Self) -> Self {
155        Self::Selector(MleIndex::Free, Box::new(self), Box::new(rhs))
156    }
157
158    /// Call [Self::select] sequentially
159    pub fn select_seq<E: Clone + Into<AbstractExpression<F>>>(expressions: Vec<E>) -> Self {
160        let mut base = expressions[0].clone().into();
161        for e in expressions.into_iter().skip(1) {
162            base = Self::select(base, e.into());
163        }
164        base
165    }
166
167    /// Create a nested selector Expression that selects between 2^k Expressions
168    /// by creating a binary tree of Selector Expressions.
169    /// The order of the leaves is the order of the input expressions.
170    /// (Note that this is very different from [Self::select_seq].)
171    pub fn binary_tree_selector<E: Into<AbstractExpression<F>>>(expressions: Vec<E>) -> Self {
172        // Ensure length is a power of two
173        assert!(expressions.len().is_power_of_two());
174        let mut expressions = expressions
175            .into_iter()
176            .map(|e| e.into())
177            .collect::<Vec<_>>();
178        while expressions.len() > 1 {
179            // Iterate over consecutive pairs of expressions, creating a new expression that selects between them
180            expressions = expressions
181                .into_iter()
182                .tuples()
183                .map(|(lhs, rhs)| Self::Selector(MleIndex::Free, Box::new(lhs), Box::new(rhs)))
184                .collect();
185        }
186        expressions[0].clone()
187    }
188
189    /// Create a product Expression that raises one expression to a given power
190    pub fn pow(pow: usize, node_id: Self) -> Self {
191        // lazily construct a linear-depth expression tree
192        let base = node_id;
193        let mut result = base.clone();
194        for _ in 1..pow {
195            result *= base.clone();
196        }
197        result
198    }
199
200    /// Create a product Expression that multiplies many MLEs together
201    pub fn products(node_ids: Vec<NodeId>) -> Self {
202        Self::Product(node_ids)
203    }
204
205    /// Multiplication for expressions, DO NOT USE ON SELECTORS
206    pub fn mult(lhs: Self, rhs: Self) -> Self {
207        let switch = |lhs, rhs| Self::mult(rhs, lhs);
208
209        // Simplify the expression into scaled and products
210        match (&lhs, &rhs) {
211            // Case 1: const() * X => scaled(X, const()) 
212            (AbstractExpression::Constant(f), _) => AbstractExpression::Scaled(Box::new(rhs), *f),
213            (_, AbstractExpression::Constant(_)) => switch(lhs, rhs),
214
215            // Case 2: sel() * X => KILL (not allowed)
216            (AbstractExpression::Selector(..), _) => panic!("Multiplying a non-constant with a selector is not allowed! Create a separate sector or fold the operand into each branch!"),
217            (_, AbstractExpression::Selector(..)) => switch(lhs, rhs),
218
219            // Case 3: add(X, Y) * Z => add(X * Z, Y * Z)
220            (AbstractExpression::Sum(x, y), _) => {
221                let xr = Self::mult(*x.clone(), rhs.clone());
222                let yr = Self::mult(*y.clone(), rhs);
223                AbstractExpression::Sum(Box::new(xr), Box::new(yr))
224            }
225            (_, AbstractExpression::Sum(..)) => switch(lhs, rhs),
226
227            // Case 4: scaled(X, c1) * scaled(Y, c2) => scaled(X * Y, c1 * c2); scaled(X, c) * Z => scaled(X * Z, c)
228            (AbstractExpression::Scaled(x, c1), AbstractExpression::Scaled(y, c2)) => {
229                let xy = Self::mult(*x.clone(), *y.clone());
230                let c = *c1 * *c2;
231                AbstractExpression::Scaled(Box::new(xy), c)
232            }
233            (AbstractExpression::Scaled(x, c), _) => {
234                let xz = Self::mult(*x.clone(), rhs);
235                AbstractExpression::Scaled(Box::new(xz), *c)
236            }
237            (_, AbstractExpression::Scaled(..)) => switch(lhs, rhs),
238
239            // Case 5: mle() * mle(); prod() * prod()
240            (l, r) => {
241                let l_ids = match l {
242                    AbstractExpression::Mle(id) => vec![*id],
243                    AbstractExpression::Product(ids) => ids.clone(),
244                    _ => unreachable!()
245                };
246                let r_ids = match r {
247                    AbstractExpression::Mle(id) => vec![*id],
248                    AbstractExpression::Product(ids) => ids.clone(),
249                    _ => unreachable!()
250                };
251                let ids = [l_ids, r_ids].concat();
252                AbstractExpression::Product(ids)
253            }
254        }
255    }
256
257    /// Create a mle Expression that contains one MLE
258    pub fn mle(node_id: NodeId) -> Self {
259        AbstractExpression::Mle(node_id)
260    }
261
262    /// Create a constant Expression that contains one field element
263    pub fn constant(constant: F) -> Self {
264        AbstractExpression::Constant(constant)
265    }
266
267    /// negates an Expression
268    pub fn negated(expression: Self) -> Self {
269        AbstractExpression::Scaled(Box::new(expression), F::from(1).neg())
270    }
271
272    /// Create a Sum Expression that contains two MLEs
273    pub fn sum(lhs: Self, rhs: Self) -> Self {
274        AbstractExpression::Sum(Box::new(lhs), Box::new(rhs))
275    }
276
277    /// scales an Expression by a field element
278    pub fn scaled(expression: AbstractExpression<F>, scale: F) -> Self {
279        AbstractExpression::Scaled(Box::new(expression), scale)
280    }
281}
282
283impl<F: Field> AbstractExpression<F> {
284    fn build_circuit_node(
285        self,
286        node_map: &HashMap<NodeId, (usize, &CircuitLocation)>,
287    ) -> Result<ExpressionNode<F, ExprDescription>> {
288        // Note that the node_map is the map of node_ids to the internal vec of MLEs, not the circuit_map
289        match self {
290            AbstractExpression::Constant(val) => Ok(ExpressionNode::Constant(val)),
291            AbstractExpression::Selector(mle_index, lhs, rhs) => {
292                let lhs = lhs.build_circuit_node(node_map)?;
293                let rhs = rhs.build_circuit_node(node_map)?;
294                Ok(ExpressionNode::Selector(
295                    mle_index,
296                    Box::new(lhs),
297                    Box::new(rhs),
298                ))
299            }
300            AbstractExpression::Mle(node_id) => {
301                let (
302                    num_vars,
303                    CircuitLocation {
304                        prefix_bits,
305                        layer_id,
306                    },
307                ) = node_map
308                    .get(&node_id)
309                    .ok_or(LayoutingError::DanglingNodeId(node_id))?;
310                let total_indices = get_total_mle_indices(prefix_bits, *num_vars);
311                let circuit_mle = MleDescription::new(*layer_id, &total_indices);
312                Ok(ExpressionNode::Mle(circuit_mle))
313            }
314            AbstractExpression::Sum(lhs, rhs) => {
315                let lhs = lhs.build_circuit_node(node_map)?;
316                let rhs = rhs.build_circuit_node(node_map)?;
317                Ok(ExpressionNode::Sum(Box::new(lhs), Box::new(rhs)))
318            }
319            AbstractExpression::Product(nodes) => {
320                let circuit_mles = nodes
321                    .into_iter()
322                    .map(|node_id| {
323                        let (
324                            num_vars,
325                            CircuitLocation {
326                                prefix_bits,
327                                layer_id,
328                            },
329                        ) = node_map
330                            .get(&node_id)
331                            .ok_or(LayoutingError::DanglingNodeId(node_id))
332                            .unwrap();
333                        let total_indices = get_total_mle_indices::<F>(prefix_bits, *num_vars);
334                        MleDescription::new(*layer_id, &total_indices)
335                    })
336                    .collect::<Vec<MleDescription<F>>>();
337                Ok(ExpressionNode::Product(circuit_mles))
338            }
339            AbstractExpression::Scaled(expr, scalar) => {
340                let expr = expr.build_circuit_node(node_map)?;
341                Ok(ExpressionNode::Scaled(Box::new(expr), scalar))
342            }
343        }
344    }
345
346    fn get_node_ids(&self, mut node_ids: Vec<NodeId>) -> Vec<NodeId> {
347        match self {
348            AbstractExpression::Constant(_) => node_ids,
349            AbstractExpression::Selector(_, lhs, rhs) => {
350                let node_ids = rhs.get_node_ids(node_ids);
351                lhs.get_node_ids(node_ids)
352            }
353            AbstractExpression::Mle(node_id) => {
354                node_ids.push(*node_id);
355                node_ids
356            }
357            AbstractExpression::Sum(lhs, rhs) => {
358                let node_ids = lhs.get_node_ids(node_ids);
359                rhs.get_node_ids(node_ids)
360            }
361            AbstractExpression::Product(nodes) => {
362                node_ids.extend(nodes.iter());
363                node_ids
364            }
365            AbstractExpression::Scaled(expr, _) => expr.get_node_ids(node_ids),
366        }
367    }
368}
369
370// Additional operators
371impl<F: Field> Neg for AbstractExpression<F> {
372    type Output = AbstractExpression<F>;
373    fn neg(self) -> Self::Output {
374        AbstractExpression::<F>::negated(self)
375    }
376}
377
378impl<F: Field> Neg for &AbstractExpression<F> {
379    type Output = AbstractExpression<F>;
380    fn neg(self) -> Self::Output {
381        AbstractExpression::<F>::negated(self.clone())
382    }
383}
384
385impl<F: Field> From<F> for AbstractExpression<F> {
386    fn from(f: F) -> Self {
387        AbstractExpression::<F>::constant(f)
388    }
389}
390
391/// implement the Add, Sub, and Mul traits for the Expression
392impl<F: Field, Rhs: Into<AbstractExpression<F>>> Add<Rhs> for AbstractExpression<F> {
393    type Output = AbstractExpression<F>;
394    fn add(self, rhs: Rhs) -> Self::Output {
395        AbstractExpression::sum(self, rhs.into())
396    }
397}
398impl<F: Field, Rhs: Into<AbstractExpression<F>>> Add<Rhs> for &AbstractExpression<F> {
399    type Output = AbstractExpression<F>;
400    fn add(self, rhs: Rhs) -> Self::Output {
401        AbstractExpression::sum(self.clone(), rhs.into())
402    }
403}
404
405impl<F: Field, Rhs: Into<AbstractExpression<F>>> AddAssign<Rhs> for AbstractExpression<F> {
406    fn add_assign(&mut self, rhs: Rhs) {
407        *self = self.clone() + rhs;
408    }
409}
410
411impl<F: Field, Rhs: Into<AbstractExpression<F>>> Sub<Rhs> for AbstractExpression<F> {
412    type Output = AbstractExpression<F>;
413    fn sub(self, rhs: Rhs) -> Self::Output {
414        AbstractExpression::sum(self, rhs.into().neg())
415    }
416}
417impl<F: Field, Rhs: Into<AbstractExpression<F>>> Sub<Rhs> for &AbstractExpression<F> {
418    type Output = AbstractExpression<F>;
419    fn sub(self, rhs: Rhs) -> Self::Output {
420        AbstractExpression::sum(self.clone(), rhs.into().neg())
421    }
422}
423impl<F: Field, Rhs: Into<AbstractExpression<F>>> SubAssign<Rhs> for AbstractExpression<F> {
424    fn sub_assign(&mut self, rhs: Rhs) {
425        *self = self.clone() - rhs;
426    }
427}
428
429impl<F: Field, Rhs: Into<AbstractExpression<F>>> Mul<Rhs> for AbstractExpression<F> {
430    type Output = AbstractExpression<F>;
431    fn mul(self, rhs: Rhs) -> Self::Output {
432        AbstractExpression::mult(self, rhs.into())
433    }
434}
435impl<F: Field, Rhs: Into<AbstractExpression<F>>> Mul<Rhs> for &AbstractExpression<F> {
436    type Output = AbstractExpression<F>;
437    fn mul(self, rhs: Rhs) -> Self::Output {
438        AbstractExpression::mult(self.clone(), rhs.into())
439    }
440}
441impl<F: Field, Rhs: Into<AbstractExpression<F>>> MulAssign<Rhs> for AbstractExpression<F> {
442    fn mul_assign(&mut self, rhs: Rhs) {
443        *self = self.clone() * rhs;
444    }
445}
446
447impl<F: Field, Rhs: Into<AbstractExpression<F>>> BitXor<Rhs> for AbstractExpression<F> {
448    type Output = AbstractExpression<F>;
449    fn bitxor(self, rhs: Rhs) -> Self::Output {
450        let rhs_expr: AbstractExpression<F> = rhs.into();
451        self.clone() + rhs_expr.clone() - self * rhs_expr * F::from(2)
452    }
453}
454impl<F: Field, Rhs: Into<AbstractExpression<F>>> BitXor<Rhs> for &AbstractExpression<F> {
455    type Output = AbstractExpression<F>;
456    fn bitxor(self, rhs: Rhs) -> Self::Output {
457        let rhs_expr: &AbstractExpression<F> = &rhs.into();
458        self.clone() + rhs_expr.clone() - self.clone() * rhs_expr * F::from(2)
459    }
460}
461
462impl<F: Field> From<&AbstractExpression<F>> for AbstractExpression<F> {
463    fn from(val: &AbstractExpression<F>) -> Self {
464        val.clone()
465    }
466}
467
/// Builds a constant expression from a field element.
///
/// Expands to `AbstractExpression::Constant($val)`. The type is named through
/// `$crate`, so the macro also works inside this crate and when the crate is
/// imported under a different name.
#[macro_export]
macro_rules! const_expr {
    ($val:expr) => {{
        $crate::abstract_expr::AbstractExpression::Constant($val)
    }};
}
476
/// selector
/// equivalent of calling `AbstractExpression::<F>::select_seq(vec![<INPUTS>])`
/// but allows the entries to be of different type
///
/// Each entry is converted with `Into<AbstractExpression<F>>` first. A type
/// named `F` must be in scope at the call site. The expression type is named
/// through `$crate`, so the macro also works inside this crate and when the
/// crate is imported under a different name.
#[macro_export]
macro_rules! sel_expr {
    ($($expr:expr),+ $(,)?) => {{
        let v = vec![
            $(Into::<$crate::abstract_expr::AbstractExpression<F>>::into($expr)),+
        ];
        $crate::abstract_expr::AbstractExpression::<F>::select_seq(v)
    }};
}
488
489// defines how the AbstractExpression are printed and displayed
490impl<F: std::fmt::Debug + Field> std::fmt::Debug for AbstractExpression<F> {
491    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492        match self {
493            AbstractExpression::Constant(scalar) => {
494                f.debug_tuple("Constant").field(scalar).finish()
495            }
496            AbstractExpression::Selector(index, a, b) => f
497                .debug_tuple("Selector")
498                .field(index)
499                .field(a)
500                .field(b)
501                .finish(),
502            // Skip enum variant and print query struct directly to maintain backwards compatibility.
503            AbstractExpression::Mle(mle) => f.debug_struct("Mle").field("mle", mle).finish(),
504            AbstractExpression::Sum(a, b) => f.debug_tuple("Sum").field(a).field(b).finish(),
505            AbstractExpression::Product(a) => f.debug_tuple("Product").field(a).finish(),
506            AbstractExpression::Scaled(poly, scalar) => {
507                f.debug_tuple("Scaled").field(poly).field(scalar).finish()
508            }
509        }
510    }
511}