package net.sf.distrib_rsa.cryptosystems.benaloh;

import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Arrays;

import net.sf.distrib_rsa.cryptosystems.PrimeUtils;

import org.bouncycastle.crypto.AsymmetricBlockCipher;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.InvalidCipherTextException;

/**
 * Benaloh cipher engine exposed through the Bouncy Castle
 * {@link AsymmetricBlockCipher} interface. When the key's message border is 2,
 * decryption switches to the Goldwasser-Micali quadratic-residuosity test.
 * <p>
 * Besides plain encryption and decryption the engine offers
 * {@link #addCryptedBlocks(byte[], byte[])} and
 * {@link #multiplyCryptedBlock(byte[], BigInteger)}, which operate directly on
 * ciphertexts modulo the public modulus.
 * 
 * @author lippold Published under the GPLv2 Licence (c) 2006 Georg Lippold
 * 
 */
public class BenalohEngine implements AsymmetricBlockCipher {

	/**
	 * When set, intermediate values are printed to {@code System.out}.
	 * 
	 * @uml.property name="debug"
	 */
	private boolean debug = false;

	/**
	 * Mode selected by the last {@link #init(boolean, CipherParameters)} call.
	 * 
	 * @uml.property name="forEncryption"
	 */
	private boolean forEncryption = true;

	/**
	 * Key supplied to {@link #init(boolean, CipherParameters)}. Decryption
	 * requires a {@link BenalohPrivateKeyParameters} instance.
	 * 
	 * @uml.property name="bkp"
	 * @uml.associationEnd
	 */
	private BenalohKeyParameters bkp;

	/**
	 * Random value drawn during the most recent encryption; {@code null} until
	 * the first encryption.
	 * 
	 * @uml.property name="cert"
	 */
	private BigInteger cert;

	/**
	 * Source of randomness used by encryption.
	 * 
	 * @uml.property name="rand"
	 */
	private final SecureRandom rand;

	/** Creates an engine backed by a fresh {@link SecureRandom}. */
	public BenalohEngine() {
		rand = new SecureRandom();
	}

	/**
	 * Creates an engine that draws its encryption randomness from
	 * {@code random}.
	 * 
	 * @param random
	 *            the randomness source used by encryption
	 */
	public BenalohEngine(final SecureRandom random) {
		rand = random;
	}

	/**
	 * Returns the input block size in bytes.
	 * <p>
	 * For encryption this is one byte less than the two's-complement encoding
	 * of the message border. For decryption it is the encoded length of the
	 * modulus.
	 * 
	 * @see org.bouncycastle.crypto.AsymmetricBlockCipher#getInputBlockSize()
	 */
	public int getInputBlockSize() {
		if (forEncryption) {
			return bkp.getMessageBorder().toByteArray().length - 1;
		}
		return bkp.getModulus().toByteArray().length;
	}

	/**
	 * Returns the output block size in bytes.
	 * <p>
	 * For encryption this is the encoded length of the modulus. For decryption
	 * it is the encoded length of the message border.
	 * 
	 * @see org.bouncycastle.crypto.AsymmetricBlockCipher#getOutputBlockSize()
	 */
	public int getOutputBlockSize() {
		if (forEncryption) {
			return bkp.getModulus().toByteArray().length;
		}
		return bkp.getMessageBorder().toByteArray().length;
	}

	/**
	 * Initialises the engine for encryption or decryption.
	 * 
	 * @param forEncryption
	 *            {@code true} to encrypt, {@code false} to decrypt
	 * @param param
	 *            the key; must be a {@link BenalohKeyParameters}, and must not
	 *            be a {@link BenalohPrivateKeyParameters} when encrypting
	 * @throws IllegalArgumentException
	 *             if a private key is supplied for encryption
	 * @see org.bouncycastle.crypto.AsymmetricBlockCipher#init(boolean,
	 *      org.bouncycastle.crypto.CipherParameters)
	 */
	public void init(final boolean forEncryption, final CipherParameters param) {
		this.forEncryption = forEncryption;
		if (this.forEncryption
				&& (param instanceof BenalohPrivateKeyParameters)) {
			throw new IllegalArgumentException(
					"Encryption needs BenalohKeyParameters, "
							+ "not BenalohPrivateKeyParameters");
		}
		bkp = (BenalohKeyParameters) param;
	}

	/**
	 * Encrypts or decrypts one block, depending on the mode chosen in
	 * {@link #init(boolean, CipherParameters)}.
	 * 
	 * @param input
	 *            buffer holding the block
	 * @param inOff
	 *            offset of the block within {@code input}
	 * @param len
	 *            length of the block
	 * @return the processed block, left-padded with zero bytes to the output
	 *         size
	 * @throws InvalidCipherTextException
	 *             when decrypting a block shorter than the modulus encoding,
	 *             or one that cannot be decrypted
	 * @throws DataLengthException
	 *             when encrypting a value not below the message border
	 * @see org.bouncycastle.crypto.AsymmetricBlockCipher#processBlock(byte[],
	 *      int, int)
	 */
	public byte[] processBlock(final byte[] input, final int inOff,
			final int len) throws InvalidCipherTextException {
		if (!forEncryption && (len < getInputBlockSize())) {
			// At decryption make sure that we receive padded data blocks
			throw new InvalidCipherTextException(
					"BlockLength does not match modulus for Benaloh cipher.\n");
		}

		final byte[] block;
		if ((inOff != 0) || (len != input.length)) {
			block = new byte[len];
			System.arraycopy(input, inOff, block, 0, len);
		} else {
			block = input;
		}

		// interpret the block as an unsigned big-endian integer
		final BigInteger value = new BigInteger(1, block);

		if (forEncryption && (value.compareTo(bkp.getMessageBorder()) >= 0)) {
			throw new DataLengthException("Input too large for Benaloh cipher.");
		}
		if (debug) {
			System.out.println("Input as BigInteger: " + value);
		}
		return forEncryption ? encrypt(value) : decrypt(value);
	}

	/**
	 * Encrypts a plaintext value with the public key as
	 * {@code y^plain * cert^r mod n}. Here {@code r} is the message border and
	 * {@code cert} is a fresh random value, which is remembered for
	 * {@link #getCert()}.
	 * 
	 * @param plain
	 *            the plaintext message
	 * @return the ciphertext, zero-padded to the encoded length of the modulus
	 */
	private byte[] encrypt(final BigInteger plain) {
		final BigInteger n = bkp.getModulus();
		cert = PrimeUtils.getRandom(n, rand);

		final BigInteger encrypted = bkp.getY().modPow(plain, n)
				.multiply(cert.modPow(bkp.getMessageBorder(), n)).mod(n);

		final byte[] output = padToLength(encrypted, n.toByteArray().length);
		if (debug) {
			System.out.println("Encrypted value is:  "
					+ new BigInteger(1, output));
		}
		return output;
	}

	/**
	 * Picks the decryption routine.
	 * <p>
	 * A message border of 2 uses the Goldwasser-Micali residuosity test.
	 * Anything else uses Benaloh lookup-table decryption.
	 * 
	 * @param encrypted
	 *            the encrypted message
	 * @return the decrypted message
	 * @throws InvalidCipherTextException
	 *             if the ciphertext cannot be decrypted
	 * @throws IllegalStateException
	 *             if the engine was not initialised with a private key
	 */
	private byte[] decrypt(final BigInteger encrypted)
			throws InvalidCipherTextException {
		final BenalohPrivateKeyParameters privKey = requirePrivateKey();
		if (privKey.getMessageBorder().equals(BigInteger.valueOf(2))) {
			// Goldwasser-Micali scheme
			return decryptGM(privKey, encrypted);
		}
		// Benaloh scheme
		return decryptBenaloh(privKey, encrypted);
	}

	/**
	 * Returns the current key as a private key.
	 * <p>
	 * This fails with a descriptive message instead of a bare
	 * {@link ClassCastException} when only a public key is available.
	 * 
	 * @return the private key set by {@link #init(boolean, CipherParameters)}
	 * @throws IllegalStateException
	 *             if the current key is not a
	 *             {@link BenalohPrivateKeyParameters}
	 */
	private BenalohPrivateKeyParameters requirePrivateKey() {
		if (!(bkp instanceof BenalohPrivateKeyParameters)) {
			throw new IllegalStateException(
					"Decryption needs BenalohPrivateKeyParameters");
		}
		return (BenalohPrivateKeyParameters) bkp;
	}

	/**
	 * Decrypts one Goldwasser-Micali encrypted bit.
	 * <p>
	 * The result is 0 when the ciphertext's Jacobi symbol is 1 with respect to
	 * both primes, and 1 otherwise.
	 * 
	 * @param privKey
	 *            the private key holding the primes
	 * @param encrypted
	 *            the encrypted bit
	 * @return {@code BigInteger.ZERO.toByteArray()} or
	 *         {@code BigInteger.ONE.toByteArray()}
	 */
	private byte[] decryptGM(final BenalohPrivateKeyParameters privKey,
			final BigInteger encrypted) {
		final BigInteger[] pq = privKey.getPQ();
		final int pRes = PrimeUtils.jacobiSymbol(encrypted, pq[0]);
		final int qRes = PrimeUtils.jacobiSymbol(encrypted, pq[1]);

		if (debug) {
			System.out.println("pRes is " + pRes);
			System.out.println("qRes is " + qRes);
		}

		final boolean residue = (pRes == 1) && (qRes == 1);
		return (residue ? BigInteger.ZERO : BigInteger.ONE).toByteArray();
	}

	/**
	 * Decrypts a Benaloh-encrypted message.
	 * <p>
	 * The ciphertext is raised to {@code (p-1)(q-1)/r} and the result is
	 * looked up in the key's lookup table. The table is built on first use.
	 * 
	 * @param privKey
	 *            the private key holding the primes and the lookup table
	 * @param encrypted
	 *            the encrypted message
	 * @return the plaintext, zero-padded to the encoded length of the message
	 *         border
	 * @throws InvalidCipherTextException
	 *             if the computed value has no entry in the lookup table
	 */
	private byte[] decryptBenaloh(final BenalohPrivateKeyParameters privKey,
			final BigInteger encrypted) throws InvalidCipherTextException {
		// set up lookup table if necessary
		if (privKey.getLookupList().size() == 0) {
			privKey.setupLookupList();
		}

		final BigInteger[] pq = privKey.getPQ();
		final BigInteger phi = pq[0].subtract(BigInteger.ONE).multiply(
				pq[1].subtract(BigInteger.ONE));
		final BigInteger exp = phi.divide(privKey.getMessageBorder());
		final BigInteger res = PrimeUtils.chineseModPow(encrypted, exp, pq);

		final BigInteger decrypted = (BigInteger) privKey.getLookupList().get(
				res);
		if (decrypted == null) {
			// No table entry: the block is presumably not a valid ciphertext
			// for this key. Report it as a cipher error rather than an NPE.
			throw new InvalidCipherTextException(
					"Benaloh decryption failed: value not found in lookup list.");
		}
		if (debug) {
			System.out.println("Decrypted value is " + decrypted);
		}
		return padToLength(decrypted,
				privKey.getMessageBorder().toByteArray().length);
	}

	/**
	 * Returns the random value drawn by the most recent encryption.
	 * 
	 * @return the cert, or {@code null} if nothing has been encrypted yet
	 * @uml.property name="cert"
	 */
	public BigInteger getCert() {
		return cert;
	}

	/**
	 * Multiplies two ciphertexts modulo the public modulus.
	 * <p>
	 * As stated in the debug output, the product is an encryption of the sum
	 * of the two plaintexts modulo the message border.
	 * 
	 * @param block1
	 *            first ciphertext (unsigned big-endian)
	 * @param block2
	 *            second ciphertext (unsigned big-endian)
	 * @return the combined ciphertext, zero-padded to the encoded length of
	 *         the modulus
	 * @throws InvalidCipherTextException
	 *             if either block is not smaller than the modulus
	 */
	public byte[] addCryptedBlocks(final byte[] block1, final byte[] block2)
			throws InvalidCipherTextException {
		final BigInteger n = bkp.getModulus();
		final BigInteger m1Crypt = new BigInteger(1, block1);
		final BigInteger m2Crypt = new BigInteger(1, block2);

		// check for correct blocksize
		if ((m1Crypt.compareTo(n) >= 0) || (m2Crypt.compareTo(n) >= 0)) {
			throw new InvalidCipherTextException(
					"BlockLength too large for addition");
		}

		// calculate resulting block
		final BigInteger m1m2Crypt = m1Crypt.multiply(m2Crypt).mod(n);
		if (debug) {
			System.out.println("c(m1) as BigInteger:......... " + m1Crypt);
			System.out.println("c(m2) as BigInteger:......... " + m2Crypt);
			System.out.println("(c(m1)*c(m2))%n = c(m1+m2)%r: " + m1m2Crypt);
		}

		return padToLength(m1m2Crypt, n.toByteArray().length);
	}

	/**
	 * Raises a ciphertext to the power {@code value} modulo the public
	 * modulus.
	 * <p>
	 * As stated in the debug output, the result is an encryption of the
	 * plaintext times {@code value}, modulo the message border.
	 * 
	 * @param block1
	 *            the ciphertext (unsigned big-endian)
	 * @param value
	 *            the exponent
	 * @return the resulting ciphertext, zero-padded to the encoded length of
	 *         the modulus
	 * @throws InvalidCipherTextException
	 *             if the block is not smaller than the modulus
	 */
	public byte[] multiplyCryptedBlock(final byte[] block1,
			final BigInteger value) throws InvalidCipherTextException {
		final BigInteger n = bkp.getModulus();
		final BigInteger m1Crypt = new BigInteger(1, block1);
		if (m1Crypt.compareTo(n) >= 0) {
			throw new InvalidCipherTextException(
					"BlockLength too large for multiplication.\n");
		}

		// calculate resulting block
		final BigInteger m1m2Crypt = m1Crypt.modPow(value, n);
		if (debug) {
			System.out.println("c(m1) as BigInteger:....... " + m1Crypt);
			System.out.println("m2 as BigInteger:.......... " + value);
			System.out.println("(c(m1)^m2)%n = c(m1*m2)%r: " + m1m2Crypt);
		}

		return padToLength(m1m2Crypt, n.toByteArray().length);
	}

	/**
	 * Enables or disables printing of intermediate values to
	 * {@code System.out}.
	 * 
	 * @param debug
	 *            the debug flag to set
	 * @uml.property name="debug"
	 */
	public void setDebug(final boolean debug) {
		this.debug = debug;
	}

	/**
	 * Copies the {@link BigInteger#toByteArray()} encoding of {@code value}
	 * into the tail of a new zero-filled array of {@code length} bytes.
	 * 
	 * @param value
	 *            a non-negative value whose encoding is at most
	 *            {@code length} bytes long
	 * @param length
	 *            length of the returned array
	 * @return a new, left-zero-padded array of {@code length} bytes
	 */
	private static byte[] padToLength(final BigInteger value, final int length) {
		final byte[] encoded = value.toByteArray();
		// new arrays are zero-initialised, so only the tail needs copying
		final byte[] output = new byte[length];
		System.arraycopy(encoded, 0, output, length - encoded.length,
				encoded.length);
		return output;
	}

}
