// shared_types/utils/pippengers.rs
use itertools::FoldWhile::{Continue, Done};
2use itertools::{concat, Itertools};
3use num::traits::ToBytes;
4
5use crate::{curves::PrimeOrderCurve, HasByteRepresentation};
6
7fn c_bit_scalar_mult<C: PrimeOrderCurve>(c: usize, base_vec: &[C], scalar_vec: &[u64]) -> C {
8 let mut buckets: Vec<C> = vec![C::zero(); (1 << c) - 1];
10 scalar_vec.iter().zip(base_vec).for_each(|(scalar, base)| {
11 debug_assert!(
13 scalar.to_le_bytes().as_ref().len() * 8 - (scalar.leading_zeros() as usize) <= c
14 );
15 if *scalar != 0 {
17 buckets[*scalar as usize - 1] += *base;
18 }
19 });
20 let (sum, _) = buckets
26 .iter()
27 .rev()
28 .fold((C::zero(), C::zero()), |(acc, prev), elem| {
29 let t = prev + *elem;
30 (acc + t, t)
31 });
32 sum
33}
34
35fn combine_c_bit_scalar_mults<C: PrimeOrderCurve>(c: usize, bucket_sums: &[C]) -> C {
36 if !bucket_sums.is_empty() {
42 let all_but_last =
43 bucket_sums
44 .iter()
45 .take(bucket_sums.len() - 1)
46 .fold(C::zero(), |mut acc, elem| {
47 acc += *elem;
48 (0..c).for_each(|_idx| {
49 acc = acc.double();
50 });
51 acc
52 });
53 all_but_last + *bucket_sums.last().unwrap()
54 } else {
55 C::zero()
56 }
57}
58
59fn num_bits<C: PrimeOrderCurve>(n: C::Scalar) -> usize {
61 let u64_chunks = n.to_u64s_le();
62 debug_assert_eq!(u64_chunks.len(), 4);
63 u64_chunks
64 .iter()
65 .rev()
66 .fold_while(192_usize, |acc, chunk| {
67 if *chunk == 0 {
68 if acc == 0 {
69 Done(0)
70 } else {
71 Continue(acc - 64)
72 }
73 } else {
74 Done(acc + chunk.ilog2() as usize + 1)
75 }
76 })
77 .into_inner()
78}
79
80pub fn scalar_mult_pippenger<C: PrimeOrderCurve>(
87 c_bucket_size: usize,
88 base_vec: &[C],
89 scalar_vec: &[C::Scalar],
90) -> C {
91 assert_eq!(scalar_vec.len(), base_vec.len());
92 let n = scalar_vec.len();
93 if n != 0 {
94 let max_input_mle_value = scalar_vec.iter().max().unwrap();
95 let max_num_bits_needed = num_bits::<C>(*max_input_mle_value);
96 let num_buckets = max_num_bits_needed.div_ceil(c_bucket_size);
98 let mut bucket_groups = vec![vec![0_u64; n]; num_buckets];
99 scalar_vec.iter().enumerate().for_each(|(idx, elem)| {
100 let elem_bits_full = elem
102 .to_bytes_le()
103 .iter()
104 .rev()
105 .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) != 0))
106 .collect_vec();
107 let elem_bits_full_len = elem_bits_full.len();
108 let elem_bits = if (num_buckets * c_bucket_size) < elem_bits_full_len {
111 elem_bits_full
112 .into_iter()
113 .skip(elem_bits_full_len - (num_buckets * c_bucket_size))
114 .collect_vec()
115 } else {
116 let padding_len = (num_buckets * c_bucket_size) - elem_bits_full_len;
117 let padding = vec![false; padding_len];
118 concat(vec![padding, elem_bits_full])
119 };
120 (0..num_buckets).for_each(|bucket_idx| {
124 let bits_in_bucket =
125 &elem_bits[(c_bucket_size * bucket_idx)..(c_bucket_size * (bucket_idx + 1))];
126 let bucket_value = bits_in_bucket
127 .iter()
128 .rev()
129 .enumerate()
130 .map(|(i, &b)| (b as usize) << i)
131 .sum::<usize>();
132 bucket_groups[bucket_idx][idx] = bucket_value as u64;
133 });
134 });
135 let c_bit_scalar_mults = bucket_groups
137 .iter()
138 .map(|bucket_group| c_bit_scalar_mult(c_bucket_size, base_vec, bucket_group))
139 .collect_vec();
140 combine_c_bit_scalar_mults(c_bucket_size, &c_bit_scalar_mults)
142 } else {
143 C::zero()
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::halo2curves::bn256::G1 as Bn256;
    use halo2curves::ff::Field;
    type Scalar = <Bn256 as PrimeOrderCurve>::Scalar;

    /// Reference MSM: directly computes `sum_i scalar_i * base_i`, against
    /// which the Pippenger implementation is checked.
    fn naive_msm<C: PrimeOrderCurve>(scalar_vec: &[C::Scalar], base_vec: &[C]) -> C {
        scalar_vec
            .iter()
            .zip(base_vec)
            .fold(C::zero(), |acc, (scalar, base)| acc + *base * *scalar)
    }

    /// Small scalars (at most 4 bits), window size 2.
    #[test]
    fn test_pippenger_1() {
        let mut rng = rand::thread_rng();
        let scalar_vec = vec![6_u64, 15_u64, 13_u64, 12_u64]
            .into_iter()
            .map(Scalar::from)
            .collect_vec();
        let base_vec = (0..4).map(|_idx| Bn256::random(&mut rng)).collect_vec();
        let expected = naive_msm(&scalar_vec, &base_vec);
        let actual = scalar_mult_pippenger(2, &base_vec, &scalar_vec);
        assert_eq!(expected, actual);
    }

    /// Medium-size scalars (~20 bits), window size 6 (bit count not a
    /// multiple of the window, exercising the padding path).
    #[test]
    fn test_pippenger_2() {
        let mut rng = rand::thread_rng();
        let scalar_vec = vec![289041_u64, 114202_u64, 124023_u64, 858222_u64]
            .into_iter()
            .map(Scalar::from)
            .collect_vec();
        let base_vec = (0..4).map(|_idx| Bn256::random(&mut rng)).collect_vec();
        let expected = naive_msm(&scalar_vec, &base_vec);
        let actual = scalar_mult_pippenger(6, &base_vec, &scalar_vec);
        assert_eq!(expected, actual);
    }

    /// Random full-width field scalars, window size 6.
    #[test]
    fn test_pippenger_3() {
        const NUM_ELEMS: usize = 20;
        let mut rng = rand::thread_rng();
        // NOTE: the redundant `.into_iter()` on the range (an identity
        // conversion flagged by clippy's `useless_conversion`) was removed.
        let scalar_vec = (0..NUM_ELEMS)
            .map(|_idx| Scalar::random(&mut rng))
            .collect_vec();
        let base_vec = (0..NUM_ELEMS)
            .map(|_idx| Bn256::random(&mut rng))
            .collect_vec();
        let expected = naive_msm(&scalar_vec, &base_vec);
        let actual = scalar_mult_pippenger(6, &base_vec, &scalar_vec);
        assert_eq!(expected, actual);
    }
}