diff --git a/math/modular_inverse_fermat_little_theorem.cpp b/math/modular_inverse_fermat_little_theorem.cpp index 7550e14bf23..d870c4da978 100644 --- a/math/modular_inverse_fermat_little_theorem.cpp +++ b/math/modular_inverse_fermat_little_theorem.cpp @@ -30,8 +30,8 @@ * a^{m-2} &≡& a^{-1} \;\text{mod}\; m * \f} * - * We will find the exponent using binary exponentiation. Such that the - * algorithm works in \f$O(\log m)\f$ time. + * We will find the exponent using binary exponentiation such that the + * algorithm works in \f$O(\log n)\f$ time. * * Examples: - * * a = 3 and m = 7 @@ -43,56 +43,98 @@ * (as \f$a\times a^{-1} = 1\f$) */ -#include -#include +#include /// for assert +#include /// for std::int64_t +#include /// for IO implementations -/** Recursive function to calculate exponent in \f$O(\log n)\f$ using binary - * exponent. +/** + * @namespace math + * @brief Maths algorithms. + */ +namespace math { +/** + * @namespace modular_inverse_fermat + * @brief Calculate modular inverse using Fermat's Little Theorem. + */ +namespace modular_inverse_fermat { +/** + * @brief Calculate exponent with modulo using binary exponentiation in \f$O(\log b)\f$ time. + * @param a The base + * @param b The exponent + * @param m The modulo + * @return The result of \f$a^{b} % m\f$ */ -int64_t binExpo(int64_t a, int64_t b, int64_t m) { - a %= m; - int64_t res = 1; - while (b > 0) { - if (b % 2) { - res = res * a % m; - } - a = a * a % m; - // Dividing b by 2 is similar to right shift. - b >>= 1; +std::int64_t binExpo(std::int64_t a, std::int64_t b, std::int64_t m) { + a %= m; + std::int64_t res = 1; + while (b > 0) { + if (b % 2 != 0) { + res = res * a % m; } - return res; + a = a * a % m; + // Dividing b by 2 is similar to right shift by 1 bit + b >>= 1; + } + return res; } - -/** Prime check in \f$O(\sqrt{m})\f$ time. +/** + * @brief Check if an integer is a prime number in \f$O(\sqrt{m})\f$ time. + * @param m An intger to check for primality + * @return true if the number is prime + * @return false if the number is not prime */ -bool isPrime(int64_t m) { - if (m <= 1) { - return false; - } else { - for (int64_t i = 2; i * i <= m; i++) { - if (m % i == 0) { - return false; - } - } +bool isPrime(std::int64_t m) { + if (m <= 1) { + return false; + } + for (std::int64_t i = 2; i * i <= m; i++) { + if (m % i == 0) { + return false; } - return true; + } + return true; +} +/** + * @brief calculates the modular inverse. + * @param a Integer value for the base + * @param m Integer value for modulo + * @return The result that is the modular inverse of a modulo m + */ +std::int64_t modular_inverse(std::int64_t a, std::int64_t m) { + while (a < 0) { + a += m; + } + + // Check for invalid cases + if (!isPrime(m) || a == 0) { + return -1; // Invalid input + } + + return binExpo(a, m - 2, m); // Fermat's Little Theorem +} +} // namespace modular_inverse_fermat +} // namespace math + +/** + * @brief Self-test implementation + * @return void + */ +static void test() { + assert(math::modular_inverse_fermat::modular_inverse(0, 97) == -1); + assert(math::modular_inverse_fermat::modular_inverse(15, -2) == -1); + assert(math::modular_inverse_fermat::modular_inverse(3, 10) == -1); + assert(math::modular_inverse_fermat::modular_inverse(3, 7) == 5); + assert(math::modular_inverse_fermat::modular_inverse(1, 101) == 1); + assert(math::modular_inverse_fermat::modular_inverse(-1337, 285179) == 165519); + assert(math::modular_inverse_fermat::modular_inverse(123456789, 998244353) == 25170271); + assert(math::modular_inverse_fermat::modular_inverse(-9876543210, 1000000007) == 784794281); } /** - * Main function + * @brief Main function + * @return 0 on exit */ int main() { - int64_t a, m; - // Take input of a and m. - std::cout << "Computing ((a^(-1))%(m)) using Fermat's Little Theorem"; - std::cout << std::endl << std::endl; - std::cout << "Give input 'a' and 'm' space separated : "; - std::cin >> a >> m; - if (isPrime(m)) { - std::cout << "The modular inverse of a with mod m is (a^(m-2)) : "; - std::cout << binExpo(a, m - 2, m) << std::endl; - } else { - std::cout << "m must be a prime number."; - std::cout << std::endl; - } + test(); // run self-test implementation + return 0; }