frontend/layouter/
layouting.rs

//! Defines the utilities for taking a list of nodes and turning it into a
//! laid-out circuit.
3#[cfg(test)]
4mod tests;
5
6use std::collections::HashMap;
7use std::collections::HashSet;
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use itertools::Itertools;
12use shared_types::Field;
13use thiserror::Error;
14
15use crate::layouter::nodes::sector::Sector;
16
17use super::nodes::{
18    circuit_inputs::{InputLayerNode, InputShred},
19    circuit_outputs::OutputNode,
20    fiat_shamir_challenge::FiatShamirChallengeNode,
21    gate::GateNode,
22    identity_gate::IdentityGateNode,
23    lookup::{LookupConstraint, LookupTable},
24    matmult::MatMultNode,
25    split_node::SplitNode,
26    CircuitNode, CompilableNode, NodeId,
27};
28
29use anyhow::{anyhow, Result};
30
31/// Possible errors when topologically ordering the dependency graph that arises
32/// from circuit creation and then categorizing them into layers.
33#[derive(Error, Debug, Clone)]
34pub enum LayoutingError {
35    /// There is a cycle in the node dependencies.
36    #[error("There exists a cycle in the dependency of the nodes, and therefore no topological ordering.")]
37    CircularDependency,
38    /// There exists a node which the circuit builder created, but no other node
39    /// references or depends on.
40    #[error("There exists a node which the circuit builder created, but no other node references or depends on: Id = {0:?}")]
41    DanglingNodeId(NodeId),
42    /// We have gotten to a layer whose parts of the expression have not been
43    /// generated.
44    #[error("This circuit location does not exist, or has not been compiled yet")]
45    NoCircuitLocation,
46}

/// A directed graph represented with an adjacency list.
///
/// `repr` maps each vertex `u` to a vector of all vertices `v` such that
/// `(u, v)` is an edge in the graph.
///
/// In the context of this module, vertices correspond to circuit nodes and
/// edges represent dependencies. There is an edge `(u, v)` if `u`'s input
/// depends on `v`'s output, i.e. `v` is one of `u`'s sources (this is how
/// `Graph::new_from_circuit_nodes` builds the map). Edges therefore point from
/// children to parents.
///
/// Type `N` is typically a node-identifying type, such as `NodeId`.
#[derive(Clone, Debug)]
pub struct Graph<N: Hash + Eq + Clone + Debug> {
    repr: HashMap<N, Vec<N>>,
}
60
61impl<N: Hash + Eq + Clone + Debug> Graph<N> {
62    /// Constructor given the map of dependencies.
63    fn new_from_map(map: HashMap<N, Vec<N>>) -> Self {
64        Self { repr: map }
65    }
66
67    /// Constructor specifically for a [`Graph<NodeId>`], which will convert an
68    /// array of [CompilableNode], each of which reference their sources, and
69    /// convert that into the graph representation.
70    ///
71    /// Note: we only topologically sort intermediate nodes, therefore we
72    /// provide `input_shred_ids` to exclude them from the graph.
73    fn new_from_circuit_nodes<F: Field>(
74        intermediate_circuit_nodes: &[Box<dyn CompilableNode<F>>],
75        input_shred_ids: &HashSet<NodeId>,
76    ) -> Graph<NodeId> {
77        let mut children_to_parent_map: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
78        intermediate_circuit_nodes.iter().for_each(|circuit_node| {
79            // Disregard the nodes which are input shreds.
80            if !input_shred_ids.contains(&circuit_node.id()) {
81                children_to_parent_map.insert(
82                    circuit_node.id(),
83                    circuit_node
84                        .sources()
85                        .iter()
86                        .filter_map(|circuit_node_source| {
87                            // If any of the sources are input shreds, we don't
88                            // add them to the graph.
89                            if input_shred_ids.contains(circuit_node_source) {
90                                None
91                            } else {
92                                Some(*circuit_node_source)
93                            }
94                        })
95                        .collect(),
96                );
97            }
98        });
99
100        Graph::<NodeId>::new_from_map(children_to_parent_map)
101    }
102
103    /// Given the graph, a starting node (`node`), the set of visited/"marked"
104    /// nodes (`exploring_nodes`) and nodes that have been fully explored
105    /// (`terminated_nodes`), traverse the graph recursively using the DFS
106    /// algorithm from the starting node.
107    ///
108    /// Returns nothing, but mutates the `exploring_nodes` and
109    /// `terminated_nodes` accordingly.
110    fn visit_node_dfs(
111        &self,
112        node: &N,
113        exploring_nodes: &mut HashSet<N>,
114        terminated_nodes: &mut HashSet<N>,
115        topological_order: &mut Vec<N>,
116    ) -> Result<()> {
117        if terminated_nodes.contains(node) {
118            return Ok(());
119        }
120        if exploring_nodes.contains(node) {
121            return Err(anyhow!(LayoutingError::CircularDependency));
122        }
123        let neighbors = self.repr.get(node).unwrap();
124
125        exploring_nodes.insert(node.clone());
126        for neighbor_not_terminated in neighbors {
127            self.visit_node_dfs(
128                neighbor_not_terminated,
129                exploring_nodes,
130                terminated_nodes,
131                topological_order,
132            )?;
133        }
134        exploring_nodes.remove(node);
135
136        terminated_nodes.insert(node.clone());
137        topological_order.push(node.clone());
138
139        Ok(())
140    }
141
142    /// Topologically orders the graph by visiting each node via DFS. Returns
143    /// the terminated nodes in order of termination.
144    ///
145    /// Note: normally, the topological sort is the "reverse" order of
146    /// termination via DFS. However, since we wish to order from sources ->
147    /// sinks, and our dependency graph is structured as children -> parent, we
148    /// do not reverse the termination order of the DFS search.
149    fn topo_sort(&self) -> Result<Vec<N>> {
150        let mut terminated_nodes: HashSet<N> = HashSet::new();
151        let mut topological_order: Vec<N> = Vec::with_capacity(self.repr.len());
152        let mut exploring_nodes: HashSet<N> = HashSet::new();
153
154        for node in self.repr.keys() {
155            self.visit_node_dfs(
156                node,
157                &mut exploring_nodes,
158                &mut terminated_nodes,
159                &mut topological_order,
160            )?;
161        }
162
163        assert_eq!(topological_order.len(), self.repr.len());
164
165        Ok(topological_order)
166    }
167
168    /// Given a list of nodes that are already topologically ordered, we check
169    /// all nodes that come before it to return the index in the topological
170    /// sort of the latest node that it was dependent on.
171    fn get_index_of_latest_dependency(
172        &self,
173        topological_order: &[N],
174        idx_to_check: usize,
175    ) -> Option<usize> {
176        let node_to_check = &topological_order[idx_to_check];
177        let neighbors = self.repr.get(node_to_check).unwrap();
178        let mut latest_dependency_idx: Option<usize> = None;
179
180        for idx in (0..idx_to_check).rev() {
181            let node = &topological_order[idx];
182            if neighbors.contains(node) {
183                latest_dependency_idx = Some(idx);
184                break;
185            }
186        }
187        latest_dependency_idx
188    }
189
190    /// Given a _valid_ topological ordering `topological_order == [v_0, v_1, ..., v_{n-1}]`, this
191    /// method returns a vector `latest_dependency_indices` such that `latest_dependency_indices[i]
192    /// == Some(j)` iff:
193    /// (a) there is a edge/dependency (v_i, v_j) in the graph, and
194    /// (b) j is the maximum index for which such an edge exists.
195    /// If there are no edges coming out of `v_i`, then `latest_dependency_indices[i] == None`.
196    ///
197    /// # Complexity
198    /// Linear in the size of the graph: `O(|V| + |E|)`.
199    fn gen_latest_dependecy_indices(&self, topological_order: &[N]) -> Vec<Option<usize>> {
200        assert_eq!(topological_order.len(), self.repr.len());
201
202        // Invert `topological_order` to map nodes to their index in the topological ordering.
203        let mut node_idx = HashMap::<N, usize>::new();
204        topological_order
205            .iter()
206            .enumerate()
207            .for_each(|(idx, node)| {
208                debug_assert!(self.repr.contains_key(node));
209                node_idx.insert(node.clone(), idx);
210            });
211
212        topological_order
213            .iter()
214            .map(|u| {
215                let deps = self.repr.get(u).unwrap();
216                deps.iter().map(|u| node_idx[u]).max()
217            })
218            .collect_vec()
219    }
220
221    /// Inefficient variant of `gen_latest_dependency_indices` used for correctness tests.
222    pub fn naive_gen_latest_dependecy_indices(
223        &self,
224        topological_order: &[N],
225    ) -> Vec<Option<usize>> {
226        let n = self.repr.len();
227        assert_eq!(topological_order.len(), n);
228
229        (0..n)
230            .map(|i| self.get_index_of_latest_dependency(topological_order, i))
231            .collect()
232    }
233}
234
235/// The type returned by the [layout] function. Categorizes the nodes into their
236/// respective layers.
237type LayouterNodes<F> = (
238    Vec<InputLayerNode>,
239    Vec<FiatShamirChallengeNode>,
240    // The inner vector represents nodes to be combined into the same layer. The
241    // outer vector is in the order of the layers to be compiled in terms of
242    // dependency.
243    Vec<Vec<Box<dyn CompilableNode<F>>>>,
244    Vec<LookupTable>,
245    Vec<OutputNode>,
246);
247
248/// Given the nodes provided by the circuit builder, this function returns a
249/// tuple of type `LayouterNodes`.
250///
251/// This function categorizes nodes into their respective layers, by doing the
252/// following:
253/// * Assigning `input_shred_nodes` to their respective `input_layer_nodes`
254///   parent.
255/// * The `fiat_shamir_challenge_nodes` are to be compiled next, so they are
256///   returned as is.
257/// * `sector_nodes`, `gate_nodes`, `identity_gate_nodes`, `matmult_nodes`, and
258///   `split_nodes` are considered intermediate nodes. `sector_nodes` are the
259///   only nodes which can be combined with each other via a selector.
260///   Therefore, first we topologically sort the intermediate nodes by creating
261///   a dependency graph using their specified sources. Then, we do a forward
262///   pass through these sorted nodes, identify which ones are the sectors, and
263///   combine them greedily (if there is no dependency between them, we
264///   combine). We then return a [`Vec<Vec<Box<dyn CompilableNode<F>>>>`] for
265///   which each inner vector represents nodes that can be combined into a
266///   single layer.
267/// * `lookup_constraint_nodes` are added to their respective
268///   `lookup_table_nodes`. Because no nodes are dependent on lookups (their
269///   results are always outputs), we compile them after the intermediate
270///   nodes.
271/// * `output_nodes` are compiled last, so they are returned as is.
272///
273/// The ordering in which the nodes are returned as `LayouterNodes` is the order
274/// in which the nodes are expected to be compiled into layers.
275#[allow(clippy::too_many_arguments)]
276pub fn layout<F: Field>(
277    mut input_layer_nodes: Vec<InputLayerNode>,
278    input_shred_nodes: Vec<InputShred>,
279    fiat_shamir_challenge_nodes: Vec<FiatShamirChallengeNode>,
280    output_nodes: Vec<OutputNode>,
281    sector_nodes: Vec<Sector<F>>,
282    gate_nodes: Vec<GateNode>,
283    identity_gate_nodes: Vec<IdentityGateNode>,
284    split_nodes: Vec<SplitNode>,
285    matmult_nodes: Vec<MatMultNode>,
286    lookup_constraint_nodes: Vec<LookupConstraint>,
287    mut lookup_table_nodes: Vec<LookupTable>,
288    should_combine: bool,
289) -> Result<LayouterNodes<F>> {
290    let mut input_layer_map: HashMap<NodeId, &mut InputLayerNode> = HashMap::new();
291    let sector_node_ids: HashSet<NodeId> = sector_nodes.iter().map(|sector| sector.id()).collect();
292
293    for layer in input_layer_nodes.iter_mut() {
294        input_layer_map.insert(layer.id(), layer);
295    }
296
297    // Step 1: Add `input_shred_nodes` to their specified `input_layer_nodes`
298    // parent.
299    let mut input_shred_ids = HashSet::<NodeId>::new();
300    for input_shred in input_shred_nodes {
301        let input_layer_id = input_shred.get_parent();
302        input_shred_ids.insert(input_shred.id());
303        let input_layer = input_layer_map
304            .get_mut(&input_layer_id)
305            .ok_or(LayoutingError::DanglingNodeId(input_layer_id))?;
306        input_layer.add_shred(input_shred);
307    }
308    input_shred_ids.extend(fiat_shamir_challenge_nodes.iter().map(|node| node.id()));
309
310    // We cast all intermediate nodes into their generic trait implementation
311    // type.
312    let intermediate_nodes = gate_nodes
313        .into_iter()
314        .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>)
315        .chain(
316            identity_gate_nodes
317                .into_iter()
318                .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
319        )
320        .chain(
321            split_nodes
322                .into_iter()
323                .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
324        )
325        .chain(
326            matmult_nodes
327                .into_iter()
328                .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
329        )
330        .chain(
331            sector_nodes
332                .into_iter()
333                .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
334        )
335        .collect_vec();
336
337    // Step 2a: Determine the topological ordering of intermediate nodes, given
338    // their sources.
339    let circuit_node_graph =
340        Graph::<NodeId>::new_from_circuit_nodes(&intermediate_nodes, &input_shred_ids);
341    let topo_sorted_intermediate_node_ids = &circuit_node_graph.topo_sort()?;
342
343    // Maintain a `NodeId` to `CompilableNode` mapping for later use.
344    let mut id_to_node_mapping: HashMap<NodeId, Box<dyn CompilableNode<F>>> = intermediate_nodes
345        .into_iter()
346        .map(|node| (node.id(), node))
347        .collect();
348
349    let mut intermediate_layers: Vec<Vec<Box<dyn CompilableNode<F>>>> = Vec::new();
350
351    // In this case, we do not wish to combine any sectors together into the same layer.
352    if !should_combine {
353        // We take the topologically sorted nodes, and each one of them forms their own layer.
354        topo_sorted_intermediate_node_ids
355            .iter()
356            .for_each(|node_id| {
357                intermediate_layers.push(vec![id_to_node_mapping.remove(node_id).unwrap()])
358            });
359        assert!(id_to_node_mapping.is_empty())
360    }
361    // Otherwise, combine the sectors according to the max layer size (if provided), otherwise optimizing
362    // to create the least number of layers as possible.
363    else {
364        // For each node, compute the maximum index (in the topological ordering) of all of its
365        // dependencies.
366        let latest_dependency_indices =
367            circuit_node_graph.gen_latest_dependecy_indices(topo_sorted_intermediate_node_ids);
368
369        // Step 2b: Re-order the topological ordering such that non-sector nodes appear as early as
370        // possible.
371        // This is done by sorting the nodes according to the tuple `(adjusted_priority, node_type)`.
372        // See comments below for definitions of those quantities.
373        let mut adjusted_priority: Vec<usize> = vec![0; topo_sorted_intermediate_node_ids.len()];
374        let mut idx_of_latest_sector_node: Option<usize> = None;
375
376        // A vector containing `(node, sorting_key)`, where `sorting_key = (adjusted_priority, node_type)`.
377        let mut nodes_with_sorting_key = topo_sorted_intermediate_node_ids
378            .iter()
379            .enumerate()
380            .map(|(idx, node_id)| {
381                let node = id_to_node_mapping.remove(node_id).unwrap();
382
383                if sector_node_ids.contains(node_id) {
384                    // For a sector node, its adjusted priority is defined as its 1-based index among
385                    // only the other sector nodes in the original topological ordering.
386                    // A 1-based index is used as way of handling non-sector nodes with no (direct or
387                    // indirect) dependencies on sector nodes. See comment on the "else" case of this if
388                    // statement.
389
390                    // The adjusted priority of this sector node is one more than the adjusted priority
391                    // of the last seen sector node, or `1` if no other sector has been seen yet.
392                    adjusted_priority[idx] = idx_of_latest_sector_node
393                        .map(|last_sector_idx| adjusted_priority[last_sector_idx] + 1)
394                        .unwrap_or(1);
395
396                    idx_of_latest_sector_node = Some(idx);
397
398                    // Sector nodes have `node_type == 0` to give them priority in the new ordering.
399                    (node, (adjusted_priority[idx], 0))
400                } else {
401                    // A non-sector node's adjusted priority is the adjusted priority of its latest
402                    // dependencly.
403                    // If this node has no dependencies, set its priority to zero, which is equivalent
404                    // to having a dummy sector node on index `0` on which all other nodes depend on.
405                    adjusted_priority[idx] = latest_dependency_indices[idx]
406                        .map(|latest_idx| adjusted_priority[latest_idx])
407                        .unwrap_or(0);
408
409                    // Non-sector nodes have `node_type == 1`, which in the new ordering places them
410                    // right _after_ the sector nodes they directly depend on.
411                    (node, (adjusted_priority[idx], 1))
412                }
413            })
414            .collect_vec();
415
416        // Sort in increasing order according to the adjusted `sorting_key`.
417        nodes_with_sorting_key.sort_by_key(|&(_, priority)| priority);
418
419        // Remove the sorting key.
420        let topo_sorted_nodes = nodes_with_sorting_key
421            .into_iter()
422            .map(|(val, _)| val)
423            .collect_vec();
424
425        // Step 2c: Determine which nodes can be combined into one.
426        let mut node_to_layer_map: HashMap<NodeId, usize> = HashMap::new();
427        // The first layer that stores sectors.
428        let mut first_sector_layer_idx = 0;
429
430        // For index i in the vector, keep track of the layer number corresponding to the next
431        // sector layer, None if it does not exist.
432        //
433        // For example, if we have the layers [non-sector, sector, non-sector, non-sector, sector, non-sector],
434        // the list would be [1, 4, 4, 4, 0, 0]
435        let mut next_sector_layer_idx_list = Vec::new();
436
437        topo_sorted_nodes.into_iter().for_each(|node| {
438            // If it is a non-sector node, insert it as its own layer
439            // Note that non-sector nodes are already re-ordered to the earliest possible location,
440            // so their layers are also the earliest possible layer.
441            let layer_idx = if !(sector_node_ids.contains(&node.id())) {
442                // Insert a new layer.
443                intermediate_layers.push(Vec::new());
444                // Next sector layer is currently unknown:
445                next_sector_layer_idx_list.push(0);
446                let layer_idx = intermediate_layers.len() - 1;
447                // If every layer so far are non-sector, then the first sector layer hasn't appeared.
448                if layer_idx == first_sector_layer_idx {
449                    first_sector_layer_idx += 1;
450                }
451                layer_idx
452            }
453            // If it is a sector node:
454            else {
455                let maybe_latest_layer_dependency = node
456                    .sources()
457                    .iter()
458                    .filter_map(|node_source| {
459                        if input_shred_ids.contains(node_source) {
460                            None
461                        } else {
462                            Some(node_to_layer_map.get(node_source).unwrap())
463                        }
464                    })
465                    .max();
466                // If it is dependent on some previous node:
467                if let Some(layer_of_node) = maybe_latest_layer_dependency {
468                    // There is no sector layer after the layer it is dependent on.
469                    if next_sector_layer_idx_list[*layer_of_node] == 0 {
470                        // If the dependency is at the last sector layer, create a new layer.
471                        intermediate_layers.push(Vec::new());
472                        // The next sector layer is currently unknown.
473                        next_sector_layer_idx_list.push(0);
474                        let layer_to_insert = intermediate_layers.len() - 1;
475                        // All the previous layers with unknown next sector must all be the
476                        // newest layers. Find them and update their next sector layer.
477                        for i in (0..intermediate_layers.len() - 1).rev() {
478                            if next_sector_layer_idx_list[i] == 0 {
479                                next_sector_layer_idx_list[i] = layer_to_insert;
480                            } else {
481                                break;
482                            }
483                        }
484                        layer_to_insert
485                    }
486                    // There exists a sector layer after the layer it is dependent on.
487                    else {
488                        next_sector_layer_idx_list[*layer_of_node]
489                    }
490                }
491                // Otherwise it is not dependent on any node:
492                else {
493                    // Add it to the first layer.
494                    if intermediate_layers.len() == first_sector_layer_idx {
495                        intermediate_layers.push(Vec::new());
496                        // The next sector layer is currently unknown.
497                        next_sector_layer_idx_list.push(0);
498                        first_sector_layer_idx = intermediate_layers.len() - 1
499                    }
500                    first_sector_layer_idx
501                }
502            };
503            node_to_layer_map.insert(node.id(), layer_idx);
504            intermediate_layers[layer_idx].push(node);
505        });
506    }
507
508    // Step 3: Add LookupConstraints to their respective LookupTables. Build a
509    // map node id -> LookupTable
510    let mut lookup_table_map: HashMap<NodeId, &mut LookupTable> = HashMap::new();
511    for lookup_table in lookup_table_nodes.iter_mut() {
512        lookup_table_map.insert(lookup_table.id(), lookup_table);
513    }
514    for lookup_constraint in lookup_constraint_nodes {
515        let lookup_table_id = lookup_constraint.table_node_id;
516        let lookup_table = lookup_table_map
517            .get_mut(&lookup_table_id)
518            .ok_or(LayoutingError::DanglingNodeId(lookup_table_id))?;
519        lookup_table.add_lookup_constraint(lookup_constraint);
520    }
521
522    Ok((
523        input_layer_nodes,
524        fiat_shamir_challenge_nodes,
525        intermediate_layers,
526        lookup_table_nodes,
527        output_nodes,
528    ))
529}