frontend/layouter/nodes/
identity_gate.rs1use shared_types::Field;
5
6use crate::layouter::builder::CircuitMap;
7use remainder::{
8 circuit_layout::CircuitLocation,
9 layer::{
10 identity_gate::IdentityGateLayerDescription, layer_enum::LayerDescriptionEnum, LayerId,
11 },
12 mle::mle_description::MleDescription,
13 utils::mle::get_total_mle_indices,
14};
15
16use super::{CircuitNode, CompilableNode, NodeId};
17
18use anyhow::Result;
19
/// Circuit node that routes (copies) values out of a single source node
/// according to a fixed list of identity gates.
#[derive(Clone, Debug)]
pub struct IdentityGateNode {
    /// Unique identifier of this node.
    id: NodeId,
    /// Number of variables of the MLE produced by this node.
    num_vars: usize,
    /// Optional number of dataparallel variables. `new` subtracts it from
    /// `num_vars` when bounding the gate destination indices.
    num_dataparallel_vars: Option<usize>,
    /// Gate wiring as `(destination index, source index)` pairs.
    nonzero_gates: Vec<(u32, u32)>,
    /// Id of the node whose values are routed through the gates.
    pre_routed_data: NodeId,
}
30
31impl CircuitNode for IdentityGateNode {
32 fn id(&self) -> NodeId {
33 self.id
34 }
35
36 fn sources(&self) -> Vec<NodeId> {
37 vec![self.pre_routed_data]
38 }
39
40 fn get_num_vars(&self) -> usize {
41 self.num_vars
42 }
43}
44
45impl IdentityGateNode {
46 pub fn new(
53 pre_routed_data: &dyn CircuitNode,
54 nonzero_gates: Vec<(u32, u32)>,
55 num_vars: usize,
56 num_dataparallel_vars: Option<usize>,
57 ) -> Self {
58 let gate_idx_bound = 1 << (num_vars - num_dataparallel_vars.unwrap_or(0));
59 nonzero_gates.iter().for_each(|(dest_idx, _)| {
60 assert!(
61 *dest_idx < gate_idx_bound,
62 "Gate index {dest_idx} too large for layer with {num_vars} variables",
63 )
64 });
65 Self {
66 id: NodeId::new(),
67 num_vars,
68 num_dataparallel_vars,
69 nonzero_gates,
70 pre_routed_data: pre_routed_data.id(),
71 }
72 }
73}
74
75impl<F: Field> CompilableNode<F> for IdentityGateNode {
76 fn generate_circuit_description(
77 &self,
78 circuit_map: &mut CircuitMap,
79 ) -> Result<Vec<LayerDescriptionEnum<F>>> {
80 let (pre_routed_data_location, pre_routed_num_vars) =
81 circuit_map.get_location_num_vars_from_node_id(&self.pre_routed_data)?;
82 let total_mle_indices =
83 get_total_mle_indices(&pre_routed_data_location.prefix_bits, *pre_routed_num_vars);
84 let pre_routed_mle =
85 MleDescription::new(pre_routed_data_location.layer_id, &total_mle_indices);
86
87 let id_gate_layer_id = LayerId::next_layer_id();
88 let id_gate_layer = IdentityGateLayerDescription::new(
89 id_gate_layer_id,
90 self.nonzero_gates.clone(),
91 pre_routed_mle,
92 self.num_vars,
93 self.num_dataparallel_vars,
94 );
95 circuit_map.add_node_id_and_location_num_vars(
96 self.id,
97 (
98 CircuitLocation::new(id_gate_layer_id, vec![]),
99 self.get_num_vars(),
100 ),
101 );
102
103 Ok(vec![LayerDescriptionEnum::IdentityGate(id_gate_layer)])
104 }
105}
106
#[cfg(test)]
mod test {

    use ark_std::{rand::Rng, test_rng};
    use shared_types::{Field, Fr};

    use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};

    use remainder::{
        mle::evals::MultilinearExtension,
        prover::helpers::test_circuit_with_runtime_optimized_config,
    };

    /// Builds a circuit that routes `Input MLE` through a shift-by-one identity
    /// gate and outputs its difference from `Shifter Input MLE`.
    ///
    /// Gates are `(idx, idx - 1)` for every `idx` in
    /// `1..2^mle_and_shifted_mle_num_vars`. Destination `0` has no gate.
    fn build_identity_gate_test_circuit_description<F: Field>(
        mle_and_shifted_mle_num_vars: usize,
    ) -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        // Shift-by-one wiring: destination `idx` reads source `idx - 1`.
        let nonzero_gates: Vec<_> = (1..(1 << mle_and_shifted_mle_num_vars))
            .map(|idx| (idx, idx - 1))
            .collect();

        let public_input_layer_node =
            builder.add_input_layer("Public Input Layer", LayerVisibility::Public);

        let mle_shred = builder.add_input_shred(
            "Input MLE",
            mle_and_shifted_mle_num_vars,
            &public_input_layer_node,
        );
        let shifted_mle_shred = builder.add_input_shred(
            "Shifter Input MLE",
            mle_and_shifted_mle_num_vars,
            &public_input_layer_node,
        );

        let gate_sector = builder.add_identity_gate_node(
            &mle_shred,
            nonzero_gates,
            mle_and_shifted_mle_num_vars,
            None,
        );
        // Difference between the gate output and the expected shifted input.
        let diff_sector = builder.add_sector(gate_sector - shifted_mle_shred);

        builder.set_output(&diff_sector);

        builder.build_with_layer_combination().unwrap()
    }

    #[test]
    fn test_identity_gate_node_in_circuit() {
        const NUM_FREE_BITS: usize = 1;
        let size = 1 << NUM_FREE_BITS;

        let mut rng = test_rng();

        let mle_vec: Vec<Fr> = (0..size).map(|_| Fr::from(rng.gen::<u64>())).collect();
        // Expected gate output: a zero at index 0, followed by the input shifted
        // up one slot (the last input element is dropped). Built from a borrow
        // so `mle_vec` can be moved below without cloning it.
        let shifted_mle_vec: Vec<Fr> = std::iter::once(Fr::zero())
            .chain(mle_vec.iter().take(size - 1).copied())
            .collect();
        let mle = MultilinearExtension::new(mle_vec);
        let shifted_mle = MultilinearExtension::new(shifted_mle_vec);

        let mut circuit = build_identity_gate_test_circuit_description(NUM_FREE_BITS);

        circuit.set_input("Input MLE", mle);
        circuit.set_input("Shifter Input MLE", shifted_mle);

        let provable_circuit = circuit.gen_provable_circuit().unwrap();

        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }
}