frontend/layouter/nodes/circuit_inputs/
compile_inputs.rs

1use ark_std::log2;
2use itertools::Itertools;
3use shared_types::Field;
4
5use crate::layouter::{builder::CircuitMap, nodes::CircuitNode};
6use remainder::{
7    circuit_layout::CircuitLocation, input_layer::InputLayerDescription, utils::mle::argsort,
8};
9
10use super::InputLayerNode;
11
12use anyhow::Result;
13
/// Returns the leading `total_num_bits - num_free_bits` bits of `capacity`, most significant
/// first, where `capacity` is read as a `total_num_bits`-bit unsigned integer.
///
/// `index_input_mles` calls this with the starting offset of an MLE inside the combined input
/// layer, the number of variables of the combined MLE, and the number of variables of that
/// MLE, so the result is the fixed prefix that selects the MLE's slot.
///
/// Bit positions at or above 32 (reachable when `total_num_bits > 32`) always read as `false`,
/// because a `u32` has no set bits there.
///
/// # Panics
///
/// Panics if `num_free_bits > total_num_bits`.
fn get_prefix_bits_from_capacity(
    capacity: u32,
    total_num_bits: usize,
    num_free_bits: usize,
) -> Vec<bool> {
    assert!(
        num_free_bits <= total_num_bits,
        "num_free_bits ({}) exceeds total_num_bits ({})",
        num_free_bits,
        total_num_bits
    );
    (0..total_num_bits - num_free_bits)
        .map(|bit_position| {
            // Index (counted from the least-significant end) of the bit emitted here.
            let shift = total_num_bits - bit_position - 1;
            // A plain `>>` by 32 or more overflows: it panics in debug builds and masks the
            // shift amount in release builds, which would read the wrong bit.
            shift < u32::BITS as usize && (capacity >> shift) & 1 == 1
        })
        .collect()
}
29
30fn index_input_mles(input_mle_num_vars: &[usize]) -> (Vec<Vec<bool>>, Vec<usize>, usize) {
31    // Add input-output MLE length if needed
32    let mle_combine_indices = argsort(input_mle_num_vars, true);
33
34    // Get the total needed capacity by rounding the raw capacity up to the nearest power of 2
35    let raw_needed_capacity = input_mle_num_vars
36        .iter()
37        .fold(0, |prev, input_mle_num_vars| {
38            prev + 2_usize.pow(*input_mle_num_vars as u32)
39        });
40    let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
41    let total_num_vars = log2(padded_needed_capacity) as usize;
42
43    // Go through individual MLEs and collect the prefix bits that need to be added to each one
44    let mut current_padded_usage: u32 = 0;
45    let res = mle_combine_indices
46        .iter()
47        .map(|input_mle_idx| {
48            let input_mle_bits = input_mle_num_vars[*input_mle_idx];
49
50            // Collect the prefix bits for each MLE
51            let prefix_bits: Vec<_> =
52                get_prefix_bits_from_capacity(current_padded_usage, total_num_vars, input_mle_bits);
53            current_padded_usage += 2_u32.pow(input_mle_bits as u32);
54            prefix_bits
55        })
56        .collect();
57    (res, mle_combine_indices, total_num_vars)
58}
59
60impl InputLayerNode {
61    /// From the circuit description map and a starting layer id, create the circuit description of
62    /// an input layer, adding the input shreds to the circuit map.
63    pub fn generate_input_layer_description<F: Field>(
64        &self,
65        circuit_map: &mut CircuitMap,
66    ) -> Result<InputLayerDescription> {
67        let Self {
68            id: _,
69            input_layer_id,
70            input_shreds,
71        } = &self;
72        let input_mle_num_vars = input_shreds
73            .iter()
74            .map(|node| node.get_num_vars())
75            .collect_vec();
76
77        let (prefix_bits, input_shred_indices, num_vars_combined_mle) =
78            index_input_mles(&input_mle_num_vars);
79        debug_assert_eq!(input_shred_indices.len(), input_shreds.len());
80
81        let input_layer_description = InputLayerDescription {
82            layer_id: *input_layer_id,
83            num_vars: num_vars_combined_mle,
84        };
85
86        input_shred_indices
87            .iter()
88            .zip(prefix_bits)
89            .for_each(|(input_shred_index, prefix_bits)| {
90                let input_shred = &input_shreds[*input_shred_index];
91                circuit_map.add_node_id_and_location_num_vars(
92                    input_shred.id,
93                    (
94                        CircuitLocation::new(*input_layer_id, prefix_bits),
95                        input_shred.get_num_vars(),
96                    ),
97                );
98            });
99
100        Ok(input_layer_description)
101    }
102}