frontend/layouter/
builder.rs

1//! A Circuit Builder struct that owns the [super::nodes::CircuitNode]s used
2//! during circuit creation.
3
4use core::fmt;
5use std::{
6    collections::{HashMap, HashSet},
7    marker::PhantomData,
8    ops::{Add, BitXor, Mul, Sub},
9    rc::{Rc, Weak},
10};
11
12use ark_std::log2;
13use hyrax::{
14    gkr::input_layer::HyraxInputLayerDescription, provable_circuit::HyraxProvableCircuit,
15    verifiable_circuit::HyraxVerifiableCircuit,
16};
17use itertools::Itertools;
18use ligero::ligero_structs::LigeroAuxInfo;
19use serde::{Deserialize, Serialize};
20use shared_types::{curves::PrimeOrderCurve, Field, Halo2FFTFriendlyField};
21
22use crate::{
23    abstract_expr::AbstractExpression,
24    layouter::{
25        layouting::LayoutingError,
26        nodes::{
27            circuit_inputs::{InputLayerNode, InputShred},
28            circuit_outputs::OutputNode,
29            fiat_shamir_challenge::FiatShamirChallengeNode,
30            gate::GateNode,
31            identity_gate::IdentityGateNode,
32            lookup::{LookupConstraint, LookupTable},
33            matmult::MatMultNode,
34            sector::{generate_sector_circuit_description, Sector},
35            split_node::SplitNode,
36            CircuitNode, NodeId,
37        },
38    },
39};
40use remainder::{
41    circuit_layout::CircuitLocation,
42    input_layer::{ligero_input_layer::LigeroInputLayerDescription, InputLayerDescription},
43    layer::{gate::BinaryOperation, layer_enum::LayerDescriptionEnum, LayerId},
44    mle::evals::MultilinearExtension,
45    output_layer::OutputLayerDescription,
46    provable_circuit::ProvableCircuit,
47    prover::{GKRCircuitDescription, GKRError},
48    utils::mle::build_composite_mle,
49    verifiable_circuit::VerifiableCircuit,
50};
51
52use anyhow::{anyhow, bail, Result};
53
54use tracing::debug;
55
56/// A dynamically-typed reference to a [CircuitNode].
57/// Used only in the front-end during the circuit-building phase.
58///
59/// A developer building a circuit can use these references to
60/// 1. indicate specific nodes when calling methods of [CircuitBuilder],
61/// 2. generate [AbstractExpression]s, typically to be used in defining [Sector] nodes.
62#[derive(Clone, Debug)]
63pub struct NodeRef<F: Field> {
64    ptr: Weak<dyn CircuitNode>,
65    _phantom: PhantomData<F>,
66}
67
68impl<F: Field> NodeRef<F> {
69    fn new(ptr: Weak<dyn CircuitNode>) -> Self {
70        Self {
71            ptr,
72            _phantom: PhantomData,
73        }
74    }
75
76    /// Generates an abstract expression containing a single MLE with the data in the node
77    /// referenced to by [Self].
78    pub fn expr(&self) -> AbstractExpression<F> {
79        self.ptr.upgrade().unwrap().id().expr()
80    }
81
82    /// Returns the [NodeId] of the node references by [Self].
83    pub fn id(&self) -> NodeId {
84        self.ptr.upgrade().unwrap().id()
85    }
86
87    /// Returns the number of variables of the MLE of this node.
88    pub fn get_num_vars(&self) -> usize {
89        self.ptr.upgrade().unwrap().get_num_vars()
90    }
91}
92
93impl<F: Field> From<NodeRef<F>> for AbstractExpression<F> {
94    fn from(value: NodeRef<F>) -> Self {
95        value.expr()
96    }
97}
98
99impl<F: Field> From<&NodeRef<F>> for AbstractExpression<F> {
100    fn from(value: &NodeRef<F>) -> Self {
101        value.expr()
102    }
103}
104
105/// A reference to a [InputLayerNode]; a specialized version of [NodeRef].
106/// Used only in the front-end during the circuit-building phase.
107#[derive(Clone, Debug)]
108pub struct InputLayerNodeRef<F: Field> {
109    ptr: Weak<InputLayerNode>,
110    _phantom: PhantomData<F>,
111}
112
113impl<F: Field> InputLayerNodeRef<F> {
114    fn new(ptr: Weak<InputLayerNode>) -> Self {
115        Self {
116            ptr,
117            _phantom: PhantomData,
118        }
119    }
120
121    /// Generates an abstract expression containing a single MLE with the data in the node
122    /// referenced to by [Self].
123    pub fn expr(&self) -> AbstractExpression<F> {
124        self.ptr.upgrade().unwrap().id().expr()
125    }
126}
127
128impl<F: Field> From<InputLayerNodeRef<F>> for AbstractExpression<F> {
129    fn from(value: InputLayerNodeRef<F>) -> Self {
130        value.expr()
131    }
132}
133
134impl<F: Field> From<&InputLayerNodeRef<F>> for AbstractExpression<F> {
135    fn from(value: &InputLayerNodeRef<F>) -> Self {
136        value.expr()
137    }
138}
139
140/// A reference to a [FiatShamirChallengeNode]; a specialized version of [NodeRef].
141/// Used only in the front-end during the circuit-building phase.
142#[derive(Clone, Debug)]
143pub struct FSNodeRef<F: Field> {
144    ptr: Weak<FiatShamirChallengeNode>,
145    _phantom: PhantomData<F>,
146}
147
148impl<F: Field> FSNodeRef<F> {
149    fn new(ptr: Weak<FiatShamirChallengeNode>) -> Self {
150        Self {
151            ptr,
152            _phantom: PhantomData,
153        }
154    }
155
156    /// Generates an abstract expression containing a single MLE with the data in the node
157    /// referenced to by [Self].
158    pub fn expr(&self) -> AbstractExpression<F> {
159        self.ptr.upgrade().unwrap().id().expr()
160    }
161}
162
163impl<F: Field> From<FSNodeRef<F>> for NodeRef<F> {
164    fn from(value: FSNodeRef<F>) -> Self {
165        NodeRef::new(value.ptr)
166    }
167}
168
169impl<F: Field> From<FSNodeRef<F>> for AbstractExpression<F> {
170    fn from(value: FSNodeRef<F>) -> Self {
171        value.expr()
172    }
173}
174
175impl<F: Field> From<&FSNodeRef<F>> for AbstractExpression<F> {
176    fn from(value: &FSNodeRef<F>) -> Self {
177        value.expr()
178    }
179}
180
181/// A reference to a [LookupTable]; a specialized version of [NodeRef].
182/// Used only in the front-end during the circuit-building phase.
183#[derive(Clone, Debug)]
184pub struct LookupTableNodeRef {
185    ptr: Weak<LookupTable>,
186}
187
188impl LookupTableNodeRef {
189    fn new(ptr: Weak<LookupTable>) -> Self {
190        Self { ptr }
191    }
192
193    /// Generates an abstract expression containing a single MLE with the data in the node
194    /// referenced to by [Self].
195    pub fn expr<F: Field>(&self) -> AbstractExpression<F> {
196        self.ptr.upgrade().unwrap().id().expr()
197    }
198}
199
200/// A reference to a [LookupConstraint]; a specialized version of [NodeRef].
201/// Used only in the front-end during the circuit-building phase.
202#[derive(Clone, Debug)]
203pub struct LookupConstraintNodeRef {
204    ptr: Weak<LookupConstraint>,
205}
206
207impl LookupConstraintNodeRef {
208    fn new(ptr: Weak<LookupConstraint>) -> Self {
209        Self { ptr }
210    }
211
212    /// Generates an abstract expression containing a single MLE with the data in the node
213    /// referenced to by [Self].
214    pub fn expr<F: Field>(&self) -> AbstractExpression<F> {
215        self.ptr.upgrade().unwrap().id().expr()
216    }
217}
218
219/// A struct that owns and manages [super::nodes::CircuitNode]s during
220/// circuit creation.
221pub struct CircuitBuilder<F: Field> {
222    input_layer_nodes: Vec<Rc<InputLayerNode>>,
223    input_shred_nodes: Vec<Rc<InputShred>>,
224    fiat_shamir_challenge_nodes: Vec<Rc<FiatShamirChallengeNode>>,
225    output_nodes: Vec<Rc<OutputNode>>,
226    sector_nodes: Vec<Rc<Sector<F>>>,
227    gate_nodes: Vec<Rc<GateNode>>,
228    identity_gate_nodes: Vec<Rc<IdentityGateNode>>,
229    split_nodes: Vec<Rc<SplitNode>>,
230    matmult_nodes: Vec<Rc<MatMultNode>>,
231    lookup_constraint_nodes: Vec<Rc<LookupConstraint>>,
232    lookup_table_nodes: Vec<Rc<LookupTable>>,
233    node_to_ptr: HashMap<NodeId, NodeRef<F>>,
234    circuit_map: CircuitMap,
235}
236
237impl<F: Field> CircuitBuilder<F> {
238    /// Constructs an empty [CircuitBuilder].
239    pub fn new() -> Self {
240        Self {
241            input_layer_nodes: vec![],
242            input_shred_nodes: vec![],
243            fiat_shamir_challenge_nodes: vec![],
244            output_nodes: vec![],
245            sector_nodes: vec![],
246            gate_nodes: vec![],
247            identity_gate_nodes: vec![],
248            split_nodes: vec![],
249            matmult_nodes: vec![],
250            lookup_constraint_nodes: vec![],
251            lookup_table_nodes: vec![],
252            node_to_ptr: HashMap::new(),
253            circuit_map: CircuitMap::new(),
254        }
255    }
256
257    fn into_owned_helper<T: CircuitNode + fmt::Debug>(xs: Vec<Rc<T>>) -> Vec<T> {
258        xs.into_iter()
259            .map(|ptr| Rc::try_unwrap(ptr).unwrap())
260            .collect()
261    }
262
263    /// Generates a circuit description of all the nodes added so far.
264    ///
265    /// Returns a [Circuit] struct containing the circuit description and all necessary metadata for
266    /// attaching inputs.
267    pub fn build_with_max_layer_size(
268        mut self,
269        maybe_maximum_log_layer_size: Option<usize>,
270    ) -> Result<Circuit<F>> {
271        let input_layer_nodes = Self::into_owned_helper(self.input_layer_nodes);
272        let input_shred_nodes = Self::into_owned_helper(self.input_shred_nodes);
273        let fiat_shamir_challenge_nodes = Self::into_owned_helper(self.fiat_shamir_challenge_nodes);
274        let output_nodes = Self::into_owned_helper(self.output_nodes);
275        let sector_nodes = Self::into_owned_helper(self.sector_nodes);
276        let gate_nodes = Self::into_owned_helper(self.gate_nodes);
277        let identity_gate_nodes = Self::into_owned_helper(self.identity_gate_nodes);
278        let split_nodes = Self::into_owned_helper(self.split_nodes);
279        let matmult_nodes = Self::into_owned_helper(self.matmult_nodes);
280        let lookup_constraint_nodes = Self::into_owned_helper(self.lookup_constraint_nodes);
281        let lookup_table_nodes = Self::into_owned_helper(self.lookup_table_nodes);
282        let id_to_sector_nodes_map: HashMap<NodeId, Sector<F>> = sector_nodes
283            .iter()
284            .cloned()
285            .map(|node| (node.id(), node))
286            .collect();
287
288        // If the specified maximum layer size is 0, then this means we do not want to combine any layers.
289        let should_combine = maybe_maximum_log_layer_size != Some(0);
290
291        let (
292            input_layer_nodes,
293            fiat_shamir_challenge_nodes,
294            intermediate_node_layers,
295            lookup_nodes,
296            output_nodes,
297        ) = super::layouting::layout(
298            input_layer_nodes,
299            input_shred_nodes,
300            fiat_shamir_challenge_nodes,
301            output_nodes,
302            sector_nodes,
303            gate_nodes,
304            identity_gate_nodes,
305            split_nodes,
306            matmult_nodes,
307            lookup_constraint_nodes,
308            lookup_table_nodes,
309            should_combine,
310        )
311        .unwrap();
312
313        let mut intermediate_layers = Vec::<LayerDescriptionEnum<F>>::new();
314        let mut output_layers = Vec::<OutputLayerDescription<F>>::new();
315
316        let input_layers = input_layer_nodes
317            .iter()
318            .map(|input_layer_node| {
319                let input_layer_description = input_layer_node
320                    .generate_input_layer_description::<F>(&mut self.circuit_map)
321                    .unwrap();
322                self.circuit_map.insert_shreds_into_input_layer(
323                    input_layer_description.layer_id,
324                    input_layer_node
325                        .input_shreds
326                        .iter()
327                        .map(CircuitNode::id)
328                        .collect(),
329                );
330                input_layer_description
331            })
332            .collect_vec();
333
334        let fiat_shamir_challenges = fiat_shamir_challenge_nodes
335            .iter()
336            .map(|fiat_shamir_challenge_node| {
337                fiat_shamir_challenge_node.generate_circuit_description::<F>(&mut self.circuit_map)
338            })
339            .collect_vec();
340
341        for layer in &intermediate_node_layers {
342            // We have no nodes to combine in this layer. Therefore we can directly
343            // compile it and add it to the layer circuit descriptions.
344            if layer.len() == 1 {
345                intermediate_layers.extend(
346                    layer
347                        .first()
348                        .unwrap()
349                        .generate_circuit_description(&mut self.circuit_map)?,
350                );
351            } else {
352                // If there are nodes to combine, they must all be sectors. We first
353                // check whether they are sectors and grab their associated node as
354                // a Vec<&Sector<F>>.
355                //
356                // From this, we can generate their circuit description.
357                let sectors = layer
358                    .iter()
359                    .map(|sector| {
360                        assert!(id_to_sector_nodes_map.contains_key(&sector.id()));
361                        id_to_sector_nodes_map.get(&sector.id()).unwrap()
362                    })
363                    .collect_vec();
364                intermediate_layers.extend(generate_sector_circuit_description(
365                    &sectors,
366                    &mut self.circuit_map,
367                    maybe_maximum_log_layer_size,
368                ));
369            }
370        }
371
372        // Get the contributions of each LookupTable to the circuit description.
373        (intermediate_layers, output_layers) = lookup_nodes.iter().fold(
374            (intermediate_layers, output_layers),
375            |(mut lookup_intermediate_acc, mut lookup_output_acc), lookup_node| {
376                let (intermediate_layers, output_layer) = lookup_node
377                    .generate_lookup_circuit_description(&mut self.circuit_map)
378                    .unwrap();
379                lookup_intermediate_acc.extend(intermediate_layers);
380                lookup_output_acc.push(output_layer);
381                (lookup_intermediate_acc, lookup_output_acc)
382            },
383        );
384        output_layers =
385            output_nodes
386                .iter()
387                .fold(output_layers, |mut output_layer_acc, output_node| {
388                    output_layer_acc
389                        .extend(output_node.generate_circuit_description(&mut self.circuit_map));
390                    output_layer_acc
391                });
392
393        let mut circuit_description = GKRCircuitDescription {
394            input_layers,
395            fiat_shamir_challenges,
396            intermediate_layers,
397            output_layers,
398        };
399        circuit_description.index_mle_indices(0);
400
401        Ok(Circuit::new(circuit_description, self.circuit_map))
402    }
403
404    /// A build function that combines layers greedily such that the circuit is optimized for having
405    /// the smallest number of layers possible.
406    pub fn build_with_layer_combination(self) -> Result<Circuit<F>> {
407        self.build_with_max_layer_size(None)
408    }
409
410    /// A build function that does not combine any layers.
411    pub fn build_without_layer_combination(self) -> Result<Circuit<F>> {
412        self.build_with_max_layer_size(Some(0))
413    }
414
415    /// A default build function which does _not_ combine layers.
416    /// Equivalent to `build_without_layer_combination`.
417    pub fn build(self) -> Result<Circuit<F>> {
418        self.build_without_layer_combination()
419    }
420}
421
422impl<F: Field> CircuitBuilder<F> {
423    /// Adds an [InputLayerNode] labeled `layer_label` to the builder's node collection, intented to
424    /// become a `layer_kind` input later during circuit instantiation.
425    ///
426    /// Returns a weak pointer to the newly created layer node.
427    ///
428    /// Note that Input Layers and Input Shred have disjoint label scopes. A label has to be unique
429    /// only in its respective scope, regardless of the inclusive relation between shreds and input
430    /// layers.
431    ///
432    /// # Panics
433    /// If `layer_label` has already been used for an existing Input Layer.
434    pub fn add_input_layer(
435        &mut self,
436        layer_label: &str,
437        layer_visibility: LayerVisibility,
438    ) -> InputLayerNodeRef<F> {
439        let node = Rc::new(InputLayerNode::new(None));
440        let node_weak_ref = Rc::downgrade(&node);
441
442        let layer_id = node.input_layer_id();
443
444        self.circuit_map
445            .add_input_layer(layer_id, layer_label, layer_visibility);
446
447        self.node_to_ptr
448            .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
449
450        self.input_layer_nodes.push(node);
451
452        InputLayerNodeRef::new(node_weak_ref)
453    }
454
455    // Adds an [InputShred] labeled `label` to the builder's node collection.
456    /// Returns a reference to the newly created node.
457    ///
458    /// Note that no method in [Self] requires to differentiate between a reference to an input
459    /// shred as opposed to a generic [NodeRef], so there is no need to retain the specific type
460    /// information in the returned type.
461    ///
462    /// # Panics
463    /// If `label` has already been used for an existing Input Shred.
464    pub fn add_input_shred(
465        &mut self,
466        label: &str,
467        num_vars: usize,
468        source: &InputLayerNodeRef<F>,
469    ) -> NodeRef<F> {
470        let source = source
471            .ptr
472            .upgrade()
473            .expect("InputShred's source data has already been dropped");
474        let node = Rc::new(InputShred::new(num_vars, &source));
475        let node_weak_ref = Rc::downgrade(&node);
476
477        let node_id = node.id();
478
479        self.node_to_ptr
480            .insert(node_id, NodeRef::new(node_weak_ref.clone()));
481
482        // Associate `label` with the `NodeId` of the newly created node.
483        self.circuit_map.add_input_shred(label, node_id);
484
485        self.input_shred_nodes.push(node);
486
487        NodeRef::new(node_weak_ref)
488    }
489
490    /// Adds an _zero_ [OutputNode] (using `OutputNode::new_zero()`) to the builder's node
491    /// collection.
492    ///
493    /// TODO(Makis): Add a check for ensuring each node can be set as output at most once.
494    pub fn set_output(&mut self, source: &NodeRef<F>) {
495        let source = source
496            .ptr
497            .upgrade()
498            .expect("Sector source has already been dropped");
499        let node = Rc::new(OutputNode::new_zero(source.as_ref()));
500        self.output_nodes.push(node);
501    }
502
503    /// Adds a [FiatShamirChallengeNode] to the builder's node colllection.
504    /// Returns a typed reference to the newly created node.
505    pub fn add_fiat_shamir_challenge_node(&mut self, num_challenges: usize) -> FSNodeRef<F> {
506        let node = Rc::new(FiatShamirChallengeNode::new(num_challenges));
507        let node_weak_ref = Rc::downgrade(&node);
508        self.node_to_ptr
509            .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
510        self.fiat_shamir_challenge_nodes.push(node);
511        FSNodeRef::new(node_weak_ref)
512    }
513
514    /// Adds a [Sector] to the builder's node collection.
515    /// Returns a typed reference to the newly created node.
516    pub fn add_sector(&mut self, expr: AbstractExpression<F>) -> NodeRef<F> {
517        let node_ids_in_use: HashSet<NodeId> = expr.get_sources().into_iter().collect();
518
519        let num_vars_map: HashMap<NodeId, usize> = node_ids_in_use
520            .into_iter()
521            .map(|id| (id, self.get_ptr_from_node_id(id).get_num_vars()))
522            .collect();
523
524        let num_vars = expr
525            .get_num_vars(&num_vars_map)
526            .expect("Internal error duing 'num_vars' computation of an AbstractExpression");
527
528        let node = Rc::new(Sector::<F>::new(expr, num_vars));
529        let node_weak_ref = Rc::downgrade(&node);
530
531        self.node_to_ptr
532            .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
533
534        self.sector_nodes.push(node);
535        NodeRef::new(node_weak_ref)
536    }
537
538    /// Adds an [IdentityGateNode] to the builder's node collection.
539    /// Returns a reference to the newly created node.
540    ///
541    /// Note that no method in [Self] requires to differentiate between a reference to an identity
542    /// gate node as opposed to a generic [NodeRef], so there is no need to retain the specific type
543    /// information in the returned type.
544    pub fn add_identity_gate_node(
545        &mut self,
546        pre_routed_data: &NodeRef<F>,
547        non_zero_gates: Vec<(u32, u32)>,
548        num_vars: usize,
549        num_dataparallel_vars: Option<usize>,
550    ) -> NodeRef<F> {
551        let pre_routed_data = pre_routed_data
552            .ptr
553            .upgrade()
554            .expect("`pre_routed_data` reference given to identity gate has been dropped");
555        let node = Rc::new(IdentityGateNode::new(
556            pre_routed_data.as_ref(),
557            non_zero_gates,
558            num_vars,
559            num_dataparallel_vars,
560        ));
561        let node_weak_ref = Rc::downgrade(&node);
562        self.node_to_ptr
563            .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
564        self.identity_gate_nodes.push(node);
565        NodeRef::new(node_weak_ref)
566    }
567
568    /// Adds an [GateNode] to the builder's node collection.
569    /// Returns a reference to the newly created node.
570    ///
571    /// Note that no method in [Self] requires to differentiate between a reference to a gate node
572    /// as opposed to a generic [NodeRef], so there is no need to retain the specific type
573    /// information in the returned type.
574    pub fn add_gate_node(
575        &mut self,
576        lhs: &NodeRef<F>,
577        rhs: &NodeRef<F>,
578        nonzero_gates: Vec<(u32, u32, u32)>,
579        gate_operation: BinaryOperation,
580        num_dataparallel_bits: Option<usize>,
581    ) -> NodeRef<F> {
582        let lhs = lhs
583            .ptr
584            .upgrade()
585            .expect("lhs give to GateNode has already been dropped");
586        let rhs = rhs
587            .ptr
588            .upgrade()
589            .expect("rhs give to GateNode has already been dropped");
590        let node = Rc::new(GateNode::new(
591            lhs.as_ref(),
592            rhs.as_ref(),
593            nonzero_gates,
594            gate_operation,
595            num_dataparallel_bits,
596        ));
597        let node_weak_ref = Rc::downgrade(&node);
598        self.node_to_ptr
599            .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
600        self.gate_nodes.push(node);
601        NodeRef::new(node_weak_ref)
602    }
603
604    /// Adds an [MatMultNode] to the builder's node collection.
605    /// Returns a reference to the newly created node.
606    ///
607    /// Note that no method in [Self] requires to differentiate between a reference to a matmult
608    /// node as opposed to a generic [NodeRef], so there is no need to retain the specific type
609    /// information in the returned type.
610    pub fn add_matmult_node(
611        &mut self,
612        matrix_a_node: &NodeRef<F>,
613        rows_cols_num_vars_a: (usize, usize),
614        matrix_b_node: &NodeRef<F>,
615        rows_cols_num_vars_b: (usize, usize),
616    ) -> NodeRef<F> {
617        let matrix_a_node = matrix_a_node
618            .ptr
619            .upgrade()
620            .expect("Matrix A input to MatMultNode has been dropped");
621        let matrix_b_node = matrix_b_node
622            .ptr
623            .upgrade()
624            .expect("Matrix B input to MatMultNode has been dropped");
625        let node = Rc::new(MatMultNode::new(
626            matrix_a_node.as_ref(),
627            rows_cols_num_vars_a,
628            matrix_b_node.as_ref(),
629            rows_cols_num_vars_b,
630        ));
631        let node_weak_ref = Rc::downgrade(&node);
632        self.node_to_ptr
633            .insert(node.id(), NodeRef::new(node_weak_ref.clone()));
634        self.matmult_nodes.push(node);
635        NodeRef::new(node_weak_ref)
636    }
637
638    /// Adds an [LookupTable] to the builder's node collection.
639    /// Returns a typed reference to the newly created node.
640    pub fn add_lookup_table(
641        &mut self,
642        table: &NodeRef<F>,
643        fiat_shamir_challenge_node: &FSNodeRef<F>,
644    ) -> LookupTableNodeRef {
645        let table = table
646            .ptr
647            .upgrade()
648            .expect("Table input to LookupTable has already been dropped");
649        let fiat_shamir_challenge_node = fiat_shamir_challenge_node
650            .ptr
651            .upgrade()
652            .expect("FiatShamirChallegeNode input to LookupTable has already been dropped");
653        let node = Rc::new(LookupTable::new(
654            table.as_ref(),
655            fiat_shamir_challenge_node.as_ref(),
656        ));
657        let node_ref = Rc::downgrade(&node);
658        self.node_to_ptr
659            .insert(node.id(), NodeRef::new(node_ref.clone()));
660        self.lookup_table_nodes.push(node);
661        LookupTableNodeRef::new(node_ref)
662    }
663
664    /// Adds an [LookupConstraint] to the builder's node collection.
665    /// Returns a typed reference to the newly created node.
666    pub fn add_lookup_constraint(
667        &mut self,
668        lookup_table: &LookupTableNodeRef,
669        constrained: &NodeRef<F>,
670        multiplicities: &NodeRef<F>,
671    ) -> LookupConstraintNodeRef {
672        let lookup_table = lookup_table
673            .ptr
674            .upgrade()
675            .expect("LookupTable input to LookupConstraint has already been dropped");
676        let constrained = constrained
677            .ptr
678            .upgrade()
679            .expect("constrained input to LookupConstrained has already been dropped");
680        let multiplicities = multiplicities
681            .ptr
682            .upgrade()
683            .expect("multiplicites input to LookupConstrained has already been dropped");
684        let node = Rc::new(LookupConstraint::new(
685            lookup_table.as_ref(),
686            constrained.as_ref(),
687            multiplicities.as_ref(),
688        ));
689        let node_ref = Rc::downgrade(&node);
690        self.node_to_ptr
691            .insert(node.id(), NodeRef::new(node_ref.clone()));
692        self.lookup_constraint_nodes.push(node);
693        LookupConstraintNodeRef::new(node_ref)
694    }
695
696    /// Adds an [SplitNode] to the builder's node collection.
697    /// Returns a vector of reference to the `2^num_vars` newly created nodes.
698    ///
699    /// Note that no method in [Self] requires to differentiate between a reference to a split node
700    /// as opposed to a generic [NodeRef], so there is no need to retain the specific type
701    /// information in the returned type.
702    pub fn add_split_node(&mut self, input_node: &NodeRef<F>, num_vars: usize) -> Vec<NodeRef<F>> {
703        let input_node = input_node
704            .ptr
705            .upgrade()
706            .expect("input_node to SplitNode has already been dropped");
707        let nodes = SplitNode::new(input_node.as_ref(), num_vars)
708            .into_iter()
709            .map(Rc::new)
710            .collect_vec();
711        debug_assert_eq!(nodes.len(), 1 << num_vars);
712        let node_refs = nodes
713            .iter()
714            .map(|node| NodeRef::new(Rc::downgrade(node) as Weak<dyn CircuitNode>))
715            .collect_vec();
716        nodes
717            .iter()
718            .zip(node_refs.iter())
719            .for_each(|(node, node_ref)| {
720                self.node_to_ptr.insert(node.id(), node_ref.clone());
721            });
722        self.split_nodes.extend(nodes);
723        node_refs
724    }
725
726    fn get_ptr_from_node_id(&self, id: NodeId) -> Rc<dyn CircuitNode> {
727        self.node_to_ptr[&id].ptr.upgrade().unwrap()
728    }
729}
730
731impl<F: Field> Default for CircuitBuilder<F> {
732    fn default() -> Self {
733        Self::new()
734    }
735}
736
#[cfg(test)]
mod test {
    use shared_types::Fr;

    use super::*;

    /// Registering two Input Layers under the same label must panic.
    #[test]
    #[should_panic]
    pub fn test_unique_input_layer_label() {
        let mut builder = CircuitBuilder::<Fr>::new();
        let label = "Public Input Layer";

        let _first = builder.add_input_layer(label, LayerVisibility::Public);
        let _second = builder.add_input_layer(label, LayerVisibility::Public);
    }

    /// An Input Layer and an Input Shred may share a label: their label scopes are disjoint.
    #[test]
    pub fn test_scope_mixing() {
        let mut builder = CircuitBuilder::<Fr>::new();
        let shared_label = "label";

        let layer = builder.add_input_layer(shared_label, LayerVisibility::Public);
        builder.add_input_shred(shared_label, 1, &layer);
    }

    /// Registering two Input Shreds of the same Input Layer under one label must panic.
    #[test]
    #[should_panic]
    pub fn test_unique_input_shred_label() {
        let mut builder = CircuitBuilder::<Fr>::new();
        let layer = builder.add_input_layer("Public Input Layer", LayerVisibility::Public);

        for _ in 0..2 {
            builder.add_input_shred("shred1", 1, &layer);
        }
    }

    /// Input Shred labels must be unique even across shreds of different Input Layers.
    #[test]
    #[should_panic]
    pub fn test_unique_input_shred_label2() {
        let mut builder = CircuitBuilder::<Fr>::new();
        let first_layer = builder.add_input_layer("Input Layer 1", LayerVisibility::Public);
        let second_layer = builder.add_input_layer("Input Layer 2", LayerVisibility::Committed);

        for layer in [&first_layer, &second_layer] {
            builder.add_input_shred("shred1", 1, layer);
        }
    }
}
783
784/// The Layer kind defines the visibility of an input layer's data.
785#[derive(Clone, Debug, PartialEq, Copy, Serialize, Deserialize)]
786pub enum LayerVisibility {
787    /// Input layers whose data are visible to the verifier.
788    Public,
789
790    /// Input layers whose data are only accessible through their commitments; according to some
791    /// Polynomial Commitment Scheme (PCS). The specific commitment scheme is determined when the
792    /// circuit is finalized.
793    Committed,
794}
795
796/// Used only inside a [CircuitMap] to keep track of its state.
797#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
798enum CircuitMapState {
799    /// The circuit is under construction, meaning that some of the internals mappings might be in
800    /// an incomplete state.
801    UnderConstruction,
802
803    /// The circuit has been built, and all internal mappings must be in a complete and consistent
804    /// state.
805    Ready,
806}
807
808/// Manages the relations between all different kinds of identifiers used to specify nodes during
809/// circuit building and circuit instantiation.
810/// Keeps track of [LayerId]s, [NodeId]s, Labels, [LayerVisibility]s, [CircuitLocation]s.
811#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
812pub struct CircuitMap {
813    state: CircuitMapState,
814    shreds_in_layer: HashMap<LayerId, Vec<NodeId>>,
815    label_to_shred_id: HashMap<String, NodeId>,
816    layer_label_to_layer_id: HashMap<String, LayerId>,
817    layer_visibility: HashMap<LayerId, LayerVisibility>,
818    node_location: HashMap<NodeId, (CircuitLocation, usize)>,
819}
820
821impl CircuitMap {
822    /// Constructs an empty [CircuitMap] in a `CircuitMapState::UnderConstruction` state.
823    /// It can receive new data until the [Self::freeze] method is called which will transition it
824    /// to the `CircuitMapState::Ready` state, at which point it can only be used to answer queries.
825    pub fn new() -> Self {
826        Self {
827            state: CircuitMapState::UnderConstruction,
828            shreds_in_layer: HashMap::new(),
829            label_to_shred_id: HashMap::new(),
830            layer_label_to_layer_id: HashMap::new(),
831            layer_visibility: HashMap::new(),
832            node_location: HashMap::new(),
833        }
834    }
835
836    /// Associates the node with ID `node_id` to its corresponding circuit location as well as
837    /// number of variables.
838    ///
839    /// # Panics
840    /// If [self] is not in state `CircuitMapState::UnderConstruction`,
841    /// or if the node with ID `node_id` has already been assigned a location.
842    pub fn add_node_id_and_location_num_vars(
843        &mut self,
844        node_id: NodeId,
845        value: (CircuitLocation, usize),
846    ) {
847        assert_eq!(self.state, CircuitMapState::UnderConstruction);
848        assert!(!self.node_location.contains_key(&node_id));
849        self.node_location.insert(node_id, value);
850    }
851
852    /// Adds a collection of `shreds` to Input Layer with ID `input_layer_id`.
853    ///
854    /// # Panics
855    /// If [self] is not in state `CircuitMapState::UnderConstruction`,
856    /// or if `input_layer_id` has already been assigned shreds.
857    pub fn insert_shreds_into_input_layer(&mut self, input_layer_id: LayerId, shreds: Vec<NodeId>) {
858        assert_eq!(self.state, CircuitMapState::UnderConstruction);
859        assert!(!self.shreds_in_layer.contains_key(&input_layer_id));
860        self.shreds_in_layer.insert(input_layer_id, shreds);
861    }
862
863    /// Using `node_id`, retrieves the number of variables and location of this
864    /// node in the circuit, or returns an error if `node_id` is missing.
865    /// This method is safe to use in any `CircuitMapState`.
866    pub fn get_location_num_vars_from_node_id(
867        &self,
868        node_id: &NodeId,
869    ) -> Result<&(CircuitLocation, usize)> {
870        self.node_location
871            .get(node_id)
872            .ok_or(anyhow!(LayoutingError::DanglingNodeId(*node_id)))
873    }
874
875    /// Adds a new Input Layer with ID `layer_id`, label `layer_label`, and visibility defined by
876    /// `layer_kind`.
877    ///
878    /// # Panics
879    /// If [self] is _not_ in state `CircuitMapState::UnderConstruction`, or if a layer already
880    /// exists with either an ID equal to `layer_id` or a label equal to `layer_label`.
881    pub fn add_input_layer(
882        &mut self,
883        layer_id: LayerId,
884        layer_label: &str,
885        layer_kind: LayerVisibility,
886    ) {
887        assert_eq!(self.state, CircuitMapState::UnderConstruction);
888        assert!(!self.layer_visibility.contains_key(&layer_id));
889        assert!(!self.layer_label_to_layer_id.contains_key(layer_label));
890        self.layer_label_to_layer_id
891            .insert(String::from(layer_label), layer_id);
892        self.layer_visibility.insert(layer_id, layer_kind);
893    }
894
895    /// Adds a new Input Shred labeled `label` with node ID `shred_id`.
896    ///
897    /// # Panics
898    /// If [self] is _not_ in a `CircuitMapState::UnderConstruction`, or if `label` is already in
899    /// use.
900    ///
901    /// # Note
902    /// While this method does _not_ panic if `shred_id` has already been given a different label,
903    /// it is considered a semantic error to associate two different labels with the same node, and
904    /// [Self::freeze] may detect it and panic .
905    pub fn add_input_shred(&mut self, label: &str, shred_id: NodeId) {
906        assert_eq!(self.state, CircuitMapState::UnderConstruction);
907        assert!(!self.label_to_shred_id.contains_key(label));
908        self.label_to_shred_id.insert(String::from(label), shred_id);
909    }
910
911    /// Desctructively mutates [self] to transition it from a `CircuitMapState::UnderConstruction`
912    /// state to a `CircuitMapState::Ready` state. As part of the transition, consistency checks are
913    /// performed to ensure all internal mappings have all the expected properties.
914    ///
915    /// # Panics
916    /// If [self] is already in a `CircuitMapState::Ready` state, or if its internal state is
917    /// inconsistent.
918    pub fn freeze(mut self) -> Self {
919        assert_eq!(self.state, CircuitMapState::UnderConstruction);
920
921        // Ensure consistency between `self.shreds_in_layer` and `self.layer_visibility`: their domains
922        // should conincide.
923        assert_eq!(
924            self.shreds_in_layer.keys().sorted().collect_vec(),
925            self.layer_visibility.keys().sorted().collect_vec()
926        );
927
928        // Keep the shred location of the input shreds only.
929        let input_shred_location: HashMap<NodeId, (CircuitLocation, usize)> = self
930            .label_to_shred_id
931            .values()
932            .map(|shred_id| (*shred_id, self.node_location[shred_id].clone()))
933            .collect();
934        self.node_location = input_shred_location;
935
936        // Ensure consistency between `self.shreds_in_layer` and `self.shred_location`: the
937        // flattened image of the former should equal the domain of the latter.
938        assert_eq!(
939            self.shreds_in_layer
940                .values()
941                .flatten()
942                .sorted()
943                .collect_vec(),
944            self.node_location.keys().sorted().collect_vec()
945        );
946
947        // Ensure consistency between `self.label_to_shred_id` and `self.shred_location`: the image
948        // of the former shoould equal the domain of the latter.
949        assert_eq!(
950            self.label_to_shred_id.values().sorted().collect_vec(),
951            self.node_location.keys().sorted().collect_vec()
952        );
953
954        // TODO: Ensure all circuit locations are covered in `self.shred_location`.
955        // TODO: Ensure mappings are bijective.
956
957        self.state = CircuitMapState::Ready;
958
959        self
960    }
961
962    /// Returns the [LayerVisibility] of the Input Layer that the shred with label `shred_label` is in,
963    /// or an error if the `shred_label` does not correspond to any input shred.
964    pub fn get_node_kind(&self, shred_label: &str) -> Result<LayerVisibility> {
965        // The call to `self.get_node_id` will check ensure the state is `Ready`.
966
967        // This lookup may fail because the caller provided an invalid label.
968        // In this case, return an `Error`.
969        let shred_id = self.get_node_id(shred_label)?;
970
971        // Subsequent lookups should never fail, assuming `Self` maintains a consistent state.
972        let layer_id = self.node_location[&shred_id].0.layer_id;
973        let layer_visibility = self.layer_visibility[&layer_id];
974
975        Ok(layer_visibility)
976    }
977
978    /// Returns the [NodeId] of the Input Shred labeled `shred_label`, or an error if `shred_label`
979    /// does not correspond to any input shred.
980    pub fn get_node_id(&self, shred_label: &str) -> Result<NodeId> {
981        // This lookup may fail because the caller provided an invalid label.
982        // In this case, return an error.
983        let shred_id = self
984            .label_to_shred_id
985            .get(shred_label)
986            .ok_or(anyhow!("Unrecognized Shred Label '{shred_label}'"))?;
987
988        Ok(*shred_id)
989    }
990
991    /// Returns a vector of all [NodeId]s of Input Shreds.
992    ///
993    /// # Panics
994    /// If [self] is _not_ in `CircuitMapState::Ready` state.
995    pub fn get_all_input_shred_ids(&self) -> Vec<NodeId> {
996        assert_eq!(self.state, CircuitMapState::Ready);
997
998        self.node_location.keys().cloned().collect_vec()
999    }
1000
1001    /// Returns the label of the input shred with ID `shred_id`, or an error if there is no input
1002    /// shred with ID `shred_id`.
1003    ///
1004    /// Intended to be used only for error-reporting; current implementation is inefficient.
1005    ///
1006    /// # Panics
1007    /// If [self] is _not_ in state `CircuitMapState::Ready`.
1008    pub fn get_shred_label_from_id(&self, shred_id: NodeId) -> Result<String> {
1009        assert_eq!(self.state, CircuitMapState::Ready);
1010
1011        // Reverse lookup `shred_id` in `self.label_to_shred_id`.
1012        let labels = self
1013            .label_to_shred_id
1014            .iter()
1015            .filter_map(|(label, node_id)| {
1016                if *node_id == shred_id {
1017                    Some(label)
1018                } else {
1019                    None
1020                }
1021            })
1022            .collect_vec();
1023
1024        if labels.is_empty() {
1025            bail!("Unrecognized Input Shred ID '{shred_id}'");
1026        } else {
1027            // Panic if more than one label maps to this input shred ID as this indicates an
1028            // inconsistent internal state.
1029            assert_eq!(labels.len(), 1);
1030
1031            Ok(labels[0].clone())
1032        }
1033    }
1034
1035    /// Returns a vector of all input layer IDs.
1036    ///
1037    /// TODO: Consider returning an iterator instead of `Vec`.
1038    ///
1039    /// # Panics
1040    /// If [self] is _not_ in `CircuitMapState::Ready` state.
1041    pub fn get_all_input_layer_ids(&self) -> Vec<LayerId> {
1042        assert_eq!(self.state, CircuitMapState::Ready);
1043
1044        self.layer_visibility.keys().cloned().collect_vec()
1045    }
1046
1047    /// Returns a vector of all _public_ input layer IDs.
1048    ///
1049    /// TODO: Consider returning an iterator instead of `Vec`.
1050    ///
1051    /// # Panics
1052    /// If [self] is _not_ in `CircuitMapState::Ready` state.
1053    pub fn get_all_public_input_layer_ids(&self) -> Vec<LayerId> {
1054        assert_eq!(self.state, CircuitMapState::Ready);
1055
1056        self.layer_visibility
1057            .iter()
1058            .filter_map(|(layer_id, visibility)| match *visibility {
1059                LayerVisibility::Public => Some(layer_id),
1060                LayerVisibility::Committed => None,
1061            })
1062            .cloned()
1063            .collect_vec()
1064    }
1065
1066    /// Returns a vector of all Input Shred IDs in the Input Layer with ID `layer_id`, or an error
1067    /// if there is no input layer with that ID.
1068    ///
1069    /// # Panics
1070    /// If [self] is _not_ in `CircuitMapState::Ready`.
1071    pub fn get_input_shreds_from_layer_id(&self, layer_id: LayerId) -> Result<Vec<NodeId>> {
1072        assert_eq!(self.state, CircuitMapState::Ready);
1073
1074        Ok(self
1075            .shreds_in_layer
1076            .get(&layer_id)
1077            .ok_or(anyhow!("Unrecognized Input Layer ID '{layer_id}'"))?
1078            .clone())
1079    }
1080
1081    /// Returns the [CircuitLocation] and number of variables of the Input Shred with ID `shred_id`,
1082    /// or an error if no input shred with this ID exists.
1083    ///
1084    /// # Panics
1085    /// If [self] is _not_ in `CircuitMapState::Ready` state.
1086    pub fn get_shred_location(&self, shred_id: NodeId) -> Result<(CircuitLocation, usize)> {
1087        assert_eq!(self.state, CircuitMapState::Ready);
1088
1089        self.node_location
1090            .get(&shred_id)
1091            .ok_or(anyhow!("Unrecognized Shred ID '{shred_id}'."))
1092            .cloned()
1093    }
1094
1095    /// Returns a vector with all [LayerId]s of the Input Layers with [LayerVisibility::Committed]
1096    /// visibility.
1097    ///
1098    /// # Panics
1099    /// If [self] is _not_ in `CircuitMapState::Ready` state.
1100    pub fn get_all_committed_layers(&self) -> Vec<LayerId> {
1101        assert_eq!(self.state, CircuitMapState::Ready);
1102
1103        self.layer_visibility
1104            .iter()
1105            .filter_map(|(layer_id, layer_visibility)| {
1106                if *layer_visibility == LayerVisibility::Committed {
1107                    Some(layer_id)
1108                } else {
1109                    None
1110                }
1111            })
1112            .cloned()
1113            .collect_vec()
1114    }
1115
1116    /// Returns the layer ID of the input layer with label `layer_label`, or error if no such layer
1117    /// exists.
1118    pub fn get_layer_id_from_label(&self, layer_label: &str) -> Result<LayerId> {
1119        self.layer_label_to_layer_id
1120            .get(layer_label)
1121            .cloned()
1122            .ok_or(anyhow!("No Input Layer with label {layer_label}."))
1123    }
1124}
1125
1126impl Default for CircuitMap {
1127    fn default() -> Self {
1128        Self::new()
1129    }
1130}
1131
1132/// A circuit whose structure is fixed, but is not yet ready to be proven or verified because its
1133/// missing all or some of its input data. This structs provides an API for attaching inputs and
1134/// generating a form of the circuit that can be proven or verified respectively, for various
1135/// proving systems (vanilla GKR with Ligero, or Hyrax).
1136#[derive(Clone, Debug, Serialize, Deserialize)]
1137#[serde(bound = "F: Field")]
1138pub struct Circuit<F: Field> {
1139    circuit_description: GKRCircuitDescription<F>,
1140    pub circuit_map: CircuitMap,
1141    partial_inputs: HashMap<NodeId, MultilinearExtension<F>>,
1142}
1143
1144impl<F: Field> Circuit<F> {
1145    /// Constructor to be used by [CircuitBuilder].
1146    fn new(circuit_description: GKRCircuitDescription<F>, circuit_map: CircuitMap) -> Self {
1147        assert_eq!(circuit_map.state, CircuitMapState::UnderConstruction);
1148
1149        Self {
1150            circuit_description,
1151            circuit_map: circuit_map.freeze(),
1152            partial_inputs: HashMap::new(),
1153        }
1154    }
1155
1156    /// Return the [GKRCircuitDescription] inside this [Circuit].
1157    pub fn get_circuit_description(&self) -> &GKRCircuitDescription<F> {
1158        &self.circuit_description
1159    }
1160
1161    /// Assign `data` to the Input Shred with label `shred_label`.
1162    ///
1163    /// # Panics
1164    /// If `shred_label` does not correspond to any Input Shred, or if the this Input Shred has
1165    /// already been assigned data. Use [Self::update_input] for replacing the data of a shred.
1166    pub fn set_input(&mut self, shred_label: &str, data: MultilinearExtension<F>) {
1167        let node_id = self.circuit_map.get_node_id(shred_label).unwrap();
1168
1169        if self.partial_inputs.contains_key(&node_id) {
1170            panic!("Input Shred with label '{shred_label}' has already been assigned data.");
1171        }
1172
1173        self.partial_inputs.insert(node_id, data);
1174    }
1175
1176    /// Assign `data` to the Input Shred with label `shred_label`, discarding any existing data
1177    /// associated with this Input Shred.
1178    ///
1179    /// # Panics
1180    /// If `shred_label` does not correspond to any Input Shred.
1181    pub fn update_input(&mut self, shred_label: &str, data: MultilinearExtension<F>) {
1182        let node_id = self.circuit_map.get_node_id(shred_label).unwrap();
1183        self.partial_inputs.insert(node_id, data);
1184    }
1185
1186    /// Returns whether the circuit contains an Input Layer labeled `label`.
1187    pub fn contains_layer(&self, label: &str) -> bool {
1188        self.circuit_map.get_layer_id_from_label(label).is_ok()
1189    }
1190
1191    fn input_shred_contains_data(&self, shred_id: NodeId) -> bool {
1192        self.partial_inputs.contains_key(&shred_id)
1193    }
1194
1195    /// Returns whether data has already been assigned to the Input Layer labeled `label`, or an
1196    /// error if no such input layer exists.
1197    pub fn input_layer_contains_data(&self, label: &str) -> Result<bool> {
1198        let layer_id = self.circuit_map.get_layer_id_from_label(label)?;
1199
1200        Ok(self
1201            .circuit_map
1202            .get_input_shreds_from_layer_id(layer_id)?
1203            .iter()
1204            .all(|shred_id| self.input_shred_contains_data(*shred_id)))
1205    }
1206
1207    /// Returns the Input Layer Description of the Input Layer with label `layer_label`.
1208    ///
1209    /// # Panics
1210    /// If no such layer exists, or if `self` is in an inconsistent state.
1211    pub fn get_input_layer_description_ref(&self, layer_label: &str) -> &InputLayerDescription {
1212        let layer_id = self
1213            .circuit_map
1214            .get_layer_id_from_label(layer_label)
1215            .unwrap();
1216
1217        let x: Vec<&InputLayerDescription> = self
1218            .circuit_description
1219            .input_layers
1220            .iter()
1221            .filter(|input_layer| input_layer.layer_id == layer_id)
1222            .collect();
1223
1224        assert_eq!(x.len(), 1);
1225
1226        x[0]
1227    }
1228
1229    /// Builds the layer MLE for `layer_id` by combining the data in all the input shreds of that layer.
1230    ///
1231    /// Returns error if `layer_id` is an invalid input layer ID, or if any shred data is missing.
1232    fn build_input_layer_data(&self, layer_id: LayerId) -> Result<MultilinearExtension<F>> {
1233        let input_shred_ids = self.circuit_map.get_input_shreds_from_layer_id(layer_id)?;
1234
1235        let mut shred_mles_and_prefix_bits = vec![];
1236        for input_shred_id in input_shred_ids {
1237            let mle = self.partial_inputs.get(&input_shred_id).ok_or(anyhow!(
1238                "Input shred {input_shred_id} does not contain any data!"
1239            ))?;
1240
1241            let (circuit_location, num_vars) =
1242                self.circuit_map.get_shred_location(input_shred_id).unwrap();
1243
1244            if num_vars != mle.num_vars() {
1245                return Err(anyhow!(GKRError::InputShredLengthMismatch(
1246                    input_shred_id.get_id(),
1247                    num_vars,
1248                    mle.num_vars(),
1249                )));
1250            }
1251            shred_mles_and_prefix_bits.push((mle, circuit_location.prefix_bits))
1252        }
1253
1254        Ok(build_composite_mle(&shred_mles_and_prefix_bits))
1255    }
1256
1257    fn build_public_input_layer_data(
1258        &self,
1259        verifier_optional_inputs: bool,
1260    ) -> Result<HashMap<LayerId, MultilinearExtension<F>>> {
1261        let mut public_inputs: HashMap<LayerId, MultilinearExtension<F>> = HashMap::new();
1262
1263        for input_layer_id in self.circuit_map.get_all_public_input_layer_ids() {
1264            // Attempt to build the input layer's MLE.
1265            let maybe_layer_mle = self.build_input_layer_data(input_layer_id);
1266            match maybe_layer_mle {
1267                Ok(layer_mle) => {
1268                    public_inputs.insert(input_layer_id, layer_mle);
1269                }
1270                Err(err) => {
1271                    // In the verifier case, we skip adding input data to a
1272                    // particular input layer if any of the inputs are missing.
1273                    if !verifier_optional_inputs {
1274                        return Result::Err(err);
1275                    }
1276                }
1277            }
1278        }
1279
1280        Ok(public_inputs)
1281    }
1282
1283    fn build_all_input_layer_data(&self) -> Result<HashMap<LayerId, MultilinearExtension<F>>> {
1284        // Ensure all Input Shreds have been assigned input data.
1285        /*
1286        if let Some(shred_id) = self
1287            .circuit_map
1288            .get_all_input_shred_ids()
1289            .into_iter()
1290            .find(|shred_id| !self.partial_inputs.contains_key(shred_id))
1291        {
1292            // Try to return a readable error message if possible.
1293            if let Ok(shred_label) = self.circuit_map.get_shred_label_from_id(shred_id) {
1294                bail!("Circuit Instantiation Failed: Input Shred '{shred_label}' has not been assigned any data. The label of this shred is not available.");
1295            } else {
1296                bail!("Circuit Instantiation Failed: Input Shred ID '{shred_id}' has not been assigned any data.");
1297            }
1298        }
1299        */
1300
1301        // Build Input Layer data.
1302        let mut inputs: HashMap<LayerId, MultilinearExtension<F>> = HashMap::new();
1303
1304        for input_layer_id in self.circuit_map.get_all_input_layer_ids() {
1305            let layer_mle = self.build_input_layer_data(input_layer_id)?;
1306            inputs.insert(input_layer_id, layer_mle);
1307        }
1308
1309        Ok(inputs)
1310    }
1311
1312    /// Helper function for grabbing all of the committed input layers + descriptions
1313    /// from the circuit map.
1314    ///
1315    /// We do this by first filtering all input layers which are
1316    /// [LayerVisibility::Committed], getting all input "shreds" which correspond to
1317    /// those input layers, and aggregating those to compute the number of variables
1318    /// required to represent each input layer.
1319    ///
1320    /// Finally, we set a default configuration for the Ligero PCS used to commit to
1321    /// each of the committed input layers' MLEs. TODO(tfHARD team): add support for
1322    /// custom settings for the PCS configurations.
1323    fn get_all_committed_input_layer_descriptions_to_ligero(
1324        &self,
1325    ) -> Vec<LigeroInputLayerDescription<F>> {
1326        self.circuit_map
1327            .get_all_committed_layers()
1328            .into_iter()
1329            .map(|layer_id| {
1330                let raw_needed_capacity = self
1331                    .circuit_map
1332                    .get_input_shreds_from_layer_id(layer_id)
1333                    .unwrap()
1334                    .into_iter()
1335                    .fold(0, |acc, shred_id| {
1336                        let (_, num_vars) = self.circuit_map.get_shred_location(shred_id).unwrap();
1337                        acc + (1_usize << num_vars)
1338                    });
1339                let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
1340                let total_num_vars = log2(padded_needed_capacity) as usize;
1341
1342                LigeroInputLayerDescription {
1343                    layer_id,
1344                    num_vars: total_num_vars,
1345                    aux: LigeroAuxInfo::<F>::new(1 << (total_num_vars), 4, 1.0, None),
1346                }
1347            })
1348            .collect()
1349    }
1350
1351    /// Returns a [HyraxVerifiableCircuit] containing the public input layer data that have been
1352    /// added to `self` so far.
1353    pub fn gen_hyrax_verifiable_circuit<C>(&self) -> Result<HyraxVerifiableCircuit<C>>
1354    where
1355        C: PrimeOrderCurve<Scalar = F>,
1356    {
1357        let public_inputs = self.build_public_input_layer_data(true)?;
1358
1359        debug!("Public inputs available: {:#?}", public_inputs.keys());
1360        debug!(
1361            "Layer Labels to Layer ID map: {:#?}",
1362            self.circuit_map.layer_label_to_layer_id
1363        );
1364
1365        let hyrax_private_inputs = self
1366            .circuit_map
1367            .get_all_committed_layers()
1368            .into_iter()
1369            .map(|layer_id| {
1370                let raw_needed_capacity = self
1371                    .circuit_map
1372                    .get_input_shreds_from_layer_id(layer_id)
1373                    .unwrap()
1374                    .into_iter()
1375                    .fold(0, |acc, shred_id| {
1376                        let (_, num_vars) = self.circuit_map.get_shred_location(shred_id).unwrap();
1377                        acc + (1_usize << num_vars)
1378                    });
1379                let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
1380                let total_num_vars = log2(padded_needed_capacity) as usize;
1381
1382                Ok((
1383                    layer_id,
1384                    (
1385                        HyraxInputLayerDescription::new(layer_id, total_num_vars),
1386                        None,
1387                    ),
1388                ))
1389            })
1390            .collect::<Result<HashMap<_, _>>>()?;
1391
1392        Ok(HyraxVerifiableCircuit::new(
1393            self.circuit_description.clone(),
1394            public_inputs,
1395            hyrax_private_inputs,
1396            self.circuit_map.layer_label_to_layer_id.clone(),
1397        ))
1398    }
1399
1400    /// Produces a provable form of this circuit for the Hyrax-GKR proving system which uses Hyrax
1401    /// as a commitment scheme for private input layers, and offers zero-knowledge guarantees.
1402    /// Requires all input data to be populated (use `Self::set_input()` on _all_ input shreds).
1403    ///
1404    /// # Returns
1405    /// The generated provable circuit, or an error if the [self] is missing input data.
1406    pub fn gen_hyrax_provable_circuit<C>(&self) -> Result<HyraxProvableCircuit<C>>
1407    where
1408        C: PrimeOrderCurve<Scalar = F>,
1409    {
1410        let inputs = self.build_all_input_layer_data()?;
1411
1412        let hyrax_private_inputs = self
1413            .circuit_map
1414            .get_all_committed_layers()
1415            .into_iter()
1416            .map(|layer_id| {
1417                let raw_needed_capacity = self
1418                    .circuit_map
1419                    .get_input_shreds_from_layer_id(layer_id)
1420                    .unwrap()
1421                    .into_iter()
1422                    .fold(0, |acc, shred_id| {
1423                        let (_, num_vars) = self.circuit_map.get_shred_location(shred_id).unwrap();
1424                        acc + (1_usize << num_vars)
1425                    });
1426                let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
1427                let total_num_vars = log2(padded_needed_capacity) as usize;
1428
1429                Ok((
1430                    layer_id,
1431                    (
1432                        HyraxInputLayerDescription::new(layer_id, total_num_vars),
1433                        None,
1434                    ),
1435                ))
1436            })
1437            .collect::<Result<HashMap<_, _>>>()?;
1438
1439        Ok(HyraxProvableCircuit::new(
1440            self.circuit_description.clone(),
1441            inputs,
1442            hyrax_private_inputs,
1443            self.circuit_map.layer_label_to_layer_id.clone(),
1444        ))
1445    }
1446}
1447
1448impl<F: Halo2FFTFriendlyField> Circuit<F> {
1449    /// Produces a provable form of this circuit for the vanilla GKR proving system which uses
1450    /// Ligero as a commitment scheme for committed input layers, and does _not_ offer any
1451    /// zero-knowledge guarantees.
1452    /// Requires all input data to be populated (use `Self::set_input()` on _all_ input shreds).
1453    ///
1454    /// # Returns
1455    /// The generated provable circuit, or an error if the [self] is missing input data.
1456    pub fn gen_provable_circuit(&self) -> Result<ProvableCircuit<F>> {
1457        let inputs = self.build_all_input_layer_data()?;
1458
1459        let ligero_committed_inputs = self
1460            .get_all_committed_input_layer_descriptions_to_ligero()
1461            .into_iter()
1462            .map(|ligero_input_layer_description| {
1463                (
1464                    ligero_input_layer_description.layer_id,
1465                    (ligero_input_layer_description, None),
1466                )
1467            })
1468            .collect();
1469
1470        Ok(ProvableCircuit::new(
1471            self.circuit_description.clone(),
1472            inputs,
1473            ligero_committed_inputs,
1474            self.circuit_map.layer_label_to_layer_id.clone(),
1475        ))
1476    }
1477
1478    /// Returns a [VerifiableCircuit] initialized with all input data which is already
1479    /// known to the verifier, but no commitments to the data in the committed input layers
1480    /// yet.
1481    #[allow(clippy::type_complexity)]
1482    pub fn gen_verifiable_circuit(&self) -> Result<VerifiableCircuit<F>> {
1483        // Input data which is known to the verifier ahead of time -- note that
1484        // this data was manually appended using the `circuit.set_input()`
1485        // function.
1486        let verifier_predetermined_public_inputs = self.build_public_input_layer_data(true)?;
1487
1488        // Sets default Ligero parameters for each of the committed input layers.
1489        let ligero_committed_inputs = self
1490            .get_all_committed_input_layer_descriptions_to_ligero()
1491            .into_iter()
1492            .map(|ligero_input_layer_description| {
1493                (
1494                    ligero_input_layer_description.layer_id,
1495                    (ligero_input_layer_description, None),
1496                )
1497            })
1498            .collect();
1499
1500        Ok(VerifiableCircuit::new(
1501            self.circuit_description.clone(),
1502            verifier_predetermined_public_inputs,
1503            ligero_committed_inputs,
1504            self.circuit_map.layer_label_to_layer_id.clone(),
1505        ))
1506    }
1507}
1508
1509/// implement the Add, Sub, and Mul traits for NodeRef and FSNodeRef
1510macro_rules! impl_add {
1511    ($Lhs:ty) => {
1512        impl<F: Field, Rhs: Into<AbstractExpression<F>>> Add<Rhs> for $Lhs {
1513            type Output = AbstractExpression<F>;
1514
1515            fn add(self, rhs: Rhs) -> Self::Output {
1516                self.expr() + rhs.into()
1517            }
1518        }
1519    };
1520}
1521impl_add!(NodeRef<F>);
1522impl_add!(&NodeRef<F>);
1523impl_add!(InputLayerNodeRef<F>);
1524impl_add!(&InputLayerNodeRef<F>);
1525impl_add!(FSNodeRef<F>);
1526impl_add!(&FSNodeRef<F>);
1527
1528macro_rules! impl_sub {
1529    ($Lhs:ty) => {
1530        impl<F: Field, Rhs: Into<AbstractExpression<F>>> Sub<Rhs> for $Lhs {
1531            type Output = AbstractExpression<F>;
1532
1533            fn sub(self, rhs: Rhs) -> Self::Output {
1534                self.expr() - rhs.into()
1535            }
1536        }
1537    };
1538}
1539impl_sub!(NodeRef<F>);
1540impl_sub!(&NodeRef<F>);
1541impl_sub!(InputLayerNodeRef<F>);
1542impl_sub!(&InputLayerNodeRef<F>);
1543impl_sub!(FSNodeRef<F>);
1544impl_sub!(&FSNodeRef<F>);
1545
1546macro_rules! impl_mul {
1547    ($Lhs:ty) => {
1548        impl<F: Field, Rhs: Into<AbstractExpression<F>>> Mul<Rhs> for $Lhs {
1549            type Output = AbstractExpression<F>;
1550
1551            fn mul(self, rhs: Rhs) -> Self::Output {
1552                self.expr() * rhs.into()
1553            }
1554        }
1555    };
1556}
1557impl_mul!(NodeRef<F>);
1558impl_mul!(&NodeRef<F>);
1559impl_mul!(InputLayerNodeRef<F>);
1560impl_mul!(&InputLayerNodeRef<F>);
1561impl_mul!(FSNodeRef<F>);
1562impl_mul!(&FSNodeRef<F>);
1563
1564macro_rules! impl_xor {
1565    ($Lhs:ty) => {
1566        impl<F: Field, Rhs: Into<AbstractExpression<F>>> BitXor<Rhs> for $Lhs {
1567            type Output = AbstractExpression<F>;
1568
1569            fn bitxor(self, rhs: Rhs) -> Self::Output {
1570                self.expr() ^ rhs.into()
1571            }
1572        }
1573    };
1574}
1575impl_xor!(NodeRef<F>);
1576impl_xor!(&NodeRef<F>);
1577impl_xor!(InputLayerNodeRef<F>);
1578impl_xor!(&InputLayerNodeRef<F>);
1579impl_xor!(FSNodeRef<F>);
1580impl_xor!(&FSNodeRef<F>);