Source code for toy_crypto.bit_utils

"""
Various utilities for manipulating bit-like things.

The utilities here are most subject to change
as many were just quick things I needed.
This is more of a trash heap than a well-thought-out module.
"""

from typing import Optional, Any, Union, Callable, Self
from collections.abc import Iterator
import operator
from .utils import xor
from .types import SupportsBool


def bits(n: int) -> Iterator[int]:
    """
    Yield the binary digits of n as 0s and 1s, least significant first.

    Nothing is yielded for 0.  Because this is a generator, the check on
    n runs when iteration starts, not when the function is called.

    :raises ValueError: if n is negative.
    """
    remaining = int(n)
    if remaining < 0:
        raise ValueError("n cannot be negative")

    while remaining:
        remaining, low_bit = remaining >> 1, remaining & 1
        yield low_bit
def get_bit(n: int, k: int) -> int:
    """
    Return bit k of n (0 or 1), counting from the least significant bit.

    :raises ValueError: if k is negative.
    """
    if k < 0:
        raise ValueError("k cannot be negative")

    # Move the wanted bit down to position 0, then discard everything else.
    return (n >> k) & 1
[docs] def set_bit(n: int, k: int, value: bool | int = True) -> int: """Returns copy of n with k-th bit set to value.""" # k must be greater than 0 if k < 0: raise ValueError("k cannot be negative") if not value: # setting to 0 return n & ~(1 << k) return (1 << k) | n
[docs] def bit_index_linear(n: int, k: int, b: bool | int = 1) -> int | None: """Returns which bit is the k-th b bit in n. This is to mimic bitarray.utils.count_n, but for working with python built-in int """ if k < 1: raise ValueError("n must be positive") bc = n.bit_count() if b else (~n).bit_count() if k > bc: return None b = 1 if b else 0 # naive code that just does a linear search count through bits of n count = 0 for i in range(n.bit_length()): n, r = divmod(n, 2) if r == b: count += 1 if count >= k: return i return None # This really shouldn't ever be reached.
[docs] def bit_index(n: int, k: int, b: bool | int = 1) -> int | None: """Returns which bit is the k-th b bit in n. This is to mimic bitarray.utils.count_n, but for working with python built-in int """ if k < 1: raise ValueError("k must be positive") bc = n.bit_count() if b else (~n).bit_count() bl = n.bit_length() if k > bc: return None b = 1 if b else 0 linear_theshold = 31 if bc <= linear_theshold: return bit_index_linear(n, k, b) mid_index = bl // 2 midpoint = int(2**mid_index) upper, lower = divmod(n, midpoint) # Debugging check against alternative way to compute upper and lower assert upper == n >> mid_index assert lower == n & ((1 << mid_index) - 1) bcl = lower.bit_count() if bcl == k: return mid_index - 1 if bcl > k: ret_val = bit_index(lower, k, b) return ret_val u_count = bit_index(upper, k - bcl, b) if u_count is not None: return bcl + u_count return None
class Bit:
    """
    Because I made poor choices earlier of how to represent bits,
    I need an abstraction.

    A Bit wraps a single truth value.  It can be compared and combined
    (``&``, ``|``, ``^``) with other Bits, with bytes, and with anything
    recognised as :class:`SupportsBool`.  A bytes operand counts as 1
    when any of its bytes is non-zero.
    """

    def __init__(self, b: "SupportsBool") -> None:
        """
        :param b: value to wrap.  Bytes are 1 if any byte is non-zero;
            everything else is judged by its truthiness.
        """
        # ``b is True`` would turn Bit(1) into a 0 bit, so use truthiness.
        self._value: bool = any(b) if isinstance(b, bytes) else bool(b)
        # Lazily filled caches for the int and bytes representations.
        self._as_int: Optional[int] = None
        self._as_bytes: Optional[bytes] = None

    def __bool__(self) -> bool:
        """True for a 1 bit, False for a 0 bit."""
        return self._value

    def as_int(self) -> int:
        """Returns 0 or 1."""
        # Test against None: a cached 0 is falsy but still a valid cache.
        if self._as_int is None:
            self._as_int = 1 if self._value else 0
        return self._as_int

    def as_bytes(self) -> bytes:
        """Returns a big-endian byte. Either 0x00 or 0x01."""
        if self._as_bytes is None:
            self._as_bytes = (1 if self._value else 0).to_bytes(1, "big")
        return self._as_bytes

    def __eq__(self, other: Any) -> bool:
        """Equality with other, using the same conversion as the operators."""
        ob = self._other_bool(other)
        if ob is None:
            return NotImplemented
        return self._value == ob

    @staticmethod
    def _other_bool(other: Any) -> Optional[bool]:
        """
        Truth value of an operand, or None if it cannot be interpreted.

        Bytes are true when any byte is non-zero (unlike ``bool(bytes)``,
        which is true for any non-empty value).
        """
        if isinstance(other, Bit):
            return other._value
        if isinstance(other, bytes):
            return any(other)
        if not isinstance(other, SupportsBool):
            return None
        return bool(other)

    def _logic(
        self, other: Any, expr: Callable[[bool, bool], bool]
    ) -> Union["Bit", int, bool, bytes]:
        """
        Abstraction to manage type of :data:`other` for things like
        :func:`__and__` and :func:`__or__`.

        The result is computed on truth values and then returned in the
        form of the operand: a Bit for a Bit, 0/1 for an int (including
        bool), a single byte for bytes, and a plain bool otherwise.
        Returns NotImplemented when :data:`other` cannot be interpreted.
        """
        ob = self._other_bool(other)
        if ob is None:
            return NotImplemented

        tvalue = expr(self._value, ob)

        # Dispatch on the ORIGINAL operand, not on its converted bool.
        if isinstance(other, Bit):
            return Bit(tvalue)
        if isinstance(other, int):
            return 1 if tvalue else 0
        if isinstance(other, bytes):
            return (1 if tvalue else 0).to_bytes(1, "big")
        return tvalue

    def __and__(self, other: Any) -> Union["Bit", int, bool, bytes]:
        """For the "&" operator."""
        return self._logic(other=other, expr=operator.and_)

    def __xor__(self, other: Any) -> Union["Bit", int, bool, bytes]:
        """For the "^" operator."""
        return self._logic(other=other, expr=operator.xor)

    def __or__(self, other: Any) -> Union["Bit", int, bool, bytes]:
        """For the "|" operator."""
        return self._logic(other=other, expr=operator.or_)

    def inv(self) -> "Bit":
        """Returns a new Bit with the opposite value."""
        return Bit(not self._value)

    def __invert__(self) -> "Bit":
        """For the "~" operator."""
        return self.inv()

    def __inv__(self) -> "Bit":
        """Alias of :func:`inv`, kept for callers using the old name.

        Python's ``~`` operator looks up ``__invert__``, not ``__inv__``.
        """
        return self.inv()
def set_bit_in_byte(byte: int, bit: int, value: SupportsBool) -> int:
    """
    Return byte with one of its eight bits forced to value.

    Both inputs are reduced before use: byte is taken modulo 256 and bit
    modulo 8, with bit 0 being the least significant bit.  A truthy value
    sets the bit; a falsy value clears it.  The result is in 0..255.
    """
    low_eight = byte % 256
    mask = 1 << (bit % 8)

    if value:
        return low_eight | mask
    # 0xFF ^ mask keeps every bit of the byte except the selected one.
    return low_eight & (0xFF ^ mask)
def flip_end(byte: int) -> int:
    """
    Return int with reversed bit sequence.

    The value is treated as exactly eight bits, so bit 0 swaps with
    bit 7, bit 1 with bit 6, and so on; e.g. 0b00000001 -> 0b10000000.

    :raises ValueError: if byte is not in the range 0..255.
    """
    if byte < 0 or byte > 255:
        raise ValueError("byte is not representable as byte")

    result = 0
    for _ in range(8):
        # Push the current lowest input bit onto the low end of the result;
        # after eight rounds the first bit taken ends up at position 7.
        result = (result << 1) | (byte & 1)
        byte >>= 1
    return result
def hamming_distance(a: bytes, b: bytes) -> int:
    """
    Number of bit positions in which a and b differ.

    The distance is the count of 1 bits in ``xor(a, b)``.

    :raises ValueError: if len(a) != len(b).
    """
    if len(a) != len(b):
        raise ValueError("Lengths are unequal")

    differing: bytes = xor(a, b)

    # bit_count lives on int, so view the xor result as one big integer.
    as_number = int.from_bytes(differing, signed=False)
    return as_number.bit_count()
class PyBitArray: """A pure Python bitarray-like object. This does not implement all methods of bitarray, nor does it fully follow the bitarray API. This is very much a work in progress. """ """ Internal little-endian bytearray of big endian bytes, but that shouldn't be seen by users. So the bit at index 0 will be the right-most bit of the left-most byte. If I do this right, users will never have to know or deal with that. """ def __init__(self, bit_length: int, fill_bit: SupportsBool = 0) -> None: # Instance attributes that should always exist self._data: bytearray self._length: int # length in used bits self._free_bits: int # number of unused bits in last byte if bit_length < 0: raise ValueError("bit_length cannot be negative") fill_byte: int if not fill_bit: fill_byte = 0 else: fill_byte = 255 self._length = bit_length byte_len, self._free_bits = divmod(self._length, 8) if self._free_bits > 0: byte_len += 1 self._data = bytearray([fill_byte] * byte_len) self._data[-1] >>= self._free_bits @classmethod def from_int(cls, n: int) -> Self: """New instance from int.""" instance = cls(n.bit_length()) idx = 0 while n: n, r = divmod(n, 2) if r: instance[idx] = 1 idx += 1 return instance def append(self, b: SupportsBool) -> None: """appends bit b.""" b = 1 if b else 0 self._length += 1 if self._free_bits == 0: self._data.append(b) else: self[-1] = b self._free_bits -= 1 self._free_bits %= 8 def _inner_getitem(self, index: int) -> int: while index < 0: index += self._length byte_index, bit_index = divmod(index, 8) byte = self._data[byte_index] value = byte & (1 << bit_index) return 1 if value != 0 else 0 def __getitem__(self, index: int) -> int: """Retrieve a bit using [] notation.""" if not index < self._length: raise IndexError return self._inner_getitem(index) def __setitem__(self, index: int, value: SupportsBool) -> None: """Set a bit using [] notation.""" while index < 0: index += self._length byte_index, bit_index = divmod(index, 8) byte = self._data[byte_index] 
new_byte = set_bit_in_byte(byte, bit_index, value) self._data[byte_index] = new_byte @property def nbytes(self) -> int: """Length in bytes""" return len(self._data) @property def padbits(self) -> int: """Number of pad bits""" return self._free_bits def __len__(self) -> int: """len() is defined.""" return self._length def bits( self, start: int = 0, stop: int | None = None, step: int = 1 ) -> Iterator[int]: """Iterator of bits, from least significant.""" if step == 0: raise ValueError("step cannot be 0") step_sign = 1 if step > 0 else -1 match (stop, step_sign): case None, 1: stop = self._length case None, _: # sign is negative stop = 0 case _, _: stop = stop # I know this, you know this. Now the type checker will know this assert stop is not None indices = range(start, stop, step) for idx in indices: yield self[idx] def __iter__(self) -> Iterator[int]: return self.bits() def __int__(self) -> int: """Integer with first element as least significant bit""" return sum(b * (2**i) for i, b in enumerate(self)) def count(self) -> int: """Returns number of 1 bits""" return int.from_bytes(self._data).bit_count() @classmethod def from_bytes(cls, byte_data: bytes, endian: str = "big") -> Self: ints: list[int] match endian.lower(): case "big": ints = [int(b) for b in byte_data] case "little": ints = [int(b) for b in reversed(byte_data)] case _: raise ValueError('endian must be "big" or "little"') instance = cls(8 * len(ints)) instance._data = bytearray(ints) return instance def to_bytes(self, endian: str = "big") -> bytes: if endian not in ["big", "little"]: raise ValueError('endian must be "big" or "little"') if endian == "big": return bytes(self._data) return bytes(self._data[::-1]) def bit_index(self, k: int, b: int = 1) -> int | None: """The index of the n-th 1 or 0 bit.""" # For the moment I will just use bit_index for int # TODO: Write b-search for bit index using native data as_int = int(self) return bit_index(as_int, k, b)