/**
 * 
 */
package net.sf.distrib_rsa.protocols.straubSharing;

import java.io.IOException;
import java.math.BigInteger;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.util.Enumeration;

import javax.net.ssl.SSLSocket;

import org.apache.log4j.Logger;

/**
 * Client side of the Straub sharing protocol ("distributed sieving").
 * <p>
 * {@link #run()} first announces the locally preferred shared key size and
 * then processes one command byte at a time from the server. Each command is
 * only accepted in a specific state ({@code states[0]} .. {@code states[10]},
 * provided by {@link StraubCommon}); a command arriving in the wrong state is
 * logged as a protocol violation and otherwise ignored.
 * <p>
 * Share conversions ({@code addToMult}, {@code multToAdd}) are synchronous
 * request/response exchanges with the server: the local share is encrypted
 * with {@code pubEngine} before it is sent, and the reply is decrypted with
 * {@code privEngine}.
 * 
 * @author lippold
 */
public class StraubClient extends StraubCommon {

	private final static Logger log = Logger
			.getLogger(StraubClient.class);

	/** Cached constant for parity computations (BigInteger.TWO needs Java 9). */
	private static final BigInteger TWO = BigInteger.valueOf(2);

	/**
	 * Creates a client on an already connected SSL socket.
	 * 
	 * @param socket
	 *            connection to the server; passed to {@link StraubCommon}
	 * @throws IOException
	 *             propagated from the {@link StraubCommon} constructor
	 * @throws NoSuchAlgorithmException
	 *             propagated from the {@link StraubCommon} constructor
	 * @throws NoSuchProviderException
	 *             propagated from the {@link StraubCommon} constructor
	 */
	public StraubClient(final SSLSocket socket) throws IOException,
			NoSuchAlgorithmException, NoSuchProviderException {
		super(socket);
	}

	/**
	 * Runs the client state machine until the input stream ends,
	 * {@code running} becomes false, or an exception is raised.
	 * <p>
	 * Exceptions are not rethrown: they are logged and {@code successful} is
	 * set to {@code false}.
	 */
	public void run() {
		log.debug("starting distributed sieving with server:\n"
				+ super.remoteID);
		sendDesiredPrimeLength();
		int command = -1;
		try {
			while (running && ((command = in.read()) != -1)) {
				switch (command) {

				case StraubProtocol.PREFERRED_SHARED_KEY_SIZE:
					// Only legal as the very first command (no state yet).
					if (super.actualState == null) {
						log.debug("Got PREFERRED_SHARED_KEY_SIZE");
						super.actualState = states[0];
						final int serverRecommends = in.readInt();
						if (serverRecommends > preferredSharedKeySize) {
							log.fatal("Server wants bigger key size. " +
										"Not supported. Exiting.");
							successful = false;
							endThread();
						} else {
							preferredSharedKeySize = serverRecommends;
							log.debug("agreed on keysize "
									+ preferredSharedKeySize);
							sendNCSPubKey();
						}
					} else {
						log.fatal(PROTO_VIOL + "PREFERRED_SHARED_KEY_SIZE");
					}
					break;

				case StraubProtocol.REMOTE_NCS_PUB_KEY:
					if (inState(0)) {
						log.debug("Got REMOTE_NCS_PUB_KEY");
						super.actualState = states[1];
						readRemotePubKey();
						log.debug("Writing START_CANDIDATE_GENERATION");
						out.write(StraubProtocol.START_CANDIDATE_GENERATION);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "REMOTE_NCS_PUB_KEY");
					}
					break;

				case StraubProtocol.START_CANDIDATE_GENERATION:
					// states[4] is accepted as well so that a new round can
					// start after the verification phase.
					if (inState(1) || inState(4)) {
						log.debug("Got START_CANDIDATE_GENERATION");
						super.actualState = states[2];
						super.populatePrimeLists();
						generatePrimeCandidate();
						log.debug("Writing GEN_TABLE_SUCCESSFUL");
						out.write(StraubProtocol.GEN_TABLE_SUCCESSFUL);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "START_CANDIDATE_GENERATION");
					}
					break;

				case StraubProtocol.GEN_TABLE_SUCCESSFUL:
					if (inState(2)) {
						log.debug("Got GEN_TABLE_SUCCESSFUL");
						super.actualState = states[3];
						candidate = super.crtCandidate(generatingPrimes);
						log.debug("Writing START_CANDIDATE_VERIFICATION");
						out.write(StraubProtocol.START_CANDIDATE_VERIFICATION);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "GEN_TABLE_SUCCESSFUL");
					}
					break;

				case StraubProtocol.START_CANDIDATE_VERIFICATION:
					if (inState(3)) {
						log.debug("Got START_CANDIDATE_VERIFICATION");
						super.actualState = states[4];
						// On failure verifyCandidate() has already reset the
						// state and requested a new generation round, so
						// success must not be reported.
						if (verifyCandidate()) {
							log.debug("Writing VER_TABLE_SUCCESSFUL");
							out.write(StraubProtocol.VER_TABLE_SUCCESSFUL);
							out.flush();
						}
					} else {
						log.fatal(PROTO_VIOL + "START_CANDIDATE_VERIFICATION");
					}
					break;

				case StraubProtocol.VER_TABLE_SUCCESSFUL:
					if (inState(4)) {
						log.debug("Got VER_TABLE_SUCCESSFUL");
						super.actualState = states[5];
						candidate = super.crtCandidate(verifyingPrimes);
						log.debug("New prime found, fermat test pending:\n"
								+ candidate);
						log.debug("Writing PUBLISH_MQ");
						out.write(StraubProtocol.PUBLISH_MQ);
						myMQ = q.multiply(candidate.getMultiplicative().mod(
								mySigma));
						out.writeObject(myMQ);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "VER_TABLE_SUCCESSFUL");
					}
					break;

				case StraubProtocol.PUBLISH_MQ:
					if (inState(5)) {
						log.debug("Got PUBLISH_MQ");
						super.actualState = states[6];
						super.readRSAModulus();
						log.debug("Writing CONVERT_Q1_P2");
						out.write(StraubProtocol.CONVERT_Q1_P2);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "PUBLISH_MQ");
					}
					break;

				case StraubProtocol.CONVERT_Q1_P2:
					if (inState(6)) {
						log.debug("Got CONVERT_Q1_P2");
						super.actualState = states[7];
						super.resetMultToAddPrimes();
						convertQ1P2();
						log.debug("Writing CONVERT_Q1_P2_FINISHED");
						out.write(StraubProtocol.CONVERT_Q1_P2_FINISHED);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "CONVERT_Q1_P2");
					}
					break;

				case StraubProtocol.CONVERT_Q1_P2_FINISHED:
					if (inState(7)) {
						log.debug("Got CONVERT_Q1_P2_FINISHED");
						super.actualState = states[8];
						myA = crtCandidate(multToAddPrimes).getAdditive().mod(
								mySigma);
						log.debug("Writing A_MOD_2");
						out.write(StraubProtocol.A_MOD_2);
						// Only the parity of the local additive share is sent.
						out.write(myA.mod(TWO).intValue());
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "CONVERT_Q1_P2_FINISHED");
					}
					break;

				case StraubProtocol.A_MOD_2:
					if (inState(8)) {
						log.debug("Got A_MOD_2");
						super.actualState = states[9];
						final int remoteParity = in.read();
						if (remoteParity == -1) {
							throw new IOException(
									"Unexpected end of stream while reading A_MOD_2");
						}
						// Combined parity of both shares decides whether one
						// multiple of mult2addProd is subtracted from myA.
						final int subtractP2 = (remoteParity + myA.mod(TWO)
								.intValue()) % 2;
						myA = myA.subtract(mult2addProd.multiply(BigInteger
								.valueOf(subtractP2)));
						log.debug("Writing CONVERT_P1_Q1_Q2");
						out.write(StraubProtocol.CONVERT_P1_Q1_Q2);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "A_MOD_2");
					}
					break;

				case StraubProtocol.CONVERT_P1_Q1_Q2:
					if (inState(9)) {
						log.debug("Got CONVERT_P1_Q1_Q2");
						super.actualState = states[10];
						super.resetMultToAddPrimes();
						convertP1Q1Q2();
						log.debug("Writing CONVERT_P1_Q1_Q2_FINISHED");
						out.write(StraubProtocol.CONVERT_P1_Q1_Q2_FINISHED);
						out.flush();
					} else {
						log.fatal(PROTO_VIOL + "CONVERT_P1_Q1_Q2");
					}
					break;

				// NOTE(review): CONVERT_P1_Q1_Q2_FINISHED has no case here and
				// therefore ends up in the default branch, which ends the
				// thread. Confirm whether a handler is still missing.

				case StraubProtocol.FERMAT_TEST:
					// NOTE(review): this uses the same states[9] -> states[10]
					// transition as CONVERT_P1_Q1_Q2; confirm intended state.
					if (inState(9)) {
						super.actualState = states[10];
						super.fermatTest();
					} else {
						log.fatal(PROTO_VIOL + "FERMAT_TEST");
					}
					break;

				case StraubProtocol.TABLE_FAILED:
					if (inState(2)) {
						log.fatal("Server reports that his table " +
							"is incomplete although " +
							"all primes are through");
						log.fatal("Exiting");
						successful = false;
						endThread();
					} else {
						log.fatal(PROTO_VIOL + "TABLE_FAILED");
					}
					break;

				default:
					log.fatal(PROTO_VIOL + "Unknown command. Exiting.");
					successful = false;
					endThread();
					break;
				}
			}
		} catch (final Exception e) {
			successful = false;
			log.error("Caught exception : ", e);
		}
		log.info("finished distributed sieving with server:\n"
						+ remoteID);
	}

	/**
	 * Null-safe check of the current protocol state.
	 * 
	 * @param index
	 *            index into {@code states}
	 * @return {@code true} if a state has been set and it equals
	 *         {@code states[index]}
	 */
	private boolean inState(final int index) {
		return super.actualState != null
				&& super.actualState.equals(states[index]);
	}

	/**
	 * Sends PREFERRED_SHARED_KEY_SIZE followed by
	 * {@code preferredSharedKeySize}. An I/O failure is logged but not
	 * propagated.
	 */
	private void sendDesiredPrimeLength() {
		try {
			log.debug("Writing PREFERRED_SHARED_KEY_SIZE");
			out.write(StraubProtocol.PREFERRED_SHARED_KEY_SIZE);
			log.debug("Sending Desired Prime Length of "
					+ preferredSharedKeySize);
			out.writeInt(preferredSharedKeySize);
			log.debug("Sent Desired Prime Length of " + preferredSharedKeySize);
			out.flush();
		} catch (final IOException e) {
			log.error("Could not send PREFERRED_SHARED_KEY_SIZE", e);
		}
	}

	/**
	 * For each prime in {@code generatingPrimes}, draws a random additive
	 * share in {@code [0, prime)}, converts it to a multiplicative share with
	 * the server, and stores both. Draws are repeated while the multiplicative
	 * share is zero or {@code mySigma / mult} equals the prime.
	 * 
	 * @throws Exception
	 *             on I/O or protocol errors during the conversion
	 */
	private void generatePrimeCandidate() throws Exception {
		log.debug("generating candidate");
		final Enumeration primes = super.generatingPrimes.keys();
		while (primes.hasMoreElements()) {
			final BigInteger key = (BigInteger) primes.nextElement();
			final int prime = key.intValue();
			BigInteger add = null;
			BigInteger mult = null;
			boolean accepted = false;
			while (!accepted) {
				add = BigInteger.valueOf(rand.nextInt(prime));
				mult = addToMult(key, add);
				// A zero share is redrawn; it must be rejected before the
				// division, which would otherwise throw ArithmeticException.
				if (mult.signum() != 0) {
					if (mySigma.divide(mult).equals(key)) {
						log.debug("found multiple of prime");
					} else {
						accepted = true;
					}
				}
				log.debug("prime is " + prime + " add is " + add + " mult is "
						+ mult);
			}
			log.debug("Found new multiplicative sharing for prime " + key
					+ ": add " + add + " mult " + mult);
			generatingPrimes
					.put(key, new AddMultRemainders(add, mult, mySigma));
		}
	}

	/**
	 * Converts the candidate's additive share, reduced modulo each prime in
	 * {@code verifyingPrimes}, into a multiplicative share and records both.
	 * 
	 * @return {@code true} if every multiplicative share was non-zero;
	 *         {@code false} if a zero share was found, in which case the state
	 *         has been reset to {@code states[1]} and
	 *         START_CANDIDATE_GENERATION has already been sent
	 * @throws Exception
	 *             on I/O or protocol errors during the conversion
	 */
	private boolean verifyCandidate() throws Exception {
		log.debug("verifying candidate");
		final BigInteger additive = super.candidate.getAdditive();
		final Enumeration primes = super.verifyingPrimes.keys();
		while (primes.hasMoreElements()) {
			final BigInteger prime = (BigInteger) primes.nextElement();
			final BigInteger remainder = additive.mod(prime);
			final BigInteger mult = addToMult(prime, remainder);
			if (mult.equals(BigInteger.ZERO)) {
				log.info("Detected that candidate is not prime");
				log.info("Resetting to " + states[1]);
				super.actualState = states[1];
				log.debug("Writing START_CANDIDATE_GENERATION");
				out.write(StraubProtocol.START_CANDIDATE_GENERATION);
				out.flush();
				return false;
			}
			log.debug("Found new multiplicative sharing for prime " + prime
					+ ": add " + remainder + " mult " + mult);
			verifyingPrimes.put(prime, new AddMultRemainders(remainder, mult,
					mySigma));
		}
		return true;
	}

	/**
	 * For each prime in {@code multToAddPrimes}, converts the multiplicative
	 * share {@code q mod prime} into an additive share with the server and
	 * stores both.
	 * 
	 * @throws Exception
	 *             on I/O or protocol errors during the conversion
	 */
	private void convertQ1P2() throws Exception {
		log.debug("Converting q1*p2 form multiplicative to additive sharing");
		final Enumeration primes = super.multToAddPrimes.keys();
		while (primes.hasMoreElements()) {
			final BigInteger prime = (BigInteger) primes.nextElement();
			final BigInteger remainder = q.mod(prime);
			final BigInteger add = multToAdd(prime, remainder);
			log.debug("found new additive sharing for (q1*p2) mod prime "
					+ prime + ": add " + add + " mult " + remainder);
			multToAddPrimes.put(prime, new AddMultRemainders(add, remainder,
					mySigma));
		}
	}

	/**
	 * For each prime in {@code multToAddPrimes}, converts the multiplicative
	 * share {@code (candidate.getMultiplicative() + q) mod prime} into an
	 * additive share with the server and stores both.
	 * 
	 * @throws Exception
	 *             on I/O or protocol errors during the conversion
	 */
	private void convertP1Q1Q2() throws Exception {
		log
				.debug("Converting (p1 + q1)q2 form multiplicative to additive sharing");
		// Loop-invariant: the candidate and q do not change while converting.
		final BigInteger p1Q1 = candidate.getMultiplicative().add(q);
		final Enumeration primes = super.multToAddPrimes.keys();
		while (primes.hasMoreElements()) {
			final BigInteger prime = (BigInteger) primes.nextElement();
			final BigInteger remainder = p1Q1.mod(prime);
			final BigInteger add = multToAdd(prime, remainder);
			log.debug("found new additive sharing for (p1 + q1)*q2 mod prime "
					+ prime + ": add " + add + " mult " + remainder);
			multToAddPrimes.put(prime, new AddMultRemainders(add, remainder,
					mySigma));
		}
	}

	/**
	 * Sends an ADD_TO_MULT request for {@code prime} with the encrypted
	 * additive share and returns the decrypted multiplicative share from the
	 * server's COMPUTED_M reply.
	 * 
	 * @throws Exception
	 *             if the reply does not follow the expected sequence, refers
	 *             to a different prime, or on I/O errors
	 */
	private BigInteger addToMult(final BigInteger prime, final BigInteger addSharing)
			throws Exception {
		out.write(StraubProtocol.ADD_TO_MULT);
		out.writeInt(prime.intValue());
		out.write(StraubProtocol.REMOTE_A);
		sendEncrypted(addSharing);
		out.flush();
		if (in.read() != StraubProtocol.ADD_TO_MULT) {
			throw new IOException(PROTO_VIOL + "ADD_TO_MULT");
		}
		if (in.readInt() != prime.intValue()) {
			throw new IOException(
					"AddToMult Error: Not talking about the same prime");
		}
		if (in.read() != StraubProtocol.COMPUTED_M) {
			throw new IOException(PROTO_VIOL + "COMPUTED_M");
		}
		return receiveDecrypted();
	}

	/**
	 * Sends a MULT_TO_ADD request for {@code prime} with the encrypted
	 * multiplicative share and returns the decrypted additive share from the
	 * server's COMPUTED_A reply.
	 * 
	 * @throws Exception
	 *             if the reply does not follow the expected sequence, refers
	 *             to a different prime, or on I/O errors
	 */
	private BigInteger multToAdd(final BigInteger prime, final BigInteger multSharing)
			throws Exception {
		out.write(StraubProtocol.MULT_TO_ADD);
		out.writeInt(prime.intValue());
		out.write(StraubProtocol.REMOTE_M);
		sendEncrypted(multSharing);
		out.flush();
		if (in.read() != StraubProtocol.MULT_TO_ADD) {
			throw new IOException(PROTO_VIOL + "MULT_TO_ADD");
		}
		if (in.readInt() != prime.intValue()) {
			throw new IOException(
					"MultToAdd Error: not talking about the same prime");
		}
		if (in.read() != StraubProtocol.COMPUTED_A) {
			throw new IOException(PROTO_VIOL + "COMPUTED_A");
		}
		return receiveDecrypted();
	}

	/** Encrypts {@code value} with pubEngine and writes it length-prefixed (no flush). */
	private void sendEncrypted(final BigInteger value) throws Exception {
		final byte[] encrypted = pubEngine.processData(value.toByteArray());
		out.writeInt(encrypted.length);
		out.write(encrypted);
	}

	/**
	 * Reads a length-prefixed block, decrypts it with privEngine and returns
	 * it as a non-negative BigInteger.
	 */
	private BigInteger receiveDecrypted() throws Exception {
		final int length = in.readInt();
		if (length < 0) {
			throw new IOException("Invalid encrypted block length: " + length);
		}
		final byte[] encrypted = new byte[length];
		// read(byte[]) may return early on a socket; readFully does not.
		in.readFully(encrypted);
		return new BigInteger(1, privEngine.processData(encrypted));
	}
}
