frontend/components/binary_operations/logical_shift.rs

1//! Implements bit shifting gates.
2
3use std::cmp::{max, min};
4
5use itertools::Itertools;
6use shared_types::Field;
7
8use crate::layouter::builder::{CircuitBuilder, NodeRef};
9
10/// A component that performs logical bit shift operations using [crate::layouter::nodes::identity_gate::IdentityGateNode].
11///
12/// Both left and right shifts are supported by using negative/positive values in the `shift_amount`
13/// parameter.
14///
15/// This is a _logical_ shift, meaning that bits which are shifted out are discarded, and zeros are
16/// filled in on the other side.
17///
18/// Requires that the input node has already been verified to contain binary digits.
19#[derive(Clone, Debug)]
20pub struct ShiftNode<F: Field> {
21    output: NodeRef<F>,
22}
23
24impl<F: Field> ShiftNode<F> {
25    /// Create a new [ShiftNode] that performs a shift by `shift_amount` (to the right if
26    /// `shift_amount > 0` or to the left if `shift_amount < 0`) on `input` node which contains
27    /// `2^num_vars` binary digits.
28    ///
29    /// # Requires
30    /// `input` is assumed to only contain binary digits (i.e. only values from the set
31    /// `{F::ZERO, F::ONE}` for a field `F`).
32    pub fn new(
33        builder_ref: &mut CircuitBuilder<F>,
34        num_vars: usize,
35        shift_amount: i32,
36        input: &NodeRef<F>,
37    ) -> Self {
38        // Compute the bit reroutings that effectively shift the
39        // input MLE by the appropriate amount.
40        let shift_wirings = generate_shift_wirings(num_vars, shift_amount);
41        let output = builder_ref.add_identity_gate_node(input, shift_wirings, num_vars, None);
42
43        Self { output }
44    }
45
46    /// Returns a reference to the node containing the shifted value.
47    pub fn get_output(&self) -> NodeRef<F> {
48        self.output.clone()
49    }
50}
51
52fn generate_shift_wirings(num_vars: usize, shift_amount: i32) -> Vec<(u32, u32)> {
53    // Ensure `shift_amount` can represent all possible shift amounts for a given value of
54    // `num_vars`.
55    // In general, if `shift_amount` is a signed `n`-bit integer, it can represent shift amounts
56    // in the integer range `[-2^(n-1), +2^(n-1) - 1]`. For a `2^num_vars`-bit shifter,
57    // ideally we'd like to support all bit shift values in the integer range `[-2^num_vars,
58    // +2^num_vars]`, therefore we have to work under the assumption that `num_vars < n-1`.
59    // Here `shift_amount` is of type `i32` (`n == 32`), so we need `num_vars <= 30`.
60    assert!(num_vars <= 30);
61
62    // Cap the shift amount to be in the range `[-2^num_vars, 2^num_vars]`.
63    let shift_amount = max(min(shift_amount, 1 << num_vars), -(1 << num_vars));
64
65    if shift_amount >= 0 {
66        let shift_amount = shift_amount as u32;
67        (0..(1 << num_vars) - shift_amount)
68            .map(|i| (i + shift_amount, i))
69            .collect_vec()
70    } else {
71        let shift_amount: u32 = shift_amount.unsigned_abs();
72        (0..(1 << num_vars) - shift_amount)
73            .map(|i| (i, i + shift_amount))
74            .collect_vec()
75    }
76}
77
#[cfg(test)]
mod test {
    //! Unit tests for the wiring generator behind [`ShiftNode`]: small shifts in both directions,
    //! the identity (zero) shift, clamping of over-long shifts, and the `num_vars` limit.

    use super::*;

    #[test]
    fn test_8bit_right_shift_by_1() {
        let shift_wirings = generate_shift_wirings(3, 1);

        assert_eq!(
            shift_wirings,
            vec![(1, 0), (2, 1), (3, 2), (4, 3), (5, 4), (6, 5), (7, 6)]
        );
    }

    #[test]
    fn test_256bit_right_shift_by_1() {
        let shift_wirings = generate_shift_wirings(8, 1);

        assert_eq!(shift_wirings, (0..255).map(|i| (i + 1, i)).collect_vec());
    }

    #[test]
    fn test_8bit_left_shift_by_1() {
        let shift_wirings = generate_shift_wirings(3, -1);

        assert_eq!(
            shift_wirings,
            vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
        );
    }

    #[test]
    fn test_256bit_left_shift_by_1() {
        let shift_wirings = generate_shift_wirings(8, -1);

        assert_eq!(shift_wirings, (0..255).map(|i| (i, i + 1)).collect_vec());
    }

    #[test]
    fn test_zero_shift() {
        let shift_wirings = generate_shift_wirings(3, 0);

        assert_eq!(
            shift_wirings,
            vec![
                (0, 0),
                (1, 1),
                (2, 2),
                (3, 3),
                (4, 4),
                (5, 5),
                (6, 6),
                (7, 7)
            ]
        );
    }

    #[test]
    fn test_8bit_right_shift_by_bit_length() {
        let shift_wirings = generate_shift_wirings(3, 1 << 3);

        assert_eq!(shift_wirings, vec![]);
    }

    #[test]
    fn test_8bit_right_shift_by_more_than_bit_length() {
        let shift_wirings = generate_shift_wirings(3, 100);

        assert_eq!(shift_wirings, vec![]);
    }

    #[test]
    fn test_8bit_left_shift_by_bit_length() {
        assert!(generate_shift_wirings(3, -(1 << 3)).is_empty());
    }

    #[test]
    fn test_8bit_left_shift_by_more_than_bit_length() {
        assert!(generate_shift_wirings(3, -100).is_empty());
    }

    #[test]
    fn test_extreme_shift_amounts_are_clamped() {
        // Exercises the clamp at both ends of the `i32` range (no overflow in `unsigned_abs`).
        assert!(generate_shift_wirings(3, i32::MAX).is_empty());
        assert!(generate_shift_wirings(3, i32::MIN).is_empty());
    }

    #[test]
    fn test_single_bit_input() {
        assert_eq!(generate_shift_wirings(0, 0), vec![(0, 0)]);
        assert!(generate_shift_wirings(0, 1).is_empty());
        assert!(generate_shift_wirings(0, -1).is_empty());
    }

    #[test]
    #[should_panic(expected = "num_vars <= 30")]
    fn test_too_many_vars_panics() {
        let _ = generate_shift_wirings(31, 0);
    }
}