remainder/expression/
generic_expr.rs

//! Functionality common to all "expression"s (see the documentation within
//! [crate::expression]). See the documentation on [Expression] for a
//! high-level summary.
4
5use crate::mle::MleIndex;
6use serde::{Deserialize, Serialize};
7use shared_types::Field;
8use std::hash::Hash;
9
10use anyhow::{Ok, Result};
11
12/// An [ExpressionType] defines two fields -- the type of MLE representation
13/// at the leaf of the expression node tree, and the "global" unique copies
14/// of each of the MLEs (this is so that if an expression references the
15/// same MLE multiple times, the data stored therein is not duplicated)
16pub trait ExpressionType<F: Field>: Serialize + for<'de> Deserialize<'de> {
17    /// The type of thing representing an MLE within the leaves of an
18    /// expression. Note that for most expression types, this is the
19    /// intuitive thing (e.g. for [crate::expression::circuit_expr::ExprDescription]
20    /// this is an [`crate::mle::mle_description::MleDescription<F>`]),
21    /// but for [crate::expression::prover_expr::ProverExpr] specifically this
22    /// is an [crate::expression::prover_expr::MleVecIndex], i.e. the
23    /// index within the `MleVec` which contains the unique representation
24    /// of the prover's view of each MLE.
25    type MLENodeRepr: Clone + Serialize + for<'de> Deserialize<'de> + Hash;
26
27    /// The idea here is that an expression may have many MLEs (or things
28    /// representing MLEs) in its description, including duplicates, but
29    /// we only wish to store one copy for each instance of a thing
30    /// representing an MLE. The `MleVec` represents that list of unique
31    /// copies.
32    /// For example, this is `Vec<DenseMle>` for
33    /// [crate::expression::prover_expr::ProverExpr].
34    type MleVec: Serialize + for<'de> Deserialize<'de>;
35}
36
37/// [ExpressionNode] can be made up of the following:
38/// * [ExpressionNode::Constant], i.e. + c for c \in \mathbb{F}
39/// * [ExpressionNode::Mle], i.e. \widetilde{V}_{j > i}(b_1, ..., b_{m \leq n})
40/// * [ExpressionNode::Product], i.e. \prod_j \widetilde{V}_{j > i}(b_1, ..., b_{m \leq n})
41/// * [ExpressionNode::Selector], i.e. (1 - b_0) * Expr(b_1, ..., b_{m \leq n}) + b_0 * Expr(b_1, ..., b_{m \leq n})
42/// * [ExpressionNode::Sum], i.e. \widetilde{V}_{j_1 > i}(b_1, ..., b_{m_1 \leq n}) + \widetilde{V}_{j_2 > i}(b_1, ..., b_{m_2 \leq n})
43/// * [ExpressionNode::Scaled], i.e. c * Expr(b_1, ..., b_{m \leq n}) for c \in mathbb{F}
44#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq)]
45#[serde(bound = "F: Field")]
46pub enum ExpressionNode<F: Field, E: ExpressionType<F>> {
47    /// See documentation for [ExpressionNode]. Note that
48    /// [ExpressionNode::Constant] can be an expression tree's leaf.
49    Constant(F),
50    /// See documentation for [ExpressionNode].
51    Selector(
52        MleIndex<F>,
53        Box<ExpressionNode<F, E>>,
54        Box<ExpressionNode<F, E>>,
55    ),
56    /// An [ExpressionNode] representing the leaf of an expression tree which
57    /// is actually mathematically defined as a multilinear extension.
58    Mle(E::MLENodeRepr),
59    /// See documentation for [ExpressionNode].
60    Sum(Box<ExpressionNode<F, E>>, Box<ExpressionNode<F, E>>),
61    /// The product of several multilinear extension functions. This is also
62    /// an expression tree's leaf.
63    Product(Vec<E::MLENodeRepr>),
64    /// See documentation for [ExpressionNode].
65    Scaled(Box<ExpressionNode<F, E>>, F),
66}
67
68/// The high-level idea is that an [Expression] is generic over [ExpressionType]
69/// , and contains within it a single parent [ExpressionNode] as well as an
70/// [ExpressionType::MleVec] containing the unique leaf representations for the
71/// leaves of the [ExpressionNode] tree.
72#[derive(Serialize, Deserialize, Clone, Hash)]
73#[serde(bound = "F: Field")]
74pub struct Expression<F: Field, E: ExpressionType<F>> {
75    /// The root of the expression "tree".
76    pub expression_node: ExpressionNode<F, E>,
77    /// The unique owned copies of all MLEs which are "leaves" within the
78    /// expression "tree".
79    pub mle_vec: E::MleVec,
80}
81
82/// generic methods shared across all types of expressions
83impl<F: Field, E: ExpressionType<F>> Expression<F, E> {
84    /// Create a new expression
85    pub fn new(expression_node: ExpressionNode<F, E>, mle_vec: E::MleVec) -> Self {
86        Self {
87            expression_node,
88            mle_vec,
89        }
90    }
91
92    /// Returns a reference to the internal `expression_node` and `mle_vec` fields.
93    pub fn deconstruct_ref(&self) -> (&ExpressionNode<F, E>, &E::MleVec) {
94        (&self.expression_node, &self.mle_vec)
95    }
96
97    /// Returns a mutable reference to the `expression_node` and `mle_vec`
98    /// present within the given [Expression].
99    pub fn deconstruct_mut(&mut self) -> (&mut ExpressionNode<F, E>, &mut E::MleVec) {
100        (&mut self.expression_node, &mut self.mle_vec)
101    }
102
103    /// Takes ownership of the [Expression] and returns the owned values to its
104    /// internal `expression_node` and `mle_vec`.
105    pub fn deconstruct(self) -> (ExpressionNode<F, E>, E::MleVec) {
106        (self.expression_node, self.mle_vec)
107    }
108
109    /// traverse the expression tree, and applies the observer_fn to all child node
110    /// because the expression node has the recursive structure, the traverse_node
111    /// helper function is implemented on it, with the mle_vec reference passed in
112    pub fn traverse(
113        &self,
114        observer_fn: &mut impl FnMut(&ExpressionNode<F, E>, &E::MleVec) -> Result<()>,
115    ) -> Result<()> {
116        self.expression_node
117            .traverse_node(observer_fn, &self.mle_vec)
118    }
119
120    /// similar to traverse, but allows mutation of self (expression node and mle_vec)
121    pub fn traverse_mut(
122        &mut self,
123        observer_fn: &mut impl FnMut(&mut ExpressionNode<F, E>, &mut E::MleVec) -> Result<()>,
124    ) -> Result<()> {
125        self.expression_node
126            .traverse_node_mut(observer_fn, &mut self.mle_vec)
127    }
128}
129
130/// Generic helper methods shared across all types of [ExpressionNode]s.
131impl<F: Field, E: ExpressionType<F>> ExpressionNode<F, E> {
132    /// traverse the expression tree, and applies the observer_fn to all child node / the mle_vec reference
133    pub fn traverse_node(
134        &self,
135        observer_fn: &mut impl FnMut(&ExpressionNode<F, E>, &E::MleVec) -> Result<()>,
136        mle_vec: &E::MleVec,
137    ) -> Result<()> {
138        observer_fn(self, mle_vec)?;
139        match self {
140            ExpressionNode::Constant(_) | ExpressionNode::Mle(_) | ExpressionNode::Product(_) => {
141                Ok(())
142            }
143            ExpressionNode::Scaled(exp, _) => exp.traverse_node(observer_fn, mle_vec),
144            ExpressionNode::Selector(_, lhs, rhs) => {
145                lhs.traverse_node(observer_fn, mle_vec)?;
146                rhs.traverse_node(observer_fn, mle_vec)
147            }
148            ExpressionNode::Sum(lhs, rhs) => {
149                lhs.traverse_node(observer_fn, mle_vec)?;
150                rhs.traverse_node(observer_fn, mle_vec)
151            }
152        }
153    }
154
155    /// similar to traverse, but allows mutation of self (expression node and mle_vec)
156    pub fn traverse_node_mut(
157        &mut self,
158        observer_fn: &mut impl FnMut(&mut ExpressionNode<F, E>, &mut E::MleVec) -> Result<()>,
159        mle_vec: &mut E::MleVec,
160    ) -> Result<()> {
161        observer_fn(self, mle_vec)?;
162        match self {
163            ExpressionNode::Constant(_) | ExpressionNode::Mle(_) | ExpressionNode::Product(_) => {
164                Ok(())
165            }
166            ExpressionNode::Scaled(exp, _) => exp.traverse_node_mut(observer_fn, mle_vec),
167            ExpressionNode::Selector(_, lhs, rhs) => {
168                lhs.traverse_node_mut(observer_fn, mle_vec)?;
169                rhs.traverse_node_mut(observer_fn, mle_vec)
170            }
171            ExpressionNode::Sum(lhs, rhs) => {
172                lhs.traverse_node_mut(observer_fn, mle_vec)?;
173                rhs.traverse_node_mut(observer_fn, mle_vec)
174            }
175        }
176    }
177}