frontend/layouter/nodes/identity_gate.rs

//! A module providing identity gate functionality, used to reroute values from one layer to
//! another in an arbitrary fashion.
3
4use shared_types::Field;
5
6use crate::layouter::builder::CircuitMap;
7use remainder::{
8    circuit_layout::CircuitLocation,
9    layer::{
10        identity_gate::IdentityGateLayerDescription, layer_enum::LayerDescriptionEnum, LayerId,
11    },
12    mle::mle_description::MleDescription,
13    utils::mle::get_total_mle_indices,
14};
15
16use super::{CircuitNode, CompilableNode, NodeId};
17
18use anyhow::Result;
19
20/// A node that represents an identity gate in the circuit i.e. that wires values unmodified from
21/// one layer to another.
22#[derive(Clone, Debug)]
23pub struct IdentityGateNode {
24    id: NodeId,
25    num_vars: usize,
26    num_dataparallel_vars: Option<usize>,
27    nonzero_gates: Vec<(u32, u32)>,
28    pre_routed_data: NodeId,
29}
30
31impl CircuitNode for IdentityGateNode {
32    fn id(&self) -> NodeId {
33        self.id
34    }
35
36    fn sources(&self) -> Vec<NodeId> {
37        vec![self.pre_routed_data]
38    }
39
40    fn get_num_vars(&self) -> usize {
41        self.num_vars
42    }
43}
44
45impl IdentityGateNode {
46    /// Constructs a new IdentityGateNode.
47    /// Arguments:
48    /// * `pre_routed_data`: The Node that is being routed to this layer.
49    /// * `nonzero_gates`: A list of tuples representing the gates that are nonzero, in the form `(dest_idx, src_idx)`.
50    /// * `num_vars`: The total number of variables in the layer.
51    /// * `num_dataparallel_vars`: The number of dataparallel variables to use in this layer.
52    pub fn new(
53        pre_routed_data: &dyn CircuitNode,
54        nonzero_gates: Vec<(u32, u32)>,
55        num_vars: usize,
56        num_dataparallel_vars: Option<usize>,
57    ) -> Self {
58        let gate_idx_bound = 1 << (num_vars - num_dataparallel_vars.unwrap_or(0));
59        nonzero_gates.iter().for_each(|(dest_idx, _)| {
60            assert!(
61                *dest_idx < gate_idx_bound,
62                "Gate index {dest_idx} too large for layer with {num_vars} variables",
63            )
64        });
65        Self {
66            id: NodeId::new(),
67            num_vars,
68            num_dataparallel_vars,
69            nonzero_gates,
70            pre_routed_data: pre_routed_data.id(),
71        }
72    }
73}
74
75impl<F: Field> CompilableNode<F> for IdentityGateNode {
76    fn generate_circuit_description(
77        &self,
78        circuit_map: &mut CircuitMap,
79    ) -> Result<Vec<LayerDescriptionEnum<F>>> {
80        let (pre_routed_data_location, pre_routed_num_vars) =
81            circuit_map.get_location_num_vars_from_node_id(&self.pre_routed_data)?;
82        let total_mle_indices =
83            get_total_mle_indices(&pre_routed_data_location.prefix_bits, *pre_routed_num_vars);
84        let pre_routed_mle =
85            MleDescription::new(pre_routed_data_location.layer_id, &total_mle_indices);
86
87        let id_gate_layer_id = LayerId::next_layer_id();
88        let id_gate_layer = IdentityGateLayerDescription::new(
89            id_gate_layer_id,
90            self.nonzero_gates.clone(),
91            pre_routed_mle,
92            self.num_vars,
93            self.num_dataparallel_vars,
94        );
95        circuit_map.add_node_id_and_location_num_vars(
96            self.id,
97            (
98                CircuitLocation::new(id_gate_layer_id, vec![]),
99                self.get_num_vars(),
100            ),
101        );
102
103        Ok(vec![LayerDescriptionEnum::IdentityGate(id_gate_layer)])
104    }
105}
106
#[cfg(test)]
mod test {

    use ark_std::{rand::Rng, test_rng};
    use shared_types::{Field, Fr};

    use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};

    use remainder::{
        mle::evals::MultilinearExtension,
        prover::helpers::test_circuit_with_runtime_optimized_config,
    };

    /// Builds a [`Circuit`] exercising an identity gate.
    ///
    /// The circuit has two public inputs, "Input MLE" and "Shifter Input MLE", each with
    /// `mle_and_shifted_mle_num_vars` variables. An identity gate routes entry `idx - 1` of
    /// "Input MLE" to entry `idx` (for every `idx >= 1`). The circuit output is the difference
    /// between the gate layer and "Shifter Input MLE".
    fn build_identity_gate_test_circuit_description<F: Field>(
        mle_and_shifted_mle_num_vars: usize,
    ) -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        // Nonzero gates: shift the input by one position; destination 0 gets no gate.
        let nonzero_gates: Vec<(u32, u32)> = (1..(1u32 << mle_and_shifted_mle_num_vars))
            .map(|idx| (idx, idx - 1))
            .collect();

        // All inputs are public inputs
        let public_input_layer_node =
            builder.add_input_layer("Public Input Layer", LayerVisibility::Public);

        // Inputs to the circuit include the "primary MLE" and the "shifted MLE"
        let mle_shred = builder.add_input_shred(
            "Input MLE",
            mle_and_shifted_mle_num_vars,
            &public_input_layer_node,
        );
        let shifted_mle_shred = builder.add_input_shred(
            "Shifter Input MLE",
            mle_and_shifted_mle_num_vars,
            &public_input_layer_node,
        );

        // Create the circuit components
        let gate_sector = builder.add_identity_gate_node(
            &mle_shred,
            nonzero_gates,
            mle_and_shifted_mle_num_vars,
            None,
        );
        let diff_sector = builder.add_sector(gate_sector - shifted_mle_shred);

        builder.set_output(&diff_sector);

        builder.build_with_layer_combination().unwrap()
    }

    /// Proves and verifies the identity-gate circuit. The "shifted" input is built as
    /// `[0, mle[0], ..., mle[size - 2]]`.
    #[test]
    fn test_identity_gate_node_in_circuit() {
        const NUM_FREE_BITS: usize = 1;
        let size = 1 << NUM_FREE_BITS;

        let mut rng = test_rng();

        // Define all the input (data) to the circuit
        let mle_vec: Vec<Fr> = (0..size).map(|_| Fr::from(rng.gen::<u64>())).collect();
        let mle = MultilinearExtension::new(mle_vec.clone());
        let shifted_mle_vec = std::iter::once(Fr::zero())
            .chain(mle_vec.into_iter().take(size - 1))
            .collect();
        let shifted_mle = MultilinearExtension::new(shifted_mle_vec);

        // Build the circuit and attach its inputs
        let mut circuit = build_identity_gate_test_circuit_description(NUM_FREE_BITS);

        circuit.set_input("Input MLE", mle);
        circuit.set_input("Shifter Input MLE", shifted_mle);

        let provable_circuit = circuit.gen_provable_circuit().unwrap();

        // Prove/verify the circuit
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }
}