frontend/layouter/nodes/
lookup.rs

1//! Nodes that implement LogUp.
2
3use crate::abstract_expr::AbstractExpression;
4use remainder::{
5    expression::{circuit_expr::ExprDescription, generic_expr::Expression},
6    layer::{layer_enum::LayerDescriptionEnum, regular_layer::RegularLayerDescription, LayerId},
7    mle::{mle_description::MleDescription, MleIndex},
8    output_layer::OutputLayerDescription,
9    utils::mle::get_total_mle_indices,
10};
11
12use crate::layouter::builder::CircuitMap;
13
14use itertools::{repeat_n, Itertools};
15use shared_types::Field;
16
17use super::fiat_shamir_challenge::FiatShamirChallengeNode;
18use super::{CircuitNode, NodeId};
19
20use anyhow::Result;
21
22/// Represents the use of a lookup into a particular table (represented by a [LookupTable]).
23#[derive(Clone, Debug)]
24pub struct LookupConstraint {
25    id: NodeId,
26    /// The id of the LookupTable (lookup table) that we are a lookup up into.
27    pub table_node_id: NodeId,
28    /// The id of the node that is being constrained by this lookup.
29    constrained_node_id: NodeId,
30    /// The id of the node that provides the multiplicities for the constrained data.
31    multiplicities_node_id: NodeId,
32}
33
34impl LookupConstraint {
35    /// Creates a new [LookupConstraint], constraining the data of `constrained` to form a subset of
36    /// the data in `lookup_table` with multiplicities given by `multiplicities`. Caller is
37    /// responsible for the yielding of all nodes (including `constrained` and `multiplicities`).
38    /// The adding of lookup specific verifier-challenge- and output layers is handled automatically
39    /// by [LookupTable::generate_lookup_circuit_description].
40    ///
41    /// # Requires:
42    ///   if `constrained` has length not a power of two, then `multiplicitites` must also count the
43    ///   implicit padding!
44    pub fn new(
45        lookup_table: &LookupTable,
46        constrained: &dyn CircuitNode,
47        multiplicities: &dyn CircuitNode,
48    ) -> Self {
49        LookupConstraint {
50            id: NodeId::new(),
51            table_node_id: lookup_table.id(),
52            constrained_node_id: constrained.id(),
53            multiplicities_node_id: multiplicities.id(),
54        }
55    }
56}
57
58impl CircuitNode for LookupConstraint {
59    fn id(&self) -> NodeId {
60        self.id
61    }
62
63    fn sources(&self) -> Vec<NodeId> {
64        // NB this function never gets called, since lookup tables and constraints are placed after
65        // the intermediate nodes in the toposort
66        unimplemented!()
67    }
68
69    fn get_num_vars(&self) -> usize {
70        todo!()
71    }
72}
73
74type LookupCircuitDescription<F> = (Vec<LayerDescriptionEnum<F>>, OutputLayerDescription<F>);
75/// Represents a table of data that can be looked up into, e.g. for a range check.
76/// Implements "Improving logarithmic derivative lookups using GKR" (2023) by Papini & Haböck. Note
77/// that (as is usual e.g. in permutation checks) we do not check that the product of the
78/// denominators is nonzero. This means the soundness of logUp is bounded by
79///     `|F| / max{num_constrained_values, num_table_values}`.
80/// To adapt this to a small field setting, consider using Fermat's Little Theorem.
81///
82/// For a more detailed description of the soundness argument for the above, see
83/// <https://www.notion.so/LogUp-ext-f846956acc3640a68bad51f7897fe32f?pvs=4#15b68687d27e807fae3fcdff59791174>
84#[derive(Clone, Debug)]
85pub struct LookupTable {
86    id: NodeId,
87    /// The lookups that are performed on this table (will be automatically populated, via
88    /// [LookupTable::add_lookup_constraint], during layout).
89    constraints: Vec<LookupConstraint>,
90    /// The id of the node providing the table entries.
91    table_node_id: NodeId,
92    /// The ID of the [FiatShamirChallengeNode] for the FS challenge.
93    fiat_shamir_challenge_node_id: NodeId,
94}
95
96impl LookupTable {
97    /// Create a new LookupTable to use for subsequent lookups. (To perform a lookup using this
98    /// table, create a [LookupConstraint].)
99    ///
100    /// # Requires:
101    /// * The length of the table must be a power of two.
102    pub fn new(
103        table: &dyn CircuitNode,
104        fiat_shamir_challenge_node: &FiatShamirChallengeNode,
105    ) -> Self {
106        LookupTable {
107            id: NodeId::new(),
108            constraints: vec![],
109            table_node_id: table.id(),
110            fiat_shamir_challenge_node_id: fiat_shamir_challenge_node.id(),
111        }
112    }
113
114    /// Add a lookup constraint to this node.
115    /// (Will be called by the layouter when laying out the circuit.)
116    pub fn add_lookup_constraint(&mut self, constraint: LookupConstraint) {
117        self.constraints.push(constraint);
118    }
119
120    /// Create the circuit description of a lookup node by returning the corresponding circuit
121    /// descriptions, and output circuit description needed in order to verify the lookup.
122    pub fn generate_lookup_circuit_description<F: Field>(
123        &self,
124        circuit_map: &mut CircuitMap,
125    ) -> Result<LookupCircuitDescription<F>> {
126        type AE<F> = AbstractExpression<F>;
127        type CE<F> = Expression<F, ExprDescription>;
128
129        // Ensure that number of LookupConstraints is a power of two (otherwise when we concat the
130        // constrained nodes, there will be padding, and the padding value is potentially not in the
131        // table
132        assert_eq!(
133            self.constraints.len().count_ones(),
134            1,
135            "Number of LookupConstraints should be a power of two"
136        );
137
138        // Build the LHS of the equation (defined by the constrained values)
139        println!("Build the LHS of the equation (defined by the constrained values)");
140
141        let (fiat_shamir_challenge_location, fiat_shamir_challenge_node_vars) =
142            circuit_map.get_location_num_vars_from_node_id(&self.fiat_shamir_challenge_node_id)?;
143
144        let fiat_shamir_challenge_mle_indices = get_total_mle_indices(
145            &fiat_shamir_challenge_location.prefix_bits,
146            *fiat_shamir_challenge_node_vars,
147        );
148        let fiat_shamir_challenge_mle = MleDescription::new(
149            fiat_shamir_challenge_location.layer_id,
150            &fiat_shamir_challenge_mle_indices,
151        );
152
153        // Build the denominator r - constrained
154        // There may be more than one constraint, so build a selector tree if necessary
155        let constrained_expr = AE::<F>::binary_tree_selector(
156            self.constraints
157                .iter()
158                .map(|constraint| constraint.constrained_node_id.expr())
159                .collect(),
160        );
161        let expr = CE::sum(
162            CE::from_mle_desc(fiat_shamir_challenge_mle),
163            CE::negated(constrained_expr.build_circuit_expr(circuit_map)?),
164        );
165        let expr_num_vars = expr.num_vars();
166
167        let layer_id = LayerId::next_layer_id();
168        let layer = RegularLayerDescription::new_raw(layer_id, expr);
169        let mut intermediate_layers = vec![LayerDescriptionEnum::Regular(layer)];
170        println!("Layer that calcs r - constrained has layer id: {layer_id:?}");
171
172        let lhs_denominator_vars = repeat_n(MleIndex::Free, expr_num_vars).collect_vec();
173        let lhs_denominator_desc = MleDescription::new(layer_id, &lhs_denominator_vars);
174
175        // Super special case: need to create a 0-variable MLE for the numerator which is JUST
176        // derived from an expression producing the constant 1
177        let maybe_lhs_numerator_desc = if lhs_denominator_vars.is_empty() {
178            Some(MleDescription::new(layer_id, &[]))
179        } else {
180            None
181        };
182
183        // Build the numerator and denominator of the sum of the fractions
184        let (lhs_numerator, lhs_denominator) = build_fractional_sum(
185            maybe_lhs_numerator_desc,
186            lhs_denominator_desc,
187            &mut intermediate_layers,
188        );
189
190        // Build the RHS of the equation (defined by the table values and multiplicities)
191        println!("Build the RHS of the equation (defined by the table values and multiplicities)");
192
193        // Build the numerator (the multiplicities, which we aggregate with an extra layer if there is more than one constraint)
194        let (multiplicities_location, multiplicities_num_vars) = circuit_map
195            .get_location_num_vars_from_node_id(&self.constraints[0].multiplicities_node_id)
196            .unwrap();
197        let mut rhs_numerator_desc = MleDescription::new(
198            multiplicities_location.layer_id,
199            &get_total_mle_indices(
200                &multiplicities_location.prefix_bits,
201                *multiplicities_num_vars,
202            ),
203        );
204
205        if self.constraints.len() > 1 {
206            // Insert an extra layer that aggregates the multiplicities
207            let expr = self.constraints.iter().skip(1).fold(
208                CE::from_mle_desc(rhs_numerator_desc),
209                |acc, constraint| {
210                    let (multiplicities_location, multiplicities_num_vars) = &circuit_map
211                        .get_location_num_vars_from_node_id(&constraint.multiplicities_node_id)
212                        .unwrap();
213                    let mult_constraint_mle_desc = MleDescription::new(
214                        multiplicities_location.layer_id,
215                        &get_total_mle_indices(
216                            &multiplicities_location.prefix_bits,
217                            *multiplicities_num_vars,
218                        ),
219                    );
220                    acc + CE::from_mle_desc(mult_constraint_mle_desc)
221                },
222            );
223            let layer_id = LayerId::next_layer_id();
224            let layer = RegularLayerDescription::new_raw(layer_id, expr);
225            intermediate_layers.push(LayerDescriptionEnum::Regular(layer));
226            println!("Layer that aggs the multiplicities has layer id: {layer_id:?}");
227
228            // Note that this is the aggregated version!
229            // It's just the element-wise sum of the elements within the bookkeeping tables
230            // However, because we're only dealing with the circuit description, we can
231            // just take the number of variables within the *first* constraint
232            let (_first_self_constraint_loc, first_self_constraint_num_vars) = circuit_map
233                .get_location_num_vars_from_node_id(&self.constraints[0].multiplicities_node_id)
234                .unwrap()
235                .clone();
236            rhs_numerator_desc = MleDescription::new(
237                layer_id,
238                &get_total_mle_indices(&[], first_self_constraint_num_vars),
239            )
240        }
241
242        // Build the denominator r - table
243
244        // First grab `r` as a `MleDescription` from the `circuit_description_map`
245        let (fiat_shamir_challenge_loc, fiat_shamir_challenge_num_vars) = circuit_map
246            .get_location_num_vars_from_node_id(&self.fiat_shamir_challenge_node_id)
247            .unwrap()
248            .clone();
249        let fiat_shamir_challenge_circuit_mle = MleDescription::new(
250            fiat_shamir_challenge_loc.layer_id,
251            &get_total_mle_indices(
252                &fiat_shamir_challenge_loc.prefix_bits,
253                fiat_shamir_challenge_num_vars,
254            ),
255        );
256
257        // Next grab `table` as a `MleDescription` from the `circuit_description_map`
258        let (table_loc, table_num_vars) = circuit_map
259            .get_location_num_vars_from_node_id(&self.table_node_id)
260            .unwrap()
261            .clone();
262        let table_circuit_mle = MleDescription::new(
263            table_loc.layer_id,
264            &get_total_mle_indices(&table_loc.prefix_bits, table_num_vars),
265        );
266
267        let expr = CE::from_mle_desc(fiat_shamir_challenge_circuit_mle)
268            - CE::from_mle_desc(table_circuit_mle);
269        let r_minus_table_num_vars = expr.num_vars();
270        let layer_id = LayerId::next_layer_id();
271        let layer = RegularLayerDescription::new_raw(layer_id, expr);
272        intermediate_layers.push(LayerDescriptionEnum::Regular(layer));
273        println!("Layer that calculates r - table has layer id: {layer_id:?}");
274
275        let rhs_denominator_desc = MleDescription::new(
276            layer_id,
277            &repeat_n(MleIndex::Free, r_minus_table_num_vars).collect_vec(),
278        );
279
280        // Build the numerator and denominator of the sum of the fractions
281        let (rhs_numerator, rhs_denominator) = build_fractional_sum(
282            Some(rhs_numerator_desc),
283            rhs_denominator_desc,
284            &mut intermediate_layers,
285        );
286
287        // Add a layer that calculates the difference between the fractions on the LHS and RHS
288        assert!(rhs_numerator.is_some());
289        let rhs_numerator = rhs_numerator.unwrap();
290        let expr = if let Some(lhs_numerator) = lhs_numerator {
291            CE::<F>::products(vec![lhs_numerator.clone(), rhs_denominator.clone()])
292                - CE::<F>::products(vec![rhs_numerator.clone(), lhs_denominator.clone()])
293        } else {
294            CE::<F>::products(vec![rhs_denominator.clone()])
295                - CE::<F>::products(vec![rhs_numerator.clone(), lhs_denominator.clone()])
296        };
297
298        let layer_id = LayerId::next_layer_id();
299        let layer = RegularLayerDescription::new_raw(layer_id, expr);
300        intermediate_layers.push(LayerDescriptionEnum::Regular(layer));
301        println!("Layer that checks that fractions are equal has layer id: {layer_id:?}");
302
303        // Add an output layer that checks that the result is zero
304        let output_layer = OutputLayerDescription::new_zero(layer_id, &[]);
305
306        Ok((intermediate_layers, output_layer))
307    }
308}
309
310impl CircuitNode for LookupTable {
311    fn id(&self) -> NodeId {
312        self.id
313    }
314
315    fn sources(&self) -> Vec<NodeId> {
316        // NB this function never gets called, since lookup tables and constraints are placed after
317        // the intermediate nodes in the toposort
318        unimplemented!()
319    }
320
321    fn get_num_vars(&self) -> usize {
322        todo!()
323    }
324}
325
326/// Extract the prefix bits from a DenseMle.
327fn extract_prefix_num_free_bits<F: Field>(mle: &MleDescription<F>) -> (Vec<MleIndex<F>>, usize) {
328    let mut num_free_bits = 0;
329    let prefix_bits = mle
330        .var_indices()
331        .iter()
332        .filter_map(|mle_index| match mle_index {
333            MleIndex::Fixed(_) => Some(mle_index.clone()),
334            MleIndex::Free => {
335                num_free_bits += 1;
336                None
337            }
338            _ => None,
339        })
340        .collect();
341    (prefix_bits, num_free_bits)
342}
343
344/// Split an MLE into two MLEs, with the left half containing the even-indexed elements and
345/// the right half containing the odd-indexed elements, setting the prefix bits accordingly.
346fn split_circuit_mle<F: Field>(
347    mle_desc: &MleDescription<F>,
348) -> (MleDescription<F>, MleDescription<F>) {
349    let (prefix_bits, num_free_bits) = extract_prefix_num_free_bits(mle_desc);
350
351    let left_mle_desc = MleDescription::new(
352        mle_desc.layer_id(),
353        &prefix_bits
354            .iter()
355            .cloned()
356            .chain(vec![MleIndex::Fixed(false)])
357            .chain(repeat_n(MleIndex::Free, num_free_bits - 1))
358            .collect_vec(),
359    );
360    let right_mle_desc = MleDescription::new(
361        mle_desc.layer_id(),
362        &prefix_bits
363            .iter()
364            .cloned()
365            .chain(vec![MleIndex::Fixed(true)])
366            .chain(repeat_n(MleIndex::Free, num_free_bits - 1))
367            .collect_vec(),
368    );
369    (left_mle_desc, right_mle_desc)
370}
371
372/// Given two MLEs of the same length representing the numerators and denominators of a sequence of
373/// fractions, add layers that perform a sum of the fractions, return a new pair of length-1 MLEs
374/// representing the numerator and denominator of the sum.
375///
376/// Setting `maybe_numerator_desc` to `None` indicates that the numerator has the same length as
377/// `denominator_desc` and takes the constant value 1.
378fn build_fractional_sum<F: Field>(
379    maybe_numerator_desc: Option<MleDescription<F>>,
380    denominator_desc: MleDescription<F>,
381    layers: &mut Vec<LayerDescriptionEnum<F>>,
382) -> (Option<MleDescription<F>>, MleDescription<F>) {
383    type CE<F> = Expression<F, ExprDescription>;
384
385    // Sanitycheck number of vars in numerator == number of vars in denominator
386    // EXCEPT when we're working with the fraction with constant 1 in the numerator
387    if let Some(numerator_desc) = maybe_numerator_desc.as_ref() {
388        assert_eq!(
389            numerator_desc.num_free_vars(),
390            denominator_desc.num_free_vars()
391        );
392    }
393
394    let mut maybe_numerator_desc = maybe_numerator_desc;
395    let mut denominator_desc = denominator_desc;
396
397    for i in 0..denominator_desc.num_free_vars() {
398        let denominators = split_circuit_mle(&denominator_desc);
399        let next_numerator_expr = if let Some(numerator_desc) = maybe_numerator_desc {
400            let numerators = split_circuit_mle(&numerator_desc);
401
402            // Calculate the new numerator
403            CE::products(vec![numerators.0.clone(), denominators.1.clone()])
404                + CE::products(vec![numerators.1.clone(), denominators.0.clone()])
405        } else {
406            // If there is no numerator CircuitMLE,
407            CE::from_mle_desc(denominators.1.clone()) + CE::from_mle_desc(denominators.0.clone())
408        };
409
410        // Calculate the new denominator
411        let next_denominator_expr =
412            CE::products(vec![denominators.0.clone(), denominators.1.clone()]);
413
414        // Grab the size of each
415        let next_numerator_num_vars = next_numerator_expr.num_vars();
416        let next_denominator_num_vars = next_denominator_expr.num_vars();
417
418        // Create the circuit layer by combining the two
419        let layer_id = LayerId::next_layer_id();
420
421        let layer = RegularLayerDescription::new_raw(
422            layer_id,
423            next_denominator_expr.select(next_numerator_expr),
424        );
425
426        layers.push(LayerDescriptionEnum::Regular(layer));
427
428        println!("Iteration {i} of build_fractional_sumcheck has layer id: {layer_id:?}");
429
430        denominator_desc = MleDescription::new(
431            layer_id,
432            &std::iter::once(MleIndex::Fixed(false))
433                .chain(repeat_n(MleIndex::Free, next_denominator_num_vars))
434                .collect_vec(),
435        );
436        maybe_numerator_desc = Some(MleDescription::new(
437            layer_id,
438            &std::iter::once(MleIndex::Fixed(true))
439                .chain(repeat_n(MleIndex::Free, next_numerator_num_vars))
440                .collect_vec(),
441        ));
442    }
443    if let Some(numerator_desc) = maybe_numerator_desc.as_ref() {
444        assert_eq!(numerator_desc.num_free_vars(), 0);
445    }
446    assert_eq!(denominator_desc.num_free_vars(), 0);
447    (maybe_numerator_desc, denominator_desc)
448}