package net.sf.distrib_rsa.cryptosystems.naccacheStern;

import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;

import net.sf.distrib_rsa.cryptosystems.PrimeUtils;

import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.AsymmetricCipherKeyPairGenerator;
import org.bouncycastle.crypto.KeyGenerationParameters;

/**
 * Key pair generator for the Naccache-Stern cipher. For details on this
 * cipher, please see
 * 
 * http://www.gemplus.com/smart/rd/publications/pdf/NS98pkcs.pdf
 * 
 * Call {@link #init(KeyGenerationParameters)} with a
 * NaccacheSternKeyGenerationParameters instance first, optionally
 * {@link #setPrimes(Vector)}, and then {@link #generateKeyPair()}. The
 * expensive searches (for a, b, p, q and the public base y) are spread over
 * the number of worker threads given to the constructor.
 * 
 * An instance keeps all intermediate values in its fields and is therefore
 * not safe for concurrent use by several callers.
 */
public class NaccacheSternKeyPairGenerator implements
		AsymmetricCipherKeyPairGenerator {

	/**
	 * Bit length of the "tuning primes" p' and q'. Both appear in n, so
	 * generateAB() subtracts 2 * TUNING_PRIME_BITS from the available
	 * strength.
	 */
	private static final int TUNING_PRIME_BITS = 24;

	/**
	 * Worker threads of the phase currently in progress. Workers of the y
	 * phases remove themselves from this list when they are done.
	 * 
	 * @uml.property  name="threads"
	 * @uml.associationEnd  multiplicity="(0 -1)" elementType="java.lang.Thread"
	 */
	private final Vector threads = new Vector();

	/**
	 * Monitor on which the generating thread waits for results of workers.
	 * 
	 * @uml.property  name="waitFor"
	 */
	private final Object waitFor = new Object();

	/**
	 * Maps each small prime (BigInteger) to the random part of y computed for
	 * it (BigInteger).
	 * 
	 * @uml.property  name="yParts"
	 * @uml.associationEnd  qualifier="access$0:java.math.BigInteger java.math.BigInteger"
	 */
	private final Hashtable yParts = new Hashtable();

	/**
	 * Number of worker threads to use; always at least 1.
	 * 
	 * @uml.property  name="processorCount"
	 */
	private final int processorCount;

	/**
	 * Set by an order test when y^(phi(n)/p_i) == 1 for some small prime p_i.
	 * 
	 * @uml.property  name="gDivisible"
	 */
	private boolean gDivisible = false;

	/**
	 * The "large prime" a, factor of p - 1.
	 * 
	 * @uml.property  name="a"
	 */
	private BigInteger a = null;

	/**
	 * The "large prime" b, factor of q - 1.
	 * 
	 * @uml.property  name="b"
	 */
	private BigInteger b = null;

	/**
	 * Product of the first half of the (permuted) small primes.
	 * 
	 * @uml.property  name="u"
	 */
	private BigInteger u;

	/**
	 * Product of the second half of the (permuted) small primes.
	 * 
	 * @uml.property  name="v"
	 */
	private BigInteger v;

	/**
	 * The secret prime p = 2 a u p' + 1.
	 * 
	 * @uml.property  name="p"
	 */
	private BigInteger p = null;

	/**
	 * The secret prime q = 2 b v q' + 1.
	 * 
	 * @uml.property  name="q"
	 */
	private BigInteger q = null;

	/**
	 * The tuning prime p'.
	 * 
	 * @uml.property  name="p_"
	 */
	private BigInteger p_ = null;

	/**
	 * The tuning prime q'.
	 * 
	 * @uml.property  name="q_"
	 */
	private BigInteger q_ = null;

	// The original name g was changed to y, since in Benaloh and
	// Goldwasser-Micali it's always y
	/**
	 * The public base y (called g in NS98).
	 * 
	 * @uml.property  name="y"
	 */
	private BigInteger y;

	/**
	 * The public modulus n = p * q.
	 * 
	 * @uml.property  name="n"
	 */
	private BigInteger n;

	/**
	 * Upper bound for messages: the product of all small primes (u * v).
	 * 
	 * @uml.property  name="sigma"
	 */
	private BigInteger sigma;

	/**
	 * Euler's totient (p - 1)(q - 1).
	 * 
	 * @uml.property  name="phi_n"
	 */
	private BigInteger phi_n;

	/**
	 * Whether progress and intermediate values are printed to System.out.
	 * 
	 * @uml.property  name="debug"
	 */
	private boolean debug;

	/**
	 * Minimum bit length of n.
	 * 
	 * @uml.property  name="strength"
	 */
	private int strength;

	/**
	 * Source of randomness, shared by all worker threads.
	 * 
	 * @uml.property  name="rand"
	 */
	private SecureRandom rand;

	/**
	 * Certainty passed to the probabilistic prime tests.
	 * 
	 * @uml.property  name="certainty"
	 */
	private int certainty;

	/**
	 * The small primes (BigInteger) whose product is sigma.
	 * 
	 * @uml.property  name="smallPrimes"
	 * @uml.associationEnd  multiplicity="(0 -1)" elementType="java.math.BigInteger"
	 */
	private Vector smallPrimes;

	/**
	 * Constructor for multi-processor systems. You can get the number of
	 * processors in your system via
	 * 
	 * Runtime.getRuntime().availableProcessors();
	 * 
	 * You need java >= 1.4 for this.
	 * 
	 * @param processorCnt
	 *            The number of CPUs to use. Values below 1 are treated as 1.
	 */
	public NaccacheSternKeyPairGenerator(final int processorCnt) {
		// With fewer than one worker, generateAB() and generatePQ() would
		// start no threads at all and wait forever for a result.
		processorCount = Math.max(1, processorCnt);
	}

	/**
	 * Standard constructor, only one CPU is used.
	 * 
	 */
	public NaccacheSternKeyPairGenerator() {
		processorCount = 1;
	}

	/**
	 * Initializes the NaccacheSternKeyPairGenerator with the user supplied
	 * NaccacheSternKeyGenerationParameters and prepares the list of small
	 * primes.
	 * 
	 * @param param
	 *            The KeyGenerationParameters; must be a
	 *            NaccacheSternKeyGenerationParameters instance.
	 * @throws ClassCastException
	 *             if param is not a NaccacheSternKeyGenerationParameters.
	 * 
	 * @see NaccacheSternKeyGenerationParameters
	 * @see org.bouncycastle.crypto.AsymmetricCipherKeyPairGenerator#init(org.bouncycastle.crypto.KeyGenerationParameters)
	 */
	public void init(final KeyGenerationParameters param) {
		// Cast first, so that a wrong parameter type leaves this
		// generator unchanged.
		final NaccacheSternKeyGenerationParameters nsParam = (NaccacheSternKeyGenerationParameters) param;
		strength = param.getStrength();
		rand = param.getRandom();
		certainty = nsParam.getCertainty();
		debug = nsParam.isDebug();
		if (debug) {
			System.out.println("Fetching first " + nsParam.getCntSmallPrimes()
					+ " primes.");
		}
		// Prepare the list of small primes for keyGeneration
		smallPrimes = findFirstPrimes(nsParam.getCntSmallPrimes());
	}

	/**
	 * Generates a new NaccacheSternKeyPair using the number of processors given
	 * in the constructor and the parameters supplied in the init() method. May
	 * be called repeatedly; each call produces an independent key pair.
	 * 
	 * @return the public NaccacheSternKeyParameters and the matching
	 *         NaccacheSternPrivateKeyParameters.
	 * 
	 * @see org.bouncycastle.crypto.AsymmetricCipherKeyPairGenerator#generateKeyPair()
	 */
	public AsymmetricCipherKeyPair generateKeyPair() {

		// Forget the results of a previous call: generatePQ() only waits
		// while p is null, and the multi-threaded generateAB() only while a
		// or b is null. Stale values would be reused with the new u and v.
		a = null;
		b = null;
		p = null;
		q = null;
		p_ = null;
		q_ = null;

		// Permute the prime list individually
		smallPrimes = permuteList(smallPrimes, rand);

		// compute u and v
		u = BigInteger.ONE;
		v = BigInteger.ONE;

		final int half = smallPrimes.size() / 2;
		for (int i = 0; i < half; i++) {
			u = u.multiply((BigInteger) smallPrimes.get(i));
		}
		for (int i = half; i < smallPrimes.size(); i++) {
			v = v.multiply((BigInteger) smallPrimes.get(i));
		}

		// the upper bound for messages, sigma
		sigma = u.multiply(v);

		// generate a and b (threaded)
		generateAB();

		if (debug) {
			System.out.println("generating p and q");
		}

		// generate p and q (threaded)
		generatePQ();

		// n and \phi(n)
		n = p.multiply(q);
		phi_n = p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));

		if (debug) {
			System.out.println("generating y");
		}

		// compute y, NCS'98 calls it g (threaded)
		computeY();

		if (debug) {
			System.out.println();
			System.out.println("found new NaccacheStern cipher variables:");
			System.out.println("smallPrimes: " + smallPrimes);
			System.out.println("sigma:...... " + sigma + " ("
					+ sigma.bitLength() + " bits)");
			System.out.println("a:.......... " + a);
			System.out.println("b:.......... " + b);
			System.out.println("p':......... " + p_);
			System.out.println("q':......... " + q_);
			System.out.println("p:.......... " + p);
			System.out.println("q:.......... " + q);
			System.out.println("n:.......... " + n);
			System.out.println("phi(n):..... " + phi_n);
			System.out.println("y:.......... " + y);
			System.out.println();
		}

		return new AsymmetricCipherKeyPair(new NaccacheSternKeyParameters(
				false, y, n, sigma), new NaccacheSternPrivateKeyParameters(y,
				n, sigma, smallPrimes, p, q, debug, processorCount));
	}

	/**
	 * When called after init() and before generateKeyPair(), it is possible to
	 * generate a NCS-key that has this custom set of pairwise coprime integers
	 * (i.e. for every two different elements in primes: GCD(prime1, prime2) =
	 * 1).
	 * 
	 * If 4 * (bit length of the product of these integers) + 1 exceeds the
	 * configured strength, the strength is raised to that value.
	 * 
	 * @param primes
	 *            The pairwise coprime integers (BigInteger) that shall be used
	 *            during generation.
	 */
	public void setPrimes(final Vector primes) {
		sigma = BigInteger.ONE;
		for (int i = 0; i < primes.size(); i++) {
			// Multiply the given primes, not the previously configured ones.
			sigma = sigma.multiply((BigInteger) primes.get(i));
		}
		// In section 3.1 of NaccacheStern98 they state that
		// sigma.bitLength() / n.bitLength() < 1/4.
		// So let's adjust n appropriately:
		final int neededBitLength = sigma.bitLength() * 4 + 1;
		if (neededBitLength > strength) {
			// FIXME: optionally, one could also throw an
			// IllegalArgumentException,
			// but that would require more user interaction.
			if (debug) {
				System.out
						.println("Strength is not sufficient for given primes");
				System.out.println("adjusting to " + neededBitLength + ".");
			}
			strength = neededBitLength;
		}
		smallPrimes = primes;
	}

	/**
	 * Convenience method to compute 2*a*u*p_ +1
	 * 
	 * @param a
	 *            The "large prime" a
	 * @param u
	 *            The product of half of our small primes
	 * @param p_
	 *            The "tuning prime" p
	 * @return A BigInteger that is 2*a*u*p +1
	 */
	private static BigInteger computeP(final BigInteger a, final BigInteger u,
			final BigInteger p_) {
		return (((p_.multiply(BigInteger.valueOf(2))).multiply(a)).multiply(u))
				.add(BigInteger.ONE);
	}

	/**
	 * Generates a probable prime of exactly the desired bit length.
	 * 
	 * @param bitLength
	 *            The desired bit length for the new prime (at least 2)
	 * @param certainty
	 *            The certainty for the new prime
	 * @param rand
	 *            The source of randomness to use
	 * @return A probable prime that has exactly bitLength bits.
	 */
	private static BigInteger generatePrime(final int bitLength,
			final int certainty, final SecureRandom rand) {
		BigInteger p_ = new BigInteger(bitLength, certainty, rand);
		// Defensive: this constructor already promises the exact bit length.
		while (p_.bitLength() != bitLength) {
			p_ = new BigInteger(bitLength, certainty, rand);
		}
		return p_;
	}

	/**
	 * Callback method for the returning threads computing y. Puts the result of
	 * the thread into the list from which y will be generated. Removes the
	 * thread from the list of running threads. Notifies thread starter that
	 * thread has finished.
	 * 
	 * @param t
	 *            The thread that finished.
	 */
	private void submitYPart(final ComputeYPart t) {
		synchronized (yParts) {
			yParts.put(t.smallPrime, t.yPart);
		}

		synchronized (threads) {
			threads.remove(t);
		}

		synchronized (waitFor) {
			waitFor.notifyAll();
		}
	}

	/**
	 * Callback method for OrderOfYTest. If a thread detects that y has order
	 * phi(n)/(its prime), the result is stored and all threads not yet started
	 * are removed, which saves computing time.
	 * 
	 * @param result
	 *            The result of the OrderOfYTest
	 * @param t
	 *            The thread that finished.
	 */
	private void submitOrderTest(final boolean result, final Thread t) {
		if (result) {
			gDivisible = result;
			synchronized (threads) {
				threads.removeAllElements();
			}
		}
		synchronized (threads) {
			threads.remove(t);
		}
		synchronized (waitFor) {
			waitFor.notifyAll();
		}
	}

	/**
	 * Callback method for GeneratePQThread. The first submitted values for p,
	 * q, p' and q' are stored and the starter is notified. Later submissions
	 * are ignored.
	 * 
	 * @param p
	 *            The new p.
	 * @param q
	 *            The new q.
	 * @param p_
	 *            The new p'.
	 * @param q_
	 *            The new q'.
	 */
	private void submitPQ(final BigInteger p, final BigInteger q,
			final BigInteger p_, final BigInteger q_) {
		synchronized (waitFor) {
			// Several GeneratePQThreads may succeed at about the same time.
			// Storing all four values atomically and only once guarantees
			// that p, q, p' and q' always come from the same thread.
			if (this.p == null) {
				this.p = p;
				this.q = q;
				this.p_ = p_;
				this.q_ = q_;
			}
			waitFor.notifyAll();
		}
	}

	/**
	 * Generates a uniformly random permutation of the given list. The original
	 * list is not modified.
	 * 
	 * @param arr
	 *            the list to be permuted (may be empty)
	 * @param rand
	 *            the source of Randomness for permutation
	 * @return a new Vector with the permuted elements.
	 */
	private static Vector permuteList(final Vector arr, final SecureRandom rand) {
		final Vector retval = new Vector(arr.size());
		for (int i = 0; i < arr.size(); i++) {
			// Inserting the i-th element at a uniformly random position among
			// the i + 1 possible slots yields a uniform permutation. The first
			// element has only one slot, so no random number is drawn for it.
			final int pos = (i == 0) ? 0 : rand.nextInt(i + 1);
			retval.add(pos, arr.get(i));
		}
		return retval;
	}

	/**
	 * Generates the primes a and b in the desired strength. With more than one
	 * processor, half of the workers search for a and the rest for b. As soon
	 * as one prime is found, the other workers searching for the same prime
	 * are stopped.
	 */
	private void generateAB() {

		// n = (2 a u p_ + 1 ) ( 2 b v q_ + 1)
		// -> |n| = strength
		// |2| = 1 in bits
		// -> |a| + |b| = |n| - |u| - |v| - |p_| - |q_| - |2| - |2|
		// sigma = u * v, so |u| + |v| is approximated by |sigma|; a and b
		// each get half of the remaining bits (rounded up).
		final int remainingStrength = (strength - sigma.bitLength() - 2 * TUNING_PRIME_BITS) / 2 + 1;

		if (processorCount == 1) {
			new GeneratePrimeA(remainingStrength).run();
			new GeneratePrimeB(remainingStrength).run();
			return;
		}

		// workers [0, half) search for a, workers [half, processorCount) for b
		final int half = processorCount / 2;
		final Vector workers = new Vector(processorCount);
		for (int i = 0; i < half; i++) {
			workers.add(new GeneratePrimeA(remainingStrength));
		}
		for (int i = half; i < processorCount; i++) {
			workers.add(new GeneratePrimeB(remainingStrength));
		}

		synchronized (waitFor) {
			for (int i = 0; i < workers.size(); i++) {
				((Thread) workers.get(i)).start();
			}
			while ((a == null) || (b == null)) {
				try {
					waitFor.wait();
				} catch (final InterruptedException ignored) {
					// keep waiting until both primes have been found
				}
				if (a != null) {
					// stop all threads generating a
					for (int i = 0; i < half; i++) {
						((GeneratePrimeA) workers.get(i)).endThread();
					}
				}
				if (b != null) {
					// stop all threads generating b
					for (int i = half; i < processorCount; i++) {
						((GeneratePrimeB) workers.get(i)).endThread();
					}
				}
			}
		}

		// Stop every worker before joining, so that no worker can keep
		// searching forever regardless of which prime was found last.
		for (int i = 0; i < half; i++) {
			((GeneratePrimeA) workers.get(i)).endThread();
		}
		for (int i = half; i < processorCount; i++) {
			((GeneratePrimeB) workers.get(i)).endThread();
		}
		for (int i = 0; i < workers.size(); i++) {
			try {
				((Thread) workers.get(i)).join();
			} catch (final InterruptedException ignored) {
				// nothing to clean up; continue with the next worker
			}
		}
	}

	/**
	 * Generates p and q. Starts as many threads as we have processors, takes
	 * the first valid result and stops all threads afterwards.
	 * 
	 */
	private void generatePQ() {
		// parallelize the generation of p and q
		for (int i = 0; i < processorCount; i++) {
			threads.add(new GeneratePQThread());
		}

		// wait for one thread to notify us of the correct primes
		synchronized (waitFor) {
			for (int i = 0; i < threads.size(); i++) {
				((Thread) threads.get(i)).start();
			}
			while (p == null) {
				try {
					waitFor.wait();
				} catch (final InterruptedException ignored) {
					// keep waiting until p and q have been found
				}
			}
		}

		// stop all threads and clear thread vector
		synchronized (threads) {
			// signal all workers first so they can wind down in parallel
			for (int i = 0; i < threads.size(); i++) {
				((GeneratePQThread) threads.get(i)).endThread();
			}
			for (int i = 0; i < threads.size(); i++) {
				try {
					((GeneratePQThread) threads.get(i)).join();
				} catch (final InterruptedException ignored) {
					// nothing to clean up; continue with the next worker
				}
			}
			threads.removeAllElements();
		}

	}

	/**
	 * Computes the public base y and repeats until it fulfills our criteria:
	 * y^(phi(n)/x) != 1 mod n for every small prime x and for x in {4, p',
	 * q', a, b}.
	 * 
	 */
	private void computeY() {
		for (;;) {
			computeYParts();

			threadedYOrderTest();

			if (gDivisible) {
				continue;
			}

			// make sure that y has order > phi_n/4

			if (y.modPow(phi_n.divide(BigInteger.valueOf(4)), n).equals(
					BigInteger.ONE)) {
				if (debug) {
					System.out.println("y has order phi(n)/4\n y:" + y);
				}
				continue;
			}

			if (y.modPow(phi_n.divide(p_), n).equals(BigInteger.ONE)) {
				if (debug) {
					System.out.println("y has order phi(n)/p'\n y: " + y);
				}
				continue;
			}
			if (y.modPow(phi_n.divide(q_), n).equals(BigInteger.ONE)) {
				if (debug) {
					System.out.println("y has order phi(n)/q'\n y: " + y);
				}
				continue;
			}
			if (y.modPow(phi_n.divide(a), n).equals(BigInteger.ONE)) {
				if (debug) {
					System.out.println("y has order phi(n)/a\n y: " + y);
				}
				continue;
			}
			if (y.modPow(phi_n.divide(b), n).equals(BigInteger.ONE)) {
				if (debug) {
					System.out.println("y has order phi(n)/b\n y: " + y);
				}
				continue;
			}
			break;
		}
	}

	/**
	 * Computes y by starting (at most processorCount at a time) one thread per
	 * small prime, each of which generates a part of the final y. Finally
	 * computes y from its parts.
	 * 
	 */
	private void computeYParts() {
		yParts.clear();
		// Prepare threads that compute y
		for (int ind = 0; ind != smallPrimes.size(); ind++) {
			final BigInteger smallPrime = (BigInteger) smallPrimes
					.elementAt(ind);
			final Thread t = new ComputeYPart(smallPrime);
			synchronized (threads) {
				threads.add(t);
			}
		}

		final Vector runningThreads = new Vector();

		// wait for them to return
		synchronized (waitFor) {
			for (int i = 0; (i < threads.size()) && (i < processorCount); i++) {
				final Thread t = (Thread) threads.get(i);
				runningThreads.add(t);
				t.start();
			}
			while (threads.size() > 0) {
				try {
					waitFor.wait();
				} catch (final InterruptedException ignored) {
					// keep waiting until all parts have been submitted
				}
				// each finished thread frees one slot: start the next one
				for (int i = 0; i < threads.size(); i++) {
					final Thread t = (Thread) threads.get(i);
					if (!runningThreads.contains(t)) {
						runningThreads.add(t);
						t.start();
						break;
					}
				}
			}
		}
		for (int i = 0; i < runningThreads.size(); i++) {
			final Thread t = (Thread) runningThreads.get(i);
			try {
				t.join();
			} catch (final InterruptedException ignored) {
				// nothing to clean up; continue with the next thread
			}
		}

		if (debug) {
			System.out.println("all threads for generating y finished");
		}

		// compute y from them
		y = BigInteger.ONE;
		final Enumeration en = yParts.keys();
		while (en.hasMoreElements()) {
			final BigInteger smallPrime = (BigInteger) en.nextElement();
			final BigInteger gPart = (BigInteger) yParts.get(smallPrime);
			y = y.multiply((gPart).modPow(sigma.divide(smallPrime), n)).mod(n);
		}

	}

	/**
	 * Tests for every p_i in our smallPrimes, that y<sup>(phi(n)/p_i)</sup> !=
	 * 1. Runs up to processorCount tests in parallel and waits for them to
	 * return. Sets gDivisible if any test fails.
	 */
	private void threadedYOrderTest() {
		// make sure that y is not divisible by p_i or q_i
		gDivisible = false;
		for (int i = 0; i < smallPrimes.size(); i++) {
			// Usually (>99%) the test returns false, thus running it in
			// parallel increases speed on multi-processor platforms

			// prepare all threads
			final Thread t = new OrderOfYTest((BigInteger) smallPrimes.get(i));
			threads.add(t);
		}
		final Vector runningThreads = new Vector();

		synchronized (waitFor) {
			// start as many as needed
			for (int i = 0; (i < threads.size()) && (i < processorCount); i++) {
				final Thread t = (Thread) threads.get(i);
				runningThreads.add(t);
				t.start();
			}

			while (threads.size() > 0) {
				try {
					waitFor.wait();
				} catch (final InterruptedException ignored) {
					// keep waiting until all tests have finished
				}
				// each finished thread frees one slot: start the next one
				for (int i = 0; i < threads.size(); i++) {
					final Thread t = (Thread) threads.get(i);
					if (!runningThreads.contains(t)) {
						runningThreads.add(t);
						t.start();
						break;
					}
				}
			}
		}

		for (int i = 0; i < runningThreads.size(); i++) {
			final Thread t = (Thread) runningThreads.get(i);
			try {
				t.join();
			} catch (final InterruptedException ignored) {
				// nothing to clean up; continue with the next thread
			}
		}
		runningThreads.removeAllElements();
	}

	/**
	 * Finds the first 'count' primes starting with 3. The first value returned
	 * by PrimeUtils.getFirstPrimes (presumably 2) is skipped.
	 * 
	 * @param count
	 *            the number of primes to find
	 * @return a vector containing the found primes as BigInteger
	 */
	private static Vector findFirstPrimes(final int count) {
		final Vector primes = new Vector(count);

		final int[] smallPrimes = PrimeUtils.getFirstPrimes(count + 1);
		for (int i = 1; i != count + 1; i++) {
			primes.addElement(BigInteger.valueOf(smallPrimes[i]));
		}

		return primes;
	}

	/**
	 * Generates P and Q for encryption
	 * 
	 * @author lippold Published under the GPLv2 Licence (c) 2006 Georg Lippold
	 * 
	 */
	class GeneratePQThread extends Thread {

		/**
		 * Cleared by endThread() from another thread; volatile so that the
		 * change is guaranteed to become visible to this loop.
		 */
		volatile boolean running = true;

		GeneratePQThread() {
			super();
		}

		public void run() {
			BigInteger p_, q_, p, q;
			while (running) {
				p_ = generatePrime(TUNING_PRIME_BITS, certainty, rand);
				q_ = generatePrime(TUNING_PRIME_BITS, certainty, rand);
				// cheap checks on the tuning primes before building p and q
				if (p_.equals(q_)) {
					continue;
				}
				if (!sigma.gcd(p_.multiply(q_)).equals(BigInteger.ONE)) {
					continue;
				}
				p = computeP(a, u, p_);
				q = computeP(b, v, q_);
				if (!p.isProbablePrime(certainty)) {
					continue;
				}
				if (!q.isProbablePrime(certainty)) {
					continue;
				}
				if (p.multiply(q).bitLength() < strength) {
					if (debug) {
						System.out.println("key size too small. Should be "
								+ strength + " but is actually "
								+ p.multiply(q).bitLength());
					}
					continue;
				}
				submitPQ(p, q, p_, q_);
				running = false;
			}

		}

		public void endThread() {
			running = false;
		}

	}

	/**
	 * Computes a BigInteger to assemble Y from.
	 * 
	 */
	class ComputeYPart extends Thread {

		private final BigInteger smallPrime;

		private BigInteger yPart;

		ComputeYPart(final BigInteger smallPrime) {
			super();
			this.smallPrime = smallPrime;
		}

		public void run() {
			if (debug) {
				System.out.println("computing yPart for " + smallPrime);
			}

			for (;;) {
				yPart = new BigInteger(strength, rand);
				// A part that is 0 or shares a factor with n would make y a
				// non-unit modulo n, so only accept parts coprime to n.
				if (!yPart.gcd(n).equals(BigInteger.ONE)) {
					continue;
				}
				if (yPart.modPow(phi_n.divide(smallPrime), n).equals(
						BigInteger.ONE)) {
					continue;
				}
				if (debug) {
					System.out.println("Prime " + smallPrime + " submitting "
							+ yPart);
				}
				submitYPart(this);
				break;
			}
		}
	}

	/**
	 * Tests if y has order of phi(n)/small prime.
	 * 
	 */
	class OrderOfYTest extends Thread {
		private final BigInteger smallPrime;

		OrderOfYTest(final BigInteger smallPrime) {
			super();
			this.smallPrime = smallPrime;
		}

		public void run() {
			if (y.modPow(phi_n.divide(smallPrime), n).equals(BigInteger.ONE)) {
				if (debug) {
					System.out.println("y has order phi(n)/" + smallPrime
							+ "\n y: " + y);
				}
				submitOrderTest(true, this);
			} else {
				if (debug) {
					System.out.println("Prime " + smallPrime + " finished.");
				}
				submitOrderTest(false, this);
			}
		}

	}

	/**
	 * Generates the prime a with a specified number of bits.
	 * 
	 * @author lippold
	 */
	class GeneratePrimeA extends Thread {
		/**
		 * Cleared by endThread() from another thread; volatile so that the
		 * change is guaranteed to become visible to this loop.
		 */
		private volatile boolean running = true;

		private final int bits;

		GeneratePrimeA(final int bits) {
			this.bits = bits;
		}

		public void run() {
			if (debug) {
				System.out.println("Generating a with " + bits
						+ " bit and certainty 2^(-" + certainty + ").");
			}

			BigInteger p_;
			do {
				p_ = new BigInteger(bits, rand);
				if (p_.bitLength() != bits) {
					continue;
				}
				if (p_.isProbablePrime(certainty)) {
					break;
				}
			} while (running);

			if (running) {
				// publish under the lock the starter checks a with
				synchronized (waitFor) {
					a = p_;
					waitFor.notifyAll();
				}
			}

		}

		public void endThread() {
			running = false;
		}

	}

	/**
	 * Generates the prime b with a specified number of bits.
	 * 
	 * @author lippold
	 * 
	 */
	class GeneratePrimeB extends Thread {
		/**
		 * Cleared by endThread() from another thread; volatile so that the
		 * change is guaranteed to become visible to this loop.
		 */
		private volatile boolean running = true;

		private final int bits;

		GeneratePrimeB(final int bits) {
			this.bits = bits;
		}

		public void run() {
			if (debug) {
				System.out.println("Generating b with " + bits
						+ " bit and certainty 2^(-" + certainty + ").");
			}

			BigInteger p_;
			do {
				p_ = new BigInteger(bits, rand);
				if (p_.bitLength() != bits) {
					continue;
				}
				if (p_.isProbablePrime(certainty)) {
					break;
				}
			} while (running);

			if (running) {
				// publish under the lock the starter checks b with
				synchronized (waitFor) {
					b = p_;
					waitFor.notifyAll();
				}
			}

		}

		public void endThread() {
			running = false;
		}

	}
}
