1use core::fmt;
5use std::{
6 collections::{HashMap, HashSet},
7 marker::PhantomData,
8 ops::{Add, BitXor, Mul, Sub},
9 rc::{Rc, Weak},
10};
11
12use ark_std::log2;
13use hyrax::{
14 gkr::input_layer::HyraxInputLayerDescription, provable_circuit::HyraxProvableCircuit,
15 verifiable_circuit::HyraxVerifiableCircuit,
16};
17use itertools::Itertools;
18use ligero::ligero_structs::LigeroAuxInfo;
19use serde::{Deserialize, Serialize};
20use shared_types::{curves::PrimeOrderCurve, Field, Halo2FFTFriendlyField};
21
22use crate::{
23 abstract_expr::AbstractExpression,
24 layouter::{
25 layouting::LayoutingError,
26 nodes::{
27 circuit_inputs::{InputLayerNode, InputShred},
28 circuit_outputs::OutputNode,
29 fiat_shamir_challenge::FiatShamirChallengeNode,
30 gate::GateNode,
31 identity_gate::IdentityGateNode,
32 lookup::{LookupConstraint, LookupTable},
33 matmult::MatMultNode,
34 sector::{generate_sector_circuit_description, Sector},
35 split_node::SplitNode,
36 CircuitNode, NodeId,
37 },
38 },
39};
40use remainder::{
41 circuit_layout::CircuitLocation,
42 input_layer::{ligero_input_layer::LigeroInputLayerDescription, InputLayerDescription},
43 layer::{gate::BinaryOperation, layer_enum::LayerDescriptionEnum, LayerId},
44 mle::evals::MultilinearExtension,
45 output_layer::OutputLayerDescription,
46 provable_circuit::ProvableCircuit,
47 prover::{GKRCircuitDescription, GKRError},
48 utils::mle::build_composite_mle,
49 verifiable_circuit::VerifiableCircuit,
50};
51
52use anyhow::{anyhow, bail, Result};
53
54use tracing::debug;
55
/// A non-owning handle to a node registered with a [`CircuitBuilder`].
///
/// The builder keeps the node alive through an `Rc`; this handle only stores a `Weak`
/// pointer, so it can only be dereferenced while the node is still owned by the builder.
#[derive(Clone, Debug)]
pub struct NodeRef<F: Field> {
    // Weak pointer to the type-erased node owned by the builder.
    ptr: Weak<dyn CircuitNode>,
    // Ties the handle to the field `F` used by the expressions it produces.
    _phantom: PhantomData<F>,
}
67
68impl<F: Field> NodeRef<F> {
69 fn new(ptr: Weak<dyn CircuitNode>) -> Self {
70 Self {
71 ptr,
72 _phantom: PhantomData,
73 }
74 }
75
76 pub fn expr(&self) -> AbstractExpression<F> {
79 self.ptr.upgrade().unwrap().id().expr()
80 }
81
82 pub fn id(&self) -> NodeId {
84 self.ptr.upgrade().unwrap().id()
85 }
86
87 pub fn get_num_vars(&self) -> usize {
89 self.ptr.upgrade().unwrap().get_num_vars()
90 }
91}
92
93impl<F: Field> From<NodeRef<F>> for AbstractExpression<F> {
94 fn from(value: NodeRef<F>) -> Self {
95 value.expr()
96 }
97}
98
99impl<F: Field> From<&NodeRef<F>> for AbstractExpression<F> {
100 fn from(value: &NodeRef<F>) -> Self {
101 value.expr()
102 }
103}
104
/// A non-owning handle to an input layer created with
/// [`CircuitBuilder::add_input_layer`]; input shreds are attached to it through
/// [`CircuitBuilder::add_input_shred`].
#[derive(Clone, Debug)]
pub struct InputLayerNodeRef<F: Field> {
    // Weak pointer to the input layer node owned by the builder.
    ptr: Weak<InputLayerNode>,
    // Ties the handle to the field `F` used by the expressions it produces.
    _phantom: PhantomData<F>,
}
112
113impl<F: Field> InputLayerNodeRef<F> {
114 fn new(ptr: Weak<InputLayerNode>) -> Self {
115 Self {
116 ptr,
117 _phantom: PhantomData,
118 }
119 }
120
121 pub fn expr(&self) -> AbstractExpression<F> {
124 self.ptr.upgrade().unwrap().id().expr()
125 }
126}
127
128impl<F: Field> From<InputLayerNodeRef<F>> for AbstractExpression<F> {
129 fn from(value: InputLayerNodeRef<F>) -> Self {
130 value.expr()
131 }
132}
133
134impl<F: Field> From<&InputLayerNodeRef<F>> for AbstractExpression<F> {
135 fn from(value: &InputLayerNodeRef<F>) -> Self {
136 value.expr()
137 }
138}
139
/// A non-owning handle to a Fiat-Shamir challenge node created with
/// [`CircuitBuilder::add_fiat_shamir_challenge_node`]. It can be converted into a
/// generic [`NodeRef`] or into an [`AbstractExpression`].
#[derive(Clone, Debug)]
pub struct FSNodeRef<F: Field> {
    // Weak pointer to the challenge node owned by the builder.
    ptr: Weak<FiatShamirChallengeNode>,
    // Ties the handle to the field `F` used by the expressions it produces.
    _phantom: PhantomData<F>,
}
147
148impl<F: Field> FSNodeRef<F> {
149 fn new(ptr: Weak<FiatShamirChallengeNode>) -> Self {
150 Self {
151 ptr,
152 _phantom: PhantomData,
153 }
154 }
155
156 pub fn expr(&self) -> AbstractExpression<F> {
159 self.ptr.upgrade().unwrap().id().expr()
160 }
161}
162
163impl<F: Field> From<FSNodeRef<F>> for NodeRef<F> {
164 fn from(value: FSNodeRef<F>) -> Self {
165 NodeRef::new(value.ptr)
166 }
167}
168
169impl<F: Field> From<FSNodeRef<F>> for AbstractExpression<F> {
170 fn from(value: FSNodeRef<F>) -> Self {
171 value.expr()
172 }
173}
174
175impl<F: Field> From<&FSNodeRef<F>> for AbstractExpression<F> {
176 fn from(value: &FSNodeRef<F>) -> Self {
177 value.expr()
178 }
179}
180
/// A non-owning handle to a lookup table created with
/// [`CircuitBuilder::add_lookup_table`]; it is passed to
/// [`CircuitBuilder::add_lookup_constraint`] to constrain values against the table.
#[derive(Clone, Debug)]
pub struct LookupTableNodeRef {
    // Weak pointer to the lookup table node owned by the builder.
    ptr: Weak<LookupTable>,
}
187
188impl LookupTableNodeRef {
189 fn new(ptr: Weak<LookupTable>) -> Self {
190 Self { ptr }
191 }
192
193 pub fn expr<F: Field>(&self) -> AbstractExpression<F> {
196 self.ptr.upgrade().unwrap().id().expr()
197 }
198}
199
/// A non-owning handle to a lookup constraint created with
/// [`CircuitBuilder::add_lookup_constraint`].
#[derive(Clone, Debug)]
pub struct LookupConstraintNodeRef {
    // Weak pointer to the lookup constraint node owned by the builder.
    ptr: Weak<LookupConstraint>,
}
206
207impl LookupConstraintNodeRef {
208 fn new(ptr: Weak<LookupConstraint>) -> Self {
209 Self { ptr }
210 }
211
212 pub fn expr<F: Field>(&self) -> AbstractExpression<F> {
215 self.ptr.upgrade().unwrap().id().expr()
216 }
217}
218
/// Incrementally assembles a circuit out of nodes, then lays it out into a
/// [`Circuit`] via one of the `build*` methods.
///
/// The builder is the sole strong owner (`Rc`) of every node; callers receive
/// `Weak`-backed handles such as [`NodeRef`].
pub struct CircuitBuilder<F: Field> {
    // One owning list per node kind; consumed by `build_with_max_layer_size`.
    input_layer_nodes: Vec<Rc<InputLayerNode>>,
    input_shred_nodes: Vec<Rc<InputShred>>,
    fiat_shamir_challenge_nodes: Vec<Rc<FiatShamirChallengeNode>>,
    output_nodes: Vec<Rc<OutputNode>>,
    sector_nodes: Vec<Rc<Sector<F>>>,
    gate_nodes: Vec<Rc<GateNode>>,
    identity_gate_nodes: Vec<Rc<IdentityGateNode>>,
    split_nodes: Vec<Rc<SplitNode>>,
    matmult_nodes: Vec<Rc<MatMultNode>>,
    lookup_constraint_nodes: Vec<Rc<LookupConstraint>>,
    lookup_table_nodes: Vec<Rc<LookupTable>>,
    // Type-erased handle for every registered node, used to resolve the node ids
    // appearing in sector expressions (see `add_sector`).
    node_to_ptr: HashMap<NodeId, NodeRef<F>>,
    // Label/visibility/location bookkeeping carried over into the built `Circuit`.
    circuit_map: CircuitMap,
}
236
237impl<F: Field> CircuitBuilder<F> {
238 pub fn new() -> Self {
240 Self {
241 input_layer_nodes: vec![],
242 input_shred_nodes: vec![],
243 fiat_shamir_challenge_nodes: vec![],
244 output_nodes: vec![],
245 sector_nodes: vec![],
246 gate_nodes: vec![],
247 identity_gate_nodes: vec![],
248 split_nodes: vec![],
249 matmult_nodes: vec![],
250 lookup_constraint_nodes: vec![],
251 lookup_table_nodes: vec![],
252 node_to_ptr: HashMap::new(),
253 circuit_map: CircuitMap::new(),
254 }
255 }
256
257 fn into_owned_helper<T: CircuitNode + fmt::Debug>(xs: Vec<Rc<T>>) -> Vec<T> {
258 xs.into_iter()
259 .map(|ptr| Rc::try_unwrap(ptr).unwrap())
260 .collect()
261 }
262
263 pub fn build_with_max_layer_size(
268 mut self,
269 maybe_maximum_log_layer_size: Option<usize>,
270 ) -> Result<Circuit<F>> {
271 let input_layer_nodes = Self::into_owned_helper(self.input_layer_nodes);
272 let input_shred_nodes = Self::into_owned_helper(self.input_shred_nodes);
273 let fiat_shamir_challenge_nodes = Self::into_owned_helper(self.fiat_shamir_challenge_nodes);
274 let output_nodes = Self::into_owned_helper(self.output_nodes);
275 let sector_nodes = Self::into_owned_helper(self.sector_nodes);
276 let gate_nodes = Self::into_owned_helper(self.gate_nodes);
277 let identity_gate_nodes = Self::into_owned_helper(self.identity_gate_nodes);
278 let split_nodes = Self::into_owned_helper(self.split_nodes);
279 let matmult_nodes = Self::into_owned_helper(self.matmult_nodes);
280 let lookup_constraint_nodes = Self::into_owned_helper(self.lookup_constraint_nodes);
281 let lookup_table_nodes = Self::into_owned_helper(self.lookup_table_nodes);
282 let id_to_sector_nodes_map: HashMap<NodeId, Sector<F>> = sector_nodes
283 .iter()
284 .cloned()
285 .map(|node| (node.id(), node))
286 .collect();
287
288 let should_combine = maybe_maximum_log_layer_size != Some(0);
290
291 let (
292 input_layer_nodes,
293 fiat_shamir_challenge_nodes,
294 intermediate_node_layers,
295 lookup_nodes,
296 output_nodes,
297 ) = super::layouting::layout(
298 input_layer_nodes,
299 input_shred_nodes,
300 fiat_shamir_challenge_nodes,
301 output_nodes,
302 sector_nodes,
303 gate_nodes,
304 identity_gate_nodes,
305 split_nodes,
306 matmult_nodes,
307 lookup_constraint_nodes,
308 lookup_table_nodes,
309 should_combine,
310 )
311 .unwrap();
312
313 let mut intermediate_layers = Vec::<LayerDescriptionEnum<F>>::new();
314 let mut output_layers = Vec::<OutputLayerDescription<F>>::new();
315
316 let input_layers = input_layer_nodes
317 .iter()
318 .map(|input_layer_node| {
319 let input_layer_description = input_layer_node
320 .generate_input_layer_description::<F>(&mut self.circuit_map)
321 .unwrap();
322 self.circuit_map.insert_shreds_into_input_layer(
323 input_layer_description.layer_id,
324 input_layer_node
325 .input_shreds
326 .iter()
327 .map(CircuitNode::id)
328 .collect(),
329 );
330 input_layer_description
331 })
332 .collect_vec();
333
334 let fiat_shamir_challenges = fiat_shamir_challenge_nodes
335 .iter()
336 .map(|fiat_shamir_challenge_node| {
337 fiat_shamir_challenge_node.generate_circuit_description::<F>(&mut self.circuit_map)
338 })
339 .collect_vec();
340
341 for layer in &intermediate_node_layers {
342 if layer.len() == 1 {
345 intermediate_layers.extend(
346 layer
347 .first()
348 .unwrap()
349 .generate_circuit_description(&mut self.circuit_map)?,
350 );
351 } else {
352 let sectors = layer
358 .iter()
359 .map(|sector| {
360 assert!(id_to_sector_nodes_map.contains_key(§or.id()));
361 id_to_sector_nodes_map.get(§or.id()).unwrap()
362 })
363 .collect_vec();
364 intermediate_layers.extend(generate_sector_circuit_description(
365 §ors,
366 &mut self.circuit_map,
367 maybe_maximum_log_layer_size,
368 ));
369 }
370 }
371
372 (intermediate_layers, output_layers) = lookup_nodes.iter().fold(
374 (intermediate_layers, output_layers),
375 |(mut lookup_intermediate_acc, mut lookup_output_acc), lookup_node| {
376 let (intermediate_layers, output_layer) = lookup_node
377 .generate_lookup_circuit_description(&mut self.circuit_map)
378 .unwrap();
379 lookup_intermediate_acc.extend(intermediate_layers);
380 lookup_output_acc.push(output_layer);
381 (lookup_intermediate_acc, lookup_output_acc)
382 },
383 );
384 output_layers =
385 output_nodes
386 .iter()
387 .fold(output_layers, |mut output_layer_acc, output_node| {
388 output_layer_acc
389 .extend(output_node.generate_circuit_description(&mut self.circuit_map));
390 output_layer_acc
391 });
392
393 let mut circuit_description = GKRCircuitDescription {
394 input_layers,
395 fiat_shamir_challenges,
396 intermediate_layers,
397 output_layers,
398 };
399 circuit_description.index_mle_indices(0);
400
401 Ok(Circuit::new(circuit_description, self.circuit_map))
402 }
403
404 pub fn build_with_layer_combination(self) -> Result<Circuit<F>> {
407 self.build_with_max_layer_size(None)
408 }
409
410 pub fn build_without_layer_combination(self) -> Result<Circuit<F>> {
412 self.build_with_max_layer_size(Some(0))
413 }
414
415 pub fn build(self) -> Result<Circuit<F>> {
418 self.build_without_layer_combination()
419 }
420}
421
422impl<F: Field> CircuitBuilder<F> {
423 pub fn add_input_layer(
435 &mut self,
436 layer_label: &str,
437 layer_visibility: LayerVisibility,
438 ) -> InputLayerNodeRef<F> {
439 let node = Rc::new(InputLayerNode::new(None));
440 let node_weak_ref = Rc::downgrade(&node);
441
442 let layer_id = node.input_layer_id();
443
444 self.circuit_map
445 .add_input_layer(layer_id, layer_label, layer_visibility);
446
447 self.node_to_ptr
448 .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
449
450 self.input_layer_nodes.push(node);
451
452 InputLayerNodeRef::new(node_weak_ref)
453 }
454
455 pub fn add_input_shred(
465 &mut self,
466 label: &str,
467 num_vars: usize,
468 source: &InputLayerNodeRef<F>,
469 ) -> NodeRef<F> {
470 let source = source
471 .ptr
472 .upgrade()
473 .expect("InputShred's source data has already been dropped");
474 let node = Rc::new(InputShred::new(num_vars, &source));
475 let node_weak_ref = Rc::downgrade(&node);
476
477 let node_id = node.id();
478
479 self.node_to_ptr
480 .insert(node_id, NodeRef::new(node_weak_ref.clone()));
481
482 self.circuit_map.add_input_shred(label, node_id);
484
485 self.input_shred_nodes.push(node);
486
487 NodeRef::new(node_weak_ref)
488 }
489
490 pub fn set_output(&mut self, source: &NodeRef<F>) {
495 let source = source
496 .ptr
497 .upgrade()
498 .expect("Sector source has already been dropped");
499 let node = Rc::new(OutputNode::new_zero(source.as_ref()));
500 self.output_nodes.push(node);
501 }
502
503 pub fn add_fiat_shamir_challenge_node(&mut self, num_challenges: usize) -> FSNodeRef<F> {
506 let node = Rc::new(FiatShamirChallengeNode::new(num_challenges));
507 let node_weak_ref = Rc::downgrade(&node);
508 self.node_to_ptr
509 .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
510 self.fiat_shamir_challenge_nodes.push(node);
511 FSNodeRef::new(node_weak_ref)
512 }
513
514 pub fn add_sector(&mut self, expr: AbstractExpression<F>) -> NodeRef<F> {
517 let node_ids_in_use: HashSet<NodeId> = expr.get_sources().into_iter().collect();
518
519 let num_vars_map: HashMap<NodeId, usize> = node_ids_in_use
520 .into_iter()
521 .map(|id| (id, self.get_ptr_from_node_id(id).get_num_vars()))
522 .collect();
523
524 let num_vars = expr
525 .get_num_vars(&num_vars_map)
526 .expect("Internal error duing 'num_vars' computation of an AbstractExpression");
527
528 let node = Rc::new(Sector::<F>::new(expr, num_vars));
529 let node_weak_ref = Rc::downgrade(&node);
530
531 self.node_to_ptr
532 .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
533
534 self.sector_nodes.push(node);
535 NodeRef::new(node_weak_ref)
536 }
537
538 pub fn add_identity_gate_node(
545 &mut self,
546 pre_routed_data: &NodeRef<F>,
547 non_zero_gates: Vec<(u32, u32)>,
548 num_vars: usize,
549 num_dataparallel_vars: Option<usize>,
550 ) -> NodeRef<F> {
551 let pre_routed_data = pre_routed_data
552 .ptr
553 .upgrade()
554 .expect("`pre_routed_data` reference given to identity gate has been dropped");
555 let node = Rc::new(IdentityGateNode::new(
556 pre_routed_data.as_ref(),
557 non_zero_gates,
558 num_vars,
559 num_dataparallel_vars,
560 ));
561 let node_weak_ref = Rc::downgrade(&node);
562 self.node_to_ptr
563 .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
564 self.identity_gate_nodes.push(node);
565 NodeRef::new(node_weak_ref)
566 }
567
568 pub fn add_gate_node(
575 &mut self,
576 lhs: &NodeRef<F>,
577 rhs: &NodeRef<F>,
578 nonzero_gates: Vec<(u32, u32, u32)>,
579 gate_operation: BinaryOperation,
580 num_dataparallel_bits: Option<usize>,
581 ) -> NodeRef<F> {
582 let lhs = lhs
583 .ptr
584 .upgrade()
585 .expect("lhs give to GateNode has already been dropped");
586 let rhs = rhs
587 .ptr
588 .upgrade()
589 .expect("rhs give to GateNode has already been dropped");
590 let node = Rc::new(GateNode::new(
591 lhs.as_ref(),
592 rhs.as_ref(),
593 nonzero_gates,
594 gate_operation,
595 num_dataparallel_bits,
596 ));
597 let node_weak_ref = Rc::downgrade(&node);
598 self.node_to_ptr
599 .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
600 self.gate_nodes.push(node);
601 NodeRef::new(node_weak_ref)
602 }
603
604 pub fn add_matmult_node(
611 &mut self,
612 matrix_a_node: &NodeRef<F>,
613 rows_cols_num_vars_a: (usize, usize),
614 matrix_b_node: &NodeRef<F>,
615 rows_cols_num_vars_b: (usize, usize),
616 ) -> NodeRef<F> {
617 let matrix_a_node = matrix_a_node
618 .ptr
619 .upgrade()
620 .expect("Matrix A input to MatMultNode has been dropped");
621 let matrix_b_node = matrix_b_node
622 .ptr
623 .upgrade()
624 .expect("Matrix B input to MatMultNode has been dropped");
625 let node = Rc::new(MatMultNode::new(
626 matrix_a_node.as_ref(),
627 rows_cols_num_vars_a,
628 matrix_b_node.as_ref(),
629 rows_cols_num_vars_b,
630 ));
631 let node_weak_ref = Rc::downgrade(&node);
632 self.node_to_ptr
633 .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
634 self.matmult_nodes.push(node);
635 NodeRef::new(node_weak_ref)
636 }
637
638 pub fn add_lookup_table(
641 &mut self,
642 table: &NodeRef<F>,
643 fiat_shamir_challenge_node: &FSNodeRef<F>,
644 ) -> LookupTableNodeRef {
645 let table = table
646 .ptr
647 .upgrade()
648 .expect("Table input to LookupTable has already been dropped");
649 let fiat_shamir_challenge_node = fiat_shamir_challenge_node
650 .ptr
651 .upgrade()
652 .expect("FiatShamirChallegeNode input to LookupTable has already been dropped");
653 let node = Rc::new(LookupTable::new(
654 table.as_ref(),
655 fiat_shamir_challenge_node.as_ref(),
656 ));
657 let node_ref = Rc::downgrade(&node);
658 self.node_to_ptr
659 .insert(node.id(), NodeRef::new(node_ref.clone()));
660 self.lookup_table_nodes.push(node);
661 LookupTableNodeRef::new(node_ref)
662 }
663
664 pub fn add_lookup_constraint(
667 &mut self,
668 lookup_table: &LookupTableNodeRef,
669 constrained: &NodeRef<F>,
670 multiplicities: &NodeRef<F>,
671 ) -> LookupConstraintNodeRef {
672 let lookup_table = lookup_table
673 .ptr
674 .upgrade()
675 .expect("LookupTable input to LookupConstraint has already been dropped");
676 let constrained = constrained
677 .ptr
678 .upgrade()
679 .expect("constrained input to LookupConstrained has already been dropped");
680 let multiplicities = multiplicities
681 .ptr
682 .upgrade()
683 .expect("multiplicites input to LookupConstrained has already been dropped");
684 let node = Rc::new(LookupConstraint::new(
685 lookup_table.as_ref(),
686 constrained.as_ref(),
687 multiplicities.as_ref(),
688 ));
689 let node_ref = Rc::downgrade(&node);
690 self.node_to_ptr
691 .insert(node.id(), NodeRef::new(node_ref.clone()));
692 self.lookup_constraint_nodes.push(node);
693 LookupConstraintNodeRef::new(node_ref)
694 }
695
696 pub fn add_split_node(&mut self, input_node: &NodeRef<F>, num_vars: usize) -> Vec<NodeRef<F>> {
703 let input_node = input_node
704 .ptr
705 .upgrade()
706 .expect("input_node to SplitNode has already been dropped");
707 let nodes = SplitNode::new(input_node.as_ref(), num_vars)
708 .into_iter()
709 .map(Rc::new)
710 .collect_vec();
711 debug_assert_eq!(nodes.len(), 1 << num_vars);
712 let node_refs = nodes
713 .iter()
714 .map(|node| NodeRef::new(Rc::downgrade(node) as Weak<dyn CircuitNode>))
715 .collect_vec();
716 nodes
717 .iter()
718 .zip(node_refs.iter())
719 .for_each(|(node, node_ref)| {
720 self.node_to_ptr.insert(node.id(), node_ref.clone());
721 });
722 self.split_nodes.extend(nodes);
723 node_refs
724 }
725
726 fn get_ptr_from_node_id(&self, id: NodeId) -> Rc<dyn CircuitNode> {
727 self.node_to_ptr[&id].ptr.upgrade().unwrap()
728 }
729}
730
731impl<F: Field> Default for CircuitBuilder<F> {
732 fn default() -> Self {
733 Self::new()
734 }
735}
736
#[cfg(test)]
mod test {
    use super::*;

    use shared_types::Fr;

    /// Input layer labels are unique: registering the same label twice panics.
    #[test]
    #[should_panic]
    pub fn test_unique_input_layer_label() {
        let mut circuit_builder = CircuitBuilder::<Fr>::new();
        let duplicated_label = "Public Input Layer";

        let _first_layer = circuit_builder.add_input_layer(duplicated_label, LayerVisibility::Public);
        let _second_layer = circuit_builder.add_input_layer(duplicated_label, LayerVisibility::Public);
    }

    /// Layer labels and shred labels live in separate namespaces, so reusing
    /// one label for both is allowed.
    #[test]
    pub fn test_scope_mixing() {
        let mut circuit_builder = CircuitBuilder::<Fr>::new();
        let shared_label = "label";

        let layer = circuit_builder.add_input_layer(shared_label, LayerVisibility::Public);
        circuit_builder.add_input_shred(shared_label, 1, &layer);
    }

    /// Shred labels are unique within a single input layer.
    #[test]
    #[should_panic]
    pub fn test_unique_input_shred_label() {
        let mut circuit_builder = CircuitBuilder::<Fr>::new();
        let layer = circuit_builder.add_input_layer("Public Input Layer", LayerVisibility::Public);

        for _ in 0..2 {
            circuit_builder.add_input_shred("shred1", 1, &layer);
        }
    }

    /// Shred labels are also unique across different input layers.
    #[test]
    #[should_panic]
    pub fn test_unique_input_shred_label2() {
        let mut circuit_builder = CircuitBuilder::<Fr>::new();

        let public_layer = circuit_builder.add_input_layer("Input Layer 1", LayerVisibility::Public);
        let committed_layer =
            circuit_builder.add_input_layer("Input Layer 2", LayerVisibility::Committed);

        for layer in [&public_layer, &committed_layer] {
            circuit_builder.add_input_shred("shred1", 1, layer);
        }
    }
}
783
/// The visibility of an input layer, chosen when the layer is added through
/// [`CircuitBuilder::add_input_layer`].
#[derive(Clone, Debug, PartialEq, Copy, Serialize, Deserialize)]
pub enum LayerVisibility {
    /// Reported by `CircuitMap::get_all_public_input_layer_ids`. Its data is
    /// assembled by the public-input code path.
    Public,

    /// Reported by `CircuitMap::get_all_committed_layers`. The circuit gives it a
    /// Ligero or Hyrax input-layer description.
    Committed,
}
795
/// Lifecycle of a [`CircuitMap`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
enum CircuitMapState {
    /// The map is still being populated; the mutating methods assert this state.
    UnderConstruction,

    /// Set by `CircuitMap::freeze`; most query methods assert this state.
    Ready,
}
807
/// Bookkeeping that links user-facing labels to input layers, input shreds and their
/// locations in the circuit.
///
/// It is filled in while a [`CircuitBuilder`] builds, then frozen and stored in the
/// resulting [`Circuit`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CircuitMap {
    // UnderConstruction until `freeze` is called, then Ready.
    state: CircuitMapState,
    // Input layer id -> ids of the input shreds that make up that layer.
    shreds_in_layer: HashMap<LayerId, Vec<NodeId>>,
    // Input shred label -> input shred node id.
    label_to_shred_id: HashMap<String, NodeId>,
    // Input layer label -> input layer id.
    layer_label_to_layer_id: HashMap<String, LayerId>,
    // Input layer id -> its visibility.
    layer_visibility: HashMap<LayerId, LayerVisibility>,
    // Node id -> (circuit location, number of variables). After `freeze`, only the
    // entries for labelled input shreds are kept.
    node_location: HashMap<NodeId, (CircuitLocation, usize)>,
}
820
821impl CircuitMap {
822 pub fn new() -> Self {
826 Self {
827 state: CircuitMapState::UnderConstruction,
828 shreds_in_layer: HashMap::new(),
829 label_to_shred_id: HashMap::new(),
830 layer_label_to_layer_id: HashMap::new(),
831 layer_visibility: HashMap::new(),
832 node_location: HashMap::new(),
833 }
834 }
835
836 pub fn add_node_id_and_location_num_vars(
843 &mut self,
844 node_id: NodeId,
845 value: (CircuitLocation, usize),
846 ) {
847 assert_eq!(self.state, CircuitMapState::UnderConstruction);
848 assert!(!self.node_location.contains_key(&node_id));
849 self.node_location.insert(node_id, value);
850 }
851
852 pub fn insert_shreds_into_input_layer(&mut self, input_layer_id: LayerId, shreds: Vec<NodeId>) {
858 assert_eq!(self.state, CircuitMapState::UnderConstruction);
859 assert!(!self.shreds_in_layer.contains_key(&input_layer_id));
860 self.shreds_in_layer.insert(input_layer_id, shreds);
861 }
862
863 pub fn get_location_num_vars_from_node_id(
867 &self,
868 node_id: &NodeId,
869 ) -> Result<&(CircuitLocation, usize)> {
870 self.node_location
871 .get(node_id)
872 .ok_or(anyhow!(LayoutingError::DanglingNodeId(*node_id)))
873 }
874
875 pub fn add_input_layer(
882 &mut self,
883 layer_id: LayerId,
884 layer_label: &str,
885 layer_kind: LayerVisibility,
886 ) {
887 assert_eq!(self.state, CircuitMapState::UnderConstruction);
888 assert!(!self.layer_visibility.contains_key(&layer_id));
889 assert!(!self.layer_label_to_layer_id.contains_key(layer_label));
890 self.layer_label_to_layer_id
891 .insert(String::from(layer_label), layer_id);
892 self.layer_visibility.insert(layer_id, layer_kind);
893 }
894
895 pub fn add_input_shred(&mut self, label: &str, shred_id: NodeId) {
906 assert_eq!(self.state, CircuitMapState::UnderConstruction);
907 assert!(!self.label_to_shred_id.contains_key(label));
908 self.label_to_shred_id.insert(String::from(label), shred_id);
909 }
910
911 pub fn freeze(mut self) -> Self {
919 assert_eq!(self.state, CircuitMapState::UnderConstruction);
920
921 assert_eq!(
924 self.shreds_in_layer.keys().sorted().collect_vec(),
925 self.layer_visibility.keys().sorted().collect_vec()
926 );
927
928 let input_shred_location: HashMap<NodeId, (CircuitLocation, usize)> = self
930 .label_to_shred_id
931 .values()
932 .map(|shred_id| (*shred_id, self.node_location[shred_id].clone()))
933 .collect();
934 self.node_location = input_shred_location;
935
936 assert_eq!(
939 self.shreds_in_layer
940 .values()
941 .flatten()
942 .sorted()
943 .collect_vec(),
944 self.node_location.keys().sorted().collect_vec()
945 );
946
947 assert_eq!(
950 self.label_to_shred_id.values().sorted().collect_vec(),
951 self.node_location.keys().sorted().collect_vec()
952 );
953
954 self.state = CircuitMapState::Ready;
958
959 self
960 }
961
962 pub fn get_node_kind(&self, shred_label: &str) -> Result<LayerVisibility> {
965 let shred_id = self.get_node_id(shred_label)?;
970
971 let layer_id = self.node_location[&shred_id].0.layer_id;
973 let layer_visibility = self.layer_visibility[&layer_id];
974
975 Ok(layer_visibility)
976 }
977
978 pub fn get_node_id(&self, shred_label: &str) -> Result<NodeId> {
981 let shred_id = self
984 .label_to_shred_id
985 .get(shred_label)
986 .ok_or(anyhow!("Unrecognized Shred Label '{shred_label}'"))?;
987
988 Ok(*shred_id)
989 }
990
991 pub fn get_all_input_shred_ids(&self) -> Vec<NodeId> {
996 assert_eq!(self.state, CircuitMapState::Ready);
997
998 self.node_location.keys().cloned().collect_vec()
999 }
1000
1001 pub fn get_shred_label_from_id(&self, shred_id: NodeId) -> Result<String> {
1009 assert_eq!(self.state, CircuitMapState::Ready);
1010
1011 let labels = self
1013 .label_to_shred_id
1014 .iter()
1015 .filter_map(|(label, node_id)| {
1016 if *node_id == shred_id {
1017 Some(label)
1018 } else {
1019 None
1020 }
1021 })
1022 .collect_vec();
1023
1024 if labels.is_empty() {
1025 bail!("Unrecognized Input Shred ID '{shred_id}'");
1026 } else {
1027 assert_eq!(labels.len(), 1);
1030
1031 Ok(labels[0].clone())
1032 }
1033 }
1034
1035 pub fn get_all_input_layer_ids(&self) -> Vec<LayerId> {
1042 assert_eq!(self.state, CircuitMapState::Ready);
1043
1044 self.layer_visibility.keys().cloned().collect_vec()
1045 }
1046
1047 pub fn get_all_public_input_layer_ids(&self) -> Vec<LayerId> {
1054 assert_eq!(self.state, CircuitMapState::Ready);
1055
1056 self.layer_visibility
1057 .iter()
1058 .filter_map(|(layer_id, visibility)| match *visibility {
1059 LayerVisibility::Public => Some(layer_id),
1060 LayerVisibility::Committed => None,
1061 })
1062 .cloned()
1063 .collect_vec()
1064 }
1065
1066 pub fn get_input_shreds_from_layer_id(&self, layer_id: LayerId) -> Result<Vec<NodeId>> {
1072 assert_eq!(self.state, CircuitMapState::Ready);
1073
1074 Ok(self
1075 .shreds_in_layer
1076 .get(&layer_id)
1077 .ok_or(anyhow!("Unrecognized Input Layer ID '{layer_id}'"))?
1078 .clone())
1079 }
1080
1081 pub fn get_shred_location(&self, shred_id: NodeId) -> Result<(CircuitLocation, usize)> {
1087 assert_eq!(self.state, CircuitMapState::Ready);
1088
1089 self.node_location
1090 .get(&shred_id)
1091 .ok_or(anyhow!("Unrecognized Shred ID '{shred_id}'."))
1092 .cloned()
1093 }
1094
1095 pub fn get_all_committed_layers(&self) -> Vec<LayerId> {
1101 assert_eq!(self.state, CircuitMapState::Ready);
1102
1103 self.layer_visibility
1104 .iter()
1105 .filter_map(|(layer_id, layer_visibility)| {
1106 if *layer_visibility == LayerVisibility::Committed {
1107 Some(layer_id)
1108 } else {
1109 None
1110 }
1111 })
1112 .cloned()
1113 .collect_vec()
1114 }
1115
1116 pub fn get_layer_id_from_label(&self, layer_label: &str) -> Result<LayerId> {
1119 self.layer_label_to_layer_id
1120 .get(layer_label)
1121 .cloned()
1122 .ok_or(anyhow!("No Input Layer with label {layer_label}."))
1123 }
1124}
1125
1126impl Default for CircuitMap {
1127 fn default() -> Self {
1128 Self::new()
1129 }
1130}
1131
/// A built circuit: the GKR circuit description, the frozen [`CircuitMap`] that
/// links labels to layers and shreds, and any input data assigned so far.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Circuit<F: Field> {
    // Layer descriptions produced by `CircuitBuilder::build_with_max_layer_size`.
    circuit_description: GKRCircuitDescription<F>,
    // Frozen (Ready) label/location bookkeeping.
    pub circuit_map: CircuitMap,
    // Data assigned via `set_input`/`update_input`, keyed by input shred id.
    partial_inputs: HashMap<NodeId, MultilinearExtension<F>>,
}
1143
1144impl<F: Field> Circuit<F> {
1145 fn new(circuit_description: GKRCircuitDescription<F>, circuit_map: CircuitMap) -> Self {
1147 assert_eq!(circuit_map.state, CircuitMapState::UnderConstruction);
1148
1149 Self {
1150 circuit_description,
1151 circuit_map: circuit_map.freeze(),
1152 partial_inputs: HashMap::new(),
1153 }
1154 }
1155
1156 pub fn get_circuit_description(&self) -> &GKRCircuitDescription<F> {
1158 &self.circuit_description
1159 }
1160
1161 pub fn set_input(&mut self, shred_label: &str, data: MultilinearExtension<F>) {
1167 let node_id = self.circuit_map.get_node_id(shred_label).unwrap();
1168
1169 if self.partial_inputs.contains_key(&node_id) {
1170 panic!("Input Shred with label '{shred_label}' has already been assigned data.");
1171 }
1172
1173 self.partial_inputs.insert(node_id, data);
1174 }
1175
1176 pub fn update_input(&mut self, shred_label: &str, data: MultilinearExtension<F>) {
1182 let node_id = self.circuit_map.get_node_id(shred_label).unwrap();
1183 self.partial_inputs.insert(node_id, data);
1184 }
1185
1186 pub fn contains_layer(&self, label: &str) -> bool {
1188 self.circuit_map.get_layer_id_from_label(label).is_ok()
1189 }
1190
1191 fn input_shred_contains_data(&self, shred_id: NodeId) -> bool {
1192 self.partial_inputs.contains_key(&shred_id)
1193 }
1194
1195 pub fn input_layer_contains_data(&self, label: &str) -> Result<bool> {
1198 let layer_id = self.circuit_map.get_layer_id_from_label(label)?;
1199
1200 Ok(self
1201 .circuit_map
1202 .get_input_shreds_from_layer_id(layer_id)?
1203 .iter()
1204 .all(|shred_id| self.input_shred_contains_data(*shred_id)))
1205 }
1206
1207 pub fn get_input_layer_description_ref(&self, layer_label: &str) -> &InputLayerDescription {
1212 let layer_id = self
1213 .circuit_map
1214 .get_layer_id_from_label(layer_label)
1215 .unwrap();
1216
1217 let x: Vec<&InputLayerDescription> = self
1218 .circuit_description
1219 .input_layers
1220 .iter()
1221 .filter(|input_layer| input_layer.layer_id == layer_id)
1222 .collect();
1223
1224 assert_eq!(x.len(), 1);
1225
1226 x[0]
1227 }
1228
1229 fn build_input_layer_data(&self, layer_id: LayerId) -> Result<MultilinearExtension<F>> {
1233 let input_shred_ids = self.circuit_map.get_input_shreds_from_layer_id(layer_id)?;
1234
1235 let mut shred_mles_and_prefix_bits = vec![];
1236 for input_shred_id in input_shred_ids {
1237 let mle = self.partial_inputs.get(&input_shred_id).ok_or(anyhow!(
1238 "Input shred {input_shred_id} does not contain any data!"
1239 ))?;
1240
1241 let (circuit_location, num_vars) =
1242 self.circuit_map.get_shred_location(input_shred_id).unwrap();
1243
1244 if num_vars != mle.num_vars() {
1245 return Err(anyhow!(GKRError::InputShredLengthMismatch(
1246 input_shred_id.get_id(),
1247 num_vars,
1248 mle.num_vars(),
1249 )));
1250 }
1251 shred_mles_and_prefix_bits.push((mle, circuit_location.prefix_bits))
1252 }
1253
1254 Ok(build_composite_mle(&shred_mles_and_prefix_bits))
1255 }
1256
1257 fn build_public_input_layer_data(
1258 &self,
1259 verifier_optional_inputs: bool,
1260 ) -> Result<HashMap<LayerId, MultilinearExtension<F>>> {
1261 let mut public_inputs: HashMap<LayerId, MultilinearExtension<F>> = HashMap::new();
1262
1263 for input_layer_id in self.circuit_map.get_all_public_input_layer_ids() {
1264 let maybe_layer_mle = self.build_input_layer_data(input_layer_id);
1266 match maybe_layer_mle {
1267 Ok(layer_mle) => {
1268 public_inputs.insert(input_layer_id, layer_mle);
1269 }
1270 Err(err) => {
1271 if !verifier_optional_inputs {
1274 return Result::Err(err);
1275 }
1276 }
1277 }
1278 }
1279
1280 Ok(public_inputs)
1281 }
1282
1283 fn build_all_input_layer_data(&self) -> Result<HashMap<LayerId, MultilinearExtension<F>>> {
1284 let mut inputs: HashMap<LayerId, MultilinearExtension<F>> = HashMap::new();
1303
1304 for input_layer_id in self.circuit_map.get_all_input_layer_ids() {
1305 let layer_mle = self.build_input_layer_data(input_layer_id)?;
1306 inputs.insert(input_layer_id, layer_mle);
1307 }
1308
1309 Ok(inputs)
1310 }
1311
1312 fn get_all_committed_input_layer_descriptions_to_ligero(
1324 &self,
1325 ) -> Vec<LigeroInputLayerDescription<F>> {
1326 self.circuit_map
1327 .get_all_committed_layers()
1328 .into_iter()
1329 .map(|layer_id| {
1330 let raw_needed_capacity = self
1331 .circuit_map
1332 .get_input_shreds_from_layer_id(layer_id)
1333 .unwrap()
1334 .into_iter()
1335 .fold(0, |acc, shred_id| {
1336 let (_, num_vars) = self.circuit_map.get_shred_location(shred_id).unwrap();
1337 acc + (1_usize << num_vars)
1338 });
1339 let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
1340 let total_num_vars = log2(padded_needed_capacity) as usize;
1341
1342 LigeroInputLayerDescription {
1343 layer_id,
1344 num_vars: total_num_vars,
1345 aux: LigeroAuxInfo::<F>::new(1 << (total_num_vars), 4, 1.0, None),
1346 }
1347 })
1348 .collect()
1349 }
1350
1351 pub fn gen_hyrax_verifiable_circuit<C>(&self) -> Result<HyraxVerifiableCircuit<C>>
1354 where
1355 C: PrimeOrderCurve<Scalar = F>,
1356 {
1357 let public_inputs = self.build_public_input_layer_data(true)?;
1358
1359 debug!("Public inputs available: {:#?}", public_inputs.keys());
1360 debug!(
1361 "Layer Labels to Layer ID map: {:#?}",
1362 self.circuit_map.layer_label_to_layer_id
1363 );
1364
1365 let hyrax_private_inputs = self
1366 .circuit_map
1367 .get_all_committed_layers()
1368 .into_iter()
1369 .map(|layer_id| {
1370 let raw_needed_capacity = self
1371 .circuit_map
1372 .get_input_shreds_from_layer_id(layer_id)
1373 .unwrap()
1374 .into_iter()
1375 .fold(0, |acc, shred_id| {
1376 let (_, num_vars) = self.circuit_map.get_shred_location(shred_id).unwrap();
1377 acc + (1_usize << num_vars)
1378 });
1379 let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
1380 let total_num_vars = log2(padded_needed_capacity) as usize;
1381
1382 Ok((
1383 layer_id,
1384 (
1385 HyraxInputLayerDescription::new(layer_id, total_num_vars),
1386 None,
1387 ),
1388 ))
1389 })
1390 .collect::<Result<HashMap<_, _>>>()?;
1391
1392 Ok(HyraxVerifiableCircuit::new(
1393 self.circuit_description.clone(),
1394 public_inputs,
1395 hyrax_private_inputs,
1396 self.circuit_map.layer_label_to_layer_id.clone(),
1397 ))
1398 }
1399
1400 pub fn gen_hyrax_provable_circuit<C>(&self) -> Result<HyraxProvableCircuit<C>>
1407 where
1408 C: PrimeOrderCurve<Scalar = F>,
1409 {
1410 let inputs = self.build_all_input_layer_data()?;
1411
1412 let hyrax_private_inputs = self
1413 .circuit_map
1414 .get_all_committed_layers()
1415 .into_iter()
1416 .map(|layer_id| {
1417 let raw_needed_capacity = self
1418 .circuit_map
1419 .get_input_shreds_from_layer_id(layer_id)
1420 .unwrap()
1421 .into_iter()
1422 .fold(0, |acc, shred_id| {
1423 let (_, num_vars) = self.circuit_map.get_shred_location(shred_id).unwrap();
1424 acc + (1_usize << num_vars)
1425 });
1426 let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
1427 let total_num_vars = log2(padded_needed_capacity) as usize;
1428
1429 Ok((
1430 layer_id,
1431 (
1432 HyraxInputLayerDescription::new(layer_id, total_num_vars),
1433 None,
1434 ),
1435 ))
1436 })
1437 .collect::<Result<HashMap<_, _>>>()?;
1438
1439 Ok(HyraxProvableCircuit::new(
1440 self.circuit_description.clone(),
1441 inputs,
1442 hyrax_private_inputs,
1443 self.circuit_map.layer_label_to_layer_id.clone(),
1444 ))
1445 }
1446}
1447
1448impl<F: Halo2FFTFriendlyField> Circuit<F> {
1449 pub fn gen_provable_circuit(&self) -> Result<ProvableCircuit<F>> {
1457 let inputs = self.build_all_input_layer_data()?;
1458
1459 let ligero_committed_inputs = self
1460 .get_all_committed_input_layer_descriptions_to_ligero()
1461 .into_iter()
1462 .map(|ligero_input_layer_description| {
1463 (
1464 ligero_input_layer_description.layer_id,
1465 (ligero_input_layer_description, None),
1466 )
1467 })
1468 .collect();
1469
1470 Ok(ProvableCircuit::new(
1471 self.circuit_description.clone(),
1472 inputs,
1473 ligero_committed_inputs,
1474 self.circuit_map.layer_label_to_layer_id.clone(),
1475 ))
1476 }
1477
1478 #[allow(clippy::type_complexity)]
1482 pub fn gen_verifiable_circuit(&self) -> Result<VerifiableCircuit<F>> {
1483 let verifier_predetermined_public_inputs = self.build_public_input_layer_data(true)?;
1487
1488 let ligero_committed_inputs = self
1490 .get_all_committed_input_layer_descriptions_to_ligero()
1491 .into_iter()
1492 .map(|ligero_input_layer_description| {
1493 (
1494 ligero_input_layer_description.layer_id,
1495 (ligero_input_layer_description, None),
1496 )
1497 })
1498 .collect();
1499
1500 Ok(VerifiableCircuit::new(
1501 self.circuit_description.clone(),
1502 verifier_predetermined_public_inputs,
1503 ligero_committed_inputs,
1504 self.circuit_map.layer_label_to_layer_id.clone(),
1505 ))
1506 }
1507}
1508
1509macro_rules! impl_add {
1511 ($Lhs:ty) => {
1512 impl<F: Field, Rhs: Into<AbstractExpression<F>>> Add<Rhs> for $Lhs {
1513 type Output = AbstractExpression<F>;
1514
1515 fn add(self, rhs: Rhs) -> Self::Output {
1516 self.expr() + rhs.into()
1517 }
1518 }
1519 };
1520}
1521impl_add!(NodeRef<F>);
1522impl_add!(&NodeRef<F>);
1523impl_add!(InputLayerNodeRef<F>);
1524impl_add!(&InputLayerNodeRef<F>);
1525impl_add!(FSNodeRef<F>);
1526impl_add!(&FSNodeRef<F>);
1527
1528macro_rules! impl_sub {
1529 ($Lhs:ty) => {
1530 impl<F: Field, Rhs: Into<AbstractExpression<F>>> Sub<Rhs> for $Lhs {
1531 type Output = AbstractExpression<F>;
1532
1533 fn sub(self, rhs: Rhs) -> Self::Output {
1534 self.expr() - rhs.into()
1535 }
1536 }
1537 };
1538}
1539impl_sub!(NodeRef<F>);
1540impl_sub!(&NodeRef<F>);
1541impl_sub!(InputLayerNodeRef<F>);
1542impl_sub!(&InputLayerNodeRef<F>);
1543impl_sub!(FSNodeRef<F>);
1544impl_sub!(&FSNodeRef<F>);
1545
1546macro_rules! impl_mul {
1547 ($Lhs:ty) => {
1548 impl<F: Field, Rhs: Into<AbstractExpression<F>>> Mul<Rhs> for $Lhs {
1549 type Output = AbstractExpression<F>;
1550
1551 fn mul(self, rhs: Rhs) -> Self::Output {
1552 self.expr() * rhs.into()
1553 }
1554 }
1555 };
1556}
1557impl_mul!(NodeRef<F>);
1558impl_mul!(&NodeRef<F>);
1559impl_mul!(InputLayerNodeRef<F>);
1560impl_mul!(&InputLayerNodeRef<F>);
1561impl_mul!(FSNodeRef<F>);
1562impl_mul!(&FSNodeRef<F>);
1563
1564macro_rules! impl_xor {
1565 ($Lhs:ty) => {
1566 impl<F: Field, Rhs: Into<AbstractExpression<F>>> BitXor<Rhs> for $Lhs {
1567 type Output = AbstractExpression<F>;
1568
1569 fn bitxor(self, rhs: Rhs) -> Self::Output {
1570 self.expr() ^ rhs.into()
1571 }
1572 }
1573 };
1574}
1575impl_xor!(NodeRef<F>);
1576impl_xor!(&NodeRef<F>);
1577impl_xor!(InputLayerNodeRef<F>);
1578impl_xor!(&InputLayerNodeRef<F>);
1579impl_xor!(FSNodeRef<F>);
1580impl_xor!(&FSNodeRef<F>);