remainder/expression/generic_expr.rs
1//! Functionality which is common to all "expression"s (see documentation within
2//! [crate::expression]). See documentation in [Expression] for high-level
3//! summary.
4
5use crate::mle::MleIndex;
6use serde::{Deserialize, Serialize};
7use shared_types::Field;
8use std::hash::Hash;
9
10use anyhow::{Ok, Result};
11
12/// An [ExpressionType] defines two fields -- the type of MLE representation
13/// at the leaf of the expression node tree, and the "global" unique copies
14/// of each of the MLEs (this is so that if an expression references the
15/// same MLE multiple times, the data stored therein is not duplicated)
16pub trait ExpressionType<F: Field>: Serialize + for<'de> Deserialize<'de> {
17 /// The type of thing representing an MLE within the leaves of an
18 /// expression. Note that for most expression types, this is the
19 /// intuitive thing (e.g. for [crate::expression::circuit_expr::ExprDescription]
20 /// this is an [`crate::mle::mle_description::MleDescription<F>`]),
21 /// but for [crate::expression::prover_expr::ProverExpr] specifically this
22 /// is an [crate::expression::prover_expr::MleVecIndex], i.e. the
23 /// index within the `MleVec` which contains the unique representation
24 /// of the prover's view of each MLE.
25 type MLENodeRepr: Clone + Serialize + for<'de> Deserialize<'de> + Hash;
26
27 /// The idea here is that an expression may have many MLEs (or things
28 /// representing MLEs) in its description, including duplicates, but
29 /// we only wish to store one copy for each instance of a thing
30 /// representing an MLE. The `MleVec` represents that list of unique
31 /// copies.
32 /// For example, this is `Vec<DenseMle>` for
33 /// [crate::expression::prover_expr::ProverExpr].
34 type MleVec: Serialize + for<'de> Deserialize<'de>;
35}
36
37/// [ExpressionNode] can be made up of the following:
38/// * [ExpressionNode::Constant], i.e. + c for c \in \mathbb{F}
39/// * [ExpressionNode::Mle], i.e. \widetilde{V}_{j > i}(b_1, ..., b_{m \leq n})
40/// * [ExpressionNode::Product], i.e. \prod_j \widetilde{V}_{j > i}(b_1, ..., b_{m \leq n})
41/// * [ExpressionNode::Selector], i.e. (1 - b_0) * Expr(b_1, ..., b_{m \leq n}) + b_0 * Expr(b_1, ..., b_{m \leq n})
42/// * [ExpressionNode::Sum], i.e. \widetilde{V}_{j_1 > i}(b_1, ..., b_{m_1 \leq n}) + \widetilde{V}_{j_2 > i}(b_1, ..., b_{m_2 \leq n})
43/// * [ExpressionNode::Scaled], i.e. c * Expr(b_1, ..., b_{m \leq n}) for c \in mathbb{F}
44#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq)]
45#[serde(bound = "F: Field")]
46pub enum ExpressionNode<F: Field, E: ExpressionType<F>> {
47 /// See documentation for [ExpressionNode]. Note that
48 /// [ExpressionNode::Constant] can be an expression tree's leaf.
49 Constant(F),
50 /// See documentation for [ExpressionNode].
51 Selector(
52 MleIndex<F>,
53 Box<ExpressionNode<F, E>>,
54 Box<ExpressionNode<F, E>>,
55 ),
56 /// An [ExpressionNode] representing the leaf of an expression tree which
57 /// is actually mathematically defined as a multilinear extension.
58 Mle(E::MLENodeRepr),
59 /// See documentation for [ExpressionNode].
60 Sum(Box<ExpressionNode<F, E>>, Box<ExpressionNode<F, E>>),
61 /// The product of several multilinear extension functions. This is also
62 /// an expression tree's leaf.
63 Product(Vec<E::MLENodeRepr>),
64 /// See documentation for [ExpressionNode].
65 Scaled(Box<ExpressionNode<F, E>>, F),
66}
67
68/// The high-level idea is that an [Expression] is generic over [ExpressionType]
69/// , and contains within it a single parent [ExpressionNode] as well as an
70/// [ExpressionType::MleVec] containing the unique leaf representations for the
71/// leaves of the [ExpressionNode] tree.
72#[derive(Serialize, Deserialize, Clone, Hash)]
73#[serde(bound = "F: Field")]
74pub struct Expression<F: Field, E: ExpressionType<F>> {
75 /// The root of the expression "tree".
76 pub expression_node: ExpressionNode<F, E>,
77 /// The unique owned copies of all MLEs which are "leaves" within the
78 /// expression "tree".
79 pub mle_vec: E::MleVec,
80}
81
82/// generic methods shared across all types of expressions
83impl<F: Field, E: ExpressionType<F>> Expression<F, E> {
84 /// Create a new expression
85 pub fn new(expression_node: ExpressionNode<F, E>, mle_vec: E::MleVec) -> Self {
86 Self {
87 expression_node,
88 mle_vec,
89 }
90 }
91
92 /// Returns a reference to the internal `expression_node` and `mle_vec` fields.
93 pub fn deconstruct_ref(&self) -> (&ExpressionNode<F, E>, &E::MleVec) {
94 (&self.expression_node, &self.mle_vec)
95 }
96
97 /// Returns a mutable reference to the `expression_node` and `mle_vec`
98 /// present within the given [Expression].
99 pub fn deconstruct_mut(&mut self) -> (&mut ExpressionNode<F, E>, &mut E::MleVec) {
100 (&mut self.expression_node, &mut self.mle_vec)
101 }
102
103 /// Takes ownership of the [Expression] and returns the owned values to its
104 /// internal `expression_node` and `mle_vec`.
105 pub fn deconstruct(self) -> (ExpressionNode<F, E>, E::MleVec) {
106 (self.expression_node, self.mle_vec)
107 }
108
109 /// traverse the expression tree, and applies the observer_fn to all child node
110 /// because the expression node has the recursive structure, the traverse_node
111 /// helper function is implemented on it, with the mle_vec reference passed in
112 pub fn traverse(
113 &self,
114 observer_fn: &mut impl FnMut(&ExpressionNode<F, E>, &E::MleVec) -> Result<()>,
115 ) -> Result<()> {
116 self.expression_node
117 .traverse_node(observer_fn, &self.mle_vec)
118 }
119
120 /// similar to traverse, but allows mutation of self (expression node and mle_vec)
121 pub fn traverse_mut(
122 &mut self,
123 observer_fn: &mut impl FnMut(&mut ExpressionNode<F, E>, &mut E::MleVec) -> Result<()>,
124 ) -> Result<()> {
125 self.expression_node
126 .traverse_node_mut(observer_fn, &mut self.mle_vec)
127 }
128}
129
130/// Generic helper methods shared across all types of [ExpressionNode]s.
131impl<F: Field, E: ExpressionType<F>> ExpressionNode<F, E> {
132 /// traverse the expression tree, and applies the observer_fn to all child node / the mle_vec reference
133 pub fn traverse_node(
134 &self,
135 observer_fn: &mut impl FnMut(&ExpressionNode<F, E>, &E::MleVec) -> Result<()>,
136 mle_vec: &E::MleVec,
137 ) -> Result<()> {
138 observer_fn(self, mle_vec)?;
139 match self {
140 ExpressionNode::Constant(_) | ExpressionNode::Mle(_) | ExpressionNode::Product(_) => {
141 Ok(())
142 }
143 ExpressionNode::Scaled(exp, _) => exp.traverse_node(observer_fn, mle_vec),
144 ExpressionNode::Selector(_, lhs, rhs) => {
145 lhs.traverse_node(observer_fn, mle_vec)?;
146 rhs.traverse_node(observer_fn, mle_vec)
147 }
148 ExpressionNode::Sum(lhs, rhs) => {
149 lhs.traverse_node(observer_fn, mle_vec)?;
150 rhs.traverse_node(observer_fn, mle_vec)
151 }
152 }
153 }
154
155 /// similar to traverse, but allows mutation of self (expression node and mle_vec)
156 pub fn traverse_node_mut(
157 &mut self,
158 observer_fn: &mut impl FnMut(&mut ExpressionNode<F, E>, &mut E::MleVec) -> Result<()>,
159 mle_vec: &mut E::MleVec,
160 ) -> Result<()> {
161 observer_fn(self, mle_vec)?;
162 match self {
163 ExpressionNode::Constant(_) | ExpressionNode::Mle(_) | ExpressionNode::Product(_) => {
164 Ok(())
165 }
166 ExpressionNode::Scaled(exp, _) => exp.traverse_node_mut(observer_fn, mle_vec),
167 ExpressionNode::Selector(_, lhs, rhs) => {
168 lhs.traverse_node_mut(observer_fn, mle_vec)?;
169 rhs.traverse_node_mut(observer_fn, mle_vec)
170 }
171 ExpressionNode::Sum(lhs, rhs) => {
172 lhs.traverse_node_mut(observer_fn, mle_vec)?;
173 rhs.traverse_node_mut(observer_fn, mle_vec)
174 }
175 }
176 }
177}