CodeChef submission 44635 (C++ 4.0.0-8) plaintext list. Status: AC, problem E4, contest JULY09. By ashutoshmehra (Ashutosh Mehra), 2009-07-01 22:35:09.
// quadratic_eqns.cpp -- Wed Jul 01 2009
//
// Finds every x in [0, p) with a*x^2 + b*x + c == 0 (mod p), p prime, by
// completing the square and taking a modular square root (Tonelli-Shanks).
//
// NOTE(review): the listing this file was recovered from had lost its
// scanf/printf statements: main() read `t`, `a`, `b`, `c`, `p` without ever
// initialising them, and solve() printed nothing. The I/O below is a
// reconstruction -- input "T" followed by T lines of "a b c p"; output, per
// case, the number of roots followed by the roots in increasing order.
// Confirm it against the problem statement before relying on it.
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <assert.h>
#include <algorithm>

typedef unsigned long long ull;
typedef signed long long sll;

// #define ENABLE_TESTING 1

/* Returns base^n mod modulus by square-and-multiply.
   Products of two residues are formed in 64 bits, so modulus must be below
   2^32 for them not to overflow. */
ull modpow(ull base, ull n, ull modulus)
{
    ull result = 1 % modulus;
    base %= modulus; // an unreduced base could overflow the first product
    while (n > 0) {
        if (n & 1)
            result = (result * base) % modulus;
        base = (base * base) % modulus;
        n >>= 1;
    }
    return result;
}

/* Returns the Legendre symbol (a|p) for an odd prime p: 1 if a is a non-zero
   quadratic residue of p, -1 if it is a non-residue, 0 if p divides a.
   Uses Euler's criterion a^((p-1)/2) == (a|p) (mod p) -- see Niven's book,
   Corr 2.38. */
sll legendre(ull a, ull p)
{
    ull r = modpow(a, (p - 1) >> 1, p);
    if (r == 0)
        return 0;
    return r == 1 ? 1 : -1;
}

/* Returns an x < p with x^2 == n (mod p), p prime, or 0 if n is not a
   quadratic residue (0 is also the genuine root when n == 0 (mod p)).
   The other root is p - x.
   If p is of the form 4k+3 the root is n^((p+1)/4) (mod p); if it is of the
   form 4k+1 the Shanks-Tonelli algorithm is used. For reference see
   http://en.wikipedia.org/wiki/Shanks-Tonelli_algorithm and Niven's book,
   pp. 100-120. */
ull residuesolve(ull n, ull p)
{
    n %= p;
    // Mod 2 every n is its own square root, and 0 is the root of 0. Both
    // cases must be settled here: the Tonelli-Shanks loop below would never
    // terminate on them.
    if (p == 2 || n == 0)
        return n;
    if (legendre(n, p) != 1)
        return 0;
    if (p % 4 == 3)
        return modpow(n, (p + 1) >> 2, p);

    // p - 1 = q * 2^s with q odd.
    ull q = p - 1;
    ull s = 0;
    while (!(q & 1)) {
        q >>= 1;
        ++s;
    }

    // Any quadratic non-residue w will do; its q-th power v has order 2^s.
    ull w = 2;
    while (legendre(w, p) != -1)
        ++w;
    const ull v = modpow(w, q, p);

    ull r = modpow(n, (q + 1) >> 1, p);
    ull t = modpow(n, q, p);
    // Invariant: t == r^2 * n^(-1) (mod p); r is a root once t == 1.
    for (;;) {
        // Least i with t^(2^i) == 1.
        ull i = 0;
        for (ull y = t; y != 1; ++i)
            y = (y * y) % p;
        if (i == 0)
            break;

        // b = v^(2^(s-i-1)). Unlike the textbook form, v and s are never
        // updated; recomputing from the originals yields the same b.
        ull b = v;
        for (ull j = 0, jj = s - i - 1; j < jj; ++j)
            b = (b * b) % p;

        // New r' = r*b, new t' = t*b^2 -- the invariant still holds.
        r = (r * b) % p;
        b = (b * b) % p;
        t = (t * b) % p;
    }
    return r;
}

/* Extended Euclid: returns g = gcd(a, b) and stores into *px, *py integers
   with a*(*px) + b*(*py) == g. Callers here pass non-negative arguments. */
sll gcd_extended(sll a, sll b, sll * px, sll * py)
{
    sll x = 0, y = 1, u = 1, v = 0, m = 0, n = 0, r = 0, q = 0;
    while (a) {
        q = b / a, r = b % a;
        m = x - u * q, n = y - v * q;
        b = a, a = r, x = u, y = v, u = m, v = n;
    }
    *px = x, *py = y;
    return b;
}

/* Returns b in [0, p) with a*b == 1 (mod p). Requires gcd(a, p) == 1;
   a may be any integer, it is reduced mod p first. */
sll mod_inv(sll a, sll p)
{
    sll x, y;
    gcd_extended(((a % p) + p) % p, p, &x, &y); // a*x + p*y == 1
    return ((x % p) + p) % p;
}

/* Brute-force reference used only by the test harness.
   Stores every x in [0, p) with a*x^2 + b*x + c == 0 (mod p) into sols, in
   increasing order, and returns how many there are. O(p). sols must have
   room for every root. Intended for small p (a*j*j is not reduced). */
unsigned naive_solve(sll a, sll b, sll c, sll p, sll * sols)
{
    unsigned cnt = 0;
    for (sll j = 0; j < p; ++j) {
        if ((a * j * j + b * j + c) % p == 0)
            sols[cnt++] = j;
    }
    return cnt;
}

/* Stores into sols (room for 3) every x in [0, p) with
   a*x^2 + b*x + c == 0 (mod p), in increasing order, and returns the count.
   The coefficients may be any integers; they are reduced mod p first.
   NOTE(review): assumes a % p != 0 (the test harness forces a != 0);
   otherwise mod_inv(2a) below is meaningless -- confirm the problem
   guarantees it. Also assumes p is small enough (below about 2^31) that
   2*p*p fits in a signed 64-bit integer. */
static unsigned find_roots(sll a, sll b, sll c, sll p, sll * sols)
{
    unsigned num_sols = 0;
    a = ((a % p) + p) % p;
    b = ((b % p) + p) % p;
    c = ((c % p) + p) % p;

    // x = 0 is a root exactly when c == 0 (mod p). Testing the unreduced c
    // would miss it when c is a non-zero multiple of p.
    if (c == 0)
        sols[num_sols++] = 0;

    if (p == 2) {
        // Only two residues: test x = 1 directly.
        if ((a + b + c) % 2 == 0)
            sols[num_sols++] = 1;
    } else {
        // p is an odd prime. Complete the square (see
        // <http://planetmath.org/encyclopedia/QuadraticCongruence.html>):
        // with q = b^2 - 4ac and y = 2ax + b the equation becomes
        // y^2 == q (mod p), and each root y gives x = (y - b) / 2a.
        const sll q = ((b * b) % p - ((a * c) % p) * 4 % p + p) % p;
        const sll inv2a = mod_inv((2 * a) % p, p);

        if (q == 0) {
            // y = 0 is the only root, i.e. x = -b / 2a.
            const sll w = ((p - b) % p) * inv2a % p;
            if (w != 0) // x = 0 was already recorded above
                sols[num_sols++] = w;
        } else {
            const sll y = (sll)residuesolve((ull)q, (ull)p);
            if (y != 0) {
                // The roots +y and p-y are distinct because p is odd. A root
                // x == 0 was already recorded above, so it is skipped here.
                sll w = ((y + p - b) % p) * inv2a % p;
                if (w != 0)
                    sols[num_sols++] = w;
                w = ((p - y + p - b) % p) * inv2a % p;
                if (w != 0)
                    sols[num_sols++] = w;
            }
            // y == 0 here means q is a non-residue: no further solutions.
        }
    }

    std::sort(sols, sols + num_sols);
    return num_sols;
}

/* Solves one test case and prints "<count> <root> <root> ..." on one line.
   NOTE(review): output format reconstructed, see the file header. */
void solve(sll a, sll b, sll c, sll p)
{
    sll sols[3];
    const unsigned num_sols = find_roots(a, b, c, p, sols);

    printf("%u", num_sols);
    for (unsigned j = 0; j < num_sols; ++j)
        printf(" %lld", sols[j]);
    printf("\n");

#if ENABLE_TESTING
    // Cross-check against brute force.
    sll naive_sols[32];
    const unsigned num_naive_sols = naive_solve(a, b, c, p, naive_sols);
    assert(num_naive_sols == num_sols);
    for (unsigned j = 0; j < num_sols; ++j)
        assert(naive_sols[j] == sols[j]);
#endif
}

int main(int argc, char *argv[])
{
#if ENABLE_TESTING
    // Test harness: random equations over small primes, checked in solve().
    static const sll kPrimes[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47};
    const unsigned num_primes = sizeof kPrimes / sizeof kPrimes[0];
    srand((unsigned)time(NULL));
    const unsigned t = 10000;
    for (unsigned j = 0; j < t; ++j) {
        const sll p = kPrimes[rand() % num_primes];
        sll a = rand() % p, b = rand() % p, c = rand() % p;
        if (a == 0)
            ++a; // keep the equation genuinely quadratic
        solve(a, b, c, p);
    }
#else
    // The real world!
    unsigned t = 0;
    if (scanf("%u", &t) != 1)
        return 0;
    for (unsigned j = 0; j < t; ++j) {
        sll a, b, c, p;
        if (scanf("%lld %lld %lld %lld", &a, &b, &c, &p) != 4)
            break;
        solve(a, b, c, p);
    }
#endif
    return 0;
}
Comments


Thanks for the explanation within the code, Ashutosh. Appreciate it.