This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub kmyk/competitive-programming-library
#include "modulus/modsqrt.hpp"
#pragma once #include <cassert> #include <cstdint> #include "../modulus/modpow.hpp" /** * @brief find the sqrt $b$ of $a$ modulo $p$ * @param p must be a prime * @note i.e. $b^2 \equiv a \pmod{p}$ * @note -1 if not found */ int modsqrt(int a, int p) { auto legendre_symbol = [&](int a) { return modpow(a, (p - 1) / 2, p); // Euler's criterion }; a %= p; if (a == 0) return 0; if (p == 2) return a; assert (p >= 3); if (legendre_symbol(a) != +1) return -1; int b = 1; while (legendre_symbol(b) == 1) { b += 1; } int e = 0; int m = p - 1; while (m % 2 == 0) { m /= 2; e += 1; } int64_t x = modpow(a, (m - 1) / 2, p); int64_t y = a * x % p * x % p; x = x * a % p; int64_t z = modpow(b, m, p); while (y != 1) { int j = 0; for (int64_t t = y; t != 1; t = t * t % p) ++ j; if (e <= j) return -1; z = modpow(z, 1ll << (e - j - 1), p); x = x * z % p; z = z * z % p; y = y * z % p; e = j; } assert (x * x % p == a); return x; }
#line 2 "modulus/modsqrt.hpp"
#include <cassert>
#include <cstdint>
#line 4 "modulus/modpow.hpp"
/**
 * @brief compute $x^k \bmod MOD$ by binary (square-and-multiply) exponentiation
 * @param x base, must satisfy $0 \le x < MOD$
 * @param k exponent
 * @param MOD modulus, must be positive
 * @return $x^k \bmod MOD$ in $[0, MOD)$
 */
inline int32_t modpow(uint_fast64_t x, uint64_t k, int32_t MOD) {
    assert (MOD >= 1);
    assert (/* 0 <= x and */ x < (uint_fast64_t)MOD);
    uint_fast64_t y = 1 % MOD;  // 1 % MOD so that MOD == 1 yields 0 rather than 1
    for (; k; k >>= 1) {
        if (k & 1) (y *= x) %= MOD;
        (x *= x) %= MOD;
    }
    assert (/* 0 <= y and */ y < (uint_fast64_t)MOD);
    return y;
}
#line 5 "modulus/modsqrt.hpp"
/**
 * @brief find a square root $b$ of $a$ modulo a prime $p$ (Tonelli-Shanks)
 * @param a any int; it is first reduced into $[0, p)$, so negative values are accepted
 * @param p must be a prime
 * @return some $b \in [0, p)$ with $b^2 \equiv a \pmod{p}$ (which of the two roots is unspecified),
 *         or -1 if $a$ is a quadratic non-residue modulo $p$
 */
int modsqrt(int a, int p) {
    assert (p >= 2);
    // Euler's criterion: r^((p-1)/2) mod p is 1 for a nonzero residue, p-1 for a non-residue
    auto legendre_symbol = [&](int r) {
        return modpow(r, (p - 1) / 2, p);
    };
    // built-in % keeps the sign of the dividend, but modpow requires a base in [0, p)
    a %= p;
    if (a < 0) a += p;
    if (a == 0) return 0;
    if (p == 2) return a;
    assert (p >= 3);
    if (legendre_symbol(a) != +1) return -1;
    // find any quadratic non-residue b by linear search
    int b = 1;
    while (legendre_symbol(b) == 1) {
        b += 1;
    }
    // decompose p - 1 = m * 2^e with m odd
    int e = 0;
    int m = p - 1;
    while (m % 2 == 0) {
        m /= 2;
        e += 1;
    }
    // loop invariants: x^2 == a * y (mod p), y^(2^(e-1)) == 1, z has multiplicative order 2^e
    int64_t x = modpow(a, (m - 1) / 2, p);
    int64_t y = a * x % p * x % p;  // y = a^m
    x = x * a % p;                  // x = a^((m+1)/2)
    int64_t z = modpow(b, m, p);    // z = b^m
    while (y != 1) {
        // smallest j with y^(2^j) == 1
        int j = 0;
        for (int64_t t = y; t != 1; t = t * t % p) ++ j;
        if (e <= j) return -1;  // unreachable once Euler's criterion passed; kept as a safeguard
        z = modpow(z, 1ll << (e - j - 1), p);
        x = x * z % p;
        z = z * z % p;
        y = y * z % p;
        e = j;
    }
    assert (x * x % p == a);
    return x;
}