frontend/components/binary_operations/
rotate_bits.rs

//! Implements bit rotation gates (wire rotations built from identity gates).
2
3use shared_types::Field;
4
5use crate::layouter::builder::{CircuitBuilder, NodeRef};
6/// A component that performs bit (wire rotation) using
7/// [crate::layouter::nodes::identity_gate::IdentityGateNode] (i.e., rewire).
8///
9/// TODO: Generalize to a wire shuffle (permute) gate.
10///
11/// Both left and right rotations are supported by using
12/// negative/positive values in the `rotate_amount` parameter. +ve value
13/// means rotate left
14///
15/// This is a rotate instruction, meaning that bits which are shifted
16/// beyond num_vars of the word, are appended to the other side in the
17/// same order.
18///
19/// Requires that the input node has already been verified to contain
20/// binary digits.
21#[derive(Clone, Debug)]
22pub struct RotateNode<F: Field> {
23    output: NodeRef<F>,
24}
25
26impl<F: Field> RotateNode<F> {
27    /// Create a new [crate::components::binary_operations::logical_shift::ShiftNode] that performs a rotation by `rotate_amount` (to the right if
28    /// `rotate_amount > 0` or to the left if `rotate_amount < 0`) on `input` node which contains
29    /// `2^num_vars` binary digits.
30    ///
31    /// # Requires
32    /// `input` is assumed to only contain binary digits (i.e. only values from the set
33    /// `{F::ZERO, F::ONE}` for a field `F`).
34    pub fn new(
35        builder_ref: &mut CircuitBuilder<F>,
36        num_vars: usize,
37        rotate_amount: i32,
38        input: &NodeRef<F>,
39    ) -> Self {
40        // Compute the bit reroutings that effectively shift the
41        // input MLE by the appropriate amount.
42        let rot_wirings = generate_rot_wirings(num_vars, rotate_amount);
43        let output = builder_ref.add_identity_gate_node(input, rot_wirings, num_vars, None);
44
45        Self { output }
46    }
47
48    /// Returns a reference to the node containing the shifted value.
49    pub fn get_output(&self) -> NodeRef<F> {
50        self.output.clone()
51    }
52}
53
/// Builds the wiring list that rotates a word of `2^arity` entries by
/// `rot_amount` positions.
///
/// The result contains one pair per index `i` in `0..2^arity`, in increasing
/// order of `i`. Each pair is `((i + rot_amount) mod 2^arity, i)`, with the
/// modulus taken as a non-negative remainder. A positive `rot_amount` is a
/// right rotation and a negative one is a left rotation. Any `i32` value is
/// accepted, because the amount is reduced modulo `2^arity` before any index
/// arithmetic.
///
/// NOTE(review): the pair is presumably `(output index, input index)` as
/// consumed by the identity gate; confirm against `add_identity_gate_node`.
///
/// # Panics
/// Panics if `arity > 30`, so that `2^arity` and every computed index fit in
/// an `i32`.
fn generate_rot_wirings(arity: usize, rot_amount: i32) -> Vec<(u32, u32)> {
    // With `arity <= 30`, `2^arity` fits in an `i32`. Every rotation of the
    // word is therefore representable by some `i32` amount, in either
    // direction.
    assert!(arity <= 30, "arity must be at most 30, got {arity}");

    let arr_sz: i32 = 1 << arity;

    // Reduce the rotation into `[0, arr_sz)` first. `i + shift` below then
    // stays in `[0, 2 * arr_sz - 1)`, which cannot overflow an `i32`.
    // Computing `i + rot_amount` directly would overflow for rotation amounts
    // near `i32::MAX`.
    let shift = rot_amount.rem_euclid(arr_sz);

    (0..arr_sz)
        .map(|i| (((i + shift) % arr_sz) as u32, i as u32))
        .collect()
}
77
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_8bit_right_rotate_by_one() {
        assert_eq!(
            generate_rot_wirings(3, 1),
            vec![(1, 0), (2, 1), (3, 2), (4, 3), (5, 4), (6, 5), (7, 6), (0, 7)]
        );
    }

    #[test]
    fn test_8bit_left_rotate_by_one() {
        assert_eq!(
            generate_rot_wirings(3, -1),
            vec![(7, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
        );
    }

    #[test]
    fn test_8bit_right_rotate_by_three() {
        assert_eq!(
            generate_rot_wirings(3, 3),
            vec![(3, 0), (4, 1), (5, 2), (6, 3), (7, 4), (0, 5), (1, 6), (2, 7)]
        );
    }

    #[test]
    fn test_rotation_by_zero_or_full_width_is_identity() {
        let identity: Vec<(u32, u32)> = (0..8).map(|i| (i, i)).collect();
        assert_eq!(generate_rot_wirings(3, 0), identity);
        assert_eq!(generate_rot_wirings(3, 8), identity);
        assert_eq!(generate_rot_wirings(3, -8), identity);
    }

    #[test]
    fn test_rotation_wraps_modulo_width() {
        assert_eq!(generate_rot_wirings(3, 9), generate_rot_wirings(3, 1));
        assert_eq!(generate_rot_wirings(3, -9), generate_rot_wirings(3, -1));
        // A left rotation by `k` equals a right rotation by `2^arity - k`.
        assert_eq!(generate_rot_wirings(3, -3), generate_rot_wirings(3, 5));
    }

    #[test]
    fn test_single_entry_word() {
        assert_eq!(generate_rot_wirings(0, 5), vec![(0, 0)]);
        assert_eq!(generate_rot_wirings(0, -5), vec![(0, 0)]);
    }

    #[test]
    fn test_256bit_right_rotate_by_one() {
        let mut expected: Vec<(u32, u32)> = (1..256).map(|i| (i, i - 1)).collect();
        expected.push((0, 255));
        assert_eq!(generate_rot_wirings(8, 1), expected);
    }

    #[test]
    fn test_256bit_left_rotate_by_one() {
        let mut expected: Vec<(u32, u32)> = vec![(255, 0)];
        expected.extend((1..256).map(|i| (i - 1, i)));
        assert_eq!(generate_rot_wirings(8, -1), expected);
    }
}