package net.sf.distrib_rsa.cryptosystems;

import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Vector;

/**
 * Number-theoretic helper routines: Legendre and Jacobi symbols, Chinese
 * remaindering, rejection-sampled random numbers below a bound, arithmetic in
 * the "twisted" group T<sub>N</sub> and generation of the first small primes.
 * 
 * @author lippold Published under the GPLv2 Licence (c) 2006 Georg Lippold
 * 
 */
public class PrimeUtils {

	// short aliases so that the formulas below read more naturally
	private static final BigInteger ZERO = BigInteger.ZERO;

	private static final BigInteger ONE = BigInteger.ONE;

	private static final BigInteger TWO = BigInteger.valueOf(2);

	/**
	 * Computes the Legendre symbol of a and p using Euler's criterion
	 * a<sup>(p-1)/2</sup> mod p.
	 * 
	 * @param a
	 *            the number to test (may be negative or larger than p)
	 * @param p
	 *            the prime p
	 * @return the Legendre symbol -1, 0 or 1. See also Jacobi Symbol.
	 * @throws IllegalArgumentException
	 *             if Euler's criterion yields a value other than 0, 1 or p-1,
	 *             which can only happen when p is not prime
	 * @see #jacobiSymbol(BigInteger, BigInteger)
	 */
	public static int legendreSymbol(final BigInteger a, final BigInteger p) {
		final BigInteger residue = a.mod(p);
		// Checked explicitly: for p = 2 the exponent (p-1)/2 is 0 and modPow
		// would report 1 even for a multiple of p.
		if (residue.signum() == 0) {
			return 0;
		}
		final BigInteger pMinusOne = p.subtract(ONE);
		final BigInteger sym = residue.modPow(pMinusOne.divide(TWO), p);
		if (sym.equals(ZERO)) {
			return 0;
		}
		if (sym.equals(ONE)) {
			return 1;
		}
		// modPow always returns a value in [0, p), so -1 shows up as p-1
		if (sym.equals(pMinusOne)) {
			return -1;
		}
		throw new IllegalArgumentException(
				"probably the second argument is not prime");
	}

	/**
	 * Returns the Jacobi symbol (a over p):
	 * <ul>
	 * <li> -1 if a is not quadratic residue mod p </li>
	 * <li> 0 if a and p share a common factor </li>
	 * <li> 1 otherwise (for prime p: a is a quadratic residue mod p) </li>
	 * </ul>
	 * Computed iteratively with the law of quadratic reciprocity, so the
	 * stack depth does not grow with the size of the arguments.
	 * 
	 * @param a
	 *            the number to test; any integer, including negative values
	 * @param p
	 *            the prime or odd positive number the test is run against.
	 * @return the JacobiSymbol -1, 0 or 1.
	 * @throws IllegalArgumentException
	 *             if p is even or not positive
	 */
	public static int jacobiSymbol(final BigInteger a, final BigInteger p) {
		if (p.signum() <= 0 || !p.testBit(0)) {
			throw new IllegalArgumentException(
					"the second argument must be a positive odd integer");
		}
		// BigInteger.mod always returns a value in [0, p), so negative a is fine
		BigInteger x = a.mod(p);
		BigInteger n = p;
		int result = 1;
		while (x.signum() != 0) {
			// Remove all factors of two: (2/n) = -1 iff n = 3 or 5 (mod 8).
			final int twos = x.getLowestSetBit();
			x = x.shiftRight(twos);
			// n is positive, so the low bits of intValue() are n mod 8
			final int nMod8 = n.intValue() & 7;
			if ((twos & 1) == 1 && (nMod8 == 3 || nMod8 == 5)) {
				result = -result;
			}
			// Quadratic reciprocity (x and n both odd): the sign flips iff
			// x = n = 3 (mod 4).
			if ((x.intValue() & 3) == 3 && (nMod8 & 3) == 3) {
				result = -result;
			}
			final BigInteger swap = x;
			x = n.mod(swap);
			n = swap;
		}
		// n is now gcd(a, p); a common factor makes the symbol 0
		return n.equals(ONE) ? result : 0;
	}

	/**
	 * Computes the integer x that is expressed through the given primes and the
	 * congruences with the chinese remainder theorem (CRT).
	 * 
	 * @param congruences
	 *            the congruences c_i (BigInteger elements, may be negative)
	 * @param primes
	 *            the pairwise coprime moduli p_i (BigInteger elements)
	 * @return the unique integer x in [0, Prod(p_i)) for that x % p_i == c_i
	 * @throws IllegalArgumentException
	 *             if both vectors do not have the same size
	 * @throws ArithmeticException
	 *             if the moduli are not pairwise coprime
	 */
	public static BigInteger chineseRemainder(final Vector congruences,
			final Vector primes) {
		if (congruences.size() != primes.size()) {
			throw new IllegalArgumentException(
					"congruences and primes must have the same size");
		}
		BigInteger all = ONE;
		for (int i = 0; i < primes.size(); i++) {
			all = all.multiply((BigInteger) primes.elementAt(i));
		}
		BigInteger retval = ZERO;
		for (int i = 0; i < primes.size(); i++) {
			final BigInteger prime = (BigInteger) primes.elementAt(i);
			final BigInteger congruence = (BigInteger) congruences.elementAt(i);
			// rest = Prod(p_j, j != i); rest * rest^-1 is 1 mod p_i, 0 mod p_j
			final BigInteger rest = all.divide(prime);
			// reduce the coefficient mod p_i first to keep the product small;
			// rest * (c mod p_i) == rest * c (mod all)
			final BigInteger coefficient = congruence.multiply(
					rest.modInverse(prime)).mod(prime);
			retval = retval.add(rest.multiply(coefficient)).mod(all);
		}
		return retval;
	}

	/**
	 * Faster modPow if modulus = p_1*p_2*... with p_1 ... p_n coprime: the
	 * power is computed modulo every factor and recombined with the CRT.
	 * 
	 * @param base The base that shall be exponentiated
	 * @param exp The exponent
	 * @param modulus The coprime integers whose product is the modulus
	 * @return <tt>base<sup>exp</sup> mod Prod(modulus)</tt>
	 */
	public static BigInteger chineseModPow(final BigInteger base, final BigInteger exp, final BigInteger[] modulus){
		final Vector pq = new Vector(modulus.length);
		final Vector results = new Vector(modulus.length);
		for (int i = 0; i < modulus.length; i++) {
			pq.add(modulus[i]);
			results.add(base.modPow(exp, modulus[i]));
		}
		return chineseRemainder(results, pq);
	}
	
	/**
	 * Returns a uniformly distributed random BigInteger in [0, upperLimit),
	 * obtained by rejection sampling.
	 * 
	 * @param upperLimit
	 *            The exclusive upper limit for the return value; must be
	 *            positive
	 * @param rand
	 *            The SecureRandom to use
	 * @return a BigInteger that is random, non-negative and smaller than
	 *         upperLimit.
	 * @throws IllegalArgumentException
	 *             if upperLimit is not positive (the range would be empty and
	 *             the sampling loop would never terminate)
	 */
	public static BigInteger getRandom(final BigInteger upperLimit,
			final SecureRandom rand) {
		if (upperLimit.signum() <= 0) {
			throw new IllegalArgumentException("upperLimit must be positive");
		}
		final int bits = upperLimit.bitLength();
		BigInteger candidate;
		// each draw succeeds with probability > 1/2, since upperLimit >= 2^(bits-1)
		do {
			candidate = new BigInteger(bits, rand);
		} while (candidate.compareTo(upperLimit) >= 0);
		return candidate;
	}

	/**
	 * Gets a random element from the twisted Group of N, T<sub>N</sub> = ( Z<sub>N</sub>[x] /
	 * (x<sup>2</sup> + 1) )<sup>*</sup> / Z<sub>N</sub><sup>*</sup>
	 * 
	 * Both components are drawn with {@link #getRandom(BigInteger, SecureRandom)};
	 * if the second component is 0 the element is replaced by {1, 0}.
	 * 
	 * @param upperLimit
	 *            the N
	 * @param rand
	 *            the SecureRandom to use
	 * @return a random element from the twisted Group of N, { BigInteger a,
	 *         BigInteger b}.
	 */
	public static BigInteger[] getRandomTwistedElement(
			final BigInteger upperLimit, final SecureRandom rand) {
		BigInteger a = getRandom(upperLimit, rand);
		final BigInteger b = getRandom(upperLimit, rand);
		if (b.equals(ZERO)) {
			a = ONE;
		}
		final BigInteger[] retval = { a, b };
		return retval;
	}

	/**
	 * Computes twistedElement<sup>exp</sup> <tt>mod</tt> modulus with
	 * left-to-right square-and-multiply, scanning the exponent from its most
	 * significant bit downwards. Adapted from
	 * http://www.cs.utsa.edu/~wagner/laws/fav_alg.html
	 * 
	 * @param twistedElement
	 *            the element to be exponentiated
	 * @param exp
	 *            the non-negative exponent
	 * @param modulus
	 *            the modulus
	 * @return twistedElement<sup>exp</sup> <tt>mod</tt> modulus
	 * @throws IllegalArgumentException
	 *             if twistedElement is not a BigInteger[2] or exp is negative
	 */
	public static BigInteger[] twistedModPow(final BigInteger[] twistedElement,
			final BigInteger exp, final BigInteger modulus) {
		if (twistedElement.length != 2) {
			throw new IllegalArgumentException(
					"Expecting twistedElement to be a BigInteger[2]");
		}
		// bitLength/testBit use two's complement for negative values, which
		// would silently produce a wrong power
		if (exp.signum() < 0) {
			throw new IllegalArgumentException("exponent must not be negative");
		}
		BigInteger[] retval = { ONE, ZERO };
		for (int i = exp.bitLength() - 1; i >= 0; i--) {
			retval = twistedModMult(retval, retval, modulus);
			if (exp.testBit(i)) {
				retval = twistedModMult(retval, twistedElement, modulus);
			}
		}
		return retval;
	}

	/**
	 * Reduces both components of a twisted element modulo modulus.
	 * 
	 * @param twistedElement
	 *            the element
	 * @param modulus
	 *            the modulus
	 * @return { twistedElement[0].mod(modulus), twistedElement[1].mod(modulus) }
	 * @throws IllegalArgumentException
	 *             if twistedElement is not a BigInteger[2]
	 */
	public static BigInteger[] twistedMod(final BigInteger[] twistedElement,
			final BigInteger modulus) {
		if (twistedElement.length != 2) {
			throw new IllegalArgumentException(
					"Expecting twistedElement to be a BigInteger[2]");
		}
		final BigInteger[] retval = { twistedElement[0].mod(modulus),
				twistedElement[1].mod(modulus) };
		return retval;
	}

	/**
	 * Multiplies two elements of the twisted Group T<sub>N</sub>
	 * <tt>mod</tt> modulus, i.e. (a + bx)(c + dx) with x<sup>2</sup> = -1.
	 * If the second component of the product vanishes mod modulus, the
	 * normalized identity {1, 0} is returned.
	 * 
	 * @param twistedA
	 *            The first twistedElement
	 * @param twistedB
	 *            The second twistedElement
	 * @param modulus
	 *            the modulus
	 * @return twistedA*twistedB = { twistedA[0]*twistedB[0] -
	 *         twistedA[1]*twistedB[1] , twistedA[0]*twistedB[1] +
	 *         twistedA[1]*twistedB[0] }, both reduced mod modulus
	 * @throws IllegalArgumentException
	 *             if an argument is not a BigInteger[2]
	 */
	public static BigInteger[] twistedModMult(final BigInteger[] twistedA,
			final BigInteger[] twistedB, final BigInteger modulus) {
		if ((twistedA.length != 2) || (twistedB.length != 2)) {
			throw new IllegalArgumentException(
					"Expecting twistedElement to be a BigInteger[2]");
		}
		final BigInteger a = twistedA[0];
		final BigInteger b = twistedA[1];
		final BigInteger c = twistedB[0];
		final BigInteger d = twistedB[1];
		final BigInteger newB = a.multiply(d).add(b.multiply(c)).mod(modulus);
		if (newB.equals(ZERO)) {
			final BigInteger[] identity = { ONE, ZERO };
			return identity;
		}
		final BigInteger newA = a.multiply(c).subtract(b.multiply(d)).mod(modulus);
		final BigInteger[] retval = { newA, newB };
		return retval;
	}

	/**
	 * Adds two twisted Elements component-wise (no modular reduction).
	 * 
	 * @param twistedA
	 *            the first element
	 * @param twistedB
	 *            the second element
	 * @return twistedA + twistedB = {twistedA[0] + twistedB[0], twistedA[1] +
	 *         twistedB[1] }
	 * @throws IllegalArgumentException
	 *             if an argument is not a BigInteger[2]
	 */
	public static BigInteger[] twistedAdd(final BigInteger[] twistedA,
			final BigInteger[] twistedB) {
		if ((twistedA.length != 2) || (twistedB.length != 2)) {
			throw new IllegalArgumentException(
					"Expecting twistedElement to be a BigInteger[2]");
		}
		final BigInteger[] retval = { twistedA[0].add(twistedB[0]),
				twistedA[1].add(twistedB[1]) };
		return retval;
	}

	/**
	 * Inverts a twistedElement by conjugation: (a + bx)(a - bx) = a<sup>2</sup> +
	 * b<sup>2</sup> has no x-component. The result is not reduced modulo
	 * anything, so its second component may be negative.
	 * 
	 * @param twistedElement
	 *            the element to invert
	 * @return the inverted of the element: {element[0], -element[1] }
	 * @throws IllegalArgumentException
	 *             if twistedElement is not a BigInteger[2]
	 */
	public static BigInteger[] twistedInv(final BigInteger[] twistedElement) {
		if (twistedElement.length != 2) {
			throw new IllegalArgumentException(
					"Expecting twistedElement to be a BigInteger[2]");
		}
		final BigInteger[] retval = { twistedElement[0], twistedElement[1].negate() };
		return retval;
	}

	/**
	 * Gets the first n primes by trial division of the odd numbers with the
	 * already found primes up to the square root of each candidate.
	 * 
	 * @param count
	 *            the number of primes to find
	 * @return a integer array containing the first <tt>count</tt> primes
	 *         starting with 2 (empty if count is 0)
	 * @throws IllegalArgumentException
	 *             if count is negative
	 */
	public static int[] getFirstPrimes(final int count) {
		if (count < 0) {
			throw new IllegalArgumentException("count must not be negative: "
					+ count);
		}
		final int[] retval = new int[count];
		if (count == 0) {
			return retval;
		}
		retval[0] = 2;
		int found = 1;
		// retval[0 .. divisorCount-1] are exactly the primes whose square is
		// <= candidate; nextSquare is the square of retval[divisorCount]
		int divisorCount = 0;
		long nextSquare = 4;
		for (int candidate = 3; found < count; candidate += 2) {
			while (nextSquare <= candidate) {
				divisorCount++;
				// widen before multiplying so the square cannot overflow int
				nextSquare = (long) retval[divisorCount] * retval[divisorCount];
			}
			boolean isPrime = true;
			// start at index 1: odd candidates are never divisible by 2
			for (int i = 1; i < divisorCount; i++) {
				if (candidate % retval[i] == 0) {
					isPrime = false;
					break;
				}
			}
			if (isPrime) {
				retval[found] = candidate;
				found++;
			}
		}
		return retval;
	}

}
