package Crypto;

/**
 * This program is free software: you can redistribute it and/or modify it 
 * under the terms of the GNU General Public License as published by the Free 
 * Software Foundation, either version 3 of the License, or (at your option) 
 * any later version. 
 * 
 * This program is distributed in the hope that it will be useful, but WITHOUT 
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for 
 * more details. 
 * 
 * You should have received a copy of the GNU General Public License along with 
 * this program.  If not, see <http://www.gnu.org/licenses/>.
 */

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.*;
import java.security.SecureRandom;
import java.util.*;

import org.bouncycastle.crypto.InvalidCipherTextException;

import Algorithm.BloomEncryption;
import Helper.UnitConversion;

import net.sf.distrib_rsa.cryptosystems.naccacheStern.NaccacheSternKeyParameters;
import net.sf.distrib_rsa.cryptosystems.naccacheStern.NaccacheSternKeySerializationFactory;
import net.sf.distrib_rsa.cryptosystems.naccacheStern.NaccacheSternPrivateKeyParameters;

/**
 * Paillier Cryptosystem <br><br>
 * References: <br>
 * [1] Pascal Paillier, "Public-Key Cryptosystems Based on Composite Degree Residuosity Classes," EUROCRYPT'99.
 *    URL: <a href="http://www.gemplus.com/smart/rd/publications/pdf/Pai99pai.pdf">http://www.gemplus.com/smart/rd/publications/pdf/Pai99pai.pdf</a><br>
 * 
 * [2] Paillier cryptosystem from Wikipedia. 
 *    URL: <a href="http://en.wikipedia.org/wiki/Paillier_cryptosystem">http://en.wikipedia.org/wiki/Paillier_cryptosystem</a>
 * @author Kun Liu (kunliu1@cs.umbc.edu)
 * @version 1.0
 */
public class CryptServerP {
    /**
     * Magic number written as the first field of a serialized key so that
     * {@link #setPubKey(byte[])} can reject byte arrays produced by something else.
     * Despite its name this class does not implement {@code Serializable}.
     */
    private static final long serialVersionUID = -4600732652879071173L;

    /** Type tag written after the magic number when a public key is serialized. */
    private static final int PUBKEY = 0;
    /** Type tag reserved for a private key; nothing in this class writes it yet. */
    private static final int PRIVKEY = 1;

    /**
     * Source of randomness for prime generation and encryption nonces.
     * NOTE(review): the concrete type returned by RandomSource.getRandom() is not visible here.
     * If it is a plain java.util.Random, re-seeding it from a single long (see the constructors)
     * limits the entropy of every generated key - confirm it is a SecureRandom.
     */
    private final Random rnd = new RandomSource(BloomEncryption.secureRandom).getRandom();
    /** Seed handed to {@link #rnd} by the constructors. */
    private final long seed = RandomSource.getNextSeed(BloomEncryption.secureRandom);

    /**
     * p and q are two large primes.
     * lambda = lcm(p-1, q-1) = (p-1)*(q-1)/gcd(p-1, q-1).
     * All three are null on an instance that only holds a public key.
     */
    private BigInteger p, q, lambda;
    /**
     * n = p*q, where p and q are two large primes.
     */
    public BigInteger n;
    /**
     * nsquare = n*n
     */
    public BigInteger nsquare;
    /**
     * An element of Z*_{n^2} where gcd(L(g^lambda mod n^2), n) = 1.
     * {@link #KeyGeneration(int, int)} always uses g = 2.
     */
    private BigInteger g;
    /**
     * Number of bits of the modulus: the requested size after {@link #KeyGeneration(int, int)},
     * the actual size of n after {@link #setPubKey(byte[])}.
     */
    private int bitLength;

    /**
     * Constructs an instance of the Paillier cryptosystem and generates a fresh key pair.
     * @param bitLengthVal number of bits of modulus
     * @param certainty The probability that the new BigInteger represents a prime number will exceed (1 - 2^(-certainty)). The execution time of this constructor is proportional to the value of this parameter.
     * @throws IllegalStateException if the generated key fails the validity check for g
     */
    public CryptServerP(int bitLengthVal, int certainty) {
        this.rnd.setSeed(this.seed);
        KeyGeneration(bitLengthVal, certainty);
    }

    /**
     * Constructs an instance without any key material. Call
     * {@link #KeyGeneration(int, int)} or {@link #setPubKey(byte[])} before encrypting.
     */
    public CryptServerP() {
        this.rnd.setSeed(this.seed);
    }

    /**
     * Sets up the public key and private key. The fields of this object are only
     * modified once the new key has passed validation.
     * @param bitLengthVal number of bits of modulus; each prime gets bitLengthVal / 2 bits, so n has at most bitLengthVal bits.
     * @param certainty The probability that the new BigInteger represents a prime number will exceed (1 - 2^(-certainty)). The execution time of this method is proportional to the value of this parameter.
     * @throws ArithmeticException if bitLengthVal / 2 is smaller than 2 (raised by the BigInteger prime constructor)
     * @throws IllegalStateException if g = 2 is not a valid generator for the primes that were drawn
     */
    public void KeyGeneration(int bitLengthVal, int certainty) {
        final int primeBits = bitLengthVal / 2;

        /* Two distinct probable primes; p == q would make n a perfect square and break the scheme. */
        final BigInteger newP = new BigInteger(primeBits, certainty, this.rnd);
        BigInteger newQ;
        do {
            newQ = new BigInteger(primeBits, certainty, this.rnd);
        } while (newQ.equals(newP));

        final BigInteger newN = newP.multiply(newQ);
        final BigInteger newNsquare = newN.multiply(newN);
        final BigInteger newG = BigInteger.valueOf(2);

        final BigInteger pMinusOne = newP.subtract(BigInteger.ONE);
        final BigInteger qMinusOne = newQ.subtract(BigInteger.ONE);
        final BigInteger newLambda = pMinusOne.multiply(qMinusOne).divide(pMinusOne.gcd(qMinusOne));

        /*
         * Check whether g is good. Compare as BigInteger: intValue() would only look at the
         * low 32 bits of the gcd and could mistake a large non-trivial gcd for 1.
         */
        final BigInteger check = lFunction(newG.modPow(newLambda, newNsquare), newN).gcd(newN);
        if (!check.equals(BigInteger.ONE)) {
            throw new IllegalStateException("g is not good. Choose g again.");
        }

        this.bitLength = bitLengthVal;
        this.p = newP;
        this.q = newQ;
        this.n = newN;
        this.nsquare = newNsquare;
        this.g = newG;
        this.lambda = newLambda;
    }

    /**
     * Loads a public key (n and g) that was produced by {@link #getPubKeyAsByteA()}.
     * Also recomputes nsquare and sets the bit length to the size of n. Private key
     * fields are left untouched.
     *
     * NOTE(review): this uses Java native deserialization; do not feed it bytes from an
     * untrusted source without additional filtering.
     *
     * @param pubKey serialized public key
     * @throws IllegalArgumentException if pubKey is null, was not serialized by this class,
     *         is not tagged as a public key, or does not contain a usable n and g
     */
    public void setPubKey(byte[] pubKey) {
        if (pubKey == null) {
            throw new IllegalArgumentException("pubKey must not be null");
        }

        final BigInteger newN;
        final BigInteger newG;
        try {
            final ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(pubKey));
            try {
                final long serializedData = ois.readLong();
                if (serializedData != serialVersionUID) {
                    throw new IllegalArgumentException(
                            "The supplied byte[] was not serialized by "
                                    + CryptServerP.class
                                            .getName());
                }
                final int type = ois.read();
                if (type != PUBKEY) {
                    throw new IllegalArgumentException(
                            "The supplied byte[] does not contain a public key (type " + type + ")");
                }
                newN = readBigInteger(ois);
                newG = readBigInteger(ois);
            } finally {
                ois.close();
            }
        } catch (IOException e) {
            throw new IllegalArgumentException("The supplied byte[] could not be read as a public key", e);
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("The supplied byte[] could not be read as a public key", e);
        }

        if (newN.compareTo(BigInteger.ONE) <= 0) {
            throw new IllegalArgumentException("The modulus n of a public key must be greater than 1");
        }

        this.n = newN;
        this.g = newG;
        this.nsquare = newN.multiply(newN);
        /* Without this, getBitLength() would report 0 on an instance that never generated a key. */
        this.bitLength = newN.bitLength();
    }

    /**
     * Serializes the public key. The layout is: magic number (long), type tag (one byte),
     * n, g - the latter two as serialized BigInteger objects.
     * @return the serialized public key
     * @throws IllegalStateException if no key has been generated or loaded yet
     */
    public byte[] getPubKeyAsByteA() {
        requirePublicKey();

        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try {
            final ObjectOutputStream oos = new ObjectOutputStream(baos);
            try {
                oos.writeLong(serialVersionUID);
                oos.write(PUBKEY);

                oos.writeObject(n);
                oos.writeObject(g);
            } finally {
                /* Closing flushes the buffered block data into baos. */
                oos.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Could not serialize the public key", e);
        }
        return baos.toByteArray();
    }

    /**
     * Encrypts plaintext m. ciphertext c = g^m * r^n mod n^2. This function explicitly requires random input r to help with encryption.
     * @param m plaintext as a BigInteger
     * @param r random value to help with encryption; it is used as given, so the caller is responsible for choosing it from Z*_n
     * @return ciphertext as a BigInteger
     * @throws IllegalStateException if no key has been generated or loaded yet
     */
    public BigInteger Encryption(BigInteger m, BigInteger r) {
        requirePublicKey();
        return g.modPow(m, nsquare).multiply(r.modPow(n, nsquare)).mod(nsquare);
    }

    /**
     * @return the number of bits of the modulus (0 while no key is present)
     */
    public int getBitLength() {
        return this.bitLength;
    }

    /**
     * Encrypts plaintext m. ciphertext c = g^m * r^n mod n^2. This function automatically generates random input r (to help with encryption).
     * r is drawn from Z*_n, i.e. 0 &lt; r &lt; n with gcd(r, n) = 1.
     * @param m plaintext as a BigInteger
     * @return ciphertext as a BigInteger
     * @throws IllegalStateException if no key has been generated or loaded yet
     */
    public BigInteger Encryption(BigInteger m) {
        requirePublicKey();
        return Encryption(m, randomUnitModN());
    }

    /**
     * Decrypts ciphertext c. plaintext m = L(c^lambda mod n^2) * u mod n, where u = (L(g^lambda mod n^2))^(-1) mod n.
     * @param c ciphertext as a BigInteger
     * @return plaintext as a BigInteger
     * @throws IllegalStateException if this instance holds no private key
     */
    public BigInteger Decryption(BigInteger c) {
        requirePublicKey();
        if (lambda == null) {
            throw new IllegalStateException("No private key available; call KeyGeneration first");
        }
        final BigInteger u = lFunction(g.modPow(lambda, nsquare), n).modInverse(n);
        return lFunction(c.modPow(lambda, nsquare), n).multiply(u).mod(n);
    }

    /**
     * Homomorphic addition: D(E(m1) * E(m2) mod n^2) = (m1 + m2) mod n.
     * @param a first ciphertext
     * @param b second ciphertext
     * @return a * b mod n^2
     */
    public BigInteger cryptedAdd(BigInteger a, BigInteger b) {
        return a.multiply(b).mod(nsquare);
    }

    /**
     * Homomorphic multiplication with a plaintext constant: D(E(m)^k mod n^2) = (m * k) mod n.
     * @param a ciphertext
     * @param constant plaintext factor k
     * @return a^constant mod n^2
     */
    public BigInteger cryptedMulWithConst(BigInteger a, BigInteger constant) {
        return a.modPow(constant, nsquare);
    }

    /**
     * Homomorphic negation: the modular inverse of a ciphertext of m decrypts to (-m) mod n.
     * @param a ciphertext
     * @return a^(-1) mod n^2
     * @throws ArithmeticException if a has no inverse modulo n^2
     */
    public BigInteger cryptedInverse(BigInteger a) {
        return a.modInverse(nsquare);
    }

    /**
     * Homomorphic subtraction, implemented as adding the inverse of b to a.
     * @param a ciphertext of the minuend
     * @param b ciphertext of the subtrahend
     * @return a * b^(-1) mod n^2
     */
    public BigInteger subValues(BigInteger a, BigInteger b) {
        return cryptedAdd(a, this.cryptedInverse(b));
    }

    /**
     * Draws a uniformly distributed value r with 0 &lt;= r &lt; n^2 by rejection sampling.
     * @return the random value
     */
    public BigInteger getRandomWithinModulo() {
        BigInteger r;
        do {
            r = new BigInteger(nsquare.bitLength(), this.rnd);
        } while (r.compareTo(nsquare) >= 0);

        return r;
    }

    /** Paillier's L function: L(u) = (u - 1) / modulus. */
    private static BigInteger lFunction(BigInteger u, BigInteger modulus) {
        return u.subtract(BigInteger.ONE).divide(modulus);
    }

    /** Reads the next object from the stream and insists that it is a non-null BigInteger. */
    private static BigInteger readBigInteger(ObjectInputStream ois)
            throws IOException, ClassNotFoundException {
        final Object value = ois.readObject();
        if (!(value instanceof BigInteger)) {
            throw new IllegalArgumentException("The supplied byte[] does not contain a valid public key");
        }
        return (BigInteger) value;
    }

    /** Fails fast with a clear message when n, n^2 or g is still missing. */
    private void requirePublicKey() {
        if (n == null || nsquare == null || g == null) {
            throw new IllegalStateException("No key available; call KeyGeneration or setPubKey first");
        }
    }

    /** Draws the encryption nonce r from Z*_n by rejection sampling. */
    private BigInteger randomUnitModN() {
        BigInteger r;
        do {
            r = new BigInteger(n.bitLength(), this.rnd);
        } while (r.signum() == 0 || r.compareTo(n) >= 0 || !r.gcd(n).equals(BigInteger.ONE));
        return r;
    }
}