frontend/layouter/nodes/
split_node.rs1use itertools::{repeat_n, Itertools};
4use shared_types::Field;
5
6use crate::layouter::builder::CircuitMap;
7use remainder::{circuit_layout::CircuitLocation, layer::layer_enum::LayerDescriptionEnum};
8
9use super::{CircuitNode, CompilableNode, NodeId};
10
11use anyhow::Result;
/// A node that addresses one piece of another node's data, obtained by fixing a group
/// of the source node's variables to the values in `prefix_bits`.
///
/// Split nodes are produced in groups by [`SplitNode::new`]. They generate no layers
/// of their own; at layout time each one is registered at a sub-location of its source.
#[derive(Clone, Debug)]
pub struct SplitNode {
    /// Identifier of this split node (freshly allocated via `NodeId::new()` in `new`).
    id: NodeId,
    /// Number of variables left after the prefix bits are fixed.
    num_vars: usize,
    /// Identifier of the node being split.
    source: NodeId,
    /// Values of the fixed variables; appended after the source's own prefix bits when
    /// the circuit location is computed.
    prefix_bits: Vec<bool>,
}
21
22impl SplitNode {
23 pub fn new(node: &dyn CircuitNode, num_vars: usize) -> Vec<Self> {
27 let num_vars_node = node.get_num_vars();
28 let source = node.id();
29 let max_num_vars = num_vars_node - num_vars;
30 bits_iter(num_vars)
31 .map(|prefix_bits| {
32 let prefix_bits = prefix_bits.into_iter().collect();
33 Self {
34 id: NodeId::new(),
35 source,
36 num_vars: max_num_vars,
37 prefix_bits,
38 }
39 })
40 .collect()
41 }
42}
43
44impl CircuitNode for SplitNode {
45 fn id(&self) -> NodeId {
46 self.id
47 }
48
49 fn sources(&self) -> Vec<NodeId> {
50 vec![self.source]
51 }
52
53 fn get_num_vars(&self) -> usize {
54 self.num_vars
55 }
56}
57
58impl<F: Field> CompilableNode<F> for SplitNode {
59 fn generate_circuit_description(
60 &self,
61 circuit_map: &mut CircuitMap,
62 ) -> Result<Vec<LayerDescriptionEnum<F>>> {
63 let (source_location, _) = circuit_map.get_location_num_vars_from_node_id(&self.source)?;
64
65 let prefix_bits = source_location
66 .prefix_bits
67 .iter()
68 .chain(self.prefix_bits.iter())
69 .copied()
70 .collect();
71
72 let location = CircuitLocation::new(source_location.layer_id, prefix_bits);
73
74 circuit_map.add_node_id_and_location_num_vars(self.id, (location, self.get_num_vars()));
75 Ok(vec![])
76 }
77}
78
79pub fn bits_iter(num_bits: usize) -> impl Iterator<Item = Vec<bool>> {
95 std::iter::successors(Some(vec![false; num_bits]), move |prev| {
96 let mut prev = prev.clone();
97 let mut removed_bits = 0;
98 for index in (0..num_bits).rev() {
99 let curr = prev.remove(index);
100 if !curr {
101 prev.push(true);
102 break;
103 } else {
104 removed_bits += 1;
105 }
106 }
107 if removed_bits == num_bits {
108 None
109 } else {
110 Some(
111 prev.into_iter()
112 .chain(repeat_n(false, removed_bits))
113 .collect_vec(),
114 )
115 }
116 })
117}
118
#[cfg(test)]
mod test {
    use crate::{
        abstract_expr::AbstractExpression,
        layouter::builder::{Circuit, CircuitBuilder, LayerVisibility},
    };
    use remainder::{
        mle::evals::MultilinearExtension,
        prover::helpers::test_circuit_with_runtime_optimized_config,
    };
    use shared_types::{Field, Fr};

    /// Builds a circuit with one 3-variable public input shred named `"Input"`, splits it
    /// on 2 variables, and outputs `splits[0] - splits[1]`.
    fn build_basic_split_circuit<F: Field>() -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        let input_layer = builder.add_input_layer("Input Layer", LayerVisibility::Public);
        let input = builder.add_input_shred("Input", 3, &input_layer);
        let pieces = builder.add_split_node(&input, 2);
        let difference = builder.add_sector(&pieces[0] - &pieces[1]);
        builder.set_output(&difference);

        builder.build_with_layer_combination().unwrap()
    }

    /// Loads `values` into the basic split circuit and runs the prover test helper on it.
    fn run_basic_split_circuit(values: Vec<u64>) {
        let mle: MultilinearExtension<Fr> = values.into();
        let mut circuit = build_basic_split_circuit::<Fr>();
        circuit.set_input("Input", mle);
        let provable_circuit = circuit.gen_provable_circuit().unwrap();
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    /// Under the contiguous (big-endian) layout, `splits[0] = [1, 1]` and
    /// `splits[1] = [111, 1111]` differ, so the output is nonzero and the helper is
    /// expected to panic (presumably on the nonzero output — confirm in the helper).
    /// The `a, a, ..., b, b` values appear chosen so that a little-endian split would
    /// instead produce matching pieces.
    #[test]
    #[should_panic]
    fn test_that_split_node_works_little_endian() {
        let a = 1;
        let b = 2;
        run_basic_split_circuit(vec![a, a, 111, 1111, b, b, 11111, 111111]);
    }

    /// Under the contiguous (big-endian) layout the first two pieces are both `[11, 2]`,
    /// so `splits[0] - splits[1]` is zero and proving succeeds.
    #[test]
    fn test_that_split_node_works_big_endian() {
        run_basic_split_circuit(vec![11, 2, 11, 2, 123, 124, 125, 126]);
    }

    /// Concatenates four 2-variable inputs with a binary-tree selector. The result is then
    /// split on 2 variables back into four pieces. The output is the sum of each piece
    /// minus the input it should reproduce.
    fn build_splits_and_selectors_circuit<F: Field>() -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        let num_vars = 2;
        let input_layer = builder.add_input_layer("Input Layer", LayerVisibility::Public);
        let input0 = builder.add_input_shred("Input 0", num_vars, &input_layer);
        let input1 = builder.add_input_shred("Input 1", num_vars, &input_layer);
        let input2 = builder.add_input_shred("Input 2", num_vars, &input_layer);
        let input3 = builder.add_input_shred("Input 3", num_vars, &input_layer);

        let concatenation = builder.add_sector(AbstractExpression::binary_tree_selector(vec![
            &input0, &input1, &input2, &input3,
        ]));

        let pieces = builder.add_split_node(&concatenation, num_vars);
        assert_eq!(pieces.len(), 4);

        // Each term should vanish if splitting exactly undoes the selector concatenation.
        let residual = builder.add_sector(
            (&pieces[0] - input0)
                + (&pieces[1] - input1)
                + ((&pieces[2] - input2) + (&pieces[3] - input3)),
        );
        builder.set_output(&residual);

        builder.build_with_layer_combination().unwrap()
    }

    /// Round-trips four distinct inputs through selector concatenation and splitting.
    #[test]
    fn test_splits_and_selectors() {
        let mut circuit = build_splits_and_selectors_circuit::<Fr>();

        circuit.set_input("Input 0", vec![1, 2, 3, 4].into());
        circuit.set_input("Input 1", vec![5, 6, 7, 8].into());
        circuit.set_input("Input 2", vec![9, 10, 11, 12].into());
        circuit.set_input("Input 3", vec![13, 14, 15, 16].into());

        let provable_circuit = circuit.gen_provable_circuit().unwrap();

        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }
}