# SPDX-FileCopyrightText: 2024-present Jeffrey Goldberg <jeffrey@goldmark.org>
#
# SPDX-License-Identifier: MIT
import math
from collections import UserList
from collections.abc import Iterator, Iterable
from typing import Any, NewType, Optional, Self, TypeGuard
import primefac
from bitarray import bitarray
from bitarray.util import count_n
from . import types
Modulus = NewType("Modulus", int)
"""type Modulus is an int greater than 1"""
[docs]
def is_modulus(n: Any) -> TypeGuard[Modulus]:
if not isinstance(n, int):
return False
if n < 2:
return False
if not isprime(n):
return False
return True
[docs]
def isprime(n: int) -> bool:
"""False if composite; True if very probably prime."""
return primefac.isprime(n)
[docs]
def isqrt(n: int) -> int:
"""returns the greatest r such that r * r =< n"""
if n < 0:
raise ValueError("n cannot be negative")
return primefac.introot(n)
[docs]
def modinv(a: int, m: int) -> int:
"""
Returns b such that :math:`ab \\equiv 1 \\pmod m`.
:raises ValueError: if a is not coprime with m
"""
# python 3.8 allows -1 as power.
return pow(a, -1, m)
[docs]
class FactorList(UserList[tuple[int, int]]):
"""
A FactorList is an list of (prime, exponent) tuples.
It represents the prime factorization of a number.
"""
def __init__(
self,
prime_factors: list[tuple[int, int]] = [],
check_primes: bool = False,
) -> None:
"""
prime_factors should be a list of (prime, exponent) tuples.
Either you ensure that the primes really are prime or use
``check_primes = True``.
"""
super().__init__(prime_factors)
# Normalization will do some sanity checking as well
self.normalize(check_primes=check_primes)
# property-like things that are computed when first needed
self._n: Optional[int] = None
self._totient: Optional[int] = None
self._radical: Optional[FactorList] = None
self._radical_value: Optional[int] = None
self._factors_are_prime: Optional[bool] = None
def __repr__(self) -> str:
s: list[str] = []
for p, e in self.data:
term = f"{p}" if e == 1 else f"{p}^{e}"
s.append(term)
return " * ".join(s)
def __eq__(self, other: object) -> bool:
# Implemented for
# - list
# - int
# - UserDict
if isinstance(other, list):
try:
other_f = FactorList(other)
except (ValueError, TypeError):
return False
return self.data == other_f.data
# Fundamental theorem of arithmetic
if isinstance(other, int):
return self.n == other
if not isinstance(other, UserList):
return NotImplemented
return self.data == other.data
def __add__(self, other: Iterable[tuple[int, int]]) -> "FactorList":
added = super().__add__(other)
added = FactorList(added.data) # init will normalize
return added
[docs]
def normalize(self, check_primes: bool = False) -> Self:
"""
Deduplicates and sorts in prime order, removing exponent == 0 cases.
:raises TypeError: if prime and exponents are not ints
:raises ValueError: if p < 2 or e < 0
This only checks that primes are prime if ``check_primes`` is True.
"""
# this calls for some clever list comprehensions.
# But I am not feeling that clever at the moment
# I will construct a dict from the data and then
# reconstruct the data from the dict
d = {p: 0 for (p, _) in self.data}
for p, e in self.data:
if not isinstance(p, int) or not isinstance(e, int):
raise TypeError("Primes and exponents must be integers")
if p < 2:
raise ValueError(f"{p} should be greater than 1")
if e == 0:
continue
if e < 0:
raise ValueError(f"exponent ({e}) should not be negative")
if check_primes:
if not isprime(p):
raise ValueError(f"{p} is composite")
d[p] += e
self.data = [(p, d[p]) for p in sorted(d.keys())]
return self
@property
def factors_are_prime(self) -> bool:
"""True iff all the alleged primes are prime."""
if self._factors_are_prime is not None:
return self._factors_are_prime
self._factors_are_prime = all([isprime(p) for p, _ in self.data])
return self._factors_are_prime
@property
def n(self) -> int:
"""The integer that this is a factorization of"""
if self._n is None:
self._n = int(math.prod([p**e for p, e in self.data]))
return self._n
@property
def phi(self) -> int:
"""
Returns Euler's Totient (phi)
:math:`\\phi(n)` is the number of numbers
less than n which are coprime with n.
This assumes that the factorization (self) is a prime factorization.
"""
if self._totient is None:
self._totient = int(
math.prod([p ** (e - 1) * (p - 1) for p, e in self.data])
)
return self._totient
[docs]
def coprimes(self) -> Iterator[int]:
"""Iterator of coprimes."""
for a in range(1, self.n):
if not any([a % p == 0 for p, _ in self.data]):
yield a
[docs]
def unit(self) -> int:
"""Unit is always 1 for positive integer factorization."""
return 1
[docs]
def is_integral(self) -> bool:
"""Always true for integer factorization."""
return True
[docs]
def value(self) -> int:
"""Same as ``n()``."""
return self.n
[docs]
def radical(self) -> "FactorList":
"""All exponents on factors set to 1"""
if self._radical is None:
self._radical = FactorList([(p, 1) for p, _ in self.data])
return self._radical
[docs]
def radical_value(self) -> int:
"""Product of factors each used only once."""
if self._radical_value is None:
self._radical_value = math.prod([p for p, _ in self.data])
return self._radical_value
[docs]
def pow(self, n: int) -> "FactorList":
"""Return (self)^n, where *n* is positive int."""
if not types.is_positive_int(n):
raise TypeError("n must be a positive integer")
return FactorList([(p, n * e) for p, e in self.data])
[docs]
def factor(n: int, ith: int = 0) -> FactorList:
"""
Returns list (prime, exponent) factors of n.
Starts trial div at ith prime.
This wraps ``primefac.primefac()``, but creates our FactorList
"""
primes = primefac.primefac(n)
return FactorList([(p, 1) for p in primes])
[docs]
def gcd(*integers: int) -> int:
"""Returns greatest common denomenator of arguments."""
return math.gcd(*integers)
[docs]
def egcd(a: int, b: int) -> tuple[int, int, int]:
"""returns (g, x, y) such that :math:`ax + by = \\gcd(a, b) = g`."""
x0, x1, y0, y1 = 0, 1, 1, 0
while a != 0:
(q, a), b = divmod(b, a), a
y0, y1 = y1, y0 - q * y1
x0, x1 = x1, x0 - q * x1
return b, x0, y0
[docs]
def is_square(n: int) -> bool:
"""True iff n is a perfect square."""
return primefac.ispower(n, 2) is not None
[docs]
def mod_sqrt(a: int, m: int) -> list[int]:
"""Modular square root.
For prime m, this generally returns a list with either two members,
:math:`[r, m - r]` such that :math:`r^2 = {(m - r)}^2 = a \\pmod m`
if such an a is a quadratic residue the empty list if there is no such r.
However, for compatibility with SageMath this can return a list with just
one element in some special cases
- returns ``[0]`` if :math:`m = 3` or :math:`a = 0`.
- returns ``[a % m]`` if :math:`m = 2`.
Otherwise it returns a list with the two quadratic residues if they exist
or and empty list otherwise.
**Warning**: The behavior is undefined if m is not prime.
"""
match m:
case 2:
return [a % m]
case 3:
return [0]
case _ if m < 2:
raise ValueError("modulus must be prime")
if a == 1:
return [1, m - 1]
a = a % m
if a == 0:
return [0]
# check that a is a quadratic residue, return []] if not
if pow(a, (m - 1) // 2, m) != 1:
return []
v = primefac.sqrtmod_prime(a, m)
return [v, (m - v) % m]
[docs]
def lcm(*integers: int) -> int:
"""Least common multiple"""
# requires python 3.9, but I'm already requiring 3.11
return math.lcm(*integers)
[docs]
class Sieve:
"""Sieve of Eratosthenes.
The good parts of this implementation are lifted from the example provided
with the `bitarray package <https://pypi.org/project/bitarray/>`_ source.
"""
"""
We keep the largest sieve created, which will be shared
among all instances
"""
_cached_array = bitarray("0011")
def _make_array(self, n: int) -> None:
len_c = len(self._cached_array)
if n <= len_c:
return
len_e = n - len_c
# Not thread safe. Need to make this atomic
self._cached_array.extend([True] * len_e)
for i in range(2, isqrt(n) + 1):
if self._cached_array[i] is False:
continue
self._cached_array[i * i :: i] = False
[docs]
@classmethod
def clear(cls) -> None:
"""Resets the cached array.
There is no reason to ever use this outside of performance testing.
"""
cls._cached_array = bitarray("0011")
def __init__(self, n: int) -> None:
"""Creates sieve covering the first n integers.
:raises TypeError: if n in not an int.
:raises ValueError: if n < 2.
"""
if not isinstance(n, int):
raise TypeError
if n < 2:
raise ValueError("n must be greater than 2")
self._make_array(n)
self._n = n
self._count: int = self._cached_array[:n].count()
self._bitstring: Optional[str] = None
@property
def n(self) -> int:
return self._n
@property
def array(self) -> bitarray:
"""The sieve as a bitarray."""
return self._cached_array[: self._n]
@property
def count(self) -> int:
"""The number of primes in the sieve."""
return self._count
[docs]
def to01(self) -> str:
"""The sieve as a string of 0s and 1s.
The output is to be read left to right. That is, it should begin with
``001101010001`` corresponding to primes [2, 3, 5, 7, 11]
"""
if self._bitstring is None:
self._bitstring = self._cached_array[: self._n].to01()
return self._bitstring
[docs]
def nth_prime(self, n: int) -> int:
"""Returns n-th prime.
:raises ValueError: if n exceeds count.
"""
if n > self._count:
raise ValueError("n cannot exceed count")
return count_n(self._cached_array, n)