frontend/layouter/nodes/
split_node.rs

1//! A node that splits a single MLE into 2^num_vars smaller MLEs using prefix bits.
2
3use itertools::{repeat_n, Itertools};
4use shared_types::Field;
5
6use crate::layouter::builder::CircuitMap;
7use remainder::{circuit_layout::CircuitLocation, layer::layer_enum::LayerDescriptionEnum};
8
9use super::{CircuitNode, CompilableNode, NodeId};
10
11use anyhow::Result;
/// A node that splits a single MLE into 2^num_vars smaller MLEs using prefix bits.
/// Works big endian.
///
/// Each [SplitNode] is one of those pieces: it lives in the same layer as its source and is
/// located by appending its own `prefix_bits` after the source's prefix bits.
#[derive(Clone, Debug)]
pub struct SplitNode {
    /// Identifier of this piece (freshly created in [SplitNode::new]).
    id: NodeId,
    /// Number of variables of this piece: the source's variable count minus the number of
    /// prefix bits used for the split.
    num_vars: usize,
    /// Identifier of the node being split.
    source: NodeId,
    /// Big-endian prefix bits selecting which piece of the source this node represents.
    prefix_bits: Vec<bool>,
}
21
22impl SplitNode {
23    /// Creates 2^num_vars instances of [SplitNode] from a single [CircuitNode] using prefix bits in
24    /// big-endian order. For example, if num_vars is 2, the prefix bits of the returned
25    /// instances will be (in order): 00, 01, 10, 11.
26    pub fn new(node: &dyn CircuitNode, num_vars: usize) -> Vec<Self> {
27        let num_vars_node = node.get_num_vars();
28        let source = node.id();
29        let max_num_vars = num_vars_node - num_vars;
30        bits_iter(num_vars)
31            .map(|prefix_bits| {
32                let prefix_bits = prefix_bits.into_iter().collect();
33                Self {
34                    id: NodeId::new(),
35                    source,
36                    num_vars: max_num_vars,
37                    prefix_bits,
38                }
39            })
40            .collect()
41    }
42}
43
44impl CircuitNode for SplitNode {
45    fn id(&self) -> NodeId {
46        self.id
47    }
48
49    fn sources(&self) -> Vec<NodeId> {
50        vec![self.source]
51    }
52
53    fn get_num_vars(&self) -> usize {
54        self.num_vars
55    }
56}
57
58impl<F: Field> CompilableNode<F> for SplitNode {
59    fn generate_circuit_description(
60        &self,
61        circuit_map: &mut CircuitMap,
62    ) -> Result<Vec<LayerDescriptionEnum<F>>> {
63        let (source_location, _) = circuit_map.get_location_num_vars_from_node_id(&self.source)?;
64
65        let prefix_bits = source_location
66            .prefix_bits
67            .iter()
68            .chain(self.prefix_bits.iter())
69            .copied()
70            .collect();
71
72        let location = CircuitLocation::new(source_location.layer_id, prefix_bits);
73
74        circuit_map.add_node_id_and_location_num_vars(self.id, (location, self.get_num_vars()));
75        Ok(vec![])
76    }
77}
78
79/// Returns an iterator that gives the MSB-first binary representation of the numbers from 0 to
80/// 2^num_bits.
81/// 0,0,0 -> 0,0,1 -> 0,1,0 -> 0,1,1 -> 1,0,0 -> 1,0,1 -> 1,1,0 -> 1,1,1
82/// # Example:
83/// ```
84/// use frontend::layouter::nodes::split_node::bits_iter;
85/// let bits_iter = bits_iter(2);
86/// let bits: Vec<Vec<bool>> = bits_iter.collect();
87/// assert_eq!(bits, vec![
88///   vec![false, false],
89///   vec![false, true],
90///   vec![true, false],
91///   vec![true, true],
92/// ]);
93/// ```
94pub fn bits_iter(num_bits: usize) -> impl Iterator<Item = Vec<bool>> {
95    std::iter::successors(Some(vec![false; num_bits]), move |prev| {
96        let mut prev = prev.clone();
97        let mut removed_bits = 0;
98        for index in (0..num_bits).rev() {
99            let curr = prev.remove(index);
100            if !curr {
101                prev.push(true);
102                break;
103            } else {
104                removed_bits += 1;
105            }
106        }
107        if removed_bits == num_bits {
108            None
109        } else {
110            Some(
111                prev.into_iter()
112                    .chain(repeat_n(false, removed_bits))
113                    .collect_vec(),
114            )
115        }
116    })
117}
118
#[cfg(test)]
mod test {
    use crate::{
        abstract_expr::AbstractExpression,
        layouter::builder::{Circuit, CircuitBuilder, LayerVisibility},
    };
    use remainder::{
        mle::evals::MultilinearExtension,
        prover::helpers::test_circuit_with_runtime_optimized_config,
    };
    use shared_types::{Field, Fr};

    // Build a circuit that takes in a single input with 8 values (3 variables), splits it into
    // four pieces using 2 prefix bits, and outputs the difference of the first two pieces.
    fn build_basic_split_circuit<F: Field>() -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        let input_layer = builder.add_input_layer("Input Layer", LayerVisibility::Public);
        let input = builder.add_input_shred("Input", 3, &input_layer);
        let splits = builder.add_split_node(&input, 2);
        let subtractor = builder.add_sector(&splits[0] - &splits[1]);
        builder.set_output(&subtractor);

        builder.build_with_layer_combination().unwrap()
    }

    #[test]
    #[should_panic]
    // Expected to fail: these inputs only make the output zero under a LITTLE-endian split,
    // whereas SplitNode is documented as big endian.
    fn test_that_split_node_works_little_endian() {
        let a = 1;
        let b = 2;
        // the following values work if SplitNode is LITTLE endian
        let values: Vec<u64> = vec![a, a, 111, 1111, b, b, 11111, 111111];
        let mle: MultilinearExtension<Fr> = values.into();
        let mut circuit = build_basic_split_circuit::<Fr>();
        circuit.set_input("Input", mle);
        let provable_circuit = circuit.gen_provable_circuit().unwrap();
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    #[test]
    fn test_that_split_node_works_big_endian() {
        // the following values work if SplitNode is BIG endian: the first two pieces are
        // [11, 2] and [11, 2], so their difference is zero.
        let values: Vec<u64> = vec![11, 2, 11, 2, 123, 124, 125, 126];
        let mle: MultilinearExtension<Fr> = values.into();
        let mut circuit = build_basic_split_circuit::<Fr>();
        circuit.set_input("Input", mle);
        let provable_circuit = circuit.gen_provable_circuit().unwrap();
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    // Build a circuit that takes 4 MLEs, joins them using selectors, splits them into 4 MLEs using
    // SplitNode, and checks that this is the noop (the output is the sum of the differences
    // between each split and its corresponding original input).
    fn build_splits_and_selectors_circuit<F: Field>() -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        let num_vars = 2;
        let input_layer = builder.add_input_layer("Input Layer", LayerVisibility::Public);
        let input0 = builder.add_input_shred("Input 0", num_vars, &input_layer);
        let input1 = builder.add_input_shred("Input 1", num_vars, &input_layer);
        let input2 = builder.add_input_shred("Input 2", num_vars, &input_layer);
        let input3 = builder.add_input_shred("Input 3", num_vars, &input_layer);

        let concatenator = builder.add_sector(AbstractExpression::binary_tree_selector(vec![
            &input0, &input1, &input2, &input3,
        ]));

        let splits = builder.add_split_node(&concatenator, num_vars);
        assert_eq!(splits.len(), 4);

        let subtractor = builder.add_sector(
            (&splits[0] - input0)
                + (&splits[1] - input1)
                + (&splits[2] - input2 + (&splits[3] - input3)),
        );
        builder.set_output(&subtractor);

        builder.build_with_layer_combination().unwrap()
    }

    #[test]
    // Test that SplitNode undoes the work of the selector bits.
    fn test_splits_and_selectors() {
        let mut circuit = build_splits_and_selectors_circuit::<Fr>();

        circuit.set_input("Input 0", vec![1, 2, 3, 4].into());
        circuit.set_input("Input 1", vec![5, 6, 7, 8].into());
        circuit.set_input("Input 2", vec![9, 10, 11, 12].into());
        circuit.set_input("Input 3", vec![13, 14, 15, 16].into());

        let provable_circuit = circuit.gen_provable_circuit().unwrap();

        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }
}