frontend/layouter/nodes/
sector.rs1use std::collections::{BTreeSet, HashMap};
4
5use shared_types::Field;
6
7use remainder::{
8 circuit_layout::CircuitLocation,
9 layer::{layer_enum::LayerDescriptionEnum, regular_layer::RegularLayerDescription, LayerId},
10 utils::arithmetic::log2_ceil,
11};
12
13use crate::{
14 abstract_expr::AbstractExpression,
15 layouter::{builder::CircuitMap, nodes::CompilableNode},
16};
17
18use super::{CircuitNode, NodeId};
19
20use anyhow::Result;
21#[cfg(test)]
22mod tests;
23
24#[derive(Debug, Clone)]
25pub struct Sector<F: Field> {
28 id: NodeId,
29 expr: AbstractExpression<F>,
30 num_vars: usize,
31}
32
33impl<F: Field> Sector<F> {
34 pub fn new(expr: AbstractExpression<F>, num_vars: usize) -> Self {
36 Self {
37 id: NodeId::new(),
38 expr,
39 num_vars,
40 }
41 }
42}
43
44impl<F: Field> CircuitNode for Sector<F> {
45 fn id(&self) -> NodeId {
46 self.id
47 }
48
49 fn sources(&self) -> Vec<NodeId> {
50 self.expr.get_sources()
51 }
52
53 fn get_num_vars(&self) -> usize {
54 self.num_vars
55 }
56}
57
58impl<F: Field> CompilableNode<F> for Sector<F> {
59 fn generate_circuit_description(
60 &self,
61 circuit_map: &mut CircuitMap,
62 ) -> Result<Vec<LayerDescriptionEnum<F>>> {
63 Ok(generate_sector_circuit_description(
64 &[self],
65 circuit_map,
66 None,
67 ))
68 }
69}
70
71pub fn generate_sector_circuit_description<F: Field>(
77 sectors: &[&Sector<F>],
78 circuit_map: &mut CircuitMap,
79 maybe_maximum_log_layer_size: Option<usize>,
80) -> Vec<LayerDescriptionEnum<F>> {
81 compile_sectors_into_layer_descriptions(sectors, circuit_map, maybe_maximum_log_layer_size)
82 .unwrap()
83 .into_iter()
84 .map(|regular_layer| LayerDescriptionEnum::Regular(regular_layer))
85 .collect()
86}
87
88fn compile_sectors_into_layer_descriptions<F: Field>(
91 children: &[&Sector<F>],
92 circuit_map: &mut CircuitMap,
93 maybe_maximum_log_layer_size: Option<usize>,
94) -> Result<Vec<RegularLayerDescription<F>>> {
95 let mut total_num_coeff: usize = 0;
98 let mut expression_vec = children
101 .iter()
102 .map(|sector| {
103 total_num_coeff += 1 << sector.get_num_vars();
104 Ok((
105 vec![sector.id()],
106 sector.expr.clone(),
107 sector.get_num_vars(),
108 ))
109 })
110 .collect::<Result<Vec<_>>>()?;
111
112 let maximum_log_layer_size = maybe_maximum_log_layer_size
117 .unwrap_or(log2_ceil(total_num_coeff.next_power_of_two()) as usize);
118
119 let mut prefix_bits_map: HashMap<NodeId, (Vec<bool>, usize)> = HashMap::new();
122 for sector in children.iter() {
123 prefix_bits_map.insert(sector.id(), (vec![], sector.get_num_vars()));
124 }
125
126 let new_expr_vec = loop {
131 if expression_vec.len() == 1 {
134 break expression_vec;
135 }
136 expression_vec.sort_by(|rhs, lhs| rhs.2.cmp(&lhs.2).reverse());
138
139 let total_num_coeff_merged_expr = 1 << (expression_vec[expression_vec.len() - 2].2 + 1);
147 if total_num_coeff_merged_expr > (1 << maximum_log_layer_size) {
148 break expression_vec;
149 }
150
151 let (smallest_ids, mut smallest, smallest_num_vars) = expression_vec.pop().unwrap();
152 let (next_ids, next, next_num_vars) = expression_vec.pop().unwrap();
153
154 let padding_selector_vars = next_num_vars - smallest_num_vars;
157 for _ in 0..padding_selector_vars {
159 smallest = AbstractExpression::constant(F::ZERO).select(smallest);
162 for node_id in &smallest_ids {
163 let (prefix_bits, _) = prefix_bits_map.get_mut(node_id).unwrap();
164 prefix_bits.push(true);
165 }
166 }
167
168 smallest = next.select(smallest);
171
172 for node_id in &smallest_ids {
175 let (prefix_bits, _) = prefix_bits_map.get_mut(node_id).unwrap();
176 prefix_bits.push(true);
177 }
178
179 for node_id in &next_ids {
180 let (prefix_bits, _) = prefix_bits_map.get_mut(node_id).unwrap();
181 prefix_bits.push(false);
182 }
183
184 expression_vec.push((
185 [smallest_ids, next_ids].concat(),
186 smallest,
187 next_num_vars + 1,
188 ));
189 };
190
191 let mut node_ids_added_to_circuit_map: BTreeSet<NodeId> = BTreeSet::new();
193 let layer_vec: Vec<RegularLayerDescription<F>> = new_expr_vec
197 .into_iter()
198 .map(|(sector_nodes, expression, _expr_num_vars)| {
199 let expr = expression.build_circuit_expr(circuit_map).unwrap();
200 let regular_layer_id = LayerId::next_layer_id();
201 let layer = RegularLayerDescription::new_raw(regular_layer_id, expr);
202 prefix_bits_map
203 .iter_mut()
204 .for_each(|(node_id, (prefix_bits, num_vars))| {
205 if sector_nodes.contains(node_id) {
206 node_ids_added_to_circuit_map.insert(*node_id);
207 prefix_bits.reverse();
208 circuit_map.add_node_id_and_location_num_vars(
209 *node_id,
210 (
211 CircuitLocation::new(regular_layer_id, prefix_bits.to_vec()),
212 *num_vars,
213 ),
214 );
215 }
216 });
217 layer
218 })
219 .collect();
220 assert_eq!(node_ids_added_to_circuit_map.len(), prefix_bits_map.len());
223
224 Ok(layer_vec)
225}