frontend/layouter/nodes/
matmult.rs

//! A module for adding `MatMult` layers to components
2
3use shared_types::Field;
4
5use remainder::{
6    circuit_layout::CircuitLocation,
7    layer::{
8        layer_enum::LayerDescriptionEnum,
9        matmult::{MatMultLayerDescription, MatrixDescription},
10        LayerId,
11    },
12    mle::mle_description::MleDescription,
13    utils::mle::get_total_mle_indices,
14};
15
16use crate::layouter::builder::CircuitMap;
17
18use super::{CircuitNode, CompilableNode, NodeId};
19
20use anyhow::Result;
21
22/// A Node that represents a `Gate` layer
23#[derive(Clone, Debug)]
24pub struct MatMultNode {
25    id: NodeId,
26    matrix_a: NodeId,
27    rows_cols_num_vars_a: (usize, usize),
28    matrix_b: NodeId,
29    rows_cols_num_vars_b: (usize, usize),
30    num_vars: usize,
31}
32
33impl CircuitNode for MatMultNode {
34    fn id(&self) -> NodeId {
35        self.id
36    }
37
38    fn sources(&self) -> Vec<NodeId> {
39        vec![self.matrix_a, self.matrix_b]
40    }
41
42    fn get_num_vars(&self) -> usize {
43        self.num_vars
44    }
45}
46
47impl MatMultNode {
48    /// Constructs a new MatMultNode and computes the data it generates
49    pub fn new(
50        matrix_node_a: &dyn CircuitNode,
51        rows_cols_num_vars_a: (usize, usize),
52        matrix_node_b: &dyn CircuitNode,
53        rows_cols_num_vars_b: (usize, usize),
54    ) -> Self {
55        assert_eq!(rows_cols_num_vars_a.1, rows_cols_num_vars_b.0);
56        let num_product_vars = rows_cols_num_vars_a.0 + rows_cols_num_vars_b.1;
57
58        Self {
59            id: NodeId::new(),
60            matrix_a: matrix_node_a.id(),
61            rows_cols_num_vars_a,
62            matrix_b: matrix_node_b.id(),
63            rows_cols_num_vars_b,
64            num_vars: num_product_vars,
65        }
66    }
67}
68
69impl<F: Field> CompilableNode<F> for MatMultNode {
70    fn generate_circuit_description(
71        &self,
72        circuit_map: &mut CircuitMap,
73    ) -> Result<Vec<LayerDescriptionEnum<F>>> {
74        let (matrix_a_location, matrix_a_num_vars) =
75            circuit_map.get_location_num_vars_from_node_id(&self.matrix_a)?;
76
77        let mle_a_indices =
78            get_total_mle_indices(&matrix_a_location.prefix_bits, *matrix_a_num_vars);
79        let circuit_mle_a = MleDescription::new(matrix_a_location.layer_id, &mle_a_indices);
80
81        // Matrix A and matrix B are not padded because the data from the previous layer is only stored as the raw [MultilinearExtension].
82        let matrix_a = MatrixDescription::new(
83            circuit_mle_a,
84            self.rows_cols_num_vars_a.0,
85            self.rows_cols_num_vars_a.1,
86        );
87        let (matrix_b_location, matrix_b_num_vars) =
88            circuit_map.get_location_num_vars_from_node_id(&self.matrix_b)?;
89        let mle_b_indices =
90            get_total_mle_indices(&matrix_b_location.prefix_bits, *matrix_b_num_vars);
91        let circuit_mle_b = MleDescription::new(matrix_b_location.layer_id, &mle_b_indices);
92
93        // should already been padded
94        let matrix_b = MatrixDescription::new(
95            circuit_mle_b,
96            self.rows_cols_num_vars_b.0,
97            self.rows_cols_num_vars_b.1,
98        );
99
100        let matmult_layer_id = LayerId::next_layer_id();
101        let matmult_layer = MatMultLayerDescription::new(matmult_layer_id, matrix_a, matrix_b);
102        circuit_map.add_node_id_and_location_num_vars(
103            self.id,
104            (
105                CircuitLocation::new(matmult_layer_id, vec![]),
106                self.get_num_vars(),
107            ),
108        );
109
110        Ok(vec![LayerDescriptionEnum::MatMult(matmult_layer)])
111    }
112}
113
#[cfg(test)]
mod test {
    use shared_types::{Field, Fr};

    use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};
    use remainder::mle::evals::MultilinearExtension;

    use remainder::prover::helpers::test_circuit_with_runtime_optimized_config;

    /// Builds the matmul test circuit: a single public input layer holding
    /// matrix A, matrix B and the expected product, a matmult node computing
    /// `A * B`, and a sector computing `A * B - expected` which is set as the
    /// circuit output.
    ///
    /// All sizes are given as numbers of variables (log_2 of the real sizes).
    fn build_matmul_test_circuit_description<F: Field>(
        matrix_a_num_rows_vars: usize,
        matrix_a_num_cols_vars: usize, // This is the same as `matrix_b_num_rows_vars`
        matrix_b_num_cols_vars: usize,
    ) -> Circuit<F> {
        let mut builder = CircuitBuilder::<F>::new();

        // All inputs are public inputs
        let public_input_layer_node =
            builder.add_input_layer("Public Input Layer", LayerVisibility::Public);

        // Inputs to the circuit include the "matrix A MLE", the "matrix B MLE"
        // and the "expected result MLE"
        let matrix_a_mle_shred = builder.add_input_shred(
            "Matrix A MLE",
            matrix_a_num_rows_vars + matrix_a_num_cols_vars,
            &public_input_layer_node,
        );
        let matrix_b_mle_shred = builder.add_input_shred(
            "Matrix B MLE",
            matrix_a_num_cols_vars + matrix_b_num_cols_vars,
            &public_input_layer_node,
        );
        let expected_result_mle_shred = builder.add_input_shred(
            "Expected Result MLE",
            matrix_a_num_rows_vars + matrix_b_num_cols_vars,
            &public_input_layer_node,
        );

        // Create the circuit components
        let matmult_sector = builder.add_matmult_node(
            &matrix_a_mle_shred,
            (matrix_a_num_rows_vars, matrix_a_num_cols_vars),
            &matrix_b_mle_shred,
            (matrix_a_num_cols_vars, matrix_b_num_cols_vars),
        );

        let difference_sector = builder.add_sector(matmult_sector - expected_result_mle_shred);
        builder.set_output(&difference_sector);

        builder.build_with_layer_combination().unwrap()
    }

    /// Builds the matmul test circuit for the given numbers of variables,
    /// binds the three input MLEs, and proves/verifies the resulting circuit.
    fn prove_and_verify_matmul(
        matrix_a_num_rows_vars: usize,
        matrix_a_num_cols_vars: usize, // This is the same as `matrix_b_num_rows_vars`
        matrix_b_num_cols_vars: usize,
        matrix_a_mle: MultilinearExtension<Fr>,
        matrix_b_mle: MultilinearExtension<Fr>,
        expected_matrix_mle: MultilinearExtension<Fr>,
    ) {
        let mut circuit = build_matmul_test_circuit_description::<Fr>(
            matrix_a_num_rows_vars,
            matrix_a_num_cols_vars,
            matrix_b_num_cols_vars,
        );

        circuit.set_input("Matrix A MLE", matrix_a_mle);
        circuit.set_input("Matrix B MLE", matrix_b_mle);
        circuit.set_input("Expected Result MLE", expected_matrix_mle);

        let provable_circuit = circuit.gen_provable_circuit().unwrap();

        // Prove/verify the circuit
        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    #[test]
    fn test_matmult_node_in_circuit() {
        // Define data + input sizes first
        // (4, 2) * (2, 2) --> (4, 2) for real sizes; take log_2 for num vars
        let matrix_a_num_rows_vars = 2;
        let matrix_a_num_cols_vars = 1;
        let matrix_b_num_cols_vars = 1;

        let matrix_a_mle: MultilinearExtension<Fr> = vec![1, 2, 9, 10, 13, 1, 3, 10].into();
        let matrix_b_mle: MultilinearExtension<Fr> = vec![3, 5, 9, 6].into();
        let expected_matrix_mle: MultilinearExtension<Fr> = vec![
            3 + 2 * 9,
            5 + 2 * 6,
            9 * 3 + 10 * 9,
            9 * 5 + 10 * 6,
            13 * 3 + 9,
            13 * 5 + 6,
            3 * 3 + 10 * 9,
            3 * 5 + 10 * 6,
        ]
        .into();

        prove_and_verify_matmul(
            matrix_a_num_rows_vars,
            matrix_a_num_cols_vars,
            matrix_b_num_cols_vars,
            matrix_a_mle,
            matrix_b_mle,
            expected_matrix_mle,
        );
    }

    #[test]
    fn test_non_power_of_2_matmult_node_in_circuit() {
        // Define data + input sizes first
        // (3, 2) * (2, 2) --> (3, 2) for real sizes. The row count of A is not
        // a power of two, so the num vars below are those of the next power
        // of two (4 rows); the supplied MLEs are shorter than 2^num_vars.
        // NOTE(review): presumably the missing evaluations are treated as
        // zero padding -- confirm against `MultilinearExtension`.
        let matrix_a_num_rows_vars = 2;
        let matrix_a_num_cols_vars = 1;
        let matrix_b_num_cols_vars = 1;

        let matrix_a_mle: MultilinearExtension<Fr> = vec![1, 2, 9, 10, 13, 1].into();
        let matrix_b_mle: MultilinearExtension<Fr> = vec![3, 5, 9, 6].into();
        let expected_matrix_mle: MultilinearExtension<Fr> = vec![
            3 + 2 * 9,
            5 + 2 * 6,
            9 * 3 + 10 * 9,
            9 * 5 + 10 * 6,
            13 * 3 + 9,
            13 * 5 + 6,
        ]
        .into();

        prove_and_verify_matmul(
            matrix_a_num_rows_vars,
            matrix_a_num_cols_vars,
            matrix_b_num_cols_vars,
            matrix_a_mle,
            matrix_b_mle,
            expected_matrix_mle,
        );
    }
}