shared_types/utils/
pippengers.rs

1use itertools::FoldWhile::{Continue, Done};
2use itertools::{concat, Itertools};
3use num::traits::ToBytes;
4
5use crate::{curves::PrimeOrderCurve, HasByteRepresentation};
6
7fn c_bit_scalar_mult<C: PrimeOrderCurve>(c: usize, base_vec: &[C], scalar_vec: &[u64]) -> C {
8    // We create buckets for when the scalar is 1, up to 2^c - 1.
9    let mut buckets: Vec<C> = vec![C::zero(); (1 << c) - 1];
10    scalar_vec.iter().zip(base_vec).for_each(|(scalar, base)| {
11        // Since this is a c-bit scalar, the scalar must fall in one of these buckets.
12        debug_assert!(
13            scalar.to_le_bytes().as_ref().len() * 8 - (scalar.leading_zeros() as usize) <= c
14        );
15        // We skip when the scalar is 0 because it won't contribute to the MSM.
16        if *scalar != 0 {
17            buckets[*scalar as usize - 1] += *base;
18        }
19    });
20    // We perform an optimized addition of the \sum_i{i * value}
21    // by doing a reverse sum and keeping track of running sum
22    // of all the values along with the accumulation.
23    // I.e., 3S_3 + 2S_2 + S_1 =
24    // S_3 + (S_3 + S_2) + (S_3 + S_2 + S_1).
25    let (sum, _) = buckets
26        .iter()
27        .rev()
28        .fold((C::zero(), C::zero()), |(acc, prev), elem| {
29            let t = prev + *elem;
30            (acc + t, t)
31        });
32    sum
33}
34
35fn combine_c_bit_scalar_mults<C: PrimeOrderCurve>(c: usize, bucket_sums: &[C]) -> C {
36    // We combine the c-bit scalar multiplications by going through
37    // the bucket sums, which is in order of most-significant
38    // contribution to the scalar multiplication, and keep adding it
39    // to the accumulator while multiplying by 2^c. So we should get
40    // \sum_i{2^{i*c} * bucket_val[n - i]}.
41    if !bucket_sums.is_empty() {
42        let all_but_last =
43            bucket_sums
44                .iter()
45                .take(bucket_sums.len() - 1)
46                .fold(C::zero(), |mut acc, elem| {
47                    acc += *elem;
48                    (0..c).for_each(|_idx| {
49                        acc = acc.double();
50                    });
51                    acc
52                });
53        all_but_last + *bucket_sums.last().unwrap()
54    } else {
55        C::zero()
56    }
57}
58
59/// Helper function to compute the number of bits in a scalar.
60fn num_bits<C: PrimeOrderCurve>(n: C::Scalar) -> usize {
61    let u64_chunks = n.to_u64s_le();
62    debug_assert_eq!(u64_chunks.len(), 4);
63    u64_chunks
64        .iter()
65        .rev()
66        .fold_while(192_usize, |acc, chunk| {
67            if *chunk == 0 {
68                if acc == 0 {
69                    Done(0)
70                } else {
71                    Continue(acc - 64)
72                }
73            } else {
74                Done(acc + chunk.ilog2() as usize + 1)
75            }
76        })
77        .into_inner()
78}
79
80/// Overall function to compute the MSM of a vector of
81/// group elements to a vector of scalar field elements
82/// using Pippenger's Algorithm.
83///
84/// `c_bucket_size` is the parameter into how we want to
85/// split the larger MSM into smaller MSMs.
86pub fn scalar_mult_pippenger<C: PrimeOrderCurve>(
87    c_bucket_size: usize,
88    base_vec: &[C],
89    scalar_vec: &[C::Scalar],
90) -> C {
91    assert_eq!(scalar_vec.len(), base_vec.len());
92    let n = scalar_vec.len();
93    if n != 0 {
94        let max_input_mle_value = scalar_vec.iter().max().unwrap();
95        let max_num_bits_needed = num_bits::<C>(*max_input_mle_value);
96        // We take the ceiling division to compute the window size.
97        let num_buckets = max_num_bits_needed.div_ceil(c_bucket_size);
98        let mut bucket_groups = vec![vec![0_u64; n]; num_buckets];
99        scalar_vec.iter().enumerate().for_each(|(idx, elem)| {
100            // We compute all the bits of the field elements.
101            let elem_bits_full = elem
102                .to_bytes_le()
103                .iter()
104                .rev()
105                .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) != 0))
106                .collect_vec();
107            let elem_bits_full_len = elem_bits_full.len();
108            // We only use the most significant field elements rounded to include
109            // the minimum number of full windows we can use.
110            let elem_bits = if (num_buckets * c_bucket_size) < elem_bits_full_len {
111                elem_bits_full
112                    .into_iter()
113                    .skip(elem_bits_full_len - (num_buckets * c_bucket_size))
114                    .collect_vec()
115            } else {
116                let padding_len = (num_buckets * c_bucket_size) - elem_bits_full_len;
117                let padding = vec![false; padding_len];
118                concat(vec![padding, elem_bits_full])
119            };
120            // We create the buckets by iterating through each of the scalar
121            // field elements and splitting them by their most significant
122            // to least significant "c-bit chunk."
123            (0..num_buckets).for_each(|bucket_idx| {
124                let bits_in_bucket =
125                    &elem_bits[(c_bucket_size * bucket_idx)..(c_bucket_size * (bucket_idx + 1))];
126                let bucket_value = bits_in_bucket
127                    .iter()
128                    .rev()
129                    .enumerate()
130                    .map(|(i, &b)| (b as usize) << i)
131                    .sum::<usize>();
132                bucket_groups[bucket_idx][idx] = bucket_value as u64;
133            });
134        });
135        // Perform the smaller MSMs.
136        let c_bit_scalar_mults = bucket_groups
137            .iter()
138            .map(|bucket_group| c_bit_scalar_mult(c_bucket_size, base_vec, bucket_group))
139            .collect_vec();
140        // Combine all the smaller MSMs.
141        combine_c_bit_scalar_mults(c_bucket_size, &c_bit_scalar_mults)
142    } else {
143        C::zero()
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::halo2curves::bn256::G1 as Bn256;
    use halo2curves::ff::Field;
    type Scalar = <Bn256 as PrimeOrderCurve>::Scalar;

    /// Reference implementation: multiplies each base by its scalar and adds
    /// the products one after another.
    fn naive_msm<C: PrimeOrderCurve>(scalar_vec: &[C::Scalar], base_vec: &[C]) -> C {
        let mut total = C::zero();
        for (scalar, base) in scalar_vec.iter().zip(base_vec) {
            total = total + *base * *scalar;
        }
        total
    }

    /// Draws `count` random group elements from the thread-local RNG.
    fn random_bases(count: usize) -> Vec<Bn256> {
        let mut rng = rand::thread_rng();
        (0..count).map(|_| Bn256::random(&mut rng)).collect_vec()
    }

    /// Asserts that Pippenger with window width `c_bucket_size` agrees with
    /// the naive MSM on the given inputs.
    fn assert_matches_naive(c_bucket_size: usize, scalars: &[Scalar], bases: &[Bn256]) {
        let expected = naive_msm(scalars, bases);
        let actual = scalar_mult_pippenger(c_bucket_size, bases, scalars);
        assert_eq!(expected, actual);
    }

    /// Small 4-bit scalars split into 2-bit windows.
    #[test]
    fn test_pippenger_1() {
        let scalars = [6_u64, 15_u64, 13_u64, 12_u64].map(Scalar::from);
        assert_matches_naive(2, &scalars, &random_bases(4));
    }

    /// Medium-sized scalars whose bit length is not a multiple of the window width.
    #[test]
    fn test_pippenger_2() {
        let scalars = [289041_u64, 114202_u64, 124023_u64, 858222_u64].map(Scalar::from);
        assert_matches_naive(6, &scalars, &random_bases(4));
    }

    /// Random full-width field elements.
    #[test]
    fn test_pippenger_3() {
        const NUM_ELEMS: usize = 20;
        let mut rng = rand::thread_rng();
        let scalars = (0..NUM_ELEMS)
            .map(|_| Scalar::random(&mut rng))
            .collect_vec();
        assert_matches_naive(6, &scalars, &random_bases(NUM_ELEMS));
    }
}