/**
 * 
 */
package net.sf.distrib_rsa.protocols.computeD;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.BigInteger;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SecureRandom;

import net.sf.distrib_rsa.cryptosystems.PrimeUtils;

import org.apache.log4j.Logger;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.generators.RSAKeyPairGenerator;
import org.bouncycastle.crypto.params.RSAKeyGenerationParameters;
import org.bouncycastle.crypto.params.RSAKeyParameters;
import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters;

/**
 * Client side of the two-party key-computation protocol for the shared
 * modulus {@code rsaN}.
 * <p>
 * Both parties contribute to the public exponent {@code exp}; afterwards the
 * client computes its value {@code d} (presumably its share of the private
 * exponent — confirm against {@code ComputeKey#generateRSAKey()}). The protocol
 * is driven by single-byte {@code KeyProtocol} commands read in {@link #run()},
 * which advances {@code actualState} through {@code states[0]} ..
 * {@code states[8]}. Products of secret values are computed with bitwise
 * 1-out-of-2 oblivious transfers, for which this class generates throw-away RSA
 * key pairs ({@link #rsaDecrypt(BigInteger)} uses the current one).
 * 
 * @author lippold
 * 
 */
public class ComputeKeyClient extends ComputeKey {

	private static final Logger log = Logger.getLogger(ComputeKeyClient.class);

	/**
	 * Certainty passed to {@link BigInteger#isProbablePrime(int)} and to the
	 * RSA key generator.
	 * 
	 * @uml.property name="certainty"
	 */
	private final int certainty;

	/**
	 * Private key of the RSA pair generated by the most recent call of
	 * {@link #setupOblTransfers(BigInteger)}; used by
	 * {@link #rsaDecrypt(BigInteger)}.
	 * 
	 * @uml.property name="key"
	 * @uml.associationEnd
	 */
	private RSAPrivateCrtKeyParameters key;

	/**
	 * Intermediate protocol value: the client's additive share returned by the
	 * latest oblivious transfer, or (in the final step) a random value.
	 * 
	 * @uml.property name="y"
	 */
	BigInteger y;

	/**
	 * Creates the client; {@code phi} is initialised to {@code -(p + q)}.
	 * 
	 * @param p
	 *            this party's value p
	 * @param q
	 *            this party's value q
	 * @param rsaN
	 *            the RSA modulus
	 * @param in
	 *            stream from the remote party
	 * @param out
	 *            stream to the remote party
	 * @param remIP
	 *            remote address (used by the superclass)
	 * @param primeCertainty
	 *            certainty for primality tests
	 * @param rand
	 *            source of randomness
	 */
	public ComputeKeyClient(final BigInteger p, final BigInteger q,
			final BigInteger rsaN, final ObjectInputStream in,
			final ObjectOutputStream out, final String remIP,
			final int primeCertainty, final SecureRandom rand)
			throws NoSuchAlgorithmException, NoSuchProviderException,
			IOException {
		super(p, q, rsaN, in, out, remIP, rand);
		certainty = primeCertainty;

		phi = p.add(q).negate();

		log.debug("Key generation client set up");
	}

	/**
	 * Runs the client side of the protocol until {@code running} is cleared,
	 * the input stream ends, or an exception occurs. If the run was successful,
	 * the superclass's {@code generateRSAKey()} is called. In every case,
	 * threads waiting on this object are woken with {@code notifyAll()}.
	 */
	public void run() {
		log.info(REMOTE_IP + "starting RSA key generation");

		int command = -1;
		try {
			// Propose a random exponent contribution below rsaN; a large
			// exponent is chosen deliberately (small ones are avoided to
			// resist side-channel attacks).
			exp = PrimeUtils.getRandom(rsaN, rand);
			out.write(KeyProtocol.PUB_EXP);
			out.writeObject(exp);
			out.flush();

			while (running && ((command = in.read()) != -1)) {

				// NOTE: in the branches below, actualState.equals(...) throws a
				// NullPointerException if no PUB_EXP was accepted yet; it is
				// caught at the bottom and aborts the run unsuccessfully.
				switch (command) {

				case KeyProtocol.PUB_EXP:
					log.debug("Got PUB_EXP");
					if (super.actualState == null) {
						super.actualState = states[0];
						if (setE()) {
							super.actualState = states[1];
							setupOblTransfers(exp);
						}
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "PUB_EXP");
					}
					out.flush();
					break;

				case KeyProtocol.START_OBL_TRANSFER:
					log.debug(REMOTE_IP + "Got START_OBL_TRANSFER");
					if (super.actualState.equals(states[1])) {
						super.actualState = states[2];
						// Share (phi(rsaN) mod e) multiplicatively via
						// oblivious transfers.

						// First draw r_e mod e uniformly at random, retrying
						// until it is invertible mod e.
						BigInteger r_e = null;
						BigInteger r_e_inv = null;
						do {
							try {
								r_e = PrimeUtils.getRandom(exp, rand);
								r_e_inv = r_e.modInverse(exp);
							} catch (final ArithmeticException ignored) {
								// r_e is not invertible mod exp (e.g. zero):
								// draw again.
							}
						} while (r_e_inv == null);
						y = startObliviousTransfer(exp, r_e);
						multShare_Phi_n_mod_e = r_e_inv;
						zeta = r_e;
						out.write(KeyProtocol.PHI_MOD_E_FINISHED);
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "START_OBL_TRANSFER");
					}
					out.flush();
					break;

				case KeyProtocol.PHI_MOD_E_FINISHED:
					log.debug(REMOTE_IP + "Got PHI_MOD_E_FINISHED");
					if (super.actualState.equals(states[2])) {
						super.actualState = states[3];
						out.write(KeyProtocol.Y_ADD_PHI_R);
						// send (y + zeta * phi) mod e
						out.writeObject(y.add(zeta.multiply(phi)).mod(exp));
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "PHI_MOD_E_FINISHED");
					}
					out.flush();
					break;

				case KeyProtocol.Y_ADD_PHI_R:
					log.debug(REMOTE_IP + "Got Y_ADD_PHI_R");
					if (super.actualState.equals(states[3])) {
						super.actualState = states[4];
						// The other party inverts its zeta, no need to do it
						// here
						out.write(KeyProtocol.COMPUTE_PSI);
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "Y_ADD_PHI_R");
					}
					out.flush();
					break;

				case KeyProtocol.COMPUTE_PSI:
					log.debug(REMOTE_IP + "Got COMPUTE_PSI");
					if (super.actualState.equals(states[4])) {
						super.actualState = states[5];
						y = startObliviousTransfer(exp, zeta);
						// psi is our additive share returned by the transfer
						psi = y;
						out.write(KeyProtocol.COMPUTE_PSI_FINISHED);
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "COMPUTE_PSI");
					}
					out.flush();
					break;

				case KeyProtocol.COMPUTE_PSI_FINISHED:
					log.debug(REMOTE_IP + "Got COMPUTE_PSI_FINISHED");
					if (super.actualState.equals(states[5])) {
						super.actualState = states[6];
						// Y_SUB_PSI carries no payload: psi == y here, so
						// y - psi would always be zero.
						out.write(KeyProtocol.Y_SUB_PSI);
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL
								+ "COMPUTE_PSI_FINISHED");
					}
					out.flush();
					break;

				case KeyProtocol.Y_SUB_PSI:
					log.debug(REMOTE_IP + "Got Y_SUB_PSI");
					if (super.actualState.equals(states[6])) {
						super.actualState = states[7];
						// Pick the smallest probable prime ring > 4 * e * rsaN.
						// 4 * e * rsaN is even, so start at +1 and step by 2.
						ring = exp.multiply(BigInteger.valueOf(4)).multiply(
								rsaN);
						ring = ring.add(BigInteger.ONE);
						final BigInteger two = BigInteger.valueOf(2);
						while (!ring.isProbablePrime(certainty)) {
							ring = ring.add(two);
						}

						out.write(KeyProtocol.COMPUTE_ALPHA);

						setupOblTransfers(ring);

					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "Y_SUB_PSI");
					}
					out.flush();
					break;

				case KeyProtocol.COMPUTE_ALPHA:
					log.debug(REMOTE_IP + "Got COMPUTE_ALPHA");
					if (super.actualState.equals(states[7])) {
						super.actualState = states[8];

						out.write(KeyProtocol.ALPHA1);
						out.flush();
						alpha1 = startObliviousTransfer(ring, psi);
						out.write(KeyProtocol.ALPHA2);
						out.flush();
						alpha2 = startObliviousTransfer(ring, phi);
						// Random mask y in [0, ring / 2)
						y = PrimeUtils.getRandom(ring.divide(BigInteger
								.valueOf(2)), rand);
						out.write(KeyProtocol.ALPHA1_ALPHA2_Y);
						final BigInteger phiPsi = phi.multiply(psi);
						out.writeObject((alpha1.add(alpha2).mod(ring)).add(y)
								.add(phiPsi));
						out.flush();
						// d = (1 - y) / e, using BigInteger's truncating division
						d = y.negate().add(BigInteger.ONE);
						d = d.divide(exp);
						endThread();
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "COMPUTE_ALPHA");
					}
					out.flush();
					break;

				default:
					log.fatal(REMOTE_IP + PROTO_VIOL
							+ "Unknown command. Exiting.");
					successful = false;
					endThread();
					break;

				}

			}
		} catch (final Exception e) {
			log.error("Caught exception : ", e);
			successful = false;
		}

		if (successful) {
			super.generateRSAKey();
		} else {
			log.debug(REMOTE_IP + "Last command was " + command);
		}

		synchronized (this) {
			notifyAll();
		}
	}

	/**
	 * Prepares the RSA key pair needed to perform oblivious transfers and
	 * sends its public part to the other party.
	 * <p>
	 * The public exponent is a random probable prime of at most
	 * {@code ring.bitLength()} bits, and the modulus has
	 * {@code ring.bitLength()} bits. The private key is stored in {@code key}.
	 * Afterwards, this method blocks until the other party acknowledges with
	 * {@code OBLIVIOUS_SETUP}.
	 * 
	 * @param ring
	 *            modulus of the ring in which the transfers take place
	 * @throws Exception
	 *             If oblivious transfers could not be prepared correctly.
	 */
	private void setupOblTransfers(final BigInteger ring) throws Exception {
		out.write(KeyProtocol.OBLIVIOUS_SETUP);
		log.debug("Setting up oblivious transfers for desired exponent");
		log.info("Generating RSA key pair in ring " + ring
				+ " for oblivious transfers");

		// Use a large random prime exponent rather than a small fixed one
		// (avoids exponents with few 1-bits, against side-channel attacks).
		BigInteger pubExp;
		do {
			pubExp = new BigInteger(ring.bitLength(), rand);
		} while (!pubExp.isProbablePrime(certainty));

		// Construct the RSA key that has size of the ring
		final RSAKeyGenerationParameters param = new RSAKeyGenerationParameters(
				pubExp, rand, ring.bitLength(), certainty);
		final RSAKeyPairGenerator generator = new RSAKeyPairGenerator();
		generator.init(param);
		final AsymmetricCipherKeyPair pair = generator.generateKeyPair();
		final RSAKeyParameters oblPubKey = (RSAKeyParameters) pair.getPublic();
		key = (RSAPrivateCrtKeyParameters) pair.getPrivate();

		// Write the data needed for the oblivious transfer to the server
		out.write(KeyProtocol.OBL_RING);
		out.writeObject(ring);
		out.write(KeyProtocol.OBL_MODULUS);
		out.writeObject(oblPubKey.getModulus());
		out.write(KeyProtocol.OBL_EXP);
		out.writeObject(oblPubKey.getExponent());
		out.flush();

		// verify that all went well
		if (in.read() != KeyProtocol.OBLIVIOUS_SETUP) {
			final String msg = REMOTE_IP + PROTO_VIOL
					+ "expecting OBLIVIOUS_SETUP";
			log.fatal(msg);
			throw new Exception(msg);
		}
	}

	/**
	 * Performs {@code ring.bitLength()} 1-out-of-2 oblivious transfers of
	 * multiples of {@code value} and returns this party's additive share of the
	 * result.
	 * <p>
	 * For bit position {@code i}, the messages offered are a random mask
	 * {@code m0} and {@code m1 = (2^i * value + m0) mod ring}. The other party
	 * recovers one of them (which one is chosen on the server side; presumably
	 * by bit {@code i} of its own secret — TODO confirm). The returned share is
	 * {@code -(sum of all m0) mod ring}.
	 * 
	 * @param ring
	 *            modulus of the ring the computation takes place in
	 * @param value
	 *            the element to transfer; it is reduced mod {@code ring} first
	 * @return this party's share, an element of [0, ring)
	 * @throws Exception
	 *             if the other party violates the protocol or I/O fails
	 */
	private BigInteger startObliviousTransfer(final BigInteger ring,
			final BigInteger value) throws Exception {
		log.debug("starting Oblivious Transfer");

		// make sure that value is an element of the ring
		final BigInteger myValue = value.mod(ring);

		// sum of our masks m0 (local name avoids shadowing the field y)
		BigInteger maskSum = BigInteger.ZERO;

		// start the transfer
		out.write(KeyProtocol.OBL_START);
		out.flush();

		// verify other side is ready
		if (in.read() != KeyProtocol.OBL_START) {
			final String msg = REMOTE_IP + PROTO_VIOL + "Expecting OBL_START";
			log.fatal(msg);
			throw new Exception(msg);
		}

		// one transfer per bit of the ring modulus
		final int rounds = ring.bitLength();
		for (int i = 0; i < rounds; i++) {

			// inform server that an obl. transfer is coming
			out.write(KeyProtocol.OBL_TRANSFER);

			// prepare two random messages x0 and x1
			final BigInteger x0 = PrimeUtils.getRandom(ring, rand);
			final BigInteger x1 = PrimeUtils.getRandom(ring, rand);

			// and write them to the server
			out.write(KeyProtocol.OBL_X0);
			out.writeObject(x0);
			out.write(KeyProtocol.OBL_X1);
			out.writeObject(x1);
			out.flush();

			// server chooses one and adds his encrypted k. We read it
			if (in.read() != KeyProtocol.OBL_KX) {
				final String msg = REMOTE_IP + PROTO_VIOL + "Expecting OBL_KX";
				log.fatal(msg);
				throw new Exception(msg);
			}
			final BigInteger kx = (BigInteger) in.readObject();

			// compute the two possible results
			final BigInteger k0 = rsaDecrypt(kx.subtract(x0));
			final BigInteger k1 = rsaDecrypt(kx.subtract(x1));

			// prepare our messages m0 (random) and m1 = 2^i * value + m0
			final BigInteger m0 = PrimeUtils.getRandom(ring, rand);
			final BigInteger m1 = BigInteger.ONE.shiftLeft(i)
					.multiply(myValue).add(m0).mod(ring);
			maskSum = maskSum.add(m0).mod(ring);

			// blind each with one of the possible decryptions and send both;
			// the server can only unblind the one it chose
			out.write(KeyProtocol.OBL_M0K0);
			out.writeObject(k0.add(m0));
			out.write(KeyProtocol.OBL_M1K1);
			out.writeObject(k1.add(m1));
			out.flush();
		}

		// our share is the negated mask sum, reduced back into the ring
		final BigInteger share = maskSum.negate().mod(ring);
		log.debug("Oblivious Transfer finished");
		return share;
	}

	/**
	 * Reads the other party's exponent contribution and combines it with ours
	 * as {@code exp = (exp + remExp) mod rsaN}.
	 * <p>
	 * The combined exponent is accepted only if its bit length exceeds
	 * {@code rsaN.bitLength() - 30} and it is a probable prime. Otherwise,
	 * {@code actualState} is reset to {@code null}, {@code EXP_FAILED} is sent,
	 * and a fresh {@code PUB_EXP} proposal follows.
	 * 
	 * @return {@code true} if the combined exponent was accepted
	 * @throws IOException
	 *             on stream errors
	 * @throws ClassNotFoundException
	 *             if the received object cannot be deserialised
	 */
	private boolean setE() throws IOException, ClassNotFoundException {
		// Set RSA exponent e
		final BigInteger remExp = (BigInteger) in.readObject();
		exp = exp.add(remExp).mod(rsaN);
		if ((exp.bitLength() <= (rsaN.bitLength() - 30))
				|| !exp.isProbablePrime(certainty)) {
			log.debug(REMOTE_IP + "Public Exponent e too small or not prime");
			log.debug("Restarting Exponent");
			super.actualState = null;
			out.write(KeyProtocol.EXP_FAILED);
			exp = PrimeUtils.getRandom(rsaN, rand);
			out.write(KeyProtocol.PUB_EXP);
			out.writeObject(exp);
			out.flush();
			return false;
		}
		return true;
	}

	/**
	 * @return the y
	 * @uml.property name="y"
	 */
	protected BigInteger getY() {
		return y;
	}

	/**
	 * Decrypts {@code input} with the oblivious-transfer private key via the
	 * Chinese Remainder Theorem (Garner's recombination).
	 * 
	 * @param input
	 *            ciphertext; may be negative or exceed the modulus, since only
	 *            its residues mod p and mod q are used
	 * @return the plaintext, an element of [0, p*q)
	 * @throws IllegalStateException
	 *             if no oblivious-transfer key has been generated yet
	 */
	public BigInteger rsaDecrypt(final BigInteger input) {
		//
		// we have the extra factors, use the Chinese Remainder Theorem - the
		// author wishes to express his thanks to Dirk Bonekaemper at
		// rtsffm.com for advice regarding the expression of this.
		//
		final RSAPrivateCrtKeyParameters crtKey = key;
		if (crtKey == null) {
			throw new IllegalStateException(
					"oblivious transfer RSA key has not been generated yet");
		}

		final BigInteger p = crtKey.getP();
		final BigInteger q = crtKey.getQ();
		final BigInteger dP = crtKey.getDP();
		final BigInteger dQ = crtKey.getDQ();
		final BigInteger qInv = crtKey.getQInv();

		// mP = ((input mod p) ^ dP)) mod p
		// (modPow accepts the possibly negative remainder and returns a
		// non-negative result)
		final BigInteger mP = input.remainder(p).modPow(dP, p);

		// mQ = ((input mod q) ^ dQ)) mod q
		final BigInteger mQ = input.remainder(q).modPow(dQ, q);

		// h = qInv * (mP - mQ) mod p; mod (in Java) returns the positive
		// residue
		final BigInteger h = mP.subtract(mQ).multiply(qInv).mod(p);

		// m = h * q + mQ. Since 0 <= h <= p-1 and 0 <= mQ <= q-1, m < p*q,
		// so no final reduction mod the modulus is needed.
		return h.multiply(q).add(mQ);
	}
}
