package net.sf.distrib_rsa.protocols.straubSharing;

import java.io.IOException;
import java.math.BigInteger;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.util.Enumeration;
import java.util.Hashtable;

import javax.net.ssl.SSLSocket;

import net.sf.distrib_rsa.EnvironmentSetup;

import org.apache.log4j.Logger;


/**
 * Server side of the Straub sharing protocol. The server reads one-byte
 * commands from the client connection, checks that each command is allowed in
 * the current protocol state and answers it.
 * 
 * The protocol state is tracked in <code>actualState</code>, which always
 * holds one of the entries of <code>states</code> (or <code>null</code>
 * before the first command). The entries are entered as follows:
 * 
 * <pre>
 *  states[0]   after PREFERRED_SHARED_KEY_SIZE
 *  states[1]   after REMOTE_NCS_PUB_KEY (also the reset target after a failed table check)
 *  states[2]   after START_CANDIDATE_GENERATION
 *  states[3]   after GEN_TABLE_SUCCESSFUL
 *  states[4]   after START_CANDIDATE_VERIFICATION
 *  states[5]   after VER_TABLE_SUCCESSFUL
 *  states[6]   after PUBLISH_MQ
 *  states[7]   after CONVERT_Q1_P2
 *  states[8]   after CONVERT_Q1_P2_FINISHED
 *  states[9]   after A_MOD_2
 *  states[10]  after CONVERT_P1_Q1_Q2
 *  states[11]  after CONVERT_P1_Q1_Q2_FINISHED
 * </pre>
 * 
 * @author lippold Published under the GPLv2 Licence (c) 2006 Georg Lippold
 * 
 */
public class StraubServer extends StraubCommon {

	private static final Logger log = Logger
			.getLogger(StraubServer.class);

	/**
	 * Sanity bound for a length-prefixed block announced by the client. The
	 * length comes from the remote side, so it must not be trusted as an
	 * allocation size. NOTE(review): 1 MiB is assumed to be far above any
	 * real ciphertext block; adjust if larger blocks are legitimate.
	 */
	private static final int MAX_BLOCK_LENGTH = 1 << 20;

	/**
	 * Creates the server end of a sharing session; all setup is done by the
	 * superclass constructor.
	 * 
	 * @param socket
	 *            the connection to the client
	 * @throws IOException
	 * @throws NoSuchProviderException
	 * @throws NoSuchAlgorithmException
	 */
	public StraubServer(final SSLSocket socket) throws IOException,
			NoSuchAlgorithmException, NoSuchProviderException {
		super(socket);
	}

	/**
	 * Command loop. Reads commands until the stream ends or
	 * <code>running</code> becomes false. A command that is not allowed in
	 * the current state is logged as a protocol violation and ignored. Any
	 * exception ends the loop and marks the session as not successful.
	 */
	public void run() {
		log.debug("Starting distributed sieving with client " + remoteID);
		int command;
		try {
			while (running && ((command = in.read()) != -1)) {
				switch (command) {

				case StraubProtocol.PREFERRED_SHARED_KEY_SIZE:
					if (actualState == null) {
						log.debug("Got PREFERRED_SHARED_KEY_SIZE");
						actualState = states[0];
						agreeOnPrimeLength();
					} else {
						protocolViolation("PREFERRED_SHARED_KEY_SIZE");
					}
					break;

				case StraubProtocol.REMOTE_NCS_PUB_KEY:
					if (inState(0)) {
						log.debug("Got REMOTE_NCS_PUB_KEY");
						actualState = states[1];
						super.readRemotePubKey();
						super.sendNCSPubKey();
					} else {
						protocolViolation("REMOTE_NCS_PUB_KEY");
					}
					break;

				case StraubProtocol.START_CANDIDATE_GENERATION:
					if (inState(1) || inState(4)) {
						log.debug("Got START_CANDIDATE_GENERATION");
						actualState = states[2];
						super.populatePrimeLists();
						log.debug("Writing START_CANDIDATE_GENERATION");
						out.write(StraubProtocol.START_CANDIDATE_GENERATION);
						out.flush();
					} else {
						protocolViolation("START_CANDIDATE_GENERATION");
					}
					break;

				case StraubProtocol.ADD_TO_MULT:
					if (inState(2) || inState(4)) {
						final int prime = in.readInt();
						if (in.read() == StraubProtocol.REMOTE_A) {
							computeRemoteM(prime);
						} else {
							throw new Exception(PROTO_VIOL + "REMOTE_A");
						}
					} else {
						protocolViolation("ADD_TO_MULT");
					}
					break;

				case StraubProtocol.GEN_TABLE_SUCCESSFUL:
					if (inState(2)) {
						log.debug("Got GEN_TABLE_SUCCESSFUL");
						actualState = states[3];
						if (checkCandidateTable(generatingPrimes)) {
							super.candidate = super
									.crtCandidate(generatingPrimes);
							log.debug("Writing GEN_TABLE_SUCCESSFUL");
							out.write(StraubProtocol.GEN_TABLE_SUCCESSFUL);
						} else {
							rejectTable();
						}
						out.flush();
					} else {
						protocolViolation("GEN_TABLE_SUCCESSFUL");
					}
					break;

				case StraubProtocol.START_CANDIDATE_VERIFICATION:
					if (inState(3)) {
						log.debug("Got START_CANDIDATE_VERIFICATION");
						actualState = states[4];
						log.debug("Writing START_CANDIDATE_VERIFICATION");
						out.write(StraubProtocol.START_CANDIDATE_VERIFICATION);
						out.flush();
					} else {
						protocolViolation("START_CANDIDATE_VERIFICATION");
					}
					break;

				case StraubProtocol.VER_TABLE_SUCCESSFUL:
					if (inState(4)) {
						log.debug("Received VER_TABLE_SUCCESSFUL");
						actualState = states[5];
						if (checkCandidateTable(verifyingPrimes)) {
							super.candidate = super
									.crtCandidate(verifyingPrimes);
							log.debug("Writing VER_TABLE_SUCCESSFUL");
							out.write(StraubProtocol.VER_TABLE_SUCCESSFUL);
						} else {
							rejectTable();
						}
						out.flush();
					} else {
						protocolViolation("VER_TABLE_SUCCESSFUL");
					}
					break;

				case StraubProtocol.PUBLISH_MQ:
					if (inState(5)) {
						log.debug("Got PUBLISH_MQ");
						actualState = states[6];
						myMQ = q.multiply(candidate.getMultiplicative().mod(
								remSigma));
						readRSAModulus();
						log.debug("Writing PUBLISH_MQ");
						out.write(StraubProtocol.PUBLISH_MQ);
						out.writeObject(myMQ);
						out.flush();
					} else {
						protocolViolation("PUBLISH_MQ");
					}
					break;

				case StraubProtocol.CONVERT_Q1_P2:
					if (inState(6)) {
						log.debug("Got CONVERT_Q1_P2");
						actualState = states[7];
						super.resetMultToAddPrimes();
						out.write(StraubProtocol.CONVERT_Q1_P2);
						log.debug("Writing CONVERT_Q1_P2");
						out.flush();
					} else {
						protocolViolation("CONVERT_Q1_P2");
					}
					break;

				case StraubProtocol.MULT_TO_ADD:
					if (inState(7) || inState(10)) {
						final int prime = in.readInt();
						if (in.read() == StraubProtocol.REMOTE_M) {
							computeRemoteA(prime);
						} else {
							throw new Exception(PROTO_VIOL + "REMOTE_M");
						}
					} else {
						protocolViolation("MULT_TO_ADD");
					}
					break;

				case StraubProtocol.CONVERT_Q1_P2_FINISHED:
					if (inState(7)) {
						log.debug("Got CONVERT_Q1_P2_FINISHED");
						actualState = states[8];
						if (checkCandidateTable(multToAddPrimes)) {
							super.myA = super.crtCandidate(multToAddPrimes)
									.getAdditive().mod(remSigma);
							log.debug("Writing CONVERT_Q1_P2_FINISHED");
							out.write(StraubProtocol.CONVERT_Q1_P2_FINISHED);
						} else {
							rejectTable();
						}
						out.flush();
					} else {
						protocolViolation("CONVERT_Q1_P2_FINISHED");
					}
					break;

				case StraubProtocol.A_MOD_2:
					if (inState(8)) {
						actualState = states[9];
						log.debug("Got A_MOD_2");
						final int myAmod2 = myA.mod(BigInteger.valueOf(2))
								.intValue();
						final int remoteAmod2 = in.read();
						if (remoteAmod2 < 0) {
							// Without this check the -1 of an ended stream
							// would turn the correction below into an
							// addition.
							throw new java.io.EOFException(
									"Stream ended while reading A_MOD_2 value");
						}
						final int subtractP2 = (remoteAmod2 + myAmod2) % 2;
						myA = myA.subtract(mult2addProd.multiply(BigInteger
								.valueOf(subtractP2)));
						log.debug("Writing A_MOD_2");
						out.write(StraubProtocol.A_MOD_2);
						out.write(myAmod2);
						out.flush();
					} else {
						protocolViolation("A_MOD_2");
					}
					break;

				case StraubProtocol.CONVERT_P1_Q1_Q2:
					if (inState(9)) {
						log.debug("Got CONVERT_P1_Q1_Q2");
						actualState = states[10];
						super.resetMultToAddPrimes();
						log.debug("Writing CONVERT_P1_Q1_Q2");
						out.write(StraubProtocol.CONVERT_P1_Q1_Q2);
						out.flush();
					} else {
						protocolViolation("CONVERT_P1_Q1_Q2");
					}
					break;

				case StraubProtocol.CONVERT_P1_Q1_Q2_FINISHED:
					if (inState(10)) {
						log.debug("Got CONVERT_P1_Q1_Q2_FINISHED");
						actualState = states[11];
						if (checkCandidateTable(multToAddPrimes)) {
							super.myB = super.crtCandidate(multToAddPrimes)
									.getAdditive().mod(remSigma);
							log.debug("Writing CONVERT_P1_Q1_Q2_FINISHED");
							out
									.write(StraubProtocol.CONVERT_P1_Q1_Q2_FINISHED);
						} else {
							rejectTable();
						}
						out.flush();
					} else {
						protocolViolation("CONVERT_P1_Q1_Q2_FINISHED");
					}
					break;

				default:
					log.fatal("Unknown command. Exiting.");
					successful = false;
					endThread();
					break;

				}
			}
		} catch (final Exception e) {
			log.fatal("Caught Exception: ", e);
			successful = false;
		}
		log.info("finished computation with client " + remoteID);
	}

	/**
	 * Null-safe test of the current protocol state. Comparing from the
	 * <code>states</code> entry avoids a NullPointerException when a command
	 * arrives before the first state has been entered.
	 * 
	 * @param index
	 *            index into <code>states</code>
	 * @return true if the session is currently in that state
	 */
	private boolean inState(final int index) {
		return states[index].equals(actualState);
	}

	/**
	 * Logs a command that is not allowed in the current state. The command is
	 * otherwise ignored and the loop continues.
	 * 
	 * @param commandName
	 *            name of the offending command, appended to the log message
	 */
	private void protocolViolation(final String commandName) {
		log.fatal(PROTO_VIOL + commandName);
	}

	/**
	 * Tells the client that a table check failed, marks the session as not
	 * successful and ends the thread. The caller flushes the stream.
	 * 
	 * @throws Exception
	 */
	private void rejectTable() throws Exception {
		log.debug("Writing TABLE_FAILED");
		out.write(StraubProtocol.TABLE_FAILED);
		successful = false;
		endThread();
	}

	/**
	 * Checks that every value stored in the table is an
	 * <code>AddMultRemainders</code>. On the first other value the protocol
	 * state is reset to <code>states[1]</code>.
	 * 
	 * @param table
	 *            the table to check
	 * @return true if all values have the expected type
	 */
	private boolean checkCandidateTable(final Hashtable table) {
		final Enumeration values = table.elements();
		while (values.hasMoreElements()) {
			if (!(values.nextElement() instanceof AddMultRemainders)) {
				// reset to verified beginning
				actualState = states[1];
				return false;
			}
		}
		return true;
	}

	/**
	 * Computes the minimum of the server's and the client's preferred key
	 * size and writes it to the client. Now both parties talk about the same
	 * keySize and should use the same primes. This is verified by
	 * checkCandidateTable later in the protocol.
	 * 
	 * @throws IOException
	 *             on stream errors or if the client announces a key size
	 *             that is not positive
	 */
	private void agreeOnPrimeLength() throws IOException {
		final int localKeySize = EnvironmentSetup.getDesiredKeySize();
		log.debug("My preferred key size: " + localKeySize);
		final int remoteKeySize = in.readInt();
		log.debug("Remote preferred key size: " + remoteKeySize);
		if (remoteKeySize <= 0) {
			// The minimum below would otherwise adopt a meaningless size.
			throw new IOException(PROTO_VIOL + "invalid remote key size "
					+ remoteKeySize);
		}
		preferredSharedKeySize = Math.min(localKeySize, remoteKeySize);
		log.debug("preferred Key Size is " + preferredSharedKeySize);
		log.debug("Writing PREFERRED_SHARED_KEY_SIZE");
		out.write(StraubProtocol.PREFERRED_SHARED_KEY_SIZE);
		out.writeInt(preferredSharedKeySize);
		out.flush();
	}

	/**
	 * Reads one length-prefixed byte block from the client. A single
	 * <code>read(byte[])</code> may return fewer bytes than requested, so the
	 * block is read in a loop until it is complete.
	 * 
	 * @return the complete block
	 * @throws IOException
	 *             if the announced length is negative or larger than
	 *             {@link #MAX_BLOCK_LENGTH}, or if the stream ends early
	 */
	private byte[] readBlock() throws IOException {
		final int length = in.readInt();
		if (length < 0 || length > MAX_BLOCK_LENGTH) {
			throw new IOException(PROTO_VIOL + "invalid block length "
					+ length);
		}
		final byte[] block = new byte[length];
		int filled = 0;
		while (filled < length) {
			final int count = in.read(block, filled, length - filled);
			if (count < 0) {
				throw new java.io.EOFException("Stream ended after " + filled
						+ " of " + length + " bytes");
			}
			filled += count;
		}
		return block;
	}

	/**
	 * Answers a MULT_TO_ADD request: reads the client's encrypted block and
	 * sends back the block computed by {@link #multToAdd(byte[], int)}. The
	 * reply header is written before the computation, as before.
	 * 
	 * @param prime
	 *            the prime the request refers to
	 * @throws Exception
	 */
	private void computeRemoteA(final int prime) throws Exception {
		final byte[] clientM = readBlock();
		out.write(StraubProtocol.MULT_TO_ADD);
		out.writeInt(prime);
		out.write(StraubProtocol.COMPUTED_A);
		final byte[] clientA = multToAdd(clientM, prime);
		out.writeInt(clientA.length);
		out.write(clientA);
		out.flush();
	}

	/**
	 * Answers an ADD_TO_MULT request: reads the client's encrypted block and
	 * sends back the block computed by {@link #addToMult(byte[], int)}. The
	 * reply header is written before the computation, as before.
	 * 
	 * @param prime
	 *            the prime the request refers to
	 * @throws Exception
	 */
	private void computeRemoteM(final int prime) throws Exception {
		final byte[] clientA = readBlock();
		out.write(StraubProtocol.ADD_TO_MULT);
		out.writeInt(prime);
		out.write(StraubProtocol.COMPUTED_M);
		final byte[] clientM = addToMult(clientA, prime);
		out.writeInt(clientM.length);
		out.write(clientM);
		out.flush();
	}

	/**
	 * Bob's method to create an additively shared secret and convert it to a
	 * multiplicatively shared one. The server's own shares for the prime are
	 * stored in <code>generatingPrimes</code> (state 2) or
	 * <code>verifyingPrimes</code> (state 4).
	 * 
	 * @param aliceA
	 *            Alice's encrypted a
	 * @param prime
	 *            the prime the sharing is created for; sent by the client
	 * @return m1, Alice's encrypted m
	 * @throws Exception
	 *             if called in another state, if <code>prime</code> is
	 *             smaller than 2, or if the encryption engine fails
	 */
	private byte[] addToMult(final byte[] aliceA, final int prime) throws Exception {
		final Hashtable shares;
		if (inState(2)) {
			shares = generatingPrimes;
		} else if (inState(4)) {
			shares = verifyingPrimes;
		} else {
			throw new Exception(PROTO_VIOL + "addToMult not allowed");
		}
		if (prime < 2) {
			// prime == 1 would only ever yield m == 0, which is never
			// invertible, so the search below would not terminate.
			throw new Exception(PROTO_VIOL + "addToMult: invalid prime "
					+ prime);
		}
		if (remSigma.signum() <= 0) {
			// No value is invertible modulo a non-positive modulus.
			throw new IllegalStateException("remSigma must be positive");
		}

		final BigInteger a = BigInteger.valueOf(rand.nextInt(prime));
		// Draw m until it is invertible mod remSigma, i.e. until
		// gcd(m, remSigma) == 1; this is exactly the condition under which
		// BigInteger.modInverse succeeds.
		BigInteger m;
		do {
			m = BigInteger.valueOf(rand.nextInt(prime));
		} while (!BigInteger.ONE.equals(m.gcd(remSigma)));

		final AddMultRemainders amr = new AddMultRemainders(a, m, remSigma);
		final byte[] aCrypted = pubEngine.processData(amr.getAdditive()
				.toByteArray());
		final byte[] added = pubEngine.addCryptedBlocks(aCrypted, aliceA);
		// NOTE(review): this writes the server's shares to the debug log.
		log.debug("Found new sharing for prime " + prime + ": add "
				+ amr.getAdditive() + " mult " + amr.getMultiplicative());
		shares.put(BigInteger.valueOf(prime), amr);
		return pubEngine.multiplyCryptedBlock(added, amr
				.getMultiplicativeInverse());
	}

	/**
	 * Bob's method to create a multiplicatively shared secret and convert it
	 * to an additively shared one. The multiplicative share is the candidate
	 * reduced mod <code>prime</code> in state 7 and <code>q</code> reduced
	 * mod <code>prime</code> in state 10. The server's shares are stored in
	 * <code>multToAddPrimes</code>.
	 * 
	 * @param aliceM
	 *            Alice's encrypted m
	 * @param prime
	 *            the prime the sharing is created for; sent by the client
	 * @return aliceA, Alice's encrypted a
	 * @throws Exception
	 *             if called in another state or if the encryption engine
	 *             fails
	 */
	private byte[] multToAdd(final byte[] aliceM, final int prime) throws Exception {
		final BigInteger modulus = BigInteger.valueOf(prime);
		final BigInteger m;
		if (inState(7)) {
			m = candidate.getMultiplicative().mod(modulus);
		} else if (inState(10)) {
			m = q.mod(modulus);
		} else {
			throw new Exception(PROTO_VIOL
					+ "multToAdd not allowed in current context");
		}

		// Draw random values of one bit more than remSigma until one does
		// not exceed remSigma.
		final int bits = remSigma.bitLength() + 1;
		BigInteger a;
		do {
			a = new BigInteger(bits, rand);
		} while (a.compareTo(remSigma) > 0);

		// NOTE(review): this writes the server's shares to the debug log.
		log.debug("MultToAdd: My mult: " + m + ", my add: " + a);
		final AddMultRemainders amr = new AddMultRemainders(a, m, remSigma);
		final byte[] m1MultM2 = pubEngine.multiplyCryptedBlock(aliceM, amr
				.getMultiplicative());
		final byte[] aInvCrypted = pubEngine.processData(amr
				.getAdditiveInverse().toByteArray());
		final byte[] added = pubEngine.addCryptedBlocks(m1MultM2, aInvCrypted);
		super.multToAddPrimes.put(BigInteger.valueOf(prime), amr);
		return added;
	}

}
