frontend/layouter/nodes/
sector.rs

//! The basic building block of a regular GKR circuit: the [`Sector`] node.
2
3use std::collections::{BTreeSet, HashMap};
4
5use shared_types::Field;
6
7use remainder::{
8    circuit_layout::CircuitLocation,
9    layer::{layer_enum::LayerDescriptionEnum, regular_layer::RegularLayerDescription, LayerId},
10    utils::arithmetic::log2_ceil,
11};
12
13use crate::{
14    abstract_expr::AbstractExpression,
15    layouter::{builder::CircuitMap, nodes::CompilableNode},
16};
17
18use super::{CircuitNode, NodeId};
19
20use anyhow::Result;
21#[cfg(test)]
22mod tests;
23
24#[derive(Debug, Clone)]
25/// A sector node in the circuit DAG, can have multiple inputs, and a single
26/// output
27pub struct Sector<F: Field> {
28    id: NodeId,
29    expr: AbstractExpression<F>,
30    num_vars: usize,
31}
32
33impl<F: Field> Sector<F> {
34    /// creates a new sector node
35    pub fn new(expr: AbstractExpression<F>, num_vars: usize) -> Self {
36        Self {
37            id: NodeId::new(),
38            expr,
39            num_vars,
40        }
41    }
42}
43
44impl<F: Field> CircuitNode for Sector<F> {
45    fn id(&self) -> NodeId {
46        self.id
47    }
48
49    fn sources(&self) -> Vec<NodeId> {
50        self.expr.get_sources()
51    }
52
53    fn get_num_vars(&self) -> usize {
54        self.num_vars
55    }
56}
57
58impl<F: Field> CompilableNode<F> for Sector<F> {
59    fn generate_circuit_description(
60        &self,
61        circuit_map: &mut CircuitMap,
62    ) -> Result<Vec<LayerDescriptionEnum<F>>> {
63        Ok(generate_sector_circuit_description(
64            &[self],
65            circuit_map,
66            None,
67        ))
68    }
69}
70
71/// Generate a circuit description for a vector of sectors that are to be
72/// combined. I.e., the sectors passed into this function do not have any
73/// dependencies between each other. The expected behavior of this function is
74/// to return a single layer which has merged the expressions of each of the
75/// individual sectors into a single expression.
76pub fn generate_sector_circuit_description<F: Field>(
77    sectors: &[&Sector<F>],
78    circuit_map: &mut CircuitMap,
79    maybe_maximum_log_layer_size: Option<usize>,
80) -> Vec<LayerDescriptionEnum<F>> {
81    compile_sectors_into_layer_descriptions(sectors, circuit_map, maybe_maximum_log_layer_size)
82        .unwrap()
83        .into_iter()
84        .map(|regular_layer| LayerDescriptionEnum::Regular(regular_layer))
85        .collect()
86}
87
88/// Takes some sectors that all belong in a single layer and builds the
89/// layer/adds their locations to the circuit map
90fn compile_sectors_into_layer_descriptions<F: Field>(
91    children: &[&Sector<F>],
92    circuit_map: &mut CircuitMap,
93    maybe_maximum_log_layer_size: Option<usize>,
94) -> Result<Vec<RegularLayerDescription<F>>> {
95    // Compute the total number of coefficients required to fully merge this
96    // expression.
97    let mut total_num_coeff: usize = 0;
98    // This will store all the expression yet to be merged, along with the
99    // sector's ID as well as the number of variables.
100    let mut expression_vec = children
101        .iter()
102        .map(|sector| {
103            total_num_coeff += 1 << sector.get_num_vars();
104            Ok((
105                vec![sector.id()],
106                sector.expr.clone(),
107                sector.get_num_vars(),
108            ))
109        })
110        .collect::<Result<Vec<_>>>()?;
111
112    // If the max is specified by the circuit builder, we set it to that.
113    // Otherwise, we allow all of the expressions to be combined so we set it to
114    // what would be the total number of coefficients of the fully merged
115    // expression.
116    let maximum_log_layer_size = maybe_maximum_log_layer_size
117        .unwrap_or(log2_ceil(total_num_coeff.next_power_of_two()) as usize);
118
119    // These prefix bits will be stored in reverse order for each NodeID. Store
120    // the number of variables existing in the sector.
121    let mut prefix_bits_map: HashMap<NodeId, (Vec<bool>, usize)> = HashMap::new();
122    for sector in children.iter() {
123        prefix_bits_map.insert(sector.id(), (vec![], sector.get_num_vars()));
124    }
125
126    // We loop until all the expressions are merged, or the smallest merged
127    // expression exceeds the maximum layer size specified. This means that we
128    // cannot merge any more expressions, and we should compile the rest of the
129    // expressions as is.
130    let new_expr_vec = loop {
131        // Either we have merged all the expressions into one, in this case we
132        // are done combining the expressions.
133        if expression_vec.len() == 1 {
134            break expression_vec;
135        }
136        // We merge the two smallest expressions first.
137        expression_vec.sort_by(|rhs, lhs| rhs.2.cmp(&lhs.2).reverse());
138
139        // Or the two smallest expressions exceed the maximum layer size,
140        // meaning none of the other expressions can be combined. We are done
141        // combining if this is true.
142        //
143        // The total number of coefficients of the merged expression is one
144        // power of two more than the second smallest expression, as we are
145        // combining the two smallest expressions.
146        let total_num_coeff_merged_expr = 1 << (expression_vec[expression_vec.len() - 2].2 + 1);
147        if total_num_coeff_merged_expr > (1 << maximum_log_layer_size) {
148            break expression_vec;
149        }
150
151        let (smallest_ids, mut smallest, smallest_num_vars) = expression_vec.pop().unwrap();
152        let (next_ids, next, next_num_vars) = expression_vec.pop().unwrap();
153
154        // The number of selector variables that need to be added to the smaller
155        // expression to make both expressions the same size.
156        let padding_selector_vars = next_num_vars - smallest_num_vars;
157        // Add any new selector nodes that are needed for padding.
158        for _ in 0..padding_selector_vars {
159            // This results in padding the MLE of the smallest expression with
160            // 0s.
161            smallest = AbstractExpression::constant(F::ZERO).select(smallest);
162            for node_id in &smallest_ids {
163                let (prefix_bits, _) = prefix_bits_map.get_mut(node_id).unwrap();
164                prefix_bits.push(true);
165            }
166        }
167
168        // Merge the two expressions. Now the smallest expression and next
169        // expression are the same size.
170        smallest = next.select(smallest);
171
172        // Track the prefix bits we're creating so they can be added to the
173        // circuit_map; each concat operation pushes a new prefix_bit.
174        for node_id in &smallest_ids {
175            let (prefix_bits, _) = prefix_bits_map.get_mut(node_id).unwrap();
176            prefix_bits.push(true);
177        }
178
179        for node_id in &next_ids {
180            let (prefix_bits, _) = prefix_bits_map.get_mut(node_id).unwrap();
181            prefix_bits.push(false);
182        }
183
184        expression_vec.push((
185            [smallest_ids, next_ids].concat(),
186            smallest,
187            next_num_vars + 1,
188        ));
189    };
190
191    // Keep track of which node IDs are getting added to the circuit map.
192    let mut node_ids_added_to_circuit_map: BTreeSet<NodeId> = BTreeSet::new();
193    // Go through all of the expressions in the vector, which have either been
194    // merged or stayed the same, and compile them into
195    // [RegularLayerDescription]s.
196    let layer_vec: Vec<RegularLayerDescription<F>> = new_expr_vec
197        .into_iter()
198        .map(|(sector_nodes, expression, _expr_num_vars)| {
199            let expr = expression.build_circuit_expr(circuit_map).unwrap();
200            let regular_layer_id = LayerId::next_layer_id();
201            let layer = RegularLayerDescription::new_raw(regular_layer_id, expr);
202            prefix_bits_map
203                .iter_mut()
204                .for_each(|(node_id, (prefix_bits, num_vars))| {
205                    if sector_nodes.contains(node_id) {
206                        node_ids_added_to_circuit_map.insert(*node_id);
207                        prefix_bits.reverse();
208                        circuit_map.add_node_id_and_location_num_vars(
209                            *node_id,
210                            (
211                                CircuitLocation::new(regular_layer_id, prefix_bits.to_vec()),
212                                *num_vars,
213                            ),
214                        );
215                    }
216                });
217            layer
218        })
219        .collect();
220    // Assert that all of the node ids have been added to the circuit map from
221    // those originally populated into the `prefix_bits_map`.
222    assert_eq!(node_ids_added_to_circuit_map.len(), prefix_bits_map.len());
223
224    Ok(layer_vec)
225}