frontend/layouter/nodes/circuit_inputs/
compile_inputs.rs1use ark_std::log2;
2use itertools::Itertools;
3use shared_types::Field;
4
5use crate::layouter::{builder::CircuitMap, nodes::CircuitNode};
6use remainder::{
7 circuit_layout::CircuitLocation, input_layer::InputLayerDescription, utils::mle::argsort,
8};
9
10use super::InputLayerNode;
11
12use anyhow::Result;
13
/// Returns the `total_num_bits - num_free_bits` most significant bits of the
/// `total_num_bits`-wide binary representation of `capacity`, most significant
/// bit first.
///
/// In this file, `capacity` is the offset at which an MLE starts inside the
/// combined input-layer MLE. The returned bits are the fixed prefix addressing
/// that MLE; the remaining `num_free_bits` low bits index within it.
///
/// Bit positions at or beyond `u32::BITS` are reported as `false`, since a `u32`
/// has no such bits. Shifting by them would overflow: it panics in debug builds
/// and produces masked, wrong results in release builds.
///
/// # Panics
/// Panics if `num_free_bits > total_num_bits`.
fn get_prefix_bits_from_capacity(
    capacity: u32,
    total_num_bits: usize,
    num_free_bits: usize,
) -> Vec<bool> {
    let num_prefix_bits = total_num_bits
        .checked_sub(num_free_bits)
        .expect("num_free_bits must not exceed total_num_bits");
    (0..num_prefix_bits)
        .map(|bit_position| {
            // Index (counted from the LSB) of the bit we are extracting.
            let shift = total_num_bits - bit_position - 1;
            // Short-circuit keeps the shift amount strictly below the type width.
            shift < u32::BITS as usize && (capacity >> shift) & 1 == 1
        })
        .collect()
}
29
30fn index_input_mles(input_mle_num_vars: &[usize]) -> (Vec<Vec<bool>>, Vec<usize>, usize) {
31 let mle_combine_indices = argsort(input_mle_num_vars, true);
33
34 let raw_needed_capacity = input_mle_num_vars
36 .iter()
37 .fold(0, |prev, input_mle_num_vars| {
38 prev + 2_usize.pow(*input_mle_num_vars as u32)
39 });
40 let padded_needed_capacity = (1 << log2(raw_needed_capacity)) as usize;
41 let total_num_vars = log2(padded_needed_capacity) as usize;
42
43 let mut current_padded_usage: u32 = 0;
45 let res = mle_combine_indices
46 .iter()
47 .map(|input_mle_idx| {
48 let input_mle_bits = input_mle_num_vars[*input_mle_idx];
49
50 let prefix_bits: Vec<_> =
52 get_prefix_bits_from_capacity(current_padded_usage, total_num_vars, input_mle_bits);
53 current_padded_usage += 2_u32.pow(input_mle_bits as u32);
54 prefix_bits
55 })
56 .collect();
57 (res, mle_combine_indices, total_num_vars)
58}
59
60impl InputLayerNode {
61 pub fn generate_input_layer_description<F: Field>(
64 &self,
65 circuit_map: &mut CircuitMap,
66 ) -> Result<InputLayerDescription> {
67 let Self {
68 id: _,
69 input_layer_id,
70 input_shreds,
71 } = &self;
72 let input_mle_num_vars = input_shreds
73 .iter()
74 .map(|node| node.get_num_vars())
75 .collect_vec();
76
77 let (prefix_bits, input_shred_indices, num_vars_combined_mle) =
78 index_input_mles(&input_mle_num_vars);
79 debug_assert_eq!(input_shred_indices.len(), input_shreds.len());
80
81 let input_layer_description = InputLayerDescription {
82 layer_id: *input_layer_id,
83 num_vars: num_vars_combined_mle,
84 };
85
86 input_shred_indices
87 .iter()
88 .zip(prefix_bits)
89 .for_each(|(input_shred_index, prefix_bits)| {
90 let input_shred = &input_shreds[*input_shred_index];
91 circuit_map.add_node_id_and_location_num_vars(
92 input_shred.id,
93 (
94 CircuitLocation::new(*input_layer_id, prefix_bits),
95 input_shred.get_num_vars(),
96 ),
97 );
98 });
99
100 Ok(input_layer_description)
101 }
102}