/**
 * 
 */
package net.sf.distrib_rsa.protocols.computeD;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.BigInteger;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SecureRandom;
import java.util.Vector;

import net.sf.distrib_rsa.cryptosystems.PrimeUtils;

import org.apache.log4j.Logger;

/**
 * Server side of the two-party computation of the RSA private exponent.
 * <p>
 * {@link #run()} reads one-byte commands from the client (constants in
 * <tt>KeyProtocol</tt>), checks each one against the current protocol state
 * (<tt>states[]</tt>) and answers on the output stream. Intermediate values
 * are combined with the client through oblivious transfers (see
 * {@link #obliviousTransfer(BigInteger, BigInteger)}). If the protocol
 * completes successfully <tt>generateRSAKey()</tt> is called; in every case
 * all threads waiting on this object are notified when {@link #run()}
 * returns.
 * 
 * @author lippold
 * 
 */
public class ComputeKeyServer extends ComputeKey {

	private static final Logger log = Logger.getLogger(ComputeKeyServer.class);

	/**
	 * Exponent of the client's oblivious-transfer key, received in
	 * {@link #obliviousSetup()}.
	 * 
	 * @uml.property name="oblExp"
	 */
	private BigInteger oblExp;

	/**
	 * Modulus of the client's oblivious-transfer key, received in
	 * {@link #obliviousSetup()}.
	 * 
	 * @uml.property name="oblModulus"
	 */
	private BigInteger oblModulus;

	/**
	 * One slot per bit of the current ring. Each oblivious transfer stores its
	 * result at the index of the bit it handled; a slot that still holds the
	 * placeholder {@link #o} marks a transfer that did not happen.
	 * 
	 * @uml.property name="oblMessages"
	 * @uml.associationEnd multiplicity="(0 -1)"
	 *                     elementType="java.math.BigInteger"
	 */
	private final Vector oblMessages = new Vector();

	/**
	 * Placeholder used to pre-fill {@link #oblMessages} so that
	 * <tt>Vector.set()</tt> can be used for every bit index.
	 * 
	 * @uml.property name="o"
	 */
	private final Object o = new Object();

	/**
	 * Most recent reduced sum of oblivious-transfer results (also used as a
	 * scratch value while combining alpha shares).
	 * 
	 * @uml.property name="x"
	 */
	private BigInteger x;

	/**
	 * Certainty handed to <tt>BigInteger.isProbablePrime</tt> when the public
	 * exponent is chosen in {@link #setE()}.
	 * 
	 * @uml.property name="certainty"
	 */
	private final int certainty;

	/**
	 * Index of the bit handled by the most recent oblivious transfer. Reset to
	 * -1 by OBL_START and incremented before every transfer. The initial value
	 * -13 presumably just marks "no transfer round started yet".
	 * 
	 * @uml.property name="actualBit"
	 */
	int actualBit = -13;

	/** Set once {@link #obliviousSetup()} has received the client's parameters. */
	private boolean oblSetupDone = false;

	/**
	 * Creates the server and initialises its share of phi as
	 * <tt>rsaN - p - q + 1</tt>.
	 * 
	 * @param p
	 *            this party's value of p (passed on to <tt>ComputeKey</tt>)
	 * @param q
	 *            this party's value of q (passed on to <tt>ComputeKey</tt>)
	 * @param rsaN
	 *            the RSA modulus
	 * @param in
	 *            stream the client's commands are read from
	 * @param out
	 *            stream the answers are written to
	 * @param remIP
	 *            address of the client (passed on to <tt>ComputeKey</tt>)
	 * @param primeCertainty
	 *            certainty used when testing the public exponent for
	 *            primality
	 * @param rand
	 *            source of randomness
	 */
	public ComputeKeyServer(final BigInteger p, final BigInteger q,
			final BigInteger rsaN, final ObjectInputStream in,
			final ObjectOutputStream out, final String remIP,
			final int primeCertainty, final SecureRandom rand)
			throws NoSuchAlgorithmException, NoSuchProviderException,
			IOException {
		super(p, q, rsaN, in, out, remIP, rand);
		certainty = primeCertainty;

		phi = rsaN.subtract(p).subtract(q).add(BigInteger.ONE);

		log.debug("Key generation server set up");
	}

	/**
	 * Runs the server side of the protocol until the stream ends, the thread is
	 * ended, or an exception occurs. Generates the RSA key if the run was
	 * successful and finally notifies all threads waiting on this object.
	 */
	public void run() {
		log.info(REMOTE_IP + "starting RSA key generation");

		int command = -1;
		try {

			while (running && ((command = in.read()) != -1)) {

				switch (command) {

				case KeyProtocol.PUB_EXP:
					log.debug("Got PUB_EXP");
					if (super.actualState == null) {
						setE();
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "PUB_EXP");
					}
					out.flush();
					break;

				case KeyProtocol.EXP_FAILED:
					log.debug("Got EXP_FAILED");
					if (inState(0)) {
						super.actualState = null;
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "EXP_FAILED");
					}
					out.flush();
					break;

				case KeyProtocol.OBLIVIOUS_SETUP:
					log.debug("Got OBLIVIOUS_SETUP");
					if (inState(0)) {
						super.actualState = states[1];
						obliviousSetup();
						out.write(KeyProtocol.START_OBL_TRANSFER);
						out.flush();
					} else if (inState(5) || inState(8) || inState(10)) {
						obliviousSetup();
						if (inState(8)) {
							out.write(KeyProtocol.COMPUTE_ALPHA);
						}
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "OBLIVIOUS_SETUP");
					}
					out.flush();
					break;

				case KeyProtocol.OBL_START:
					log.debug("Got OBL_START");
					if ((inState(1) || inState(5) || inState(9) || inState(10))
							&& oblSetupDone) {
						actualBit = -1;
						oblMessages.clear();
						// Pre-fill with the placeholder so that each transfer
						// can store its result with Vector.set() at its bit
						// index.
						for (int i = 0; i < ring.bitLength(); i++) {
							oblMessages.add(o);
						}
						out.write(KeyProtocol.OBL_START);
					} else {
						log.fatal(REMOTE_IP + PROTO_VIOL + "OBL_START");
					}
					out.flush();
					break;

				case KeyProtocol.OBL_TRANSFER:
					log.debug("Got OBL_TRANSFER");
					if (inState(1)) {
						obliviousTransfer(ring, phi);
					} else if (inState(5)) {
						obliviousTransfer(ring, zeta);
					} else if (inState(9)) {
						obliviousTransfer(ring, phi);
					} else if (inState(10)) {
						obliviousTransfer(ring, psi);
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "OBL_TRANSFER");
					}
					out.flush();
					break;

				case KeyProtocol.PHI_MOD_E_FINISHED:
					log.debug("Got PHI_MOD_E_FINISHED");
					if (inState(1)) {
						super.actualState = states[3];
						final BigInteger phiSum = sumOblMessages(exp);
						// Only confirm to the client if every transfer
						// succeeded; otherwise the thread is already ending.
						if (phiSum != null) {
							x = phiSum;
							out.write(KeyProtocol.PHI_MOD_E_FINISHED);
						}
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "PHI_MOD_E_FINISHED");
					}
					out.flush();
					break;

				case KeyProtocol.Y_ADD_PHI_R:
					log.debug("Got Y_ADD_PHI_R");
					if (inState(3)) {
						super.actualState = states[4];
						final BigInteger y_add_phi_r = (BigInteger) in
								.readObject();
						multShare_Phi_n_mod_e = x.add(y_add_phi_r);
						zeta = multShare_Phi_n_mod_e.modInverse(exp);
						// One party has to negate its zeta to have
						// zetaA * zetaB = -phi(n)^(-1)
						zeta = zeta.negate().mod(exp);
						out.write(KeyProtocol.Y_ADD_PHI_R);
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "Y_ADD_PHI_R");
					}
					out.flush();
					break;

				case KeyProtocol.COMPUTE_PSI:
					log.debug("Got COMPUTE_PSI");
					if (inState(4)) {
						super.actualState = states[5];
						out.write(KeyProtocol.COMPUTE_PSI);
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "COMPUTE_PSI");
					}
					out.flush();
					break;

				case KeyProtocol.COMPUTE_PSI_FINISHED:
					log.debug("Got COMPUTE_PSI_FINISHED");
					if (inState(5)) {
						super.actualState = states[6];
						final BigInteger psiSum = sumOblMessages(exp);
						if (psiSum != null) {
							x = psiSum;
							out.write(KeyProtocol.COMPUTE_PSI_FINISHED);
						}
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP
								+ "COMPUTE_PSI_FINISHED");
					}
					out.flush();
					break;

				case KeyProtocol.Y_SUB_PSI:
					log.debug("Got Y_SUB_PSI");
					if (inState(6)) {
						super.actualState = states[7];
						// The server's psi is the sum computed at
						// COMPUTE_PSI_FINISHED; no further value is read from
						// the stream here.
						psi = x;
						out.write(KeyProtocol.Y_SUB_PSI);
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "Y_SUB_PSI");
					}
					out.flush();
					break;

				case KeyProtocol.COMPUTE_ALPHA:
					log.debug("Got COMPUTE_ALPHA");
					if (inState(7)) {
						super.actualState = states[8];
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "COMPUTE_ALPHA");
					}
					out.flush();
					break;

				case KeyProtocol.ALPHA1:
					log.debug("Got ALPHA1");
					if (inState(8)) {
						// Only a state change; nothing is computed here.
						super.actualState = states[9];
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "ALPHA1");
					}
					out.flush();
					break;

				case KeyProtocol.ALPHA2:
					log.debug("Got ALPHA2");
					if (inState(9)) {
						super.actualState = states[10];
						final BigInteger alpha1Sum = sumOblMessages(ring);
						if (alpha1Sum != null) {
							x = alpha1Sum;
							alpha1 = x;
						}
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "ALPHA2");
					}
					out.flush();
					break;

				case KeyProtocol.ALPHA1_ALPHA2_Y:
					log.debug("Got ALPHA1_ALPHA2_Y");
					if (inState(10)) {
						final BigInteger alpha2Sum = sumOblMessages(ring);
						// Without a complete alpha2 no meaningful d can be
						// computed; sumOblMessages has already ended the run.
						if (alpha2Sum != null) {
							alpha2 = alpha2Sum;
							final BigInteger phiPsi = phi.multiply(psi);
							x = (alpha1.add(alpha2).mod(ring)).add(phiPsi);
							final BigInteger yAlpha1Alpha2 = (BigInteger) in
									.readObject();
							x = x.add(yAlpha1Alpha2);
							x = x.mod(ring);
							d = x.divide(exp);
							endThread();
						}
					} else {
						log.fatal(PROTO_VIOL + REMOTE_IP + "ALPHA1_ALPHA2_Y");
					}
					out.flush();
					break;

				default:
					log.fatal(REMOTE_IP + PROTO_VIOL
							+ "Unknown command. Exiting.");
					successful = false;
					endThread();
					break;

				}

			}
		} catch (final Exception e) {
			log.error("Caught exception : ", e);
			successful = false;
		}

		if (successful) {
			super.generateRSAKey();
		} else {
			log.debug(REMOTE_IP + "Last command was " + command);
		}

		synchronized (this) {
			notifyAll();
		}
	}

	/**
	 * Null-safe check whether the protocol is currently in
	 * <tt>states[index]</tt>. The state can be <code>null</code> (for example
	 * after EXP_FAILED resets it), in which case this returns
	 * <code>false</code> instead of throwing a NullPointerException.
	 * 
	 * @param index
	 *            index into <tt>states</tt>
	 * @return <code>true</code> if the current state equals
	 *         <tt>states[index]</tt>
	 */
	private boolean inState(final int index) {
		return states[index].equals(super.actualState);
	}

	/**
	 * Adds up the results of all oblivious transfers recorded in
	 * {@link #oblMessages} and reduces the sum modulo <tt>modulus</tt>.
	 * <p>
	 * If any slot still holds the placeholder (a transfer did not happen), the
	 * failure is logged, the run is marked unsuccessful, <tt>endThread()</tt>
	 * is called and <code>null</code> is returned so the caller can skip its
	 * success reply.
	 * 
	 * @param modulus
	 *            the modulus the sum is reduced by
	 * @return the reduced sum, or <code>null</code> if not all transfers
	 *         succeeded
	 */
	private BigInteger sumOblMessages(final BigInteger modulus) {
		BigInteger sum = BigInteger.ZERO;
		for (int i = 0; i < oblMessages.size(); i++) {
			final Object message = oblMessages.get(i);
			if (!(message instanceof BigInteger)) {
				log.fatal("Not all necessary "
						+ "oblivious transfers succeeded");
				log.fatal("Exiting");
				successful = false;
				endThread();
				return null;
			}
			sum = sum.add((BigInteger) message);
		}
		return sum.mod(modulus);
	}

	/**
	 * Moves to <tt>states[0]</tt> and agrees on the public exponent.
	 * <p>
	 * Reads the client's exponent share, then draws random shares below
	 * <tt>rsaN</tt> until the sum is a probable prime (with
	 * {@link #certainty}) whose bit length exceeds
	 * <tt>rsaN.bitLength() - 30</tt>. The sum is stored in <tt>exp</tt>, and
	 * the server's share is sent back after a PUB_EXP tag.
	 * 
	 * @throws IOException
	 *             if reading from or writing to the streams fails
	 * @throws ClassNotFoundException
	 *             if the received object cannot be deserialised
	 */
	private void setE() throws IOException, ClassNotFoundException {
		super.actualState = states[0];
		final BigInteger remExp = (BigInteger) in.readObject();
		BigInteger myExp;
		do {
			myExp = PrimeUtils.getRandom(rsaN, rand);
			exp = remExp.add(myExp);
		} while (!(exp.isProbablePrime(certainty) && (exp.bitLength() > (rsaN
				.bitLength() - 30))));
		out.write(KeyProtocol.PUB_EXP);
		out.writeObject(myExp);
		out.flush();
	}

	/**
	 * Reads one tag byte and, if it matches <tt>expectedTag</tt>, the object
	 * that follows it.
	 * 
	 * @param expectedTag
	 *            the tag the client must send next
	 * @param violation
	 *            message logged and thrown if a different byte arrives
	 * @return the object following the tag
	 * @throws IOException
	 *             if a different tag (or end of stream) is read, or the
	 *             stream fails
	 * @throws ClassNotFoundException
	 *             if the object cannot be deserialised
	 */
	private Object readTagged(final int expectedTag, final String violation)
			throws IOException, ClassNotFoundException {
		if (in.read() != expectedTag) {
			log.fatal(violation);
			throw new IOException(violation);
		}
		return in.readObject();
	}

	/**
	 * Receives the parameters needed for the following oblivious transfers
	 * (OBL_RING, OBL_MODULUS and OBL_EXP, in that order) from the client and
	 * informs the other side that the parameters were received.
	 * 
	 * @throws IOException
	 *             If either the protocol is violated or the connection
	 *             fails.
	 * @throws ClassNotFoundException
	 *             If a received object cannot be deserialised.
	 */
	private void obliviousSetup() throws IOException, ClassNotFoundException {
		ring = (BigInteger) readTagged(KeyProtocol.OBL_RING, REMOTE_IP
				+ PROTO_VIOL + "Expecting OBL_RING");
		final BigInteger n = (BigInteger) readTagged(KeyProtocol.OBL_MODULUS,
				REMOTE_IP + PROTO_VIOL + "Expecting OBL_MODULUS");
		final BigInteger e = (BigInteger) readTagged(KeyProtocol.OBL_EXP,
				REMOTE_IP + PROTO_VIOL + "Expecting OBL_EXP");
		oblExp = e;
		oblModulus = n;
		oblSetupDone = true;
		out.write(KeyProtocol.OBLIVIOUS_SETUP);
		out.flush();
	}

	/**
	 * Performs one oblivious transfer for the next bit of <tt>value</tt>
	 * (reduced mod <tt>ring</tt>). Every call advances {@link #actualBit} by
	 * one, so that consecutive calls cover all bits. The result is stored at
	 * that bit's index in {@link #oblMessages}.
	 * 
	 * @param ring
	 *            the ring the value is reduced into
	 * @param value
	 *            the value whose bits select the client's messages
	 * @throws IOException
	 *             If either the protocol is violated or the connection
	 *             fails.
	 * @throws ClassNotFoundException
	 *             If a received object cannot be deserialised.
	 */
	private void obliviousTransfer(final BigInteger ring, final BigInteger value)
			throws IOException, ClassNotFoundException {
		// make sure that value is in Z_N
		final BigInteger myValue = value.mod(ring);
		actualBit++;
		log.debug("Now executing bit " + actualBit + "/" + ring.bitLength());

		// Random private element, reduced so that it can be encrypted under
		// the client's oblivious-transfer key.
		final BigInteger k = PrimeUtils.getRandom(ring, rand).mod(oblModulus);
		final BigInteger encryptedK = k.modPow(oblExp, oblModulus);

		// the other party's prepared messages
		final BigInteger m0 = (BigInteger) readTagged(KeyProtocol.OBL_X0,
				PROTO_VIOL + REMOTE_IP + "Expecting OBL_X0");
		final BigInteger m1 = (BigInteger) readTagged(KeyProtocol.OBL_X1,
				PROTO_VIOL + REMOTE_IP + "Expecting OBL_X1");

		final boolean bitSet = myValue.testBit(actualBit);

		// add the message selected by my bit to my encrypted value and send it
		final BigInteger kx = encryptedK.add(bitSet ? m1 : m0);
		out.write(KeyProtocol.OBL_KX);
		out.writeObject(kx);
		out.flush();

		// read both answers and keep the one matching my bit
		final BigInteger m0k0 = (BigInteger) readTagged(KeyProtocol.OBL_M0K0,
				PROTO_VIOL + REMOTE_IP + "Expecting OBL_M0K0");
		final BigInteger m1k1 = (BigInteger) readTagged(KeyProtocol.OBL_M1K1,
				PROTO_VIOL + REMOTE_IP + "Expecting OBL_M1K1");
		final BigInteger m = (bitSet ? m1k1 : m0k0).subtract(k);

		// record it for later addition with the other bits' results
		oblMessages.set(actualBit, m);
	}

	/**
	 * @return the most recent reduced sum of oblivious-transfer results
	 * @uml.property name="x"
	 */
	protected BigInteger getX() {
		return x;
	}
}
