remainder/layer/
product.rs

1//! Standardized expression representation for oracle query (GKR verifier) +
2//! determining necessary proof-of-products (Hyrax prover + verifier)
3
4use shared_types::Field;
5
6use super::LayerId;
7use crate::mle::dense::DenseMle;
8use crate::mle::mle_description::MleDescription;
9use crate::mle::Mle;
10
11/// Represents a normal form for a layer expression in which the layer is represented as a linear
12/// combination of products of other layer MLEs, the coefficients of which are public.
13#[derive(Debug)]
14pub struct PostSumcheckLayer<F: Field, T>(pub Vec<Product<F, T>>);
15
16// FIXME can we implement all of these evaluate functions with a single function using trait bounds?
17// needs to have a zero() method (Default?).  need mulassign?
18impl<F: Field> PostSumcheckLayer<F, F> {
19    /// Evaluate the PostSumcheckLayer to a single scalar
20    pub fn evaluate_scalar(&self) -> F {
21        self.0.iter().fold(F::ZERO, |acc, product| {
22            acc + product.get_result() * product.coefficient
23        })
24    }
25}
26
27/// Represents a fully bound product of MLEs, or a single MLE (which we consider a simple product).
28/// Data structure for extracting the values to be commited to, their "mxi" values (i.e. their
29/// public coefficients), and also the claims made on other layers.
30#[derive(Debug)]
31pub struct Product<F: Field, T> {
32    /// the evaluated MLEs that are being multiplied and their intermediates
33    /// in the order a, b, ab, c, abc, d, abcd, ...
34    pub intermediates: Vec<Intermediate<F, T>>,
35    /// the (public) coefficient i.e. the "mxi"
36    pub coefficient: F,
37}
38
39impl<F: Field> Product<F, Option<F>> {
40    /// Creates a new Product from a vector of [`MleDescription<F>`].
41    pub fn new(mles: &[MleDescription<F>], coefficient: F, bindings: &[F]) -> Self {
42        if mles.is_empty() {
43            return Product {
44                intermediates: vec![Intermediate::Composite {
45                    value: Some(F::ONE),
46                }],
47                coefficient,
48            };
49        }
50        let mut intermediates = vec![Self::build_atom(&mles[0], bindings)];
51        mles.iter().skip(1).for_each(|mle| {
52            intermediates.push(Self::build_atom(mle, bindings));
53            intermediates.push(Intermediate::Composite { value: None });
54        });
55        Product {
56            intermediates,
57            coefficient,
58        }
59    }
60
61    /// Creates a new Product from a vector of [`MleDescription<F>`].
62    pub fn new_from_mul_gate(
63        mles: &[MleDescription<F>],
64        coefficient: F,
65        bindings: &[&[F]],
66    ) -> Self {
67        if mles.is_empty() {
68            return Product {
69                intermediates: vec![Intermediate::Composite {
70                    value: Some(F::ONE),
71                }],
72                coefficient,
73            };
74        }
75        let mut intermediates = vec![Self::build_atom(&mles[0], bindings[0])];
76        mles.iter().enumerate().skip(1).for_each(|(idx, mle_ref)| {
77            intermediates.push(Self::build_atom(mle_ref, bindings[idx]));
78            intermediates.push(Intermediate::Composite { value: None });
79        });
80        Product {
81            intermediates,
82            coefficient,
83        }
84    }
85
86    // Helper function for new
87    fn build_atom(mle: &MleDescription<F>, bindings: &[F]) -> Intermediate<F, Option<F>> {
88        Intermediate::Atom {
89            layer_id: mle.layer_id(),
90            point: mle.get_claim_point(bindings),
91            value: None,
92        }
93    }
94}
95
96impl<F: Field> Product<F, F> {
97    /// Creates a new Product from a vector of fully bound MleRefs.
98    /// Panics if any are not fully bound.
99    pub fn new(mles: &[DenseMle<F>], coefficient: F) -> Self {
100        // ensure all MLEs are fully bound
101        assert!(mles.iter().all(|mle| mle.is_fully_bounded()));
102        if mles.is_empty() {
103            return Product {
104                intermediates: vec![Intermediate::Composite { value: F::ONE }],
105                coefficient,
106            };
107        }
108        let mut intermediates = vec![Self::build_atom(&mles[0])];
109        let _ = mles.iter().skip(1).fold(mles[0].value(), |acc, mle| {
110            let prod_val = acc * mle.value();
111            intermediates.push(Self::build_atom(mle));
112            intermediates.push(Intermediate::Composite { value: prod_val });
113            prod_val
114        });
115        Product {
116            intermediates,
117            coefficient,
118        }
119    }
120
121    // Helper function for new
122    fn build_atom(mle: &DenseMle<F>) -> Intermediate<F, F> {
123        Intermediate::Atom {
124            layer_id: mle.layer_id,
125            point: mle.get_bound_point(),
126            value: mle.value(),
127        }
128    }
129}
130
131#[derive(Clone, Debug)]
132/// Represents either an atomic factor of a product (i.e. an evaluation of an MLE), or the result of
133/// an intermediate product of atoms.
134pub enum Intermediate<F: Field, T> {
135    /// A struct representing a single MLE and a commitment to its evaluation.
136    Atom {
137        /// the id of the layer upon which this is a claim
138        layer_id: LayerId,
139        /// the evaluation point
140        point: Vec<F>,
141        /// the value (C::Scalar), commitment to the value (C), or CommittedScalar
142        value: T,
143    },
144    /// A struct representing a commitment to the product of two MLE evaluations.
145    Composite {
146        /// the value, commitment to the value, or CommittedScalar
147        value: T,
148    },
149}
150
151impl<F: Field, T: Copy> PostSumcheckLayer<F, T> {
152    /// Returns a vector of the values of the intermediates
153    /// (in an order compatible with [new_with_values]).
154    pub fn get_values(&self) -> Vec<T> {
155        self.0
156            .iter()
157            .flat_map(|product| {
158                product
159                    .intermediates
160                    .iter()
161                    .map(|pp| match pp {
162                        Intermediate::Atom { value, .. } => *value,
163                        Intermediate::Composite { value } => *value,
164                    })
165                    .collect::<Vec<T>>()
166            })
167            .collect()
168    }
169
170    /// Return a vector of all the coefficients of the products.
171    pub fn get_coefficients(&self) -> Vec<F> {
172        self.0.iter().map(|product| product.coefficient).collect()
173    }
174}
175
176/// Set the values of the PostSumcheckLayer to the given values, panicking if the lengths do not match,
177/// returning a new instance. Counterpart to [PostSumcheckLayer::get_values].
178pub fn new_with_values<F: Field, S, T: Clone>(
179    post_sumcheck_layer: &PostSumcheckLayer<F, S>,
180    values: &[T],
181) -> PostSumcheckLayer<F, T> {
182    let total_len: usize = post_sumcheck_layer
183        .0
184        .iter()
185        .map(|product| product.intermediates.len())
186        .sum();
187    assert_eq!(total_len, values.len());
188    let mut start = 0;
189    PostSumcheckLayer(
190        post_sumcheck_layer
191            .0
192            .iter()
193            .map(|product| {
194                let end = start + product.intermediates.len();
195                let product_values = values[start..end].to_vec();
196                start = end;
197                new_with_values_single(product, product_values)
198            })
199            .collect(),
200    )
201}
202
203/// Helper for [new_with_values].
204/// Set the values of the Product to the given values, panicking if the lengths do not match.
205fn new_with_values_single<F: Field, S, T>(
206    product: &Product<F, S>,
207    values: Vec<T>,
208) -> Product<F, T> {
209    assert_eq!(product.intermediates.len(), values.len());
210    Product {
211        coefficient: product.coefficient,
212        intermediates: product
213            .intermediates
214            .iter()
215            .zip(values)
216            .map(|(pp, value)| match pp {
217                Intermediate::Atom {
218                    layer_id, point, ..
219                } => Intermediate::Atom {
220                    layer_id: *layer_id,
221                    point: point.clone(),
222                    value,
223                },
224                Intermediate::Composite { .. } => Intermediate::Composite { value },
225            })
226            .collect(),
227    }
228}
229
230impl<F: Field, T: Clone> Product<F, T> {
231    /// Return the resulting value of the product.
232    /// Useful to build the commitment to the oracle evaluation.
233    pub fn get_result(&self) -> T {
234        let last = &self.intermediates[self.intermediates.len() - 1];
235        match last {
236            Intermediate::Atom { value, .. } => {
237                // this product had better consist of just one MLE!
238                assert_eq!(self.intermediates.len(), 1);
239                value.clone()
240            }
241            Intermediate::Composite { value, .. } => value.clone(),
242        }
243    }
244
245    /// Return a vector of triples (x, y, z) where z=x*y, or None.
246    pub fn get_product_triples(&self) -> Option<Vec<(T, T, T)>> {
247        if self.intermediates.len() > 1 {
248            assert!(self.intermediates.len() >= 3);
249            let values = self
250                .intermediates
251                .iter()
252                .map(|pp| match pp {
253                    Intermediate::Atom { value, .. } => value.clone(),
254                    Intermediate::Composite { value, .. } => value.clone(),
255                })
256                .collect::<Vec<_>>();
257            Some(
258                values
259                    .windows(3)
260                    .map(|window| {
261                        if let [x, y, z] = window {
262                            (x.clone(), y.clone(), z.clone())
263                        } else {
264                            unreachable!()
265                        }
266                    })
267                    .collect::<Vec<_>>(),
268            )
269        } else {
270            None
271        }
272    }
273}