remainder/mle/evals/
bit_packed_vector.rs

1//! Implements `BitPackedVector`, a version of an immutable vector optimized for
2//! storing field elements compactly.
3#![allow(clippy::needless_lifetimes)]
4use ::serde::{Deserialize, Serialize};
5use ark_std::cfg_into_iter;
6use itertools::Itertools;
7use shared_types::{config::global_config::global_prover_enable_bit_packing, Field};
8
9#[cfg(feature = "parallel")]
10use rayon::iter::{IntoParallelIterator, ParallelIterator};
11
12use itertools::FoldWhile::{Continue, Done};
13use zeroize::Zeroize;
14
15// -------------- Helper Functions -----------------
16
17/// Returns the minimum numbers of bits required to represent prime field
18/// elements in the range `[0, n]`. This is equivalent to computing
19/// `ceil(log_2(n+1))`.
20///
21/// # Complexity
22/// Constant in the size of the representation of `n`.
23///
24/// # Example
25/// ```
26///     use shared_types::Fr;
27///     use remainder::mle::evals::bit_packed_vector::num_bits;
28///
29///     assert_eq!(num_bits(Fr::from(0)), 0);
30///     assert_eq!(num_bits(Fr::from(31)), 5);
31///     assert_eq!(num_bits(Fr::from(32)), 6);
32/// ```
33pub fn num_bits<F: Field>(n: F) -> usize {
34    let u64_chunks = n.to_u64s_le();
35    debug_assert_eq!(u64_chunks.len(), 4);
36    u64_chunks
37        .iter()
38        .rev()
39        .fold_while(192_usize, |acc, chunk| {
40            if *chunk == 0 {
41                if acc == 0 {
42                    Done(0)
43                } else {
44                    Continue(acc - 64)
45                }
46            } else {
47                Done(acc + chunk.ilog2() as usize + 1)
48            }
49        })
50        .into_inner()
51}
52
53// ---------------------------------------------------------
54
/// A space-efficient representation of an immutable vector of prime field
/// elements. Particularly useful when all elements have values close to each
/// other. It provides an interface similar to that of a `Vec`.
///
/// # Encoding method
///
/// This struct interprets elements of the prime field `F_p` as integers in the
/// range `[0, p-1]` and tries to encode each with fewer bits than the default
/// of `sizeof::<F>() * 8` bits.
///
/// In particular, when a new bitpacked vector is created, the
/// [BitPackedVector::new] method computes the smallest interval `[a, b]`, with
/// `a <= b`, such that all elements `v[i]` in the input vector (interpreted as
/// integers) belong to `[a, b]`, and then instead of storing `v[i]`, it stores
/// the value `(v[i] - a) \in [0, b - a]`. If `b - a` is a small integer,
/// representing `v[i] - a` can be done using `ceil(log_2(b - a + 1))` bits.
///
/// It then stores the encoded values compactly by packing together the
/// representation of many consecutive elements into a single machine word (when
/// possible). Ideally, this encoding can store `n` elements using a total of
/// `n * ceil(log_2(b - a + 1))` bits, rounded up to a whole number of words.
///
/// # Notes
/// 1. Currently the implementation uses more storage than the theoretically
///    optimal mentioned above. This is because:
///    1. If `ceil(log_2(b - a + 1)) > 64`, we resort to the standard
///       representation of using `sizeof::<F>()` bytes. This is because for our
///       use-case, there are not many instances of vectors needing `c \in [65,
///       256]` bits to encode each value.
///    2. We round `ceil(log_2(b - a + 1))` up to the nearest divisor of 64
///       (i.e. the next power of two). This is to simplify the implementation
///       by avoiding the situation where the encoding of an element spans
///       multiple words.
/// 2. For optimal performance, the buffer used to store the encoded values
///    should be using machine words (e.g. `buf: Vec<usize>`) instead of always
///    defaulting to 64-bit entries (`buf: Vec<u64>`). Here we always assume a
///    64-bit architecture for the simplicity of the implementation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub(in crate::mle::evals) struct BitPackedVector<F: Field> {
    /// The buffer for storing the bit-packed representation.
    /// As noted above, for optimal performance, the type of each element should
    /// be the machine's word size.
    /// For now, we're keeping it always to `u64` to make it easier to
    /// work with `F` chunks.
    ///
    /// *Invariant*: For every instance of a `BitPackedVector`, at most one of
    /// [BitPackedVector::buf] and [BitPackedVector::naive_buf] is populated,
    /// NEVER both. Both are empty when the vector is empty or when all of its
    /// elements are equal.
    buf: Vec<u64>,

    /// If during initialization it is deduced that the number of bits
    /// needed per element is more than 64, or if bit-packing is globally
    /// disabled, we revert back to a standard representation. In that case,
    /// `Self::buf` is never used but instead `Self::naive_buf` is populated
    /// with a verbatim copy of the input.
    ///
    /// *Invariant*: For every instance of a `BitPackedVector`, at most one of
    /// [BitPackedVector::buf] and [BitPackedVector::naive_buf] is populated,
    /// NEVER both. Both are empty when the vector is empty or when all of its
    /// elements are equal.
    naive_buf: Vec<F>,

    /// The number of field elements stored in this vector.
    /// This is generally different from `self.buf.len()`.
    num_elements: usize,

    /// The value of the smallest element in the original vector.
    /// This is the value of `a` such that all elements of the original
    /// vector belong to the interval `[a, b]` as described above.
    /// It is set to `F::ZERO` when the vector is empty or when the standard
    /// representation (`Self::naive_buf`) is used.
    offset: F,

    /// The number of bits used to store each element.
    /// This is `ceil(log_2(b - a + 1))` as described above, rounded up to the
    /// next power of two when it is at most 64. It is `0` when the vector is
    /// empty or all elements are equal, and `4 * 64 = 256` when the standard
    /// representation (`Self::naive_buf`) is used.
    bits_per_element: usize,
}
128
129impl<F: Field> Zeroize for BitPackedVector<F> {
130    fn zeroize(&mut self) {
131        self.buf.iter_mut().for_each(|x| x.zeroize());
132        self.naive_buf.iter_mut().for_each(|x| x.zeroize());
133        self.num_elements.zeroize();
134        self.offset.zeroize();
135        self.bits_per_element.zeroize();
136    }
137}
138
139impl<F: Field> BitPackedVector<F> {
140    /// Generates a bit-packed vector initialized with `data`.
141    pub fn new(data: &[F]) -> Self {
142        // TODO(ryancao): Distinguish between prover and verifier here
143        if !global_prover_enable_bit_packing() {
144            return Self {
145                buf: vec![],
146                naive_buf: data.to_vec(),
147                num_elements: data.len(),
148                offset: F::ZERO,
149                bits_per_element: 4 * (u64::BITS as usize),
150            };
151        }
152
153        // Handle empty vectors separately.
154        if data.is_empty() {
155            return Self {
156                buf: vec![],
157                naive_buf: vec![],
158                num_elements: 0,
159                offset: F::ZERO,
160                bits_per_element: 0,
161            };
162        }
163
164        let num_elements = data.len();
165
166        let min_val = *cfg_into_iter!(data).min().unwrap();
167        let max_val = *cfg_into_iter!(data).max().unwrap();
168
169        let range = max_val - min_val;
170
171        // Handle constant values separately.
172        if min_val == max_val {
173            return Self {
174                buf: vec![],
175                naive_buf: vec![],
176                num_elements,
177                offset: min_val,
178                bits_per_element: 0,
179            };
180        }
181
182        // Number of bits required to encode each element in the range.
183        let bits_per_element = num_bits(range);
184
185        // Bits available per buffer entry.
186        let entry_width = u64::BITS as usize;
187        // println!("Buffer entry width: {entry_width}");
188
189        // To simplify the implementation, for now we only support bit-packing
190        // of values whose bit-width evenly divides the available bits per
191        // buffer entry, or their bit-width equals 4*64 = 256 bits.  Any other
192        // case is reduced to one of the two by rounding up `bits_per_element`
193        // accordingly.
194        let bits_per_element = if bits_per_element > entry_width {
195            // Resort to storing the raw representation of the field element.
196            4 * entry_width
197        } else {
198            // Round up to next power of two to make sure it evenly divides the
199            // `entry_width`. This assumes that `entry_width` is always a power
200            // of two.
201            assert!(entry_width.is_power_of_two());
202            bits_per_element.next_power_of_two()
203        };
204
205        assert!(
206            bits_per_element == 4 * entry_width
207                || (bits_per_element <= entry_width
208                    && entry_width.is_multiple_of(bits_per_element))
209        );
210
211        if bits_per_element > entry_width {
212            let naive_buf = data.to_vec();
213
214            Self {
215                buf: vec![],
216                naive_buf,
217                num_elements,
218                offset: F::ZERO,
219                bits_per_element,
220            }
221        } else {
222            // Compute an upper bound to the number of buffer entries needed.
223            let buf_len = (bits_per_element * num_elements).div_ceil(entry_width);
224
225            let mut buf = vec![0_u64; buf_len];
226
227            for (i, x) in data.iter().enumerate() {
228                let encoded_x = *(*x - min_val).to_u64s_le().first().unwrap();
229                // println!("Encoded value of {:?}: {encoded_x}", x);
230
231                let buffer_idx = i * bits_per_element / entry_width;
232                assert!(buffer_idx < buf_len);
233
234                let word_idx = i * bits_per_element % entry_width;
235                assert!(word_idx < entry_width);
236
237                // println!(
238                //     "Placing {i}-th element into buffer_idx: {buffer_idx}, and word_idx: {word_idx}"
239                // );
240
241                let prev_entry = &mut buf[buffer_idx];
242
243                // Set new entry.
244                *prev_entry |= encoded_x << word_idx;
245            }
246
247            Self {
248                buf,
249                naive_buf: vec![],
250                num_elements,
251                offset: min_val,
252                bits_per_element,
253            }
254        }
255    }
256
257    /// Return the `index`-th element stored in the array,
258    /// or `None` if `index` is out of bounds.
259    pub fn get(&self, index: usize) -> Option<F> {
260        // Check for index-out-of-bounds.
261        if index >= self.num_elements {
262            return None;
263        }
264
265        if self.bits_per_element == 0 {
266            return Some(self.offset);
267        }
268
269        // Bits per buffer entry.
270        let entry_width = u64::BITS as usize;
271
272        if self.bits_per_element > entry_width {
273            Some(self.naive_buf[index])
274        } else {
275            let buffer_idx = index * self.bits_per_element / entry_width;
276            assert!(buffer_idx < self.buf.len());
277
278            let word_idx = index * self.bits_per_element % entry_width;
279            // println!("Getting buffer idx: {buffer_idx}, word_idx: {word_idx}");
280            assert!(word_idx < entry_width);
281
282            let entry = &self.buf[buffer_idx];
283            let mask: u64 = if self.bits_per_element == 64 {
284                !0x0
285            } else {
286                ((1_u64 << self.bits_per_element) - 1) << word_idx
287            };
288            // println!("Mask: {:#x}", mask);
289
290            let encoded_value = (entry & mask) >> word_idx;
291            let value = self.offset + F::from(encoded_value);
292
293            Some(value)
294        }
295    }
296
297    pub fn len(&self) -> usize {
298        self.num_elements
299    }
300
301    /// Returns the number of bits used to encode each element.
302    #[allow(unused)]
303    pub fn get_bits_per_element(&self) -> usize {
304        self.bits_per_element
305    }
306
307    pub fn iter(&self) -> BitPackedIterator<'_, F> {
308        BitPackedIterator {
309            vec: self,
310            current_index: 0,
311        }
312    }
313
314    #[cfg(test)]
315    pub fn to_vec(&self) -> Vec<F> {
316        self.iter().collect()
317    }
318}
319
320/// Iterator for a `BitPackedVector`. See `BitPackedVector::iter` for generating
321/// one.
322pub struct BitPackedIterator<'a, F: Field> {
323    vec: &'a BitPackedVector<F>,
324    current_index: usize,
325}
326
327impl<'a, F: Field> Iterator for BitPackedIterator<'a, F> {
328    type Item = F;
329
330    fn next(&mut self) -> Option<Self::Item> {
331        if self.current_index < self.vec.len() {
332            let val = self.vec.get(self.current_index).unwrap();
333            self.current_index += 1;
334
335            Some(val)
336        } else {
337            None
338        }
339    }
340}