1use std::collections::HashSet;
4
5use ::serde::{Deserialize, Serialize};
6use itertools::Itertools;
7use shared_types::{
8 transcript::{ProverTranscript, VerifierTranscript},
9 Field,
10};
11
12use super::{
13 gate::compute_sumcheck_message_no_beta_table,
14 layer_enum::{LayerEnum, VerifierLayerEnum},
15 product::{PostSumcheckLayer, Product},
16 Layer, LayerDescription, LayerError, LayerId, VerifierLayer,
17};
18use crate::{
19 circuit_layout::{CircuitEvalMap, CircuitLocation},
20 claims::{Claim, ClaimError, RawClaim},
21 layer::VerificationError,
22 mle::{
23 dense::DenseMle, evals::MultilinearExtension, mle_description::MleDescription,
24 verifier_mle::VerifierMle, Mle, MleIndex,
25 },
26 sumcheck::evaluate_at_a_point,
27};
28
29use anyhow::{anyhow, Ok, Result};
30
31#[derive(Debug, Serialize, Deserialize, Clone)]
46#[serde(bound = "F: Field")]
47pub struct Matrix<F: Field> {
48 pub mle: DenseMle<F>,
50 rows_num_vars: usize,
51 cols_num_vars: usize,
52}
53
54impl<F: Field> Matrix<F> {
55 pub fn new(mle: DenseMle<F>, rows_num_vars: usize, cols_num_vars: usize) -> Matrix<F> {
57 assert_eq!(mle.len(), (1 << rows_num_vars) * (1 << cols_num_vars));
58
59 Matrix {
60 mle,
61 rows_num_vars,
62 cols_num_vars,
63 }
64 }
65
66 pub fn rows_cols_num_vars(&self) -> (usize, usize) {
68 (self.rows_num_vars, self.cols_num_vars)
69 }
70}
71
72#[derive(Debug, Serialize, Deserialize, Clone)]
79#[serde(bound = "F: Field")]
80pub struct MatMult<F: Field> {
81 layer_id: LayerId,
82 matrix_a: Matrix<F>,
83 matrix_b: Matrix<F>,
84 num_vars_middle_ab: usize,
85}
86
87impl<F: Field> MatMult<F> {
88 pub fn new(layer_id: LayerId, matrix_a: Matrix<F>, matrix_b: Matrix<F>) -> MatMult<F> {
90 assert_eq!(matrix_a.cols_num_vars, matrix_b.rows_num_vars);
95 let num_vars_middle_ab = matrix_a.cols_num_vars;
96 MatMult {
97 layer_id,
98 matrix_a,
99 matrix_b,
100 num_vars_middle_ab,
101 }
102 }
103
104 fn pre_processing_step(&mut self, claim_a: Vec<F>, claim_b: Vec<F>) {
119 let matrix_a_mle = &mut self.matrix_a.mle;
120 let matrix_b_mle = &mut self.matrix_b.mle;
121
122 assert_eq!(
125 (1 << self.matrix_a.cols_num_vars) * (1 << self.matrix_a.rows_num_vars),
126 matrix_a_mle.len()
127 );
128 assert_eq!(
129 (1 << self.matrix_b.cols_num_vars) * (1 << self.matrix_b.rows_num_vars),
130 matrix_b_mle.len()
131 );
132
133 matrix_a_mle.index_mle_indices(0);
134 matrix_b_mle.index_mle_indices(0);
135
136 claim_a.into_iter().enumerate().for_each(|(idx, chal)| {
138 matrix_a_mle.fix_variable(idx, chal);
139 });
140
141 claim_b.into_iter().enumerate().for_each(|(idx, chal)| {
143 matrix_b_mle.fix_variable_at_index(idx + self.matrix_b.rows_num_vars, chal);
144 });
145 let new_a_indices = matrix_a_mle
149 .clone()
150 .mle_indices
151 .into_iter()
152 .map(|index| {
153 if let MleIndex::Indexed(_) = index {
154 MleIndex::Free
155 } else {
156 index
157 }
158 })
159 .collect_vec();
160 matrix_a_mle.mle_indices = new_a_indices;
161 matrix_a_mle.index_mle_indices(0);
162 }
163
164 fn append_leaf_mles_to_transcript(&self, transcript_writer: &mut impl ProverTranscript<F>) {
165 transcript_writer.append_elements(
166 "Fully bound MLE evaluation",
167 &[self.matrix_a.mle.value(), self.matrix_b.mle.value()],
168 );
169 }
170}
171
172impl<F: Field> Layer<F> for MatMult<F> {
173 fn prove(
179 &mut self,
180 claims: &[&RawClaim<F>],
181 transcript_writer: &mut impl ProverTranscript<F>,
182 ) -> Result<()> {
183 println!(
184 "MatMul::prove_rounds() for a product ({} x {}) * ({} x {}) matrix.",
185 self.matrix_a.rows_num_vars,
186 self.matrix_a.cols_num_vars,
187 self.matrix_b.rows_num_vars,
188 self.matrix_b.cols_num_vars
189 );
190
191 assert_eq!(claims.len(), 1);
196 self.initialize(claims[0].get_point())?;
197
198 let num_vars_middle = self.num_vars_middle_ab;
199
200 for round in 0..num_vars_middle {
201 let message = self.compute_round_sumcheck_message(round, &[F::ONE])?;
203 transcript_writer
206 .append_elements("Sumcheck round univariate evaluations", &message[1..]);
207 let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
209 self.bind_round_variable(round, challenge)?;
211 }
212
213 assert!(self.matrix_a.mle.is_fully_bounded());
215 assert!(self.matrix_b.mle.is_fully_bounded());
216
217 self.append_leaf_mles_to_transcript(transcript_writer);
218 Ok(())
219 }
220
221 fn layer_id(&self) -> LayerId {
222 self.layer_id
223 }
224
225 fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
226 assert_eq!(
231 claim_point.len(),
232 self.matrix_a.rows_num_vars + self.matrix_b.cols_num_vars
233 );
234 let mut claim_a = claim_point.to_vec();
235 let claim_b = claim_a.split_off(self.matrix_a.rows_num_vars);
236 self.pre_processing_step(claim_a, claim_b);
237 Ok(())
238 }
239
240 fn initialize_rlc(&mut self, _random_coefficients: &[F], _claims: &[&RawClaim<F>]) {
241 unimplemented!()
245 }
246
247 fn compute_round_sumcheck_message(
248 &mut self,
249 round_index: usize,
250 _random_coefficients: &[F],
251 ) -> Result<Vec<F>> {
252 let mles = vec![&self.matrix_a.mle, &self.matrix_b.mle];
253 let sumcheck_message =
254 compute_sumcheck_message_no_beta_table(&mles, round_index, 2).unwrap();
255 Ok(sumcheck_message)
256 }
257
258 fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
259 self.matrix_a.mle.fix_variable(round_index, challenge);
260 self.matrix_b.mle.fix_variable(round_index, challenge);
261
262 Ok(())
263 }
264
265 fn sumcheck_round_indices(&self) -> Vec<usize> {
266 (0..self.num_vars_middle_ab).collect_vec()
267 }
268
269 fn max_degree(&self) -> usize {
270 2
271 }
272
273 fn get_post_sumcheck_layer(
276 &self,
277 _round_challenges: &[F],
278 _claim_challenges: &[&[F]],
279 _random_coefficients: &[F],
280 ) -> PostSumcheckLayer<F, F> {
281 let mles = vec![self.matrix_a.mle.clone(), self.matrix_b.mle.clone()];
282 PostSumcheckLayer(vec![Product::<F, F>::new(&mles, F::ONE)])
283 }
284 fn get_claims(&self) -> Result<Vec<Claim<F>>> {
286 let claims = vec![&self.matrix_a.mle, &self.matrix_b.mle]
287 .into_iter()
288 .map(|matrix_mle| {
289 let matrix_fixed_indices = matrix_mle
290 .mle_indices()
291 .iter()
292 .map(|index| {
293 index
294 .val()
295 .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))
296 .unwrap()
297 })
298 .collect_vec();
299
300 let matrix_val = matrix_mle.value();
301 let claim: Claim<F> = Claim::new(
302 matrix_fixed_indices,
303 matrix_val,
304 self.layer_id,
305 matrix_mle.layer_id,
306 );
307 claim
308 })
309 .collect_vec();
310
311 Ok(claims)
312 }
313}
314
315#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
317#[serde(bound = "F: Field")]
318pub struct MatrixDescription<F: Field> {
319 mle: MleDescription<F>,
320 rows_num_vars: usize,
321 cols_num_vars: usize,
322}
323
324impl<F: Field> MatrixDescription<F> {
325 pub fn new(mle: MleDescription<F>, rows_num_vars: usize, cols_num_vars: usize) -> Self {
330 Self {
331 mle,
332 rows_num_vars,
333 cols_num_vars,
334 }
335 }
336
337 pub fn into_matrix(&self, circuit_map: &CircuitEvalMap<F>) -> Matrix<F> {
340 let dense_mle = self.mle.into_dense_mle(circuit_map);
341 Matrix {
342 mle: dense_mle,
343 rows_num_vars: self.rows_num_vars,
344 cols_num_vars: self.cols_num_vars,
345 }
346 }
347}
348#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
350#[serde(bound = "F: Field")]
351pub struct MatMultLayerDescription<F: Field> {
352 layer_id: LayerId,
354
355 matrix_a: MatrixDescription<F>,
357
358 matrix_b: MatrixDescription<F>,
360}
361
362impl<F: Field> MatMultLayerDescription<F> {
363 pub fn new(
366 layer_id: LayerId,
367 matrix_a: MatrixDescription<F>,
368 matrix_b: MatrixDescription<F>,
369 ) -> Self {
370 Self {
371 layer_id,
372 matrix_a,
373 matrix_b,
374 }
375 }
376}
377
378impl<F: Field> LayerDescription<F> for MatMultLayerDescription<F> {
379 type VerifierLayer = VerifierMatMultLayer<F>;
380
381 fn layer_id(&self) -> LayerId {
383 self.layer_id
384 }
385
386 fn verify_rounds(
387 &self,
388 claims: &[&RawClaim<F>],
389 transcript_reader: &mut impl VerifierTranscript<F>,
390 ) -> Result<VerifierLayerEnum<F>> {
391 let mut challenges = vec![];
393
394 assert_eq!(claims.len(), 1);
396 let claim = claims[0];
397
398 let mut g_prev_round = vec![claim.get_eval()];
402
403 let mut prev_challenge = F::ZERO;
405
406 assert_eq!(self.matrix_a.cols_num_vars, self.matrix_b.rows_num_vars);
408 let num_rounds = self.matrix_a.cols_num_vars;
409
410 for _round in 0..num_rounds {
412 let degree = 2;
413
414 let mut g_cur_round: Vec<_> = [Ok(F::from(0))]
416 .into_iter()
417 .chain((0..degree).map(|_| {
418 transcript_reader.consume_element("Sumcheck round univariate evaluations")
419 }))
420 .collect::<Result<_, _>>()?;
421
422 let challenge = transcript_reader.get_challenge("Sumcheck round challenge")?;
424
425 let g_prev_r_prev = evaluate_at_a_point(&g_prev_round, prev_challenge).unwrap();
428 let g_i_one = evaluate_at_a_point(&g_cur_round, F::ONE).unwrap();
429 g_cur_round[0] = g_prev_r_prev - g_i_one;
430
431 g_prev_round = g_cur_round;
432 prev_challenge = challenge;
433 challenges.push(challenge);
434 }
435
436 let g_final_r_final = evaluate_at_a_point(&g_prev_round, prev_challenge)?;
440
441 let verifier_layer: VerifierMatMultLayer<F> = self
442 .convert_into_verifier_layer(&challenges, &[claim.get_point()], transcript_reader)
443 .unwrap();
444
445 let matrix_product = verifier_layer.evaluate();
446
447 if g_final_r_final != matrix_product {
448 return Err(anyhow!(VerificationError::FinalSumcheckFailed));
449 }
450
451 Ok(VerifierLayerEnum::MatMult(verifier_layer))
452 }
453
454 fn sumcheck_round_indices(&self) -> Vec<usize> {
457 (0..self.matrix_a.cols_num_vars).collect_vec()
458 }
459
460 fn compute_data_outputs(
464 &self,
465 mle_outputs_necessary: &HashSet<&MleDescription<F>>,
466 circuit_map: &mut CircuitEvalMap<F>,
467 ) {
468 assert_eq!(mle_outputs_necessary.len(), 1);
469 let mle_output_necessary = mle_outputs_necessary.iter().next().unwrap();
470
471 let matrix_a_data = circuit_map
472 .get_data_from_circuit_mle(&self.matrix_a.mle)
473 .unwrap();
474 assert_eq!(
475 matrix_a_data.num_vars(),
476 self.matrix_a.rows_num_vars + self.matrix_a.cols_num_vars
477 );
478
479 let matrix_b_data = circuit_map
480 .get_data_from_circuit_mle(&self.matrix_b.mle)
481 .unwrap();
482 assert_eq!(
483 matrix_b_data.num_vars(),
484 self.matrix_b.rows_num_vars + self.matrix_b.cols_num_vars
485 );
486
487 let product = product_two_matrices_from_flattened_vectors(
488 &matrix_a_data.to_vec(),
489 &matrix_b_data.to_vec(),
490 1 << self.matrix_a.rows_num_vars,
491 1 << self.matrix_a.cols_num_vars,
492 1 << self.matrix_b.rows_num_vars,
493 1 << self.matrix_b.cols_num_vars,
494 );
495
496 let output_data = MultilinearExtension::new(product);
497 assert_eq!(
498 output_data.num_vars(),
499 mle_output_necessary.var_indices().len()
500 );
501
502 circuit_map.add_node(CircuitLocation::new(self.layer_id(), vec![]), output_data);
503 }
504
505 fn convert_into_verifier_layer(
506 &self,
507 sumcheck_bindings: &[F],
508 claim_points: &[&[F]],
509 transcript_reader: &mut impl VerifierTranscript<F>,
510 ) -> Result<Self::VerifierLayer> {
511 assert_eq!(claim_points.len(), 1);
513 let claim_point = claim_points[0];
514
515 let mut claim_a = claim_point.to_vec();
517 let claim_b = claim_a.split_off(self.matrix_a.rows_num_vars);
518
519 let full_claim_chals_a = claim_a
521 .into_iter()
522 .chain(sumcheck_bindings.to_vec())
523 .collect_vec();
524
525 let full_claim_chals_b = sumcheck_bindings
527 .iter()
528 .copied()
529 .chain(claim_b)
530 .collect_vec();
531
532 assert_eq!(
534 full_claim_chals_a.len(),
535 self.matrix_a.rows_num_vars + self.matrix_a.cols_num_vars
536 );
537 assert_eq!(
538 full_claim_chals_b.len(),
539 self.matrix_b.rows_num_vars + self.matrix_b.cols_num_vars
540 );
541
542 let matrix_a = VerifierMatrix {
544 mle: self
545 .matrix_a
546 .mle
547 .into_verifier_mle(&full_claim_chals_a, transcript_reader)
548 .unwrap(),
549 rows_num_vars: self.matrix_a.rows_num_vars,
550 cols_num_vars: self.matrix_a.cols_num_vars,
551 };
552 let matrix_b = VerifierMatrix {
553 mle: self
554 .matrix_b
555 .mle
556 .into_verifier_mle(&full_claim_chals_b, transcript_reader)
557 .unwrap(),
558 rows_num_vars: self.matrix_b.rows_num_vars,
559 cols_num_vars: self.matrix_b.cols_num_vars,
560 };
561
562 Ok(VerifierMatMultLayer {
563 layer_id: self.layer_id,
564 matrix_a,
565 matrix_b,
566 })
567 }
568
569 fn get_post_sumcheck_layer(
571 &self,
572 round_challenges: &[F],
573 claim_challenges: &[&[F]],
574 _random_coefficients: &[F],
575 ) -> PostSumcheckLayer<F, Option<F>> {
576 assert_eq!(claim_challenges.len(), 1);
578 let claim_challenge = claim_challenges[0];
579 let mut pre_bound_matrix_a_mle = self.matrix_a.mle.clone();
580 let claim_chals_matrix_a = claim_challenge[..self.matrix_a.rows_num_vars].to_vec();
581 let mut indexed_index_counter = 0;
582 let mut bound_index_counter = 0;
583
584 let matrix_a_new_indices = self
591 .matrix_a
592 .mle
593 .var_indices()
594 .iter()
595 .map(|mle_idx| match mle_idx {
596 &MleIndex::Indexed(_) => {
597 if bound_index_counter < self.matrix_a.rows_num_vars {
598 let ret = MleIndex::Bound(
599 claim_chals_matrix_a[bound_index_counter],
600 bound_index_counter,
601 );
602 bound_index_counter += 1;
603 ret
604 } else {
605 let ret = MleIndex::Indexed(indexed_index_counter);
606 indexed_index_counter += 1;
607 ret
608 }
609 }
610 MleIndex::Fixed(_) => mle_idx.clone(),
611 MleIndex::Free => panic!("should not have any free indices"),
612 MleIndex::Bound(_, _) => panic!("should not have any bound indices"),
613 })
614 .collect_vec();
615 pre_bound_matrix_a_mle.set_mle_indices(matrix_a_new_indices);
616
617 let mut pre_bound_matrix_b_mle = self.matrix_b.mle.clone();
621 let claim_chals_matrix_b = claim_challenge[self.matrix_a.rows_num_vars..].to_vec();
622 let mut bound_index_counter = 0;
623 let mut indexed_index_counter = 0;
624 let matrix_b_new_indices = self
625 .matrix_b
626 .mle
627 .var_indices()
628 .iter()
629 .map(|mle_idx| match mle_idx {
630 &MleIndex::Indexed(_) => {
631 if indexed_index_counter < self.matrix_b.rows_num_vars {
632 let ret = MleIndex::Indexed(indexed_index_counter);
633 indexed_index_counter += 1;
634 ret
635 } else {
636 let ret = MleIndex::Bound(
637 claim_chals_matrix_b[bound_index_counter],
638 bound_index_counter,
639 );
640 bound_index_counter += 1;
641 ret
642 }
643 }
644 MleIndex::Fixed(_) => mle_idx.clone(),
645 MleIndex::Free => panic!("should not have any free indices"),
646 MleIndex::Bound(_, _) => panic!("should not have any bound indices"),
647 })
648 .collect_vec();
649 pre_bound_matrix_b_mle.set_mle_indices(matrix_b_new_indices);
650 let mles = vec![pre_bound_matrix_a_mle, pre_bound_matrix_b_mle];
651
652 PostSumcheckLayer(vec![Product::<F, Option<F>>::new(
653 &mles,
654 F::ONE,
655 round_challenges,
656 )])
657 }
658
659 fn max_degree(&self) -> usize {
660 2
661 }
662
663 fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
664 vec![&self.matrix_a.mle, &self.matrix_b.mle]
665 }
666
667 fn convert_into_prover_layer<'a>(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
668 let prover_matrix_a = self.matrix_a.into_matrix(circuit_map);
669 let prover_matrix_b = self.matrix_b.into_matrix(circuit_map);
670 let matmult_layer = MatMult::new(self.layer_id, prover_matrix_a, prover_matrix_b);
671 matmult_layer.into()
672 }
673
674 fn index_mle_indices(&mut self, start_index: usize) {
675 self.matrix_a.mle.index_mle_indices(start_index);
676 self.matrix_b.mle.index_mle_indices(start_index);
677 }
678}
679
680#[derive(Serialize, Deserialize, Clone, Debug)]
682#[serde(bound = "F: Field")]
683pub struct VerifierMatrix<F: Field> {
684 mle: VerifierMle<F>,
685 rows_num_vars: usize,
686 cols_num_vars: usize,
687}
688
689#[derive(Serialize, Deserialize, Clone, Debug)]
691#[serde(bound = "F: Field")]
692pub struct VerifierMatMultLayer<F: Field> {
693 layer_id: LayerId,
695
696 matrix_a: VerifierMatrix<F>,
698
699 matrix_b: VerifierMatrix<F>,
701}
702
703impl<F: Field> VerifierLayer<F> for VerifierMatMultLayer<F> {
704 fn layer_id(&self) -> LayerId {
705 self.layer_id
706 }
707
708 fn get_claims(&self) -> Result<Vec<Claim<F>>> {
709 let claims = vec![&self.matrix_a, &self.matrix_b]
710 .into_iter()
711 .map(|matrix| {
712 let matrix_fixed_indices = matrix
713 .mle
714 .var_indices()
715 .iter()
716 .map(|index| {
717 index
718 .val()
719 .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))
720 .unwrap()
721 })
722 .collect_vec();
723
724 let matrix_claimed_val = matrix.mle.value();
725
726 let claim: Claim<F> = Claim::new(
727 matrix_fixed_indices,
728 matrix_claimed_val,
729 self.layer_id,
730 matrix.mle.layer_id(),
731 );
732 claim
733 })
734 .collect_vec();
735
736 Ok(claims)
737 }
738}
739
740impl<F: Field> VerifierMatMultLayer<F> {
741 fn evaluate(&self) -> F {
742 self.matrix_a.mle.value() * self.matrix_b.mle.value()
743 }
744}
745
746pub fn product_two_matrices_from_flattened_vectors<F: Field>(
749 matrix_a_vec: &[F],
750 matrix_b_vec: &[F],
751 matrix_a_num_rows: usize,
752 matrix_a_num_cols: usize,
753 matrix_b_num_rows: usize,
754 matrix_b_num_cols: usize,
755) -> Vec<F> {
756 assert_eq!(
757 matrix_a_num_cols, matrix_b_num_rows,
758 "Matrix dimensions are not compatible for multiplication"
759 );
760
761 let mut result = vec![F::ZERO; matrix_a_num_rows * matrix_b_num_cols];
762
763 for i in 0..matrix_a_num_rows {
764 for j in 0..matrix_b_num_cols {
765 for k in 0..matrix_a_num_cols {
766 result[i * matrix_b_num_cols + j] += matrix_a_vec[i * matrix_a_num_cols + k]
767 * matrix_b_vec[k * matrix_b_num_cols + j];
768 }
769 }
770 }
771
772 result
773}
774
#[cfg(test)]
mod test {

    use shared_types::Fr;

    use crate::layer::matmult::product_two_matrices_from_flattened_vectors;

    /// Lifts small integers into field elements.
    fn to_fr(values: &[u64]) -> Vec<Fr> {
        values.iter().copied().map(Fr::from).collect()
    }

    #[test]
    fn test_product_two_matrices() {
        // 4 x 2 matrix, row-major.
        let mle_vec_a = to_fr(&[1, 2, 9, 10, 13, 1, 3, 10]);
        // 2 x 2 matrix, row-major.
        let mle_vec_b = to_fr(&[3, 5, 9, 6]);

        let res_product =
            product_two_matrices_from_flattened_vectors(&mle_vec_a, &mle_vec_b, 4, 2, 2, 2);

        let exp_product = to_fr(&[
            3 + 2 * 9,
            5 + 2 * 6,
            9 * 3 + 10 * 9,
            9 * 5 + 10 * 6,
            13 * 3 + 9,
            13 * 5 + 6,
            3 * 3 + 10 * 9,
            3 * 5 + 10 * 6,
        ]);

        assert_eq!(res_product, exp_product);
    }

    #[test]
    fn test_product_two_matrices_2() {
        // 8 x 4 matrix made of the same 4 x 4 block stacked twice.
        let a_block: [u64; 16] = [3, 4, 1, 6, 2, 9, 0, 1, 4, 5, 4, 2, 4, 2, 6, 7];
        let mle_vec_a = to_fr(&a_block.repeat(2));
        // 4 x 2 matrix.
        let mle_vec_b = to_fr(&[3, 2, 1, 5, 3, 6, 7, 4]);

        let res_product =
            product_two_matrices_from_flattened_vectors(&mle_vec_a, &mle_vec_b, 8, 4, 4, 2);

        // Stacking A's block twice stacks the 4 x 2 block product twice.
        let product_block: [u64; 8] = [58, 56, 22, 53, 43, 65, 81, 82];
        let exp_product = to_fr(&product_block.repeat(2));

        assert_eq!(res_product, exp_product);
    }
}