frontend/layouter/nodes/
gate.rs

1//! A Module for adding `Gate` Layers to components
2
3use ark_std::log2;
4use itertools::{repeat_n, Itertools};
5use shared_types::Field;
6
7use remainder::{
8    circuit_layout::CircuitLocation,
9    layer::{
10        gate::{BinaryOperation, GateLayerDescription},
11        layer_enum::LayerDescriptionEnum,
12        LayerId,
13    },
14    mle::{mle_description::MleDescription, MleIndex},
15};
16
17use crate::layouter::builder::CircuitMap;
18
19use super::{CircuitNode, CompilableNode, NodeId};
20
21use anyhow::Result;
22
/// A Node that represents a `Gate` layer.
///
/// The layer combines the outputs of two source nodes (`lhs` and `rhs`) using the wiring in
/// `nonzero_gates` and the binary operation `gate_operation`; when compiled it becomes a
/// single `GateLayerDescription`.
#[derive(Clone, Debug)]
pub struct GateNode {
    // Unique id of this node, assigned by `NodeId::new()` in `GateNode::new`.
    id: NodeId,
    // Number of dataparallel variables, if any; forwarded as-is to `GateLayerDescription`.
    num_dataparallel_bits: Option<usize>,
    // Gate wiring as `(z, x, y)` triples; the largest `z` determines the output size.
    // NOTE(review): `x`/`y` presumably index into `lhs`/`rhs` — confirm against the gate layer.
    nonzero_gates: Vec<(u32, u32, u32)>,
    // Id of the left source node.
    lhs: NodeId,
    // Id of the right source node.
    rhs: NodeId,
    // Binary operation the gate layer applies.
    gate_operation: BinaryOperation,
    // Number of variables of this node's output, computed once in `GateNode::new`.
    num_vars: usize,
}
34
35impl CircuitNode for GateNode {
36    fn id(&self) -> NodeId {
37        self.id
38    }
39
40    fn sources(&self) -> Vec<NodeId> {
41        vec![self.lhs, self.rhs]
42    }
43
44    fn get_num_vars(&self) -> usize {
45        self.num_vars
46    }
47}
48
49impl GateNode {
50    /// Constructs a new GateNode and computes the data it generates
51    pub fn new(
52        lhs: &dyn CircuitNode,
53        rhs: &dyn CircuitNode,
54        nonzero_gates: Vec<(u32, u32, u32)>,
55        gate_operation: BinaryOperation,
56        num_dataparallel_bits: Option<usize>,
57    ) -> Self {
58        let max_gate_val = nonzero_gates
59            .clone()
60            .into_iter()
61            .fold(0, |acc, (z, _, _)| std::cmp::max(acc, z));
62
63        // number of entries in the resulting table is the max gate z value * 2 to the power of the number of dataparallel bits, as we are
64        // evaluating over all values in the boolean hypercube which includes dataparallel bits
65        let num_dataparallel_vals = 1 << (num_dataparallel_bits.unwrap_or(0));
66        let res_table_num_entries = (max_gate_val + 1) * num_dataparallel_vals;
67
68        let num_vars = log2(res_table_num_entries as usize) as usize;
69
70        Self {
71            id: NodeId::new(),
72            num_dataparallel_bits,
73            nonzero_gates,
74            gate_operation,
75            lhs: lhs.id(),
76            rhs: rhs.id(),
77            num_vars,
78        }
79    }
80}
81
82impl<F: Field> CompilableNode<F> for GateNode {
83    fn generate_circuit_description(
84        &self,
85        circuit_map: &mut CircuitMap,
86    ) -> Result<Vec<LayerDescriptionEnum<F>>> {
87        let (lhs_location, lhs_num_vars) =
88            circuit_map.get_location_num_vars_from_node_id(&self.lhs)?;
89        let total_indices = lhs_location
90            .prefix_bits
91            .iter()
92            .map(|bit| MleIndex::Fixed(*bit))
93            .chain(repeat_n(MleIndex::Free, *lhs_num_vars))
94            .collect_vec();
95        let lhs_circuit_mle = MleDescription::new(lhs_location.layer_id, &total_indices);
96
97        let (rhs_location, rhs_num_vars) =
98            circuit_map.get_location_num_vars_from_node_id(&self.rhs)?;
99        let total_indices = rhs_location
100            .prefix_bits
101            .iter()
102            .map(|bit| MleIndex::Fixed(*bit))
103            .chain(repeat_n(MleIndex::Free, *rhs_num_vars))
104            .collect_vec();
105        let rhs_circuit_mle = MleDescription::new(rhs_location.layer_id, &total_indices);
106
107        let gate_layer_id = LayerId::next_layer_id();
108        let gate_circuit_description = GateLayerDescription::new(
109            self.num_dataparallel_bits,
110            self.nonzero_gates.clone(),
111            lhs_circuit_mle,
112            rhs_circuit_mle,
113            gate_layer_id,
114            self.gate_operation,
115        );
116        circuit_map.add_node_id_and_location_num_vars(
117            self.id,
118            (
119                CircuitLocation::new(gate_layer_id, vec![]),
120                self.get_num_vars(),
121            ),
122        );
123
124        Ok(vec![LayerDescriptionEnum::Gate(gate_circuit_description)])
125    }
126}
127
#[cfg(test)]
mod test {

    use ark_std::{rand::Rng, test_rng};
    use itertools::Itertools;
    use shared_types::Fr;

    use crate::layouter::builder::{CircuitBuilder, LayerVisibility};
    use remainder::{
        layer::gate::BinaryOperation, mle::evals::MultilinearExtension,
        prover::helpers::test_circuit_with_runtime_optimized_config,
    };

    /// Number of non-dataparallel variables the gate wiring ranges over.
    const NUM_FREE_VARS: usize = 4;

    /// Builds, proves and verifies a circuit with one `Add` gate layer.
    ///
    /// The layer combines a random input shred with its negation, using the wiring
    /// `(i, i, i)` for every `i < 2^NUM_FREE_VARS`.
    ///
    /// Each input shred has `NUM_FREE_VARS` variables plus `num_dataparallel_vars` (when
    /// present), and `num_dataparallel_vars` is passed unchanged to `add_gate_node`.
    fn check_add_with_negation(num_dataparallel_vars: Option<usize>) {
        let mut builder = CircuitBuilder::<Fr>::new();

        let shred_num_vars = num_dataparallel_vars.unwrap_or(0) + NUM_FREE_VARS;

        let mut rng = test_rng();
        let mle = MultilinearExtension::new(
            (0..1 << shred_num_vars)
                .map(|_| Fr::from(rng.gen::<u64>()))
                .collect(),
        );
        let neg_mle = MultilinearExtension::new(mle.iter().map(|elem| -elem).collect_vec());

        // Every gate uses the same index for all three wiring positions.
        let nonzero_gates = (0..1 << NUM_FREE_VARS)
            .map(|idx| (idx, idx, idx))
            .collect_vec();

        let input_layer = builder.add_input_layer("Input Layer", LayerVisibility::Public);
        let input_shred_pos =
            builder.add_input_shred("Positive Input", shred_num_vars, &input_layer);
        let input_shred_neg =
            builder.add_input_shred("Negative Input", shred_num_vars, &input_layer);

        let gate_sector = builder.add_gate_node(
            &input_shred_pos,
            &input_shred_neg,
            nonzero_gates,
            BinaryOperation::Add,
            num_dataparallel_vars,
        );
        builder.set_output(&gate_sector);

        let mut circuit = builder.build_with_layer_combination().unwrap();
        circuit.set_input("Positive Input", mle);
        circuit.set_input("Negative Input", neg_mle);

        let provable_circuit = circuit.gen_provable_circuit().unwrap();
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    #[test]
    fn test_gate_node_in_circuit() {
        check_add_with_negation(None);
    }

    #[test]
    fn test_data_parallel_gate_node_in_circuit() {
        const NUM_DATAPARALLEL_VARS: usize = 3;
        check_add_with_negation(Some(NUM_DATAPARALLEL_VARS));
    }
}