frontend/layouter/
nodes.rs

1//! Module for nodes that can be added to a circuit DAG
2pub use itertools::Either;
3use serde::{Deserialize, Serialize};
4pub use shared_types::{Field, Fr};
5
6use crate::abstract_expr::AbstractExpression;
7use crate::layouter::builder::CircuitMap;
8use remainder::layer::layer_enum::LayerDescriptionEnum;
9
10use remainder::circuit_building_context::CircuitBuildingContext;
11
12use anyhow::Result;
13
14pub mod circuit_inputs;
15pub mod circuit_outputs;
16pub mod fiat_shamir_challenge;
17pub mod gate;
18pub mod identity_gate;
19pub mod lookup;
20pub mod matmult;
21pub mod node_enum;
22pub mod sector;
23pub mod split_node;
24
25/// The circuit-unique ID for each node
26#[derive(Clone, Debug, Hash, PartialEq, Eq, Copy, Ord, PartialOrd, Serialize, Deserialize)]
27pub struct NodeId(usize);
28
29impl Default for NodeId {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl NodeId {
36    /// Creates a new NodeId from the global circuit context.
37    pub fn new() -> Self {
38        Self(CircuitBuildingContext::next_node_id())
39    }
40
41    /// Creates a new NodeId from a usize, for testing only
42    #[cfg(test)]
43    pub fn new_unsafe(id: usize) -> Self {
44        Self(id)
45    }
46
47    /// creates an [`AbstractExpression<F>`] from this NodeId
48    pub fn expr<F: Field>(self) -> AbstractExpression<F> {
49        AbstractExpression::<F>::mle(self)
50    }
51
52    /// Obtain the integer value, for printing GkrError messages
53    pub fn get_id(self) -> usize {
54        self.0
55    }
56}
57
58/// Implement Display for NodeId, so that we can use it in error messages
59impl std::fmt::Display for NodeId {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "{:#x}", self.0)
62    }
63}
64
65/// A Node in the directed acyclic graph (DAG). The directed edges in the DAG model the dependencies
66/// between nodes, with the dependent node being the target of the edge and the dependency being the
67/// source.
68///
69/// All "contiguous" MLEs, e.g. sector, gate, matmul, output, input shred
70pub trait CircuitNode {
71    /// The unique ID of this node
72    fn id(&self) -> NodeId;
73
74    /// Return the ids of the nodes that this node depends upon, i.e. nodes whose values must be
75    /// known before the values of this node can be node.  These are the source nodes of the
76    /// directed edges of the DAG that terminate at this node.
77    fn sources(&self) -> Vec<NodeId>;
78
79    /// Get the number of variables used to represent the data in this node.
80    fn get_num_vars(&self) -> usize;
81}
82
83/// A Node that contains the information neccessary to Compile itself
84///
85/// Implement this for any node that does not need additional Layingout before compilation
86///
87/// TODO: Merge this with `circuitnode`
88pub trait CompilableNode<F: Field>: CircuitNode {
89    /// Generate the circuit description of a node, which represents the
90    /// shape of a certain layer.
91    fn generate_circuit_description(
92        &self,
93        circuit_map: &mut CircuitMap,
94    ) -> Result<Vec<LayerDescriptionEnum<F>>>;
95}
96
/// Generates an enum that wraps a fixed set of circuit node types.
///
/// Usage:
///
/// ```text
/// node_enum!(EnumName: FieldBound, (FirstVariant: FirstNodeType<F>), (SecondVariant: SecondNodeType<F>), ..);
/// ```
///
/// expands to:
/// - `pub enum EnumName<F: FieldBound>` with one tuple variant per `(Variant: Type)`
///   pair, deriving `Clone` and `Debug`;
/// - an implementation of [`CircuitNode`](crate::layouter::nodes::CircuitNode) for
///   `EnumName<F>` that forwards `id`, `sources` and `get_num_vars` to the wrapped node;
/// - `impl From<Type> for EnumName<F>` for every variant, so nodes convert with `.into()`.
///
/// Every variant type must implement `CircuitNode` (and `Clone` + `Debug` for the
/// derives). Variant types may refer to the enum's generic parameter `F`. A trailing
/// comma after the last variant is accepted.
///
/// The forwarding calls use fully qualified trait paths, so the invoking module does
/// not need to import `CircuitNode` itself.
#[macro_export]
macro_rules! node_enum {
    ($type_name:ident: $bound:tt, $(($var_name:ident: $variant:ty)),+ $(,)?) => {
        #[derive(Clone, Debug)]
        #[doc = r"Remainder generated node enum"]
        pub enum $type_name<F: $bound> {
            $(
                #[doc = "Remainder generated node variant"]
                $var_name($variant),
            )*
        }

        impl<F: $bound> $crate::layouter::nodes::CircuitNode for $type_name<F> {
            fn id(&self) -> $crate::layouter::nodes::NodeId {
                match self {
                    $(
                        Self::$var_name(node) => $crate::layouter::nodes::CircuitNode::id(node),
                    )*
                }
            }

            fn sources(&self) -> Vec<$crate::layouter::nodes::NodeId> {
                match self {
                    $(
                        Self::$var_name(node) => {
                            $crate::layouter::nodes::CircuitNode::sources(node)
                        }
                    )*
                }
            }

            fn get_num_vars(&self) -> usize {
                match self {
                    $(
                        Self::$var_name(node) => {
                            $crate::layouter::nodes::CircuitNode::get_num_vars(node)
                        }
                    )*
                }
            }
        }

        $(
            impl<F: $bound> From<$variant> for $type_name<F> {
                fn from(var: $variant) -> $type_name<F> {
                    Self::$var_name(var)
                }
            }
        )*
    };
}