1use crate::abstract_expr::AbstractExpression;
4use remainder::{
5 expression::{circuit_expr::ExprDescription, generic_expr::Expression},
6 layer::{layer_enum::LayerDescriptionEnum, regular_layer::RegularLayerDescription, LayerId},
7 mle::{mle_description::MleDescription, MleIndex},
8 output_layer::OutputLayerDescription,
9 utils::mle::get_total_mle_indices,
10};
11
12use crate::layouter::builder::CircuitMap;
13
14use itertools::{repeat_n, Itertools};
15use shared_types::Field;
16
17use super::fiat_shamir_challenge::FiatShamirChallengeNode;
18use super::{CircuitNode, NodeId};
19
20use anyhow::Result;
21
/// A constraint stating that every value of one circuit node must appear in a
/// [`LookupTable`].
#[derive(Clone, Debug)]
pub struct LookupConstraint {
    /// Id of this constraint node.
    id: NodeId,
    /// Id of the [`LookupTable`] this constraint was created against (taken from
    /// `LookupTable::id()` in `LookupConstraint::new`).
    pub table_node_id: NodeId,
    /// Id of the node whose values are being constrained to lie in the table.
    constrained_node_id: NodeId,
    /// Id of the node holding the multiplicities; these are summed across constraints and
    /// used as the numerator over the table entries when the lookup circuit is generated.
    multiplicities_node_id: NodeId,
}
33
34impl LookupConstraint {
35 pub fn new(
45 lookup_table: &LookupTable,
46 constrained: &dyn CircuitNode,
47 multiplicities: &dyn CircuitNode,
48 ) -> Self {
49 LookupConstraint {
50 id: NodeId::new(),
51 table_node_id: lookup_table.id(),
52 constrained_node_id: constrained.id(),
53 multiplicities_node_id: multiplicities.id(),
54 }
55 }
56}
57
58impl CircuitNode for LookupConstraint {
59 fn id(&self) -> NodeId {
60 self.id
61 }
62
63 fn sources(&self) -> Vec<NodeId> {
64 unimplemented!()
67 }
68
69 fn get_num_vars(&self) -> usize {
70 todo!()
71 }
72}
73
/// Output of [`LookupTable::generate_lookup_circuit_description`]: the intermediate layers (in
/// creation order) and the output layer that asserts the final lookup check is zero.
type LookupCircuitDescription<F> = (Vec<LayerDescriptionEnum<F>>, OutputLayerDescription<F>);

/// A lookup table: the values of a table node, the constraints that look up into it, and the
/// Fiat-Shamir challenge node used to build the `r - value` terms of the lookup check.
#[derive(Clone, Debug)]
pub struct LookupTable {
    /// Id of this lookup-table node.
    id: NodeId,
    /// Constraints registered via `add_lookup_constraint`; their count must be a power of two
    /// when the circuit description is generated.
    constraints: Vec<LookupConstraint>,
    /// Id of the node whose values form the table entries.
    table_node_id: NodeId,
    /// Id of the Fiat-Shamir challenge node providing `r`.
    fiat_shamir_challenge_node_id: NodeId,
}
95
96impl LookupTable {
97 pub fn new(
103 table: &dyn CircuitNode,
104 fiat_shamir_challenge_node: &FiatShamirChallengeNode,
105 ) -> Self {
106 LookupTable {
107 id: NodeId::new(),
108 constraints: vec![],
109 table_node_id: table.id(),
110 fiat_shamir_challenge_node_id: fiat_shamir_challenge_node.id(),
111 }
112 }
113
114 pub fn add_lookup_constraint(&mut self, constraint: LookupConstraint) {
117 self.constraints.push(constraint);
118 }
119
120 pub fn generate_lookup_circuit_description<F: Field>(
123 &self,
124 circuit_map: &mut CircuitMap,
125 ) -> Result<LookupCircuitDescription<F>> {
126 type AE<F> = AbstractExpression<F>;
127 type CE<F> = Expression<F, ExprDescription>;
128
129 assert_eq!(
133 self.constraints.len().count_ones(),
134 1,
135 "Number of LookupConstraints should be a power of two"
136 );
137
138 println!("Build the LHS of the equation (defined by the constrained values)");
140
141 let (fiat_shamir_challenge_location, fiat_shamir_challenge_node_vars) =
142 circuit_map.get_location_num_vars_from_node_id(&self.fiat_shamir_challenge_node_id)?;
143
144 let fiat_shamir_challenge_mle_indices = get_total_mle_indices(
145 &fiat_shamir_challenge_location.prefix_bits,
146 *fiat_shamir_challenge_node_vars,
147 );
148 let fiat_shamir_challenge_mle = MleDescription::new(
149 fiat_shamir_challenge_location.layer_id,
150 &fiat_shamir_challenge_mle_indices,
151 );
152
153 let constrained_expr = AE::<F>::binary_tree_selector(
156 self.constraints
157 .iter()
158 .map(|constraint| constraint.constrained_node_id.expr())
159 .collect(),
160 );
161 let expr = CE::sum(
162 CE::from_mle_desc(fiat_shamir_challenge_mle),
163 CE::negated(constrained_expr.build_circuit_expr(circuit_map)?),
164 );
165 let expr_num_vars = expr.num_vars();
166
167 let layer_id = LayerId::next_layer_id();
168 let layer = RegularLayerDescription::new_raw(layer_id, expr);
169 let mut intermediate_layers = vec![LayerDescriptionEnum::Regular(layer)];
170 println!("Layer that calcs r - constrained has layer id: {layer_id:?}");
171
172 let lhs_denominator_vars = repeat_n(MleIndex::Free, expr_num_vars).collect_vec();
173 let lhs_denominator_desc = MleDescription::new(layer_id, &lhs_denominator_vars);
174
175 let maybe_lhs_numerator_desc = if lhs_denominator_vars.is_empty() {
178 Some(MleDescription::new(layer_id, &[]))
179 } else {
180 None
181 };
182
183 let (lhs_numerator, lhs_denominator) = build_fractional_sum(
185 maybe_lhs_numerator_desc,
186 lhs_denominator_desc,
187 &mut intermediate_layers,
188 );
189
190 println!("Build the RHS of the equation (defined by the table values and multiplicities)");
192
193 let (multiplicities_location, multiplicities_num_vars) = circuit_map
195 .get_location_num_vars_from_node_id(&self.constraints[0].multiplicities_node_id)
196 .unwrap();
197 let mut rhs_numerator_desc = MleDescription::new(
198 multiplicities_location.layer_id,
199 &get_total_mle_indices(
200 &multiplicities_location.prefix_bits,
201 *multiplicities_num_vars,
202 ),
203 );
204
205 if self.constraints.len() > 1 {
206 let expr = self.constraints.iter().skip(1).fold(
208 CE::from_mle_desc(rhs_numerator_desc),
209 |acc, constraint| {
210 let (multiplicities_location, multiplicities_num_vars) = &circuit_map
211 .get_location_num_vars_from_node_id(&constraint.multiplicities_node_id)
212 .unwrap();
213 let mult_constraint_mle_desc = MleDescription::new(
214 multiplicities_location.layer_id,
215 &get_total_mle_indices(
216 &multiplicities_location.prefix_bits,
217 *multiplicities_num_vars,
218 ),
219 );
220 acc + CE::from_mle_desc(mult_constraint_mle_desc)
221 },
222 );
223 let layer_id = LayerId::next_layer_id();
224 let layer = RegularLayerDescription::new_raw(layer_id, expr);
225 intermediate_layers.push(LayerDescriptionEnum::Regular(layer));
226 println!("Layer that aggs the multiplicities has layer id: {layer_id:?}");
227
228 let (_first_self_constraint_loc, first_self_constraint_num_vars) = circuit_map
233 .get_location_num_vars_from_node_id(&self.constraints[0].multiplicities_node_id)
234 .unwrap()
235 .clone();
236 rhs_numerator_desc = MleDescription::new(
237 layer_id,
238 &get_total_mle_indices(&[], first_self_constraint_num_vars),
239 )
240 }
241
242 let (fiat_shamir_challenge_loc, fiat_shamir_challenge_num_vars) = circuit_map
246 .get_location_num_vars_from_node_id(&self.fiat_shamir_challenge_node_id)
247 .unwrap()
248 .clone();
249 let fiat_shamir_challenge_circuit_mle = MleDescription::new(
250 fiat_shamir_challenge_loc.layer_id,
251 &get_total_mle_indices(
252 &fiat_shamir_challenge_loc.prefix_bits,
253 fiat_shamir_challenge_num_vars,
254 ),
255 );
256
257 let (table_loc, table_num_vars) = circuit_map
259 .get_location_num_vars_from_node_id(&self.table_node_id)
260 .unwrap()
261 .clone();
262 let table_circuit_mle = MleDescription::new(
263 table_loc.layer_id,
264 &get_total_mle_indices(&table_loc.prefix_bits, table_num_vars),
265 );
266
267 let expr = CE::from_mle_desc(fiat_shamir_challenge_circuit_mle)
268 - CE::from_mle_desc(table_circuit_mle);
269 let r_minus_table_num_vars = expr.num_vars();
270 let layer_id = LayerId::next_layer_id();
271 let layer = RegularLayerDescription::new_raw(layer_id, expr);
272 intermediate_layers.push(LayerDescriptionEnum::Regular(layer));
273 println!("Layer that calculates r - table has layer id: {layer_id:?}");
274
275 let rhs_denominator_desc = MleDescription::new(
276 layer_id,
277 &repeat_n(MleIndex::Free, r_minus_table_num_vars).collect_vec(),
278 );
279
280 let (rhs_numerator, rhs_denominator) = build_fractional_sum(
282 Some(rhs_numerator_desc),
283 rhs_denominator_desc,
284 &mut intermediate_layers,
285 );
286
287 assert!(rhs_numerator.is_some());
289 let rhs_numerator = rhs_numerator.unwrap();
290 let expr = if let Some(lhs_numerator) = lhs_numerator {
291 CE::<F>::products(vec![lhs_numerator.clone(), rhs_denominator.clone()])
292 - CE::<F>::products(vec![rhs_numerator.clone(), lhs_denominator.clone()])
293 } else {
294 CE::<F>::products(vec![rhs_denominator.clone()])
295 - CE::<F>::products(vec![rhs_numerator.clone(), lhs_denominator.clone()])
296 };
297
298 let layer_id = LayerId::next_layer_id();
299 let layer = RegularLayerDescription::new_raw(layer_id, expr);
300 intermediate_layers.push(LayerDescriptionEnum::Regular(layer));
301 println!("Layer that checks that fractions are equal has layer id: {layer_id:?}");
302
303 let output_layer = OutputLayerDescription::new_zero(layer_id, &[]);
305
306 Ok((intermediate_layers, output_layer))
307 }
308}
309
310impl CircuitNode for LookupTable {
311 fn id(&self) -> NodeId {
312 self.id
313 }
314
315 fn sources(&self) -> Vec<NodeId> {
316 unimplemented!()
319 }
320
321 fn get_num_vars(&self) -> usize {
322 todo!()
323 }
324}
325
326fn extract_prefix_num_free_bits<F: Field>(mle: &MleDescription<F>) -> (Vec<MleIndex<F>>, usize) {
328 let mut num_free_bits = 0;
329 let prefix_bits = mle
330 .var_indices()
331 .iter()
332 .filter_map(|mle_index| match mle_index {
333 MleIndex::Fixed(_) => Some(mle_index.clone()),
334 MleIndex::Free => {
335 num_free_bits += 1;
336 None
337 }
338 _ => None,
339 })
340 .collect();
341 (prefix_bits, num_free_bits)
342}
343
344fn split_circuit_mle<F: Field>(
347 mle_desc: &MleDescription<F>,
348) -> (MleDescription<F>, MleDescription<F>) {
349 let (prefix_bits, num_free_bits) = extract_prefix_num_free_bits(mle_desc);
350
351 let left_mle_desc = MleDescription::new(
352 mle_desc.layer_id(),
353 &prefix_bits
354 .iter()
355 .cloned()
356 .chain(vec![MleIndex::Fixed(false)])
357 .chain(repeat_n(MleIndex::Free, num_free_bits - 1))
358 .collect_vec(),
359 );
360 let right_mle_desc = MleDescription::new(
361 mle_desc.layer_id(),
362 &prefix_bits
363 .iter()
364 .cloned()
365 .chain(vec![MleIndex::Fixed(true)])
366 .chain(repeat_n(MleIndex::Free, num_free_bits - 1))
367 .collect_vec(),
368 );
369 (left_mle_desc, right_mle_desc)
370}
371
372fn build_fractional_sum<F: Field>(
379 maybe_numerator_desc: Option<MleDescription<F>>,
380 denominator_desc: MleDescription<F>,
381 layers: &mut Vec<LayerDescriptionEnum<F>>,
382) -> (Option<MleDescription<F>>, MleDescription<F>) {
383 type CE<F> = Expression<F, ExprDescription>;
384
385 if let Some(numerator_desc) = maybe_numerator_desc.as_ref() {
388 assert_eq!(
389 numerator_desc.num_free_vars(),
390 denominator_desc.num_free_vars()
391 );
392 }
393
394 let mut maybe_numerator_desc = maybe_numerator_desc;
395 let mut denominator_desc = denominator_desc;
396
397 for i in 0..denominator_desc.num_free_vars() {
398 let denominators = split_circuit_mle(&denominator_desc);
399 let next_numerator_expr = if let Some(numerator_desc) = maybe_numerator_desc {
400 let numerators = split_circuit_mle(&numerator_desc);
401
402 CE::products(vec![numerators.0.clone(), denominators.1.clone()])
404 + CE::products(vec![numerators.1.clone(), denominators.0.clone()])
405 } else {
406 CE::from_mle_desc(denominators.1.clone()) + CE::from_mle_desc(denominators.0.clone())
408 };
409
410 let next_denominator_expr =
412 CE::products(vec![denominators.0.clone(), denominators.1.clone()]);
413
414 let next_numerator_num_vars = next_numerator_expr.num_vars();
416 let next_denominator_num_vars = next_denominator_expr.num_vars();
417
418 let layer_id = LayerId::next_layer_id();
420
421 let layer = RegularLayerDescription::new_raw(
422 layer_id,
423 next_denominator_expr.select(next_numerator_expr),
424 );
425
426 layers.push(LayerDescriptionEnum::Regular(layer));
427
428 println!("Iteration {i} of build_fractional_sumcheck has layer id: {layer_id:?}");
429
430 denominator_desc = MleDescription::new(
431 layer_id,
432 &std::iter::once(MleIndex::Fixed(false))
433 .chain(repeat_n(MleIndex::Free, next_denominator_num_vars))
434 .collect_vec(),
435 );
436 maybe_numerator_desc = Some(MleDescription::new(
437 layer_id,
438 &std::iter::once(MleIndex::Fixed(true))
439 .chain(repeat_n(MleIndex::Free, next_numerator_num_vars))
440 .collect_vec(),
441 ));
442 }
443 if let Some(numerator_desc) = maybe_numerator_desc.as_ref() {
444 assert_eq!(numerator_desc.num_free_vars(), 0);
445 }
446 assert_eq!(denominator_desc.num_free_vars(), 0);
447 (maybe_numerator_desc, denominator_desc)
448}