# Source code for toy_crypto.rsa

from dataclasses import dataclass
import hashlib
from hmac import compare_digest
import math
import secrets
from typing import Callable
from toy_crypto import utils
from toy_crypto.nt import lcm, modinv

# Conventional RSA public exponent, used whenever a caller does not pick one.
_DEFAULT_E = 65537


def default_e() -> int:
    """Return the default public exponent, 65537."""
    return _DEFAULT_E
class DecryptionError(Exception): """ For secure implementations it is important to not report why a decryption failed. See :func:`Oaep.allow_unsafe_messages` to enable reasons why decryption fails. """ type HashFunc = Callable[[bytes], hashlib._Hash] """Type for hashlib style hash function.""" type MgfFunc = Callable[[bytes, int, str], bytes] """Type for RFC8017 Mask Generation Function."""
[docs] class Oaep: """ Tools and data for OAEP. Although this attempts to follow :rfc:`8017` in many respects, this is not designed to be interoperable with compliant keys and ciphertext. """ _unsafe_messages: str = "forbid" """Anything other than "allow" will keep DecryptionError messages safe."""
[docs] @classmethod def allow_unsafe_messages(cls, allow: bool = True) -> None: """Allow (or disallow) verbose DecryptionError messages.""" cls._unsafe_messages = "allow" if allow else "forbid"
[docs] @classmethod def are_unsafe_messages_allowed(cls) -> bool: """Does what it says on the tin.""" if cls._unsafe_messages == "allow": return True return False
@classmethod def _unsafe_msg(cls, message: str) -> str: if cls._unsafe_messages == "allow": return message return "Not telling, nohow!"
[docs] @dataclass(frozen=True, kw_only=True) class HashInfo: """Information about hash function :param hashlib_name: Name as known by hashlib :param function: The callable function :param digest_size: in bytes :param input_limit: in bytes Note that names and identifiers here do not conform to RFCs. These are not mean to be interoperable with anything out in the world. """ #: Name as known by hashlib hashlib_name: str function: HashFunc #: The callable function itself digest_size: int #: in bytes input_limit: int #: maximum input, in bytes
[docs] @dataclass(frozen=True, kw_only=True) class MgfInfo: """Information about Mask Generation function. :param algorithm: Name of the algorithm. :param hashAlgorithm: Key in :data:`KNOWN_HASHES`. :param function: A Callable mask generation function """ algorithm: str # eg "id-mfg1" hashAlgorithm: str # Key in KNOWN_HASHES function: MgfFunc
[docs] @staticmethod def mgf1( seed: bytes, # There do not appear to be any constraints on this length: int, # Output length hash_id: str, # key for KNOWN_HASHES ) -> bytes: """Mask generation function. Generates a unique mask of length :data:`length` as described in :rfc:`appendix B.2.1 <8017#appendix-B.2.1>` of :rfc:`8017`. :param seed: This should come from a CSPRNG :param length: Length in bytes of the mask to generate. :param hash_id: The name hash function in :data:`KNOWN_HASHES`. :raises ValueError: if :data:`length` :math:`> 2^{32}` bytes. :raises ValueError: if :data:`hash_id` is unknown. """ """ I am using Pythonic variable naming instead of what is in the RFC RFC Name | My name | Description | Type ----------------------------------------------------------------- mfgSeed | seed | seed from which mask is generated | bytes maskLen | length | Intended length of mask | int mask | mask | Output | bytes T | t | Internal array for building mask | bytearray C | counter | Counter, four octets | bytes | hash_id | ID of hash function | str Hash | hasher | CS hash function | HashFunc """ if length > 2**32: raise ValueError("mask too long") try: hash = Oaep.KNOWN_HASHES[hash_id] except KeyError: raise ValueError(f'Unsupported hash function: "{hash_id}') digest_size = hash.digest_size hasher = hash.function t = bytearray() # "For counter from 0 to \ceil (maskLen / hLen) - 1, ..." # range is not inclusive at high end. So we need to add one for c in range(math.ceil(length / digest_size) - 1 + 1): counter = Oaep.i2osp(c, 4) assert len(counter) == 4 digest = hasher(seed + counter).digest() t.extend(digest) assert len(t) >= length mask = bytes(t[:length]) return mask
KNOWN_HASHES: dict[str, HashInfo] = { "sha256": HashInfo( hashlib_name="sha256", function=hashlib.sha256, digest_size=32, input_limit=2**61 - 1, ), # We need sha1 because that is what test vectors exist for "sha1": HashInfo( hashlib_name="sha1", function=hashlib.sha1, digest_size=20, input_limit=2**64 - 1, ), } """Hashes known for OAEP. keys will be hashlib names.""" KNOWN_MGFS: dict[str, MgfInfo] = { "mgf1SHA256": MgfInfo( algorithm="id_mgf1", hashAlgorithm="sha256", function=mgf1 ), "mgf1SHA1": MgfInfo( algorithm="id_mgf1", hashAlgorithm="sha1", function=mgf1 ), } """Known Mask Generation Functions."""
[docs] @staticmethod def i2osp(n: int, length: int) -> bytes: """Integer to an octet string of length length. Implements function from :rfc:`RFC 8017 §4.1 <8017#section-4.1>`. :param n: A non-negative integer :param length: Length of returned bytes object :raises ValueError: if :data:`n` is negative. :raises ValueError: if :data:`n` cannot fit in :data:`length` bytes .. warning:: When called from a decryption operation, exceptions should be caught and handled discretely. All operations big-endian. """ if n < 0: raise ValueError("Number cannot be negative") if (n.bit_length() + 7) // 8 > length: raise ValueError("input is too large for the given length") return n.to_bytes(length, byteorder="big", signed=False)
[docs] @staticmethod def os2ip(x: bytes) -> int: """octet-stream to unsigned big-endian int. Implements function from :rfc:`RFC 8017 §4.2 <8017#section-4.2>`. :param x: The octet-stream (:py:class:`bytes`) you want to make an :py:class:`int` from. Returned is a non-negative integer. All operations are big-endian. """ return int.from_bytes(x, byteorder="big", signed=False)
[docs] class PublicKey: def __init__(self, modulus: int, public_exponent: int) -> None: """Public key from public values.""" self._N = modulus self._e = public_exponent @property def N(self) -> int: """Public modulus N.""" return self._N @property def e(self) -> int: """Public exponent e""" return self._e
[docs] def encrypt(self, message: int) -> int: """Primitive encryption with neither padding nor nonce. :raises ValueError: if message < 0 :raises ValueError: if message isn't less than the public modulus """ if message < 0: raise ValueError("Positive messages only") """ There is a reason for the explicit conversion to int in the comparison below. If message was created as a member of a SageMath finite group mod N, self._N would be converted to that before comparison and self._N ≡ 0 (mod self._N). """ if not int(message) < self._N: raise ValueError("Message too big") return pow(base=message, exp=self._e, mod=self._N)
[docs] def oaep_encrypt( self, message: bytes, label: bytes = b"", hash_id: str = "sha256", mgf_id: str = "mgf1SHA256", _seed: bytes | None = None, # For testing only ) -> bytes: """ RSA OAEP encryption. :param message: The message to encrypt. :param label: Rarely used. Just leave as default. :param hash_id: Name of the hash function. :param mgf_id: Name of the MGF function (with hash). :param _seed: Used for testing only. OAEP is not supposed to be deterministic. :raises ValueError: if hash or MGF is not recognized. :raises ValueError: if lengths of inputs exceed what modulus and hash sizes can accommodate. https://datatracker.ietf.org/doc/html/rfc8017#section-7.1.1 """ try: h = Oaep.KNOWN_HASHES[hash_id] except KeyError: raise ValueError(f'Unsupported hash: "{hash_id}') try: mgf = Oaep.KNOWN_MGFS[mgf_id] except KeyError: raise ValueError( f'Unsupported mask generation function: "{mgf_id}' ) if len(label) > h.input_limit: raise ValueError("label too long") k = (self.N.bit_length() + 7) // 8 # length of N in bytes if len(message) > k - 2 * h.digest_size - 2: raise ValueError("message too long") lhash = h.function(label).digest() ps_length = k - len(message) - 2 * h.digest_size - 2 padding_string = bytes(ps_length) data_block = lhash + padding_string + bytes([0x01]) + message # So that we can test with a set seed seed: bytes if _seed is None: seed = secrets.token_bytes(h.digest_size) else: seed = _seed mask = mgf.function(seed, k - h.digest_size - 1, mgf.hashAlgorithm) masked_db = utils.xor(data_block, mask) seed_mask = mgf.function(masked_db, h.digest_size, mgf.hashAlgorithm) masked_seed = utils.xor(seed, seed_mask) encoded_m = bytes([0x00]) + masked_seed + masked_db m = Oaep.os2ip(encoded_m) ctext = self.encrypt(m) return Oaep.i2osp(ctext, k)
def __eq__(self, other: object) -> bool: """True when each has the same modulus and public exponent. When comparing to a PrivateKey, this compares only the public parts. """ if isinstance(other, PublicKey): return self.e == other.e and self.N == other.N return NotImplemented
class PrivateKey:
    """RSA private key built from two primes, with CRT parameters."""

    def __init__(self, p: int, q: int, pub_exponent: int = _DEFAULT_E) -> None:
        """RSA private key from primes p and q.

        This does not perform any sanity checks on p and q.
        It is your responsibility to ensure that p and q are prime.

        :param p: First prime factor of the modulus.
        :param q: Second prime factor of the modulus.
        :param pub_exponent: The public exponent e.

        :raises ValueError: if e is not coprime with lcm(p - 1, q - 1).
        """
        self._p = p
        self._q = q
        self._e = pub_exponent

        self._N = self._p * self._q
        self._pubkey = PublicKey(self._N, self._e)

        # All modular inverses are computed inside the try so that every
        # incompatibility between p, q and e is reported the same way.
        try:
            self._d = self._compute_d()
            # CRT parameters, as in RFC 8017 section 3.2
            self._dP = modinv(self._e, self._p - 1)
            self._dQ = modinv(self._e, self._q - 1)
            self._qInv = modinv(self._q, self._p)
        except ValueError as err:
            raise ValueError(
                "p, q, and e are incompatible with each other "
            ) from err

    @property
    def pub_key(self) -> PublicKey:
        """The public key corresponding to self.

        The public key does not contain any secrets.
        """
        return self._pubkey

    @property
    def e(self) -> int:
        """Public exponent."""
        return self._e

    def __eq__(self, other: object) -> bool:
        """True iff keys are mathematically equivalent

        Private keys with internal differences can behave identically
        with respect to input and output. This comparison will return True
        when they are equivalent in this respect.

        When compared to a PublicKey, this compares only the public part.
        """
        if isinstance(other, PrivateKey):
            return self.pub_key == other.pub_key

        if isinstance(other, PublicKey):
            return self.pub_key == other

        return NotImplemented

    def _compute_d(self) -> int:
        """Return the private exponent d, the inverse of e mod λ.

        :raises ValueError: if e has no inverse modulo lcm(p - 1, q - 1).
        """
        λ = lcm(self._p - 1, self._q - 1)
        try:
            return modinv(self.e, λ)
        except ValueError:
            raise ValueError("Inverse of e mod λ does not exist")

    def decrypt(self, ciphertext: int) -> int:
        """Primitive decryption.

        :param ciphertext: Ciphertext as :py:class:`int`

        :returns: The decrypted message as an integer.

        :raises ValueError:
            if :data:`ciphertext` is out of range for this key.
        """
        ciphertext = int(ciphertext)  # See comment in PublicKey.encrypt()

        if ciphertext < 1 or ciphertext >= self.pub_key.N:
            raise ValueError("ciphertext is out of range")

        # The direct computation would be pow(ciphertext, d, N), but we
        # use the CRT version from rfc8017 §5.1.2 instead.
        m_1 = pow(ciphertext, self._dP, self._p)
        m_2 = pow(ciphertext, self._dQ, self._q)

        # CRT for the win!
        h = ((m_1 - m_2) * self._qInv) % self._p
        m = m_2 + self._q * h
        return m

    def oaep_decrypt(
        self,
        ciphertext: bytes,
        label: bytes = b"",
        hash_id: str = "sha256",
        mgf_id: str = "mgf1SHA256",
    ) -> bytes:
        """
        RSA OAEP decryption.

        :param ciphertext: The ciphertext to decrypt.
        :param label: Rarely used.
        :param hash_id: Name of the hash function.
        :param mgf_id: Name of the MGF function (with hash).

        :returns: The recovered message.

        :raises ValueError: if hash or MGF is not recognized.
        :raises DecryptionError: on various decryption errors.

        If unsafe error reporting is enabled, details of
        decryption errors will be provided.
        """
        try:
            h = Oaep.KNOWN_HASHES[hash_id]
        except KeyError:
            raise ValueError(f'Unsupported hash: "{hash_id}"') from None

        try:
            mgf = Oaep.KNOWN_MGFS[mgf_id]
        except KeyError:
            raise ValueError(
                f'Unsupported mask generation function: "{mgf_id}"'
            ) from None

        # ceil(length of N in bytes)
        k = (self.pub_key.N.bit_length() + 7) // 8

        # checks in Step 1 involve no secrets
        # Step 1.a
        if len(label) > h.input_limit:
            raise DecryptionError(Oaep._unsafe_msg("Label too long"))
        # Step 1.b
        if len(ciphertext) != k:
            raise DecryptionError(Oaep._unsafe_msg("Ciphertext is wrong size"))
        # Step 1.c
        if k < 2 * h.digest_size + 2:
            raise DecryptionError(Oaep._unsafe_msg("Modulus is way too short"))

        # Step 2.a
        c = Oaep.os2ip(ciphertext)

        # Step 2.b
        try:
            m = self.decrypt(c)
        except Exception as e:
            raise DecryptionError(
                Oaep._unsafe_msg(f"Primitive decryption error: {e}")
            )

        # Step 2.c
        # if k is computed correctly, conversion of bytes
        # should never fail.
        em = Oaep.i2osp(m, k)

        # Step 3 is where we perform validations that make use of secrets

        # Step 3.a is done before other tests to thwart certain timing attacks
        lhash = h.function(label).digest()

        # Step 3.b. Parse encoded message
        y: int = em[0]
        masked_seed: bytes = em[1 : h.digest_size + 1]
        masked_datablock: bytes = em[h.digest_size + 1 :]

        # Step 3.c and 3.d
        seed_mask = mgf.function(
            masked_datablock, h.digest_size, mgf.hashAlgorithm
        )
        seed = utils.xor(masked_seed, seed_mask)

        # Steps 3.e and 3.f
        db_mask = mgf.function(seed, k - h.digest_size - 1, mgf.hashAlgorithm)
        data_block = utils.xor(masked_datablock, db_mask)

        # Parsing portion of Step 3.g
        lhash_prime: bytes = data_block[: h.digest_size]
        remainder: bytes = data_block[h.digest_size :].lstrip(bytes([0]))

        # Validation portion of Step 3.g
        # A slice is compared rather than remainder[0] so that a data block
        # with no separator at all is a DecryptionError, not an IndexError.
        if remainder[:1] != bytes([0x01]):
            raise DecryptionError(Oaep._unsafe_msg("Expected 0x01"))
        if not compare_digest(lhash, lhash_prime):
            raise DecryptionError(Oaep._unsafe_msg("Label mismatch"))
        # Must be a DecryptionError like every other failure here, so that
        # callers catching DecryptionError see all malformed ciphertexts.
        if y != 0:
            raise DecryptionError(
                Oaep._unsafe_msg("Expected 0x00 leading byte")
            )

        message: bytes = remainder[1:]
        return message