1use shared_types::Field;
4
5use remainder::{
6 circuit_layout::CircuitLocation,
7 layer::{
8 layer_enum::LayerDescriptionEnum,
9 matmult::{MatMultLayerDescription, MatrixDescription},
10 LayerId,
11 },
12 mle::mle_description::MleDescription,
13 utils::mle::get_total_mle_indices,
14};
15
16use crate::layouter::builder::CircuitMap;
17
18use super::{CircuitNode, CompilableNode, NodeId};
19
20use anyhow::Result;
21
22#[derive(Clone, Debug)]
24pub struct MatMultNode {
25 id: NodeId,
26 matrix_a: NodeId,
27 rows_cols_num_vars_a: (usize, usize),
28 matrix_b: NodeId,
29 rows_cols_num_vars_b: (usize, usize),
30 num_vars: usize,
31}
32
33impl CircuitNode for MatMultNode {
34 fn id(&self) -> NodeId {
35 self.id
36 }
37
38 fn sources(&self) -> Vec<NodeId> {
39 vec![self.matrix_a, self.matrix_b]
40 }
41
42 fn get_num_vars(&self) -> usize {
43 self.num_vars
44 }
45}
46
47impl MatMultNode {
48 pub fn new(
50 matrix_node_a: &dyn CircuitNode,
51 rows_cols_num_vars_a: (usize, usize),
52 matrix_node_b: &dyn CircuitNode,
53 rows_cols_num_vars_b: (usize, usize),
54 ) -> Self {
55 assert_eq!(rows_cols_num_vars_a.1, rows_cols_num_vars_b.0);
56 let num_product_vars = rows_cols_num_vars_a.0 + rows_cols_num_vars_b.1;
57
58 Self {
59 id: NodeId::new(),
60 matrix_a: matrix_node_a.id(),
61 rows_cols_num_vars_a,
62 matrix_b: matrix_node_b.id(),
63 rows_cols_num_vars_b,
64 num_vars: num_product_vars,
65 }
66 }
67}
68
69impl<F: Field> CompilableNode<F> for MatMultNode {
70 fn generate_circuit_description(
71 &self,
72 circuit_map: &mut CircuitMap,
73 ) -> Result<Vec<LayerDescriptionEnum<F>>> {
74 let (matrix_a_location, matrix_a_num_vars) =
75 circuit_map.get_location_num_vars_from_node_id(&self.matrix_a)?;
76
77 let mle_a_indices =
78 get_total_mle_indices(&matrix_a_location.prefix_bits, *matrix_a_num_vars);
79 let circuit_mle_a = MleDescription::new(matrix_a_location.layer_id, &mle_a_indices);
80
81 let matrix_a = MatrixDescription::new(
83 circuit_mle_a,
84 self.rows_cols_num_vars_a.0,
85 self.rows_cols_num_vars_a.1,
86 );
87 let (matrix_b_location, matrix_b_num_vars) =
88 circuit_map.get_location_num_vars_from_node_id(&self.matrix_b)?;
89 let mle_b_indices =
90 get_total_mle_indices(&matrix_b_location.prefix_bits, *matrix_b_num_vars);
91 let circuit_mle_b = MleDescription::new(matrix_b_location.layer_id, &mle_b_indices);
92
93 let matrix_b = MatrixDescription::new(
95 circuit_mle_b,
96 self.rows_cols_num_vars_b.0,
97 self.rows_cols_num_vars_b.1,
98 );
99
100 let matmult_layer_id = LayerId::next_layer_id();
101 let matmult_layer = MatMultLayerDescription::new(matmult_layer_id, matrix_a, matrix_b);
102 circuit_map.add_node_id_and_location_num_vars(
103 self.id,
104 (
105 CircuitLocation::new(matmult_layer_id, vec![]),
106 self.get_num_vars(),
107 ),
108 );
109
110 Ok(vec![LayerDescriptionEnum::MatMult(matmult_layer)])
111 }
112}
113
#[cfg(test)]
mod test {
    use shared_types::{Field, Fr};

    use crate::layouter::builder::{Circuit, CircuitBuilder, LayerVisibility};
    use remainder::mle::evals::MultilinearExtension;

    use remainder::prover::helpers::test_circuit_with_runtime_optimized_config;

    /// Row variable count of matrix A used by the tests below.
    const MATRIX_A_NUM_ROWS_VARS: usize = 2;
    /// Column variable count of matrix A, which is also the row variable
    /// count of matrix B.
    const MATRIX_A_NUM_COLS_VARS: usize = 1;
    /// Column variable count of matrix B used by the tests below.
    const MATRIX_B_NUM_COLS_VARS: usize = 1;

    /// Builds a circuit that multiplies matrix A by matrix B and outputs the
    /// difference between that product and a supplied expected result.
    ///
    /// All three matrices are input shreds of one public input layer, named
    /// "Matrix A MLE", "Matrix B MLE" and "Expected Result MLE". Matrix B's
    /// row variable count is taken to be matrix A's column variable count.
    fn build_matmul_test_circuit_description<F: Field>(
        matrix_a_num_rows_vars: usize,
        matrix_a_num_cols_vars: usize,
        matrix_b_num_cols_vars: usize,
    ) -> Circuit<F> {
        let a_dims = (matrix_a_num_rows_vars, matrix_a_num_cols_vars);
        let b_dims = (matrix_a_num_cols_vars, matrix_b_num_cols_vars);

        let mut builder = CircuitBuilder::<F>::new();
        let input_layer = builder.add_input_layer("Public Input Layer", LayerVisibility::Public);

        // Each shred is sized by the sum of its row and column variable counts.
        let matrix_a = builder.add_input_shred("Matrix A MLE", a_dims.0 + a_dims.1, &input_layer);
        let matrix_b = builder.add_input_shred("Matrix B MLE", b_dims.0 + b_dims.1, &input_layer);
        let expected = builder.add_input_shred(
            "Expected Result MLE",
            matrix_a_num_rows_vars + matrix_b_num_cols_vars,
            &input_layer,
        );

        let product = builder.add_matmult_node(&matrix_a, a_dims, &matrix_b, b_dims);

        // The circuit's output is `product - expected`.
        let difference = builder.add_sector(product - expected);
        builder.set_output(&difference);

        builder.build_with_layer_combination().unwrap()
    }

    /// Builds the matmult test circuit with the shared test dimensions, feeds
    /// it the three given inputs and runs it through
    /// `test_circuit_with_runtime_optimized_config`.
    fn prove_matmult_circuit(
        matrix_a_mle: MultilinearExtension<Fr>,
        matrix_b_mle: MultilinearExtension<Fr>,
        expected_matrix_mle: MultilinearExtension<Fr>,
    ) {
        let mut circuit = build_matmul_test_circuit_description::<Fr>(
            MATRIX_A_NUM_ROWS_VARS,
            MATRIX_A_NUM_COLS_VARS,
            MATRIX_B_NUM_COLS_VARS,
        );

        circuit.set_input("Matrix A MLE", matrix_a_mle);
        circuit.set_input("Matrix B MLE", matrix_b_mle);
        circuit.set_input("Expected Result MLE", expected_matrix_mle);

        let provable_circuit = circuit.gen_provable_circuit().unwrap();

        test_circuit_with_runtime_optimized_config(&provable_circuit);
    }

    /// Multiplies a 4x2 matrix by a 2x2 matrix (both listed row by row) and
    /// checks the circuit against the hand-computed product.
    #[test]
    fn test_matmult_node_in_circuit() {
        let matrix_a_mle: MultilinearExtension<Fr> = vec![1, 2, 9, 10, 13, 1, 3, 10].into();
        let matrix_b_mle: MultilinearExtension<Fr> = vec![3, 5, 9, 6].into();
        // Entry (i, j) is row i of A dotted with column j of B.
        let expected_matrix_mle: MultilinearExtension<Fr> = vec![
            3 + 2 * 9,
            5 + 2 * 6,
            9 * 3 + 10 * 9,
            9 * 5 + 10 * 6,
            13 * 3 + 9,
            13 * 5 + 6,
            3 * 3 + 10 * 9,
            3 * 5 + 10 * 6,
        ]
        .into();

        prove_matmult_circuit(matrix_a_mle, matrix_b_mle, expected_matrix_mle);
    }

    /// Same circuit shape as above, but matrix A and the expected result only
    /// supply three rows of data, fewer values than the declared variable
    /// counts can index.
    #[test]
    fn test_non_power_of_2_matmult_node_in_circuit() {
        let matrix_a_mle: MultilinearExtension<Fr> = vec![1, 2, 9, 10, 13, 1].into();
        let matrix_b_mle: MultilinearExtension<Fr> = vec![3, 5, 9, 6].into();
        let expected_matrix_mle: MultilinearExtension<Fr> = vec![
            3 + 2 * 9,
            5 + 2 * 6,
            9 * 3 + 10 * 9,
            9 * 5 + 10 * 6,
            13 * 3 + 9,
            13 * 5 + 6,
        ]
        .into();

        prove_matmult_circuit(matrix_a_mle, matrix_b_mle, expected_matrix_mle);
    }
}