// remainder/mle/evals/bit_packed_vector.rs
1//! Implements `BitPackedVector`, a version of an immutable vector optimized for
2//! storing field elements compactly.
3#![allow(clippy::needless_lifetimes)]
4use ::serde::{Deserialize, Serialize};
5use ark_std::cfg_into_iter;
6use itertools::Itertools;
7use shared_types::{config::global_config::global_prover_enable_bit_packing, Field};
8
9#[cfg(feature = "parallel")]
10use rayon::iter::{IntoParallelIterator, ParallelIterator};
11
12use itertools::FoldWhile::{Continue, Done};
13use zeroize::Zeroize;
14
15// -------------- Helper Functions -----------------
16
17/// Returns the minimum numbers of bits required to represent prime field
18/// elements in the range `[0, n]`. This is equivalent to computing
19/// `ceil(log_2(n+1))`.
20///
21/// # Complexity
22/// Constant in the size of the representation of `n`.
23///
24/// # Example
25/// ```
26/// use shared_types::Fr;
27/// use remainder::mle::evals::bit_packed_vector::num_bits;
28///
29/// assert_eq!(num_bits(Fr::from(0)), 0);
30/// assert_eq!(num_bits(Fr::from(31)), 5);
31/// assert_eq!(num_bits(Fr::from(32)), 6);
32/// ```
33pub fn num_bits<F: Field>(n: F) -> usize {
34 let u64_chunks = n.to_u64s_le();
35 debug_assert_eq!(u64_chunks.len(), 4);
36 u64_chunks
37 .iter()
38 .rev()
39 .fold_while(192_usize, |acc, chunk| {
40 if *chunk == 0 {
41 if acc == 0 {
42 Done(0)
43 } else {
44 Continue(acc - 64)
45 }
46 } else {
47 Done(acc + chunk.ilog2() as usize + 1)
48 }
49 })
50 .into_inner()
51}
52
53// ---------------------------------------------------------
54
/// A space-efficient representation of an immutable vector of prime field
/// elements. Particularly useful when all elements have values close to each
/// other. It provides an interface similar to that of a `Vec`.
///
/// # Encoding method
///
/// This struct interprets elements of the prime field `F_p` as integers in the
/// range `[0, p-1]` and tries to encode each with fewer bits than the default
/// of `sizeof::<F>() * 8` bits.
///
/// In particular, when a new bitpacked vector is created, the
/// [BitPackedVector::new] method computes the smallest interval `[a, b]`, with
/// `a < b`, such that all elements `v[i]` in the input vector (interpreted as
/// integers) belong to `[a, b]`, and then instead of storing `v[i]`, it stores
/// the value `(v[i] - a) \in [0, b - a]`. If `b - a` is a small integer,
/// representing `v[i] - a` can be done using `ceil(log_2(b - a + 1))` bits.
///
/// It then stores the encoded values compactly by packing together the
/// representation of many consecutive elements into a single machine word (when
/// possible). This encoding can store `n` elements using a total of
/// `n * ceil(log_2(b - a + 1))` bits, rounded up to a whole number of
/// `word_width`-bit words (i.e. `ceil(n * ceil(log_2(b - a + 1)) / word_width)`
/// words).
///
/// # Notes
/// 1. Currently the implementation uses more storage than the theoretically
/// optimal mentioned above. This is because:
/// 1. If `ceil(log_2(b - a + 1)) > 64`, we resort to the standard
/// representation of using `sizeof::<F>()` bytes. This is because for our
/// use-case, there are not many instances of vectors needing `c \in [65,
/// 256]` bits to encode each value.
/// 2. We round `ceil(log_2(b - a + 1))` up to the nearest divisor of 64
/// (the next power of two). This is to simplify the implementation by
/// avoiding the situation where the encoding of an element spans multiple
/// words.
/// 2. For optimal performance, the buffer used to store the encoded values
/// should be using machine words (e.g. `buf: Vec<usize>`) instead of always
/// defaulting to 64-bit entries (`buf: Vec<u64>`) Here we always assume a
/// 64-bit architecture for the simplicity of the implementation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub(in crate::mle::evals) struct BitPackedVector<F: Field> {
    /// The buffer for storing the bit-packed representation.
    /// As noted above, for optimal performance, the type of each element should
    /// be the machine's word size.
    /// For now, we're keeping it always to `u64` to make it easier to
    /// work with `F` chunks.
    ///
    /// *Invariant*: For every instance of a `BitPackedVector`, either
    /// [BitPackedVector::buf] or [BitPackedVector::naive_buf] is populated but
    /// NEVER both.
    buf: Vec<u64>,

    /// If during initialization it is deduced that the number of bits
    /// needed per element is more than 64, we revert back to a standard
    /// representation. In that case, `Self::buf` is never used but instead
    /// `Self::naive_buf` is populated.
    ///
    /// *Invariant*: For every instance of a `BitPackedVector`, either
    /// [BitPackedVector::buf] or [BitPackedVector::naive_buf] is populated but
    /// NEVER both.
    naive_buf: Vec<F>,

    /// The number of field elements stored in this vector.
    /// This is generally different from `self.buf.len()`.
    num_elements: usize,

    /// The value of the smallest element in the original vector.
    /// This is the value of `a` such that all elements of the original
    /// vector belong to the interval `[a, b]` as described above.
    offset: F,

    /// The number of bits used to encode each element after rounding (see
    /// Notes above); `0` for empty or constant vectors.
    bits_per_element: usize,
}
128
129impl<F: Field> Zeroize for BitPackedVector<F> {
130 fn zeroize(&mut self) {
131 self.buf.iter_mut().for_each(|x| x.zeroize());
132 self.naive_buf.iter_mut().for_each(|x| x.zeroize());
133 self.num_elements.zeroize();
134 self.offset.zeroize();
135 self.bits_per_element.zeroize();
136 }
137}
138
139impl<F: Field> BitPackedVector<F> {
140 /// Generates a bit-packed vector initialized with `data`.
141 pub fn new(data: &[F]) -> Self {
142 // TODO(ryancao): Distinguish between prover and verifier here
143 if !global_prover_enable_bit_packing() {
144 return Self {
145 buf: vec![],
146 naive_buf: data.to_vec(),
147 num_elements: data.len(),
148 offset: F::ZERO,
149 bits_per_element: 4 * (u64::BITS as usize),
150 };
151 }
152
153 // Handle empty vectors separately.
154 if data.is_empty() {
155 return Self {
156 buf: vec![],
157 naive_buf: vec![],
158 num_elements: 0,
159 offset: F::ZERO,
160 bits_per_element: 0,
161 };
162 }
163
164 let num_elements = data.len();
165
166 let min_val = *cfg_into_iter!(data).min().unwrap();
167 let max_val = *cfg_into_iter!(data).max().unwrap();
168
169 let range = max_val - min_val;
170
171 // Handle constant values separately.
172 if min_val == max_val {
173 return Self {
174 buf: vec![],
175 naive_buf: vec![],
176 num_elements,
177 offset: min_val,
178 bits_per_element: 0,
179 };
180 }
181
182 // Number of bits required to encode each element in the range.
183 let bits_per_element = num_bits(range);
184
185 // Bits available per buffer entry.
186 let entry_width = u64::BITS as usize;
187 // println!("Buffer entry width: {entry_width}");
188
189 // To simplify the implementation, for now we only support bit-packing
190 // of values whose bit-width evenly divides the available bits per
191 // buffer entry, or their bit-width equals 4*64 = 256 bits. Any other
192 // case is reduced to one of the two by rounding up `bits_per_element`
193 // accordingly.
194 let bits_per_element = if bits_per_element > entry_width {
195 // Resort to storing the raw representation of the field element.
196 4 * entry_width
197 } else {
198 // Round up to next power of two to make sure it evenly divides the
199 // `entry_width`. This assumes that `entry_width` is always a power
200 // of two.
201 assert!(entry_width.is_power_of_two());
202 bits_per_element.next_power_of_two()
203 };
204
205 assert!(
206 bits_per_element == 4 * entry_width
207 || (bits_per_element <= entry_width
208 && entry_width.is_multiple_of(bits_per_element))
209 );
210
211 if bits_per_element > entry_width {
212 let naive_buf = data.to_vec();
213
214 Self {
215 buf: vec![],
216 naive_buf,
217 num_elements,
218 offset: F::ZERO,
219 bits_per_element,
220 }
221 } else {
222 // Compute an upper bound to the number of buffer entries needed.
223 let buf_len = (bits_per_element * num_elements).div_ceil(entry_width);
224
225 let mut buf = vec![0_u64; buf_len];
226
227 for (i, x) in data.iter().enumerate() {
228 let encoded_x = *(*x - min_val).to_u64s_le().first().unwrap();
229 // println!("Encoded value of {:?}: {encoded_x}", x);
230
231 let buffer_idx = i * bits_per_element / entry_width;
232 assert!(buffer_idx < buf_len);
233
234 let word_idx = i * bits_per_element % entry_width;
235 assert!(word_idx < entry_width);
236
237 // println!(
238 // "Placing {i}-th element into buffer_idx: {buffer_idx}, and word_idx: {word_idx}"
239 // );
240
241 let prev_entry = &mut buf[buffer_idx];
242
243 // Set new entry.
244 *prev_entry |= encoded_x << word_idx;
245 }
246
247 Self {
248 buf,
249 naive_buf: vec![],
250 num_elements,
251 offset: min_val,
252 bits_per_element,
253 }
254 }
255 }
256
257 /// Return the `index`-th element stored in the array,
258 /// or `None` if `index` is out of bounds.
259 pub fn get(&self, index: usize) -> Option<F> {
260 // Check for index-out-of-bounds.
261 if index >= self.num_elements {
262 return None;
263 }
264
265 if self.bits_per_element == 0 {
266 return Some(self.offset);
267 }
268
269 // Bits per buffer entry.
270 let entry_width = u64::BITS as usize;
271
272 if self.bits_per_element > entry_width {
273 Some(self.naive_buf[index])
274 } else {
275 let buffer_idx = index * self.bits_per_element / entry_width;
276 assert!(buffer_idx < self.buf.len());
277
278 let word_idx = index * self.bits_per_element % entry_width;
279 // println!("Getting buffer idx: {buffer_idx}, word_idx: {word_idx}");
280 assert!(word_idx < entry_width);
281
282 let entry = &self.buf[buffer_idx];
283 let mask: u64 = if self.bits_per_element == 64 {
284 !0x0
285 } else {
286 ((1_u64 << self.bits_per_element) - 1) << word_idx
287 };
288 // println!("Mask: {:#x}", mask);
289
290 let encoded_value = (entry & mask) >> word_idx;
291 let value = self.offset + F::from(encoded_value);
292
293 Some(value)
294 }
295 }
296
297 pub fn len(&self) -> usize {
298 self.num_elements
299 }
300
301 /// Returns the number of bits used to encode each element.
302 #[allow(unused)]
303 pub fn get_bits_per_element(&self) -> usize {
304 self.bits_per_element
305 }
306
307 pub fn iter(&self) -> BitPackedIterator<'_, F> {
308 BitPackedIterator {
309 vec: self,
310 current_index: 0,
311 }
312 }
313
314 #[cfg(test)]
315 pub fn to_vec(&self) -> Vec<F> {
316 self.iter().collect()
317 }
318}
319
/// Iterator for a `BitPackedVector`. See `BitPackedVector::iter` for generating
/// one.
pub struct BitPackedIterator<'a, F: Field> {
    /// The vector being iterated over; elements are decoded lazily via `get`.
    vec: &'a BitPackedVector<F>,
    /// Index of the next element to yield; ranges over `0..=vec.len()`.
    current_index: usize,
}
326
327impl<'a, F: Field> Iterator for BitPackedIterator<'a, F> {
328 type Item = F;
329
330 fn next(&mut self) -> Option<Self::Item> {
331 if self.current_index < self.vec.len() {
332 let val = self.vec.get(self.current_index).unwrap();
333 self.current_index += 1;
334
335 Some(val)
336 } else {
337 None
338 }
339 }
340}