1use ark_std::log2;
4use itertools::{repeat_n, Itertools};
5use shared_types::Field;
6
7use remainder::{
8 circuit_layout::CircuitLocation,
9 layer::{
10 gate::{BinaryOperation, GateLayerDescription},
11 layer_enum::LayerDescriptionEnum,
12 LayerId,
13 },
14 mle::{mle_description::MleDescription, MleIndex},
15};
16
17use crate::layouter::builder::CircuitMap;
18
19use super::{CircuitNode, CompilableNode, NodeId};
20
21use anyhow::Result;
22
/// A circuit node that feeds two source nodes (`lhs`, `rhs`) into a gate layer.
///
/// The gate layer is described by the list of nonzero gates and the binary operation
/// (`gate_operation`) used to combine the two operands.
#[derive(Clone, Debug)]
pub struct GateNode {
    /// Identifier of this node (freshly allocated in `GateNode::new`).
    id: NodeId,
    /// Number of dataparallel bits, forwarded to the gate layer description.
    /// `None` is treated as zero bits when computing `num_vars`.
    num_dataparallel_bits: Option<usize>,
    /// Nonzero gate triples. The first component is the output index; its maximum
    /// determines `num_vars`. The second and third components are presumably the
    /// `lhs`/`rhs` input indices. NOTE(review): confirm against `GateLayerDescription`.
    nonzero_gates: Vec<(u32, u32, u32)>,
    /// ID of the node supplying the left operand.
    lhs: NodeId,
    /// ID of the node supplying the right operand.
    rhs: NodeId,
    /// Binary operation applied by the gate layer.
    gate_operation: BinaryOperation,
    /// Number of variables of this node's output, computed in `GateNode::new`.
    num_vars: usize,
}
34
35impl CircuitNode for GateNode {
36 fn id(&self) -> NodeId {
37 self.id
38 }
39
40 fn sources(&self) -> Vec<NodeId> {
41 vec![self.lhs, self.rhs]
42 }
43
44 fn get_num_vars(&self) -> usize {
45 self.num_vars
46 }
47}
48
49impl GateNode {
50 pub fn new(
52 lhs: &dyn CircuitNode,
53 rhs: &dyn CircuitNode,
54 nonzero_gates: Vec<(u32, u32, u32)>,
55 gate_operation: BinaryOperation,
56 num_dataparallel_bits: Option<usize>,
57 ) -> Self {
58 let max_gate_val = nonzero_gates
59 .clone()
60 .into_iter()
61 .fold(0, |acc, (z, _, _)| std::cmp::max(acc, z));
62
63 let num_dataparallel_vals = 1 << (num_dataparallel_bits.unwrap_or(0));
66 let res_table_num_entries = (max_gate_val + 1) * num_dataparallel_vals;
67
68 let num_vars = log2(res_table_num_entries as usize) as usize;
69
70 Self {
71 id: NodeId::new(),
72 num_dataparallel_bits,
73 nonzero_gates,
74 gate_operation,
75 lhs: lhs.id(),
76 rhs: rhs.id(),
77 num_vars,
78 }
79 }
80}
81
82impl<F: Field> CompilableNode<F> for GateNode {
83 fn generate_circuit_description(
84 &self,
85 circuit_map: &mut CircuitMap,
86 ) -> Result<Vec<LayerDescriptionEnum<F>>> {
87 let (lhs_location, lhs_num_vars) =
88 circuit_map.get_location_num_vars_from_node_id(&self.lhs)?;
89 let total_indices = lhs_location
90 .prefix_bits
91 .iter()
92 .map(|bit| MleIndex::Fixed(*bit))
93 .chain(repeat_n(MleIndex::Free, *lhs_num_vars))
94 .collect_vec();
95 let lhs_circuit_mle = MleDescription::new(lhs_location.layer_id, &total_indices);
96
97 let (rhs_location, rhs_num_vars) =
98 circuit_map.get_location_num_vars_from_node_id(&self.rhs)?;
99 let total_indices = rhs_location
100 .prefix_bits
101 .iter()
102 .map(|bit| MleIndex::Fixed(*bit))
103 .chain(repeat_n(MleIndex::Free, *rhs_num_vars))
104 .collect_vec();
105 let rhs_circuit_mle = MleDescription::new(rhs_location.layer_id, &total_indices);
106
107 let gate_layer_id = LayerId::next_layer_id();
108 let gate_circuit_description = GateLayerDescription::new(
109 self.num_dataparallel_bits,
110 self.nonzero_gates.clone(),
111 lhs_circuit_mle,
112 rhs_circuit_mle,
113 gate_layer_id,
114 self.gate_operation,
115 );
116 circuit_map.add_node_id_and_location_num_vars(
117 self.id,
118 (
119 CircuitLocation::new(gate_layer_id, vec![]),
120 self.get_num_vars(),
121 ),
122 );
123
124 Ok(vec![LayerDescriptionEnum::Gate(gate_circuit_description)])
125 }
126}
127
#[cfg(test)]
mod test {

    use ark_std::{rand::Rng, test_rng};
    use itertools::Itertools;
    use shared_types::Fr;

    use crate::layouter::builder::{CircuitBuilder, LayerVisibility};
    use remainder::{
        layer::gate::BinaryOperation, mle::evals::MultilinearExtension,
        prover::helpers::test_circuit_with_runtime_optimized_config,
    };

    /// Builds and proves a circuit containing one `Add` gate node.
    ///
    /// The gate node is wired `(i, i, i)` over `2^num_free_vars` slots. It is fed by two
    /// public input shreds, one holding a random MLE and the other its negation.
    ///
    /// `num_dataparallel_vars` is passed to the gate node as-is. Each input shred has
    /// `num_dataparallel_vars.unwrap_or(0) + num_free_vars` variables.
    fn prove_identity_add_gate(num_dataparallel_vars: Option<usize>, num_free_vars: usize) {
        let mut builder = CircuitBuilder::<Fr>::new();

        let input_num_vars = num_dataparallel_vars.unwrap_or(0) + num_free_vars;
        let num_evals = 1usize << input_num_vars;

        let mut rng = test_rng();
        let mle = MultilinearExtension::new(
            (0..num_evals).map(|_| Fr::from(rng.gen::<u64>())).collect(),
        );
        let neg_mle = MultilinearExtension::new(mle.iter().map(|elem| -elem).collect_vec());

        // Identity wiring: gate `i` reads slot `i` of both inputs and writes slot `i`.
        let nonzero_gates = (0..(1u32 << num_free_vars))
            .map(|idx| (idx, idx, idx))
            .collect_vec();

        let input_layer = builder.add_input_layer("Input Layer", LayerVisibility::Public);
        let input_shred_pos =
            builder.add_input_shred("Positive Input", input_num_vars, &input_layer);
        let input_shred_neg =
            builder.add_input_shred("Negative Input", input_num_vars, &input_layer);

        let gate_sector = builder.add_gate_node(
            &input_shred_pos,
            &input_shred_neg,
            nonzero_gates,
            BinaryOperation::Add,
            num_dataparallel_vars,
        );
        builder.set_output(&gate_sector);

        let mut circuit = builder.build_with_layer_combination().unwrap();
        circuit.set_input("Positive Input", mle);
        circuit.set_input("Negative Input", neg_mle);

        let provable_circuit = circuit.gen_provable_circuit().unwrap();
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    #[test]
    fn test_gate_node_in_circuit() {
        const NUM_FREE_VARS: usize = 4;
        prove_identity_add_gate(None, NUM_FREE_VARS);
    }

    #[test]
    fn test_data_parallel_gate_node_in_circuit() {
        const NUM_DATAPARALLEL_VARS: usize = 3;
        const NUM_FREE_VARS: usize = 4;
        prove_identity_add_gate(Some(NUM_DATAPARALLEL_VARS), NUM_FREE_VARS);
    }
}