// remainder/layer/matmult.rs

1//! This module contains the implementation of the matrix multiplication layer
2
3use std::collections::HashSet;
4
5use ::serde::{Deserialize, Serialize};
6use itertools::Itertools;
7use shared_types::{
8    transcript::{ProverTranscript, VerifierTranscript},
9    Field,
10};
11
12use super::{
13    gate::compute_sumcheck_message_no_beta_table,
14    layer_enum::{LayerEnum, VerifierLayerEnum},
15    product::{PostSumcheckLayer, Product},
16    Layer, LayerDescription, LayerError, LayerId, VerifierLayer,
17};
18use crate::{
19    circuit_layout::{CircuitEvalMap, CircuitLocation},
20    claims::{Claim, ClaimError, RawClaim},
21    layer::VerificationError,
22    mle::{
23        dense::DenseMle, evals::MultilinearExtension, mle_description::MleDescription,
24        verifier_mle::VerifierMle, Mle, MleIndex,
25    },
26    sumcheck::evaluate_at_a_point,
27};
28
29use anyhow::{anyhow, Ok, Result};
30
31/// Used to represent a matrix; basically an MLE which is the
32/// flattened version of this matrix along with the log2
33/// num_rows (`rows_num_vars`) and the log2 num_cols `cols_num_vars`.
34///
35/// This ensures that the flattened MLE provided already has
36/// a bookkeeping table where the rows and columns are padded to
37/// the nearest power of 2.
38///
39/// NOTE: the flattened MLE that represents a matrix is in
40/// row major order. Internal bookkeeping tables are
41/// stored in big-endian, so the FIRST "row"
42/// number of variables to represent the rows of the matrix,
43/// and the LAST "column" number of variables to represent the
44/// columns of the matrix.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(bound = "F: Field")]
pub struct Matrix<F: Field> {
    /// The underlying and padded MLE that represents this matrix.
    pub mle: DenseMle<F>,
    /// log2 of the (padded) number of rows of this matrix.
    rows_num_vars: usize,
    /// log2 of the (padded) number of columns of this matrix.
    cols_num_vars: usize,
}
53
54impl<F: Field> Matrix<F> {
55    /// Create a new matrix.
56    pub fn new(mle: DenseMle<F>, rows_num_vars: usize, cols_num_vars: usize) -> Matrix<F> {
57        assert_eq!(mle.len(), (1 << rows_num_vars) * (1 << cols_num_vars));
58
59        Matrix {
60            mle,
61            rows_num_vars,
62            cols_num_vars,
63        }
64    }
65
66    /// Get the dimensions of this matrix.
67    pub fn rows_cols_num_vars(&self) -> (usize, usize) {
68        (self.rows_num_vars, self.cols_num_vars)
69    }
70}
71
72/// Used to represent a matrix multiplication layer.
73///
74/// #Attributes:
75/// * `layer_id` - the LayerId of this MatMult layer.
76/// * `matrix_a` - the lefthand side matrix in the multiplication.
77/// * `matrix_b` - the righthand side matrix in the multiplication.
78#[derive(Debug, Serialize, Deserialize, Clone)]
79#[serde(bound = "F: Field")]
80pub struct MatMult<F: Field> {
81    layer_id: LayerId,
82    matrix_a: Matrix<F>,
83    matrix_b: Matrix<F>,
84    num_vars_middle_ab: usize,
85}
86
87impl<F: Field> MatMult<F> {
88    /// Create a new matrix multiplication layer.
89    pub fn new(layer_id: LayerId, matrix_a: Matrix<F>, matrix_b: Matrix<F>) -> MatMult<F> {
90        // Check to make sure the inner dimensions of the matrices we are
91        // producting match. I.e., the number of variables representing the
92        // columns of matrix a are the same as the number of variables
93        // representing the rows of matrix b.
94        assert_eq!(matrix_a.cols_num_vars, matrix_b.rows_num_vars);
95        let num_vars_middle_ab = matrix_a.cols_num_vars;
96        MatMult {
97            layer_id,
98            matrix_a,
99            matrix_b,
100            num_vars_middle_ab,
101        }
102    }
103
104    /// The step, according to [Tha13](https://eprint.iacr.org/2013/351.pdf), which
105    /// makes the matmult algorithm super-efficient.
106    ///
107    /// Given the claim on the output of the matrix multiplication, we bind
108    /// the variables representing the "rows" of `matrix_a` to the first
109    /// log(num_rows_a) vars in this claim, and we bind the variables
110    /// representing the "columns" of `matrix_b` to the last log(num_cols_b)
111    /// vars in the claim.
112    ///
113    /// #Arguments
114    /// * `claim_a`: the first log_num_rows variables of the claim made on the
115    ///   MLE representing the output of this layer.
116    /// * `claim_b`: the last log_num_cols variables of the claim made on the
117    ///   MLE representing the output of this layer.
118    fn pre_processing_step(&mut self, claim_a: Vec<F>, claim_b: Vec<F>) {
119        let matrix_a_mle = &mut self.matrix_a.mle;
120        let matrix_b_mle = &mut self.matrix_b.mle;
121
122        // Check that both matrices are padded such that the number of rows
123        // and the number of columns are both powers of 2.
124        assert_eq!(
125            (1 << self.matrix_a.cols_num_vars) * (1 << self.matrix_a.rows_num_vars),
126            matrix_a_mle.len()
127        );
128        assert_eq!(
129            (1 << self.matrix_b.cols_num_vars) * (1 << self.matrix_b.rows_num_vars),
130            matrix_b_mle.len()
131        );
132
133        matrix_a_mle.index_mle_indices(0);
134        matrix_b_mle.index_mle_indices(0);
135
136        // Bind the row indices of matrix A to the relevant claim point.
137        claim_a.into_iter().enumerate().for_each(|(idx, chal)| {
138            matrix_a_mle.fix_variable(idx, chal);
139        });
140
141        // Bind the column indices of matrix B to the relevant claim point.
142        claim_b.into_iter().enumerate().for_each(|(idx, chal)| {
143            matrix_b_mle.fix_variable_at_index(idx + self.matrix_b.rows_num_vars, chal);
144        });
145        // We want to re-index the MLE indices in matrix A such that it
146        // starts from 0 after the pre-processing, so we do that by first
147        // setting them to be free and then re-indexing them.
148        let new_a_indices = matrix_a_mle
149            .clone()
150            .mle_indices
151            .into_iter()
152            .map(|index| {
153                if let MleIndex::Indexed(_) = index {
154                    MleIndex::Free
155                } else {
156                    index
157                }
158            })
159            .collect_vec();
160        matrix_a_mle.mle_indices = new_a_indices;
161        matrix_a_mle.index_mle_indices(0);
162    }
163
164    fn append_leaf_mles_to_transcript(&self, transcript_writer: &mut impl ProverTranscript<F>) {
165        transcript_writer.append_elements(
166            "Fully bound MLE evaluation",
167            &[self.matrix_a.mle.value(), self.matrix_b.mle.value()],
168        );
169    }
170}
171
impl<F: Field> Layer<F> for MatMult<F> {
    // Since we pre-process the matrices first, by pre-binding the
    // row variables of matrix A and the column variables of matrix B,
    // the number of rounds of sumcheck is simply the number of variables
    // that represent the singular (same) inner dimension of both of the
    // matrices in this matrix product.

    /// Run the sumcheck prover for this layer: one round per variable of the
    /// shared inner dimension. Each round appends the univariate message
    /// (minus g_i(0), which the verifier can deduce) to the transcript,
    /// samples a challenge, and binds the round variable in both matrix MLEs.
    /// Finally, the two fully bound MLE evaluations are appended to the
    /// transcript.
    fn prove(
        &mut self,
        claims: &[&RawClaim<F>],
        transcript_writer: &mut impl ProverTranscript<F>,
    ) -> Result<()> {
        println!(
            "MatMul::prove_rounds() for a product ({} x {}) * ({} x {}) matrix.",
            self.matrix_a.rows_num_vars,
            self.matrix_a.cols_num_vars,
            self.matrix_b.rows_num_vars,
            self.matrix_b.cols_num_vars
        );

        // We always use interpolative claim aggregation for matmult layers
        // because the preprocessing step in matmult utilizes the fact that we
        // have linear variables in the expression, which RLC is unable to
        // aggregate claims for.
        assert_eq!(claims.len(), 1);
        self.initialize(claims[0].get_point())?;

        let num_vars_middle = self.num_vars_middle_ab;

        for round in 0..num_vars_middle {
            // Compute the round's sumcheck message.
            let message = self.compute_round_sumcheck_message(round, &[F::ONE])?;
            // Add to transcript.
            // Since the verifier can deduce g_i(0) by computing claim - g_i(1), the prover does not send g_i(0)
            transcript_writer
                .append_elements("Sumcheck round univariate evaluations", &message[1..]);
            // Sample the challenge to bind the round's MatMult expression to.
            let challenge = transcript_writer.get_challenge("Sumcheck round challenge");
            // Bind the Matrix MLEs to this variable.
            self.bind_round_variable(round, challenge)?;
        }

        // Assert that the MLEs have been fully bound.
        assert!(self.matrix_a.mle.is_fully_bounded());
        assert!(self.matrix_b.mle.is_fully_bounded());

        self.append_leaf_mles_to_transcript(transcript_writer);
        Ok(())
    }

    /// Gets this layer's id.
    fn layer_id(&self) -> LayerId {
        self.layer_id
    }

    /// Split the single claim point into its rows-of-A / cols-of-B halves and
    /// run the [Tha13] pre-processing step, pre-binding those variables
    /// before sumcheck begins.
    fn initialize(&mut self, claim_point: &[F]) -> Result<()> {
        // Split the claim on the MLE representing the output of this layer
        // accordingly.
        // We need to make sure the number of variables in the claim is the
        // sum of the outer dimensions of this matrix product.
        assert_eq!(
            claim_point.len(),
            self.matrix_a.rows_num_vars + self.matrix_b.cols_num_vars
        );
        let mut claim_a = claim_point.to_vec();
        let claim_b = claim_a.split_off(self.matrix_a.rows_num_vars);
        self.pre_processing_step(claim_a, claim_b);
        Ok(())
    }

    fn initialize_rlc(&mut self, _random_coefficients: &[F], _claims: &[&RawClaim<F>]) {
        // This function is not implemented for MatMult layers because we should
        // never be using RLC claim aggregation for MatMult layers. Instead, we always
        // use interpolative claim aggregation.
        unimplemented!()
    }

    /// Compute the univariate sumcheck message for `round_index` over the
    /// degree-2 product of the two (pre-bound) matrix MLEs. The random
    /// coefficients are unused since there is only a single product term.
    fn compute_round_sumcheck_message(
        &mut self,
        round_index: usize,
        _random_coefficients: &[F],
    ) -> Result<Vec<F>> {
        let mles = vec![&self.matrix_a.mle, &self.matrix_b.mle];
        let sumcheck_message =
            compute_sumcheck_message_no_beta_table(&mles, round_index, 2).unwrap();
        Ok(sumcheck_message)
    }

    /// Bind the round variable in BOTH matrix MLEs to the sampled challenge
    /// (after pre-processing, they share the inner-dimension variables).
    fn bind_round_variable(&mut self, round_index: usize, challenge: F) -> Result<()> {
        self.matrix_a.mle.fix_variable(round_index, challenge);
        self.matrix_b.mle.fix_variable(round_index, challenge);

        Ok(())
    }

    /// One sumcheck round per inner-dimension variable.
    fn sumcheck_round_indices(&self) -> Vec<usize> {
        (0..self.num_vars_middle_ab).collect_vec()
    }

    /// The expression is a product of two MLEs, hence degree 2 per variable.
    fn max_degree(&self) -> usize {
        2
    }

    /// Return the [PostSumcheckLayer], panicking if either of the MLE refs is not fully bound.
    /// Relevant for the Hyrax IP, where we need commitments to fully bound MLEs as well as their intermediate products.
    fn get_post_sumcheck_layer(
        &self,
        _round_challenges: &[F],
        _claim_challenges: &[&[F]],
        _random_coefficients: &[F],
    ) -> PostSumcheckLayer<F, F> {
        let mles = vec![self.matrix_a.mle.clone(), self.matrix_b.mle.clone()];
        PostSumcheckLayer(vec![Product::<F, F>::new(&mles, F::ONE)])
    }

    /// Get the claims that this layer makes on other layers: one claim per
    /// input matrix, at the point its MLE has been fully bound to.
    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
        let claims = vec![&self.matrix_a.mle, &self.matrix_b.mle]
            .into_iter()
            .map(|matrix_mle| {
                // Collect the fully bound point; panics (via unwrap) if any
                // index is still unbound.
                let matrix_fixed_indices = matrix_mle
                    .mle_indices()
                    .iter()
                    .map(|index| {
                        index
                            .val()
                            .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))
                            .unwrap()
                    })
                    .collect_vec();

                let matrix_val = matrix_mle.value();
                let claim: Claim<F> = Claim::new(
                    matrix_fixed_indices,
                    matrix_val,
                    self.layer_id,
                    matrix_mle.layer_id,
                );
                claim
            })
            .collect_vec();

        Ok(claims)
    }
}
314
315/// The circuit description counterpart of a [Matrix].
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "F: Field")]
pub struct MatrixDescription<F: Field> {
    /// The circuit description of the underlying flattened MLE.
    mle: MleDescription<F>,
    /// log2 of the (padded) number of rows.
    rows_num_vars: usize,
    /// log2 of the (padded) number of columns.
    cols_num_vars: usize,
}
323
324impl<F: Field> MatrixDescription<F> {
325    /// The constructor for a [MatrixDescription], which is the circuit
326    /// description of matrix, only containing shape information
327    /// which is the number of variables in the rows and the number
328    /// of variables in the columns.
329    pub fn new(mle: MleDescription<F>, rows_num_vars: usize, cols_num_vars: usize) -> Self {
330        Self {
331            mle,
332            rows_num_vars,
333            cols_num_vars,
334        }
335    }
336
337    /// Convert the circuit description of a matrix into the prover
338    /// view of a matrix, using the [CircuitEvalMap].
339    pub fn into_matrix(&self, circuit_map: &CircuitEvalMap<F>) -> Matrix<F> {
340        let dense_mle = self.mle.into_dense_mle(circuit_map);
341        Matrix {
342            mle: dense_mle,
343            rows_num_vars: self.rows_num_vars,
344            cols_num_vars: self.cols_num_vars,
345        }
346    }
347}
/// The circuit description counterpart of a [MatMult] layer.
///
/// Shape-only information; the two [MatrixDescription]s must share their
/// inner dimension (cols of `matrix_a` == rows of `matrix_b`), which is
/// asserted during verification.
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "F: Field")]
pub struct MatMultLayerDescription<F: Field> {
    /// The layer id associated with this matmult layer.
    layer_id: LayerId,

    /// The LHS Matrix to be multiplied.
    matrix_a: MatrixDescription<F>,

    /// The RHS Matrix to be multiplied.
    matrix_b: MatrixDescription<F>,
}
361
362impl<F: Field> MatMultLayerDescription<F> {
363    /// Constructor for the [MatMultLayerDescription], using the circuit description
364    /// of the matrices that make up this layer.
365    pub fn new(
366        layer_id: LayerId,
367        matrix_a: MatrixDescription<F>,
368        matrix_b: MatrixDescription<F>,
369    ) -> Self {
370        Self {
371            layer_id,
372            matrix_a,
373            matrix_b,
374        }
375    }
376}
377
impl<F: Field> LayerDescription<F> for MatMultLayerDescription<F> {
    type VerifierLayer = VerifierMatMultLayer<F>;

    /// Gets this layer's id.
    fn layer_id(&self) -> LayerId {
        self.layer_id
    }

    /// Verify the sumcheck rounds for this matmult layer against the single
    /// claim made on its output, consuming the prover's messages from the
    /// transcript and checking the final round against the product of the
    /// two fully bound matrix MLE evaluations.
    fn verify_rounds(
        &self,
        claims: &[&RawClaim<F>],
        transcript_reader: &mut impl VerifierTranscript<F>,
    ) -> Result<VerifierLayerEnum<F>> {
        // Keeps track of challenges `r_1, ..., r_n` sent by the verifier.
        let mut challenges = vec![];

        // For matmult we always use the interpolative claim aggregation method.
        assert_eq!(claims.len(), 1);
        let claim = claims[0];

        // Represents `g_{i-1}(x)` of the previous round.
        // This is initialized to the constant polynomial `g_0(x)` which evaluates
        // to the claim result for any `x`.
        let mut g_prev_round = vec![claim.get_eval()];

        // Previous round's challenge: r_{i-1}.
        let mut prev_challenge = F::ZERO;

        // Get the number of rounds, which is exactly the inner dimension of the matrix product.
        assert_eq!(self.matrix_a.cols_num_vars, self.matrix_b.rows_num_vars);
        let num_rounds = self.matrix_a.cols_num_vars;

        // For round 1 <= i <= n, perform the check:
        for _round in 0..num_rounds {
            // The message is for a degree-2 product of two MLEs.
            let degree = 2;

            // Read g_i(1), ..., g_i(d+1) from the prover, reserve space to compute g_i(0)
            let mut g_cur_round: Vec<_> = [Ok(F::from(0))]
                .into_iter()
                .chain((0..degree).map(|_| {
                    transcript_reader.consume_element("Sumcheck round univariate evaluations")
                }))
                .collect::<Result<_, _>>()?;

            // Sample random challenge `r_i`.
            let challenge = transcript_reader.get_challenge("Sumcheck round challenge")?;

            // Compute:
            //       `g_i(0) = g_{i - 1}(r_{i-1}) - g_i(1)`
            let g_prev_r_prev = evaluate_at_a_point(&g_prev_round, prev_challenge).unwrap();
            let g_i_one = evaluate_at_a_point(&g_cur_round, F::ONE).unwrap();
            g_cur_round[0] = g_prev_r_prev - g_i_one;

            g_prev_round = g_cur_round;
            prev_challenge = challenge;
            challenges.push(challenge);
        }

        // Evaluate `g_n(r_n)`.
        // Note: If there were no nonlinear rounds, this value reduces to
        // `claim.get_result()` due to how we initialized `g_prev_round`.
        let g_final_r_final = evaluate_at_a_point(&g_prev_round, prev_challenge)?;

        let verifier_layer: VerifierMatMultLayer<F> = self
            .convert_into_verifier_layer(&challenges, &[claim.get_point()], transcript_reader)
            .unwrap();

        // Final check: the last sumcheck evaluation must equal the product of
        // the two fully bound matrix MLE evaluations read from the transcript.
        let matrix_product = verifier_layer.evaluate();

        if g_final_r_final != matrix_product {
            return Err(anyhow!(VerificationError::FinalSumcheckFailed));
        }

        Ok(VerifierLayerEnum::MatMult(verifier_layer))
    }

    /// The number of sumcheck rounds are only those over the inner dimensions
    /// of the matrix, hence they enumerate from 0 to the inner dimension.
    fn sumcheck_round_indices(&self) -> Vec<usize> {
        (0..self.matrix_a.cols_num_vars).collect_vec()
    }

    /// Compute the evaluations of the MLE that represents the
    /// product of the two matrices over the boolean hypercube.
    /// Panics if the MLEs for the two matrices provided by the circuit map are of the wrong size.
    fn compute_data_outputs(
        &self,
        mle_outputs_necessary: &HashSet<&MleDescription<F>>,
        circuit_map: &mut CircuitEvalMap<F>,
    ) {
        // A matmult layer produces exactly one output MLE.
        assert_eq!(mle_outputs_necessary.len(), 1);
        let mle_output_necessary = mle_outputs_necessary.iter().next().unwrap();

        let matrix_a_data = circuit_map
            .get_data_from_circuit_mle(&self.matrix_a.mle)
            .unwrap();
        assert_eq!(
            matrix_a_data.num_vars(),
            self.matrix_a.rows_num_vars + self.matrix_a.cols_num_vars
        );

        let matrix_b_data = circuit_map
            .get_data_from_circuit_mle(&self.matrix_b.mle)
            .unwrap();
        assert_eq!(
            matrix_b_data.num_vars(),
            self.matrix_b.rows_num_vars + self.matrix_b.cols_num_vars
        );

        // Multiply the two row-major flattened matrices directly.
        let product = product_two_matrices_from_flattened_vectors(
            &matrix_a_data.to_vec(),
            &matrix_b_data.to_vec(),
            1 << self.matrix_a.rows_num_vars,
            1 << self.matrix_a.cols_num_vars,
            1 << self.matrix_b.rows_num_vars,
            1 << self.matrix_b.cols_num_vars,
        );

        let output_data = MultilinearExtension::new(product);
        assert_eq!(
            output_data.num_vars(),
            mle_output_necessary.var_indices().len()
        );

        circuit_map.add_node(CircuitLocation::new(self.layer_id(), vec![]), output_data);
    }

    /// Build the verifier's view of this layer by reconstructing the fully
    /// bound points of both matrices (claim halves plus sumcheck bindings)
    /// and reading their claimed evaluations from the transcript.
    fn convert_into_verifier_layer(
        &self,
        sumcheck_bindings: &[F],
        claim_points: &[&[F]],
        transcript_reader: &mut impl VerifierTranscript<F>,
    ) -> Result<Self::VerifierLayer> {
        // For matmult, we only use interpolative claim aggregation.
        assert_eq!(claim_points.len(), 1);
        let claim_point = claim_points[0];

        // Split the claim into the claims made on matrix A rows and matrix B cols.
        let mut claim_a = claim_point.to_vec();
        let claim_b = claim_a.split_off(self.matrix_a.rows_num_vars);

        // Construct the full claim made on A using the claim made on the layer and the sumcheck bindings.
        // (Row variables come first for A; the sumcheck bindings cover its columns.)
        let full_claim_chals_a = claim_a
            .into_iter()
            .chain(sumcheck_bindings.to_vec())
            .collect_vec();

        // Construct the full claim made on B using the claim made on the layer and the sumcheck bindings.
        // (Sumcheck bindings cover B's rows; the column variables come last.)
        let full_claim_chals_b = sumcheck_bindings
            .iter()
            .copied()
            .chain(claim_b)
            .collect_vec();

        // Shape checks.
        assert_eq!(
            full_claim_chals_a.len(),
            self.matrix_a.rows_num_vars + self.matrix_a.cols_num_vars
        );
        assert_eq!(
            full_claim_chals_b.len(),
            self.matrix_b.rows_num_vars + self.matrix_b.cols_num_vars
        );

        // Construct the verifier matrices given these fully bound points.
        let matrix_a = VerifierMatrix {
            mle: self
                .matrix_a
                .mle
                .into_verifier_mle(&full_claim_chals_a, transcript_reader)
                .unwrap(),
            rows_num_vars: self.matrix_a.rows_num_vars,
            cols_num_vars: self.matrix_a.cols_num_vars,
        };
        let matrix_b = VerifierMatrix {
            mle: self
                .matrix_b
                .mle
                .into_verifier_mle(&full_claim_chals_b, transcript_reader)
                .unwrap(),
            rows_num_vars: self.matrix_b.rows_num_vars,
            cols_num_vars: self.matrix_b.cols_num_vars,
        };

        Ok(VerifierMatMultLayer {
            layer_id: self.layer_id,
            matrix_a,
            matrix_b,
        })
    }

    /// Return the [PostSumcheckLayer], given challenges that fully bind the expression.
    fn get_post_sumcheck_layer(
        &self,
        round_challenges: &[F],
        claim_challenges: &[&[F]],
        _random_coefficients: &[F],
    ) -> PostSumcheckLayer<F, Option<F>> {
        // We are always using interpolative claim aggregation for MatMult layers.
        assert_eq!(claim_challenges.len(), 1);
        let claim_challenge = claim_challenges[0];
        let mut pre_bound_matrix_a_mle = self.matrix_a.mle.clone();
        let claim_chals_matrix_a = claim_challenge[..self.matrix_a.rows_num_vars].to_vec();
        let mut indexed_index_counter = 0;
        let mut bound_index_counter = 0;

        // We need to make sure the MLE indices of the post-sumcheck layer
        // match the MLE indices in proving, since it is pre-processed
        // when we start proving.
        // I.e, we keep the first variables representing the columns of matrix
        // A as Indexed for sumcheck, and keep the rest as bound to their
        // respective claim point in pre-processing.
        let matrix_a_new_indices = self
            .matrix_a
            .mle
            .var_indices()
            .iter()
            .map(|mle_idx| match mle_idx {
                &MleIndex::Indexed(_) => {
                    if bound_index_counter < self.matrix_a.rows_num_vars {
                        // Row variable of A: bound to the claim point.
                        let ret = MleIndex::Bound(
                            claim_chals_matrix_a[bound_index_counter],
                            bound_index_counter,
                        );
                        bound_index_counter += 1;
                        ret
                    } else {
                        // Column variable of A: remains a sumcheck variable.
                        let ret = MleIndex::Indexed(indexed_index_counter);
                        indexed_index_counter += 1;
                        ret
                    }
                }
                MleIndex::Fixed(_) => mle_idx.clone(),
                MleIndex::Free => panic!("should not have any free indices"),
                MleIndex::Bound(_, _) => panic!("should not have any bound indices"),
            })
            .collect_vec();
        pre_bound_matrix_a_mle.set_mle_indices(matrix_a_new_indices);

        // We keep the last variables representing the rows of matrix B
        // as Indexed for sumcheck, and keep the rest as bound to their
        // respective claim point in pre-processing.
        let mut pre_bound_matrix_b_mle = self.matrix_b.mle.clone();
        let claim_chals_matrix_b = claim_challenge[self.matrix_a.rows_num_vars..].to_vec();
        let mut bound_index_counter = 0;
        let mut indexed_index_counter = 0;
        let matrix_b_new_indices = self
            .matrix_b
            .mle
            .var_indices()
            .iter()
            .map(|mle_idx| match mle_idx {
                &MleIndex::Indexed(_) => {
                    if indexed_index_counter < self.matrix_b.rows_num_vars {
                        // Row variable of B: remains a sumcheck variable.
                        let ret = MleIndex::Indexed(indexed_index_counter);
                        indexed_index_counter += 1;
                        ret
                    } else {
                        // Column variable of B: bound to the claim point.
                        let ret = MleIndex::Bound(
                            claim_chals_matrix_b[bound_index_counter],
                            bound_index_counter,
                        );
                        bound_index_counter += 1;
                        ret
                    }
                }
                MleIndex::Fixed(_) => mle_idx.clone(),
                MleIndex::Free => panic!("should not have any free indices"),
                MleIndex::Bound(_, _) => panic!("should not have any bound indices"),
            })
            .collect_vec();
        pre_bound_matrix_b_mle.set_mle_indices(matrix_b_new_indices);
        let mles = vec![pre_bound_matrix_a_mle, pre_bound_matrix_b_mle];

        PostSumcheckLayer(vec![Product::<F, Option<F>>::new(
            &mles,
            F::ONE,
            round_challenges,
        )])
    }

    /// The expression is a product of two MLEs, hence degree 2 per variable.
    fn max_degree(&self) -> usize {
        2
    }

    /// The circuit MLEs this layer reads: the descriptions of both matrices.
    fn get_circuit_mles(&self) -> Vec<&MleDescription<F>> {
        vec![&self.matrix_a.mle, &self.matrix_b.mle]
    }

    /// Materialize both matrices from the circuit map and assemble the
    /// prover-side [MatMult] layer.
    fn convert_into_prover_layer<'a>(&self, circuit_map: &CircuitEvalMap<F>) -> LayerEnum<F> {
        let prover_matrix_a = self.matrix_a.into_matrix(circuit_map);
        let prover_matrix_b = self.matrix_b.into_matrix(circuit_map);
        let matmult_layer = MatMult::new(self.layer_id, prover_matrix_a, prover_matrix_b);
        matmult_layer.into()
    }

    /// Index both matrices' MLE variables starting from `start_index`.
    fn index_mle_indices(&mut self, start_index: usize) {
        self.matrix_a.mle.index_mle_indices(start_index);
        self.matrix_b.mle.index_mle_indices(start_index);
    }
}
679
/// The verifier's counterpart of a [Matrix].
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "F: Field")]
pub struct VerifierMatrix<F: Field> {
    /// The verifier's (fully bound) view of the matrix's MLE.
    mle: VerifierMle<F>,
    /// log2 of the (padded) number of rows.
    rows_num_vars: usize,
    /// log2 of the (padded) number of columns.
    cols_num_vars: usize,
}
688
/// The verifier's counterpart of a [MatMult] layer.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "F: Field")]
pub struct VerifierMatMultLayer<F: Field> {
    /// The layer id associated with this matmult layer.
    layer_id: LayerId,

    /// The LHS Matrix to be multiplied.
    matrix_a: VerifierMatrix<F>,

    /// The RHS Matrix to be multiplied.
    matrix_b: VerifierMatrix<F>,
}
702
703impl<F: Field> VerifierLayer<F> for VerifierMatMultLayer<F> {
704    fn layer_id(&self) -> LayerId {
705        self.layer_id
706    }
707
708    fn get_claims(&self) -> Result<Vec<Claim<F>>> {
709        let claims = vec![&self.matrix_a, &self.matrix_b]
710            .into_iter()
711            .map(|matrix| {
712                let matrix_fixed_indices = matrix
713                    .mle
714                    .var_indices()
715                    .iter()
716                    .map(|index| {
717                        index
718                            .val()
719                            .ok_or(LayerError::ClaimError(ClaimError::ClaimMleIndexError))
720                            .unwrap()
721                    })
722                    .collect_vec();
723
724                let matrix_claimed_val = matrix.mle.value();
725
726                let claim: Claim<F> = Claim::new(
727                    matrix_fixed_indices,
728                    matrix_claimed_val,
729                    self.layer_id,
730                    matrix.mle.layer_id(),
731                );
732                claim
733            })
734            .collect_vec();
735
736        Ok(claims)
737    }
738}
739
740impl<F: Field> VerifierMatMultLayer<F> {
741    fn evaluate(&self) -> F {
742        self.matrix_a.mle.value() * self.matrix_b.mle.value()
743    }
744}
745
746/// Compute the product of two matrices given flattened vectors rather than
747/// matrices.
748pub fn product_two_matrices_from_flattened_vectors<F: Field>(
749    matrix_a_vec: &[F],
750    matrix_b_vec: &[F],
751    matrix_a_num_rows: usize,
752    matrix_a_num_cols: usize,
753    matrix_b_num_rows: usize,
754    matrix_b_num_cols: usize,
755) -> Vec<F> {
756    assert_eq!(
757        matrix_a_num_cols, matrix_b_num_rows,
758        "Matrix dimensions are not compatible for multiplication"
759    );
760
761    let mut result = vec![F::ZERO; matrix_a_num_rows * matrix_b_num_cols];
762
763    for i in 0..matrix_a_num_rows {
764        for j in 0..matrix_b_num_cols {
765            for k in 0..matrix_a_num_cols {
766                result[i * matrix_b_num_cols + j] += matrix_a_vec[i * matrix_a_num_cols + k]
767                    * matrix_b_vec[k * matrix_b_num_cols + j];
768            }
769        }
770    }
771
772    result
773}
774
#[cfg(test)]
mod test {

    use shared_types::Fr;

    use crate::layer::matmult::product_two_matrices_from_flattened_vectors;

    /// Multiplies a 4x2 matrix A by a 2x2 matrix B (both row-major
    /// flattened) and checks the 4x2 result entry by entry.
    #[test]
    fn test_product_two_matrices() {
        // A = [[1, 2], [9, 10], [13, 1], [3, 10]] in row-major order.
        let mle_vec_a = vec![
            Fr::from(1),
            Fr::from(2),
            Fr::from(9),
            Fr::from(10),
            Fr::from(13),
            Fr::from(1),
            Fr::from(3),
            Fr::from(10),
        ];
        // B = [[3, 5], [9, 6]] in row-major order.
        let mle_vec_b = vec![Fr::from(3), Fr::from(5), Fr::from(9), Fr::from(6)];

        let res_product =
            product_two_matrices_from_flattened_vectors(&mle_vec_a, &mle_vec_b, 4, 2, 2, 2);

        // Each entry is the dot product of a row of A with a column of B.
        let exp_product = vec![
            Fr::from(3 + 2 * 9),
            Fr::from(5 + 2 * 6),
            Fr::from(9 * 3 + 10 * 9),
            Fr::from(9 * 5 + 10 * 6),
            Fr::from(13 * 3 + 9),
            Fr::from(13 * 5 + 6),
            Fr::from(3 * 3 + 10 * 9),
            Fr::from(3 * 5 + 10 * 6),
        ];

        assert_eq!(res_product, exp_product);
    }

    /// Multiplies an 8x4 matrix A by a 4x2 matrix B. The second four rows
    /// of A repeat the first four, so the expected 8x2 result repeats its
    /// first four rows as well.
    #[test]
    fn test_product_two_matrices_2() {
        // A (8x4, row-major); rows 4..8 duplicate rows 0..4.
        let mle_vec_a = vec![
            Fr::from(3),
            Fr::from(4),
            Fr::from(1),
            Fr::from(6),
            Fr::from(2),
            Fr::from(9),
            Fr::from(0),
            Fr::from(1),
            Fr::from(4),
            Fr::from(5),
            Fr::from(4),
            Fr::from(2),
            Fr::from(4),
            Fr::from(2),
            Fr::from(6),
            Fr::from(7),
            Fr::from(3),
            Fr::from(4),
            Fr::from(1),
            Fr::from(6),
            Fr::from(2),
            Fr::from(9),
            Fr::from(0),
            Fr::from(1),
            Fr::from(4),
            Fr::from(5),
            Fr::from(4),
            Fr::from(2),
            Fr::from(4),
            Fr::from(2),
            Fr::from(6),
            Fr::from(7),
        ];
        // B (4x2, row-major).
        let mle_vec_b = vec![
            Fr::from(3),
            Fr::from(2),
            Fr::from(1),
            Fr::from(5),
            Fr::from(3),
            Fr::from(6),
            Fr::from(7),
            Fr::from(4),
        ];

        let res_product =
            product_two_matrices_from_flattened_vectors(&mle_vec_a, &mle_vec_b, 8, 4, 4, 2);

        // 8x2 expected result, row-major; second half repeats the first.
        let exp_product = vec![
            Fr::from(58),
            Fr::from(56),
            Fr::from(22),
            Fr::from(53),
            Fr::from(43),
            Fr::from(65),
            Fr::from(81),
            Fr::from(82),
            Fr::from(58),
            Fr::from(56),
            Fr::from(22),
            Fr::from(53),
            Fr::from(43),
            Fr::from(65),
            Fr::from(81),
            Fr::from(82),
        ];

        assert_eq!(res_product, exp_product);
    }
}