frontend/layouter/layouting.rs
//! Defines the utilities for taking a list of nodes and turning it into a
//! laid-out circuit.
3#[cfg(test)]
4mod tests;
5
6use std::collections::HashMap;
7use std::collections::HashSet;
8use std::fmt::Debug;
9use std::hash::Hash;
10
11use itertools::Itertools;
12use shared_types::Field;
13use thiserror::Error;
14
15use crate::layouter::nodes::sector::Sector;
16
17use super::nodes::{
18 circuit_inputs::{InputLayerNode, InputShred},
19 circuit_outputs::OutputNode,
20 fiat_shamir_challenge::FiatShamirChallengeNode,
21 gate::GateNode,
22 identity_gate::IdentityGateNode,
23 lookup::{LookupConstraint, LookupTable},
24 matmult::MatMultNode,
25 split_node::SplitNode,
26 CircuitNode, CompilableNode, NodeId,
27};
28
29use anyhow::{anyhow, Result};
30
31/// Possible errors when topologically ordering the dependency graph that arises
32/// from circuit creation and then categorizing them into layers.
33#[derive(Error, Debug, Clone)]
34pub enum LayoutingError {
35 /// There is a cycle in the node dependencies.
36 #[error("There exists a cycle in the dependency of the nodes, and therefore no topological ordering.")]
37 CircularDependency,
38 /// There exists a node which the circuit builder created, but no other node
39 /// references or depends on.
40 #[error("There exists a node which the circuit builder created, but no other node references or depends on: Id = {0:?}")]
41 DanglingNodeId(NodeId),
42 /// We have gotten to a layer whose parts of the expression have not been
43 /// generated.
44 #[error("This circuit location does not exist, or has not been compiled yet")]
45 NoCircuitLocation,
46}
47/// A directed graph represented with an adjacency list.
48///
49/// `repr` maps each vertex `u` to a vector of all vertices `v` such that `(u, v)` is an edge in the
50/// graph.
51///
52/// In the context of this module, vertices correspond to Circuit Nodes, and edges represent
53/// precedence constraints: there is an edge `(u, v)` if `v`'s input depends on `u`'s output.
54///
55/// Type `N` is typically a node-identifying type, such as `NodeId`.
56#[derive(Clone, Debug)]
57pub struct Graph<N: Hash + Eq + Clone + Debug> {
58 repr: HashMap<N, Vec<N>>,
59}
60
61impl<N: Hash + Eq + Clone + Debug> Graph<N> {
62 /// Constructor given the map of dependencies.
63 fn new_from_map(map: HashMap<N, Vec<N>>) -> Self {
64 Self { repr: map }
65 }
66
67 /// Constructor specifically for a [`Graph<NodeId>`], which will convert an
68 /// array of [CompilableNode], each of which reference their sources, and
69 /// convert that into the graph representation.
70 ///
71 /// Note: we only topologically sort intermediate nodes, therefore we
72 /// provide `input_shred_ids` to exclude them from the graph.
73 fn new_from_circuit_nodes<F: Field>(
74 intermediate_circuit_nodes: &[Box<dyn CompilableNode<F>>],
75 input_shred_ids: &HashSet<NodeId>,
76 ) -> Graph<NodeId> {
77 let mut children_to_parent_map: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
78 intermediate_circuit_nodes.iter().for_each(|circuit_node| {
79 // Disregard the nodes which are input shreds.
80 if !input_shred_ids.contains(&circuit_node.id()) {
81 children_to_parent_map.insert(
82 circuit_node.id(),
83 circuit_node
84 .sources()
85 .iter()
86 .filter_map(|circuit_node_source| {
87 // If any of the sources are input shreds, we don't
88 // add them to the graph.
89 if input_shred_ids.contains(circuit_node_source) {
90 None
91 } else {
92 Some(*circuit_node_source)
93 }
94 })
95 .collect(),
96 );
97 }
98 });
99
100 Graph::<NodeId>::new_from_map(children_to_parent_map)
101 }
102
103 /// Given the graph, a starting node (`node`), the set of visited/"marked"
104 /// nodes (`exploring_nodes`) and nodes that have been fully explored
105 /// (`terminated_nodes`), traverse the graph recursively using the DFS
106 /// algorithm from the starting node.
107 ///
108 /// Returns nothing, but mutates the `exploring_nodes` and
109 /// `terminated_nodes` accordingly.
110 fn visit_node_dfs(
111 &self,
112 node: &N,
113 exploring_nodes: &mut HashSet<N>,
114 terminated_nodes: &mut HashSet<N>,
115 topological_order: &mut Vec<N>,
116 ) -> Result<()> {
117 if terminated_nodes.contains(node) {
118 return Ok(());
119 }
120 if exploring_nodes.contains(node) {
121 return Err(anyhow!(LayoutingError::CircularDependency));
122 }
123 let neighbors = self.repr.get(node).unwrap();
124
125 exploring_nodes.insert(node.clone());
126 for neighbor_not_terminated in neighbors {
127 self.visit_node_dfs(
128 neighbor_not_terminated,
129 exploring_nodes,
130 terminated_nodes,
131 topological_order,
132 )?;
133 }
134 exploring_nodes.remove(node);
135
136 terminated_nodes.insert(node.clone());
137 topological_order.push(node.clone());
138
139 Ok(())
140 }
141
142 /// Topologically orders the graph by visiting each node via DFS. Returns
143 /// the terminated nodes in order of termination.
144 ///
145 /// Note: normally, the topological sort is the "reverse" order of
146 /// termination via DFS. However, since we wish to order from sources ->
147 /// sinks, and our dependency graph is structured as children -> parent, we
148 /// do not reverse the termination order of the DFS search.
149 fn topo_sort(&self) -> Result<Vec<N>> {
150 let mut terminated_nodes: HashSet<N> = HashSet::new();
151 let mut topological_order: Vec<N> = Vec::with_capacity(self.repr.len());
152 let mut exploring_nodes: HashSet<N> = HashSet::new();
153
154 for node in self.repr.keys() {
155 self.visit_node_dfs(
156 node,
157 &mut exploring_nodes,
158 &mut terminated_nodes,
159 &mut topological_order,
160 )?;
161 }
162
163 assert_eq!(topological_order.len(), self.repr.len());
164
165 Ok(topological_order)
166 }
167
168 /// Given a list of nodes that are already topologically ordered, we check
169 /// all nodes that come before it to return the index in the topological
170 /// sort of the latest node that it was dependent on.
171 fn get_index_of_latest_dependency(
172 &self,
173 topological_order: &[N],
174 idx_to_check: usize,
175 ) -> Option<usize> {
176 let node_to_check = &topological_order[idx_to_check];
177 let neighbors = self.repr.get(node_to_check).unwrap();
178 let mut latest_dependency_idx: Option<usize> = None;
179
180 for idx in (0..idx_to_check).rev() {
181 let node = &topological_order[idx];
182 if neighbors.contains(node) {
183 latest_dependency_idx = Some(idx);
184 break;
185 }
186 }
187 latest_dependency_idx
188 }
189
190 /// Given a _valid_ topological ordering `topological_order == [v_0, v_1, ..., v_{n-1}]`, this
191 /// method returns a vector `latest_dependency_indices` such that `latest_dependency_indices[i]
192 /// == Some(j)` iff:
193 /// (a) there is a edge/dependency (v_i, v_j) in the graph, and
194 /// (b) j is the maximum index for which such an edge exists.
195 /// If there are no edges coming out of `v_i`, then `latest_dependency_indices[i] == None`.
196 ///
197 /// # Complexity
198 /// Linear in the size of the graph: `O(|V| + |E|)`.
199 fn gen_latest_dependecy_indices(&self, topological_order: &[N]) -> Vec<Option<usize>> {
200 assert_eq!(topological_order.len(), self.repr.len());
201
202 // Invert `topological_order` to map nodes to their index in the topological ordering.
203 let mut node_idx = HashMap::<N, usize>::new();
204 topological_order
205 .iter()
206 .enumerate()
207 .for_each(|(idx, node)| {
208 debug_assert!(self.repr.contains_key(node));
209 node_idx.insert(node.clone(), idx);
210 });
211
212 topological_order
213 .iter()
214 .map(|u| {
215 let deps = self.repr.get(u).unwrap();
216 deps.iter().map(|u| node_idx[u]).max()
217 })
218 .collect_vec()
219 }
220
221 /// Inefficient variant of `gen_latest_dependency_indices` used for correctness tests.
222 pub fn naive_gen_latest_dependecy_indices(
223 &self,
224 topological_order: &[N],
225 ) -> Vec<Option<usize>> {
226 let n = self.repr.len();
227 assert_eq!(topological_order.len(), n);
228
229 (0..n)
230 .map(|i| self.get_index_of_latest_dependency(topological_order, i))
231 .collect()
232 }
233}
234
235/// The type returned by the [layout] function. Categorizes the nodes into their
236/// respective layers.
237type LayouterNodes<F> = (
238 Vec<InputLayerNode>,
239 Vec<FiatShamirChallengeNode>,
240 // The inner vector represents nodes to be combined into the same layer. The
241 // outer vector is in the order of the layers to be compiled in terms of
242 // dependency.
243 Vec<Vec<Box<dyn CompilableNode<F>>>>,
244 Vec<LookupTable>,
245 Vec<OutputNode>,
246);
247
248/// Given the nodes provided by the circuit builder, this function returns a
249/// tuple of type `LayouterNodes`.
250///
251/// This function categorizes nodes into their respective layers, by doing the
252/// following:
253/// * Assigning `input_shred_nodes` to their respective `input_layer_nodes`
254/// parent.
255/// * The `fiat_shamir_challenge_nodes` are to be compiled next, so they are
256/// returned as is.
257/// * `sector_nodes`, `gate_nodes`, `identity_gate_nodes`, `matmult_nodes`, and
258/// `split_nodes` are considered intermediate nodes. `sector_nodes` are the
259/// only nodes which can be combined with each other via a selector.
260/// Therefore, first we topologically sort the intermediate nodes by creating
261/// a dependency graph using their specified sources. Then, we do a forward
262/// pass through these sorted nodes, identify which ones are the sectors, and
263/// combine them greedily (if there is no dependency between them, we
264/// combine). We then return a [`Vec<Vec<Box<dyn CompilableNode<F>>>>`] for
265/// which each inner vector represents nodes that can be combined into a
266/// single layer.
267/// * `lookup_constraint_nodes` are added to their respective
268/// `lookup_table_nodes`. Because no nodes are dependent on lookups (their
269/// results are always outputs), we compile them after the intermediate
270/// nodes.
271/// * `output_nodes` are compiled last, so they are returned as is.
272///
273/// The ordering in which the nodes are returned as `LayouterNodes` is the order
274/// in which the nodes are expected to be compiled into layers.
275#[allow(clippy::too_many_arguments)]
276pub fn layout<F: Field>(
277 mut input_layer_nodes: Vec<InputLayerNode>,
278 input_shred_nodes: Vec<InputShred>,
279 fiat_shamir_challenge_nodes: Vec<FiatShamirChallengeNode>,
280 output_nodes: Vec<OutputNode>,
281 sector_nodes: Vec<Sector<F>>,
282 gate_nodes: Vec<GateNode>,
283 identity_gate_nodes: Vec<IdentityGateNode>,
284 split_nodes: Vec<SplitNode>,
285 matmult_nodes: Vec<MatMultNode>,
286 lookup_constraint_nodes: Vec<LookupConstraint>,
287 mut lookup_table_nodes: Vec<LookupTable>,
288 should_combine: bool,
289) -> Result<LayouterNodes<F>> {
290 let mut input_layer_map: HashMap<NodeId, &mut InputLayerNode> = HashMap::new();
291 let sector_node_ids: HashSet<NodeId> = sector_nodes.iter().map(|sector| sector.id()).collect();
292
293 for layer in input_layer_nodes.iter_mut() {
294 input_layer_map.insert(layer.id(), layer);
295 }
296
297 // Step 1: Add `input_shred_nodes` to their specified `input_layer_nodes`
298 // parent.
299 let mut input_shred_ids = HashSet::<NodeId>::new();
300 for input_shred in input_shred_nodes {
301 let input_layer_id = input_shred.get_parent();
302 input_shred_ids.insert(input_shred.id());
303 let input_layer = input_layer_map
304 .get_mut(&input_layer_id)
305 .ok_or(LayoutingError::DanglingNodeId(input_layer_id))?;
306 input_layer.add_shred(input_shred);
307 }
308 input_shred_ids.extend(fiat_shamir_challenge_nodes.iter().map(|node| node.id()));
309
310 // We cast all intermediate nodes into their generic trait implementation
311 // type.
312 let intermediate_nodes = gate_nodes
313 .into_iter()
314 .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>)
315 .chain(
316 identity_gate_nodes
317 .into_iter()
318 .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
319 )
320 .chain(
321 split_nodes
322 .into_iter()
323 .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
324 )
325 .chain(
326 matmult_nodes
327 .into_iter()
328 .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
329 )
330 .chain(
331 sector_nodes
332 .into_iter()
333 .map(|node| Box::new(node) as Box<dyn CompilableNode<F>>),
334 )
335 .collect_vec();
336
337 // Step 2a: Determine the topological ordering of intermediate nodes, given
338 // their sources.
339 let circuit_node_graph =
340 Graph::<NodeId>::new_from_circuit_nodes(&intermediate_nodes, &input_shred_ids);
341 let topo_sorted_intermediate_node_ids = &circuit_node_graph.topo_sort()?;
342
343 // Maintain a `NodeId` to `CompilableNode` mapping for later use.
344 let mut id_to_node_mapping: HashMap<NodeId, Box<dyn CompilableNode<F>>> = intermediate_nodes
345 .into_iter()
346 .map(|node| (node.id(), node))
347 .collect();
348
349 let mut intermediate_layers: Vec<Vec<Box<dyn CompilableNode<F>>>> = Vec::new();
350
351 // In this case, we do not wish to combine any sectors together into the same layer.
352 if !should_combine {
353 // We take the topologically sorted nodes, and each one of them forms their own layer.
354 topo_sorted_intermediate_node_ids
355 .iter()
356 .for_each(|node_id| {
357 intermediate_layers.push(vec![id_to_node_mapping.remove(node_id).unwrap()])
358 });
359 assert!(id_to_node_mapping.is_empty())
360 }
361 // Otherwise, combine the sectors according to the max layer size (if provided), otherwise optimizing
362 // to create the least number of layers as possible.
363 else {
364 // For each node, compute the maximum index (in the topological ordering) of all of its
365 // dependencies.
366 let latest_dependency_indices =
367 circuit_node_graph.gen_latest_dependecy_indices(topo_sorted_intermediate_node_ids);
368
369 // Step 2b: Re-order the topological ordering such that non-sector nodes appear as early as
370 // possible.
371 // This is done by sorting the nodes according to the tuple `(adjusted_priority, node_type)`.
372 // See comments below for definitions of those quantities.
373 let mut adjusted_priority: Vec<usize> = vec![0; topo_sorted_intermediate_node_ids.len()];
374 let mut idx_of_latest_sector_node: Option<usize> = None;
375
376 // A vector containing `(node, sorting_key)`, where `sorting_key = (adjusted_priority, node_type)`.
377 let mut nodes_with_sorting_key = topo_sorted_intermediate_node_ids
378 .iter()
379 .enumerate()
380 .map(|(idx, node_id)| {
381 let node = id_to_node_mapping.remove(node_id).unwrap();
382
383 if sector_node_ids.contains(node_id) {
384 // For a sector node, its adjusted priority is defined as its 1-based index among
385 // only the other sector nodes in the original topological ordering.
386 // A 1-based index is used as way of handling non-sector nodes with no (direct or
387 // indirect) dependencies on sector nodes. See comment on the "else" case of this if
388 // statement.
389
390 // The adjusted priority of this sector node is one more than the adjusted priority
391 // of the last seen sector node, or `1` if no other sector has been seen yet.
392 adjusted_priority[idx] = idx_of_latest_sector_node
393 .map(|last_sector_idx| adjusted_priority[last_sector_idx] + 1)
394 .unwrap_or(1);
395
396 idx_of_latest_sector_node = Some(idx);
397
398 // Sector nodes have `node_type == 0` to give them priority in the new ordering.
399 (node, (adjusted_priority[idx], 0))
400 } else {
401 // A non-sector node's adjusted priority is the adjusted priority of its latest
402 // dependencly.
403 // If this node has no dependencies, set its priority to zero, which is equivalent
404 // to having a dummy sector node on index `0` on which all other nodes depend on.
405 adjusted_priority[idx] = latest_dependency_indices[idx]
406 .map(|latest_idx| adjusted_priority[latest_idx])
407 .unwrap_or(0);
408
409 // Non-sector nodes have `node_type == 1`, which in the new ordering places them
410 // right _after_ the sector nodes they directly depend on.
411 (node, (adjusted_priority[idx], 1))
412 }
413 })
414 .collect_vec();
415
416 // Sort in increasing order according to the adjusted `sorting_key`.
417 nodes_with_sorting_key.sort_by_key(|&(_, priority)| priority);
418
419 // Remove the sorting key.
420 let topo_sorted_nodes = nodes_with_sorting_key
421 .into_iter()
422 .map(|(val, _)| val)
423 .collect_vec();
424
425 // Step 2c: Determine which nodes can be combined into one.
426 let mut node_to_layer_map: HashMap<NodeId, usize> = HashMap::new();
427 // The first layer that stores sectors.
428 let mut first_sector_layer_idx = 0;
429
430 // For index i in the vector, keep track of the layer number corresponding to the next
431 // sector layer, None if it does not exist.
432 //
433 // For example, if we have the layers [non-sector, sector, non-sector, non-sector, sector, non-sector],
434 // the list would be [1, 4, 4, 4, 0, 0]
435 let mut next_sector_layer_idx_list = Vec::new();
436
437 topo_sorted_nodes.into_iter().for_each(|node| {
438 // If it is a non-sector node, insert it as its own layer
439 // Note that non-sector nodes are already re-ordered to the earliest possible location,
440 // so their layers are also the earliest possible layer.
441 let layer_idx = if !(sector_node_ids.contains(&node.id())) {
442 // Insert a new layer.
443 intermediate_layers.push(Vec::new());
444 // Next sector layer is currently unknown:
445 next_sector_layer_idx_list.push(0);
446 let layer_idx = intermediate_layers.len() - 1;
447 // If every layer so far are non-sector, then the first sector layer hasn't appeared.
448 if layer_idx == first_sector_layer_idx {
449 first_sector_layer_idx += 1;
450 }
451 layer_idx
452 }
453 // If it is a sector node:
454 else {
455 let maybe_latest_layer_dependency = node
456 .sources()
457 .iter()
458 .filter_map(|node_source| {
459 if input_shred_ids.contains(node_source) {
460 None
461 } else {
462 Some(node_to_layer_map.get(node_source).unwrap())
463 }
464 })
465 .max();
466 // If it is dependent on some previous node:
467 if let Some(layer_of_node) = maybe_latest_layer_dependency {
468 // There is no sector layer after the layer it is dependent on.
469 if next_sector_layer_idx_list[*layer_of_node] == 0 {
470 // If the dependency is at the last sector layer, create a new layer.
471 intermediate_layers.push(Vec::new());
472 // The next sector layer is currently unknown.
473 next_sector_layer_idx_list.push(0);
474 let layer_to_insert = intermediate_layers.len() - 1;
475 // All the previous layers with unknown next sector must all be the
476 // newest layers. Find them and update their next sector layer.
477 for i in (0..intermediate_layers.len() - 1).rev() {
478 if next_sector_layer_idx_list[i] == 0 {
479 next_sector_layer_idx_list[i] = layer_to_insert;
480 } else {
481 break;
482 }
483 }
484 layer_to_insert
485 }
486 // There exists a sector layer after the layer it is dependent on.
487 else {
488 next_sector_layer_idx_list[*layer_of_node]
489 }
490 }
491 // Otherwise it is not dependent on any node:
492 else {
493 // Add it to the first layer.
494 if intermediate_layers.len() == first_sector_layer_idx {
495 intermediate_layers.push(Vec::new());
496 // The next sector layer is currently unknown.
497 next_sector_layer_idx_list.push(0);
498 first_sector_layer_idx = intermediate_layers.len() - 1
499 }
500 first_sector_layer_idx
501 }
502 };
503 node_to_layer_map.insert(node.id(), layer_idx);
504 intermediate_layers[layer_idx].push(node);
505 });
506 }
507
508 // Step 3: Add LookupConstraints to their respective LookupTables. Build a
509 // map node id -> LookupTable
510 let mut lookup_table_map: HashMap<NodeId, &mut LookupTable> = HashMap::new();
511 for lookup_table in lookup_table_nodes.iter_mut() {
512 lookup_table_map.insert(lookup_table.id(), lookup_table);
513 }
514 for lookup_constraint in lookup_constraint_nodes {
515 let lookup_table_id = lookup_constraint.table_node_id;
516 let lookup_table = lookup_table_map
517 .get_mut(&lookup_table_id)
518 .ok_or(LayoutingError::DanglingNodeId(lookup_table_id))?;
519 lookup_table.add_lookup_constraint(lookup_constraint);
520 }
521
522 Ok((
523 input_layer_nodes,
524 fiat_shamir_challenge_nodes,
525 intermediate_layers,
526 lookup_table_nodes,
527 output_nodes,
528 ))
529}