import java.util.*;
import java.io.*;

/**
 * Fits a two-component, unit-variance, one-dimensional Gaussian mixture with
 * expectation-maximization and scores the resulting hard assignments against known labels.
 *
 * <p>{@code main} performs {@link #NUM_ITERATIONS} independent random restarts of EM and prints,
 * per restart: the smaller fitted mean, the larger fitted mean, the mixing weight of the
 * smaller-mean component, the mixing weight of the larger-mean component, and the
 * label-permutation-invariant accuracy.
 */
public class EMGaussians {

    /** Number of mixture components; {@code EStep}/{@code MStep} are written for exactly two. */
    static final int NUM_CLASSES = 2;
    /** Number of lines read from each input file. */
    static final int NUM_DATA = 10000;
    /** Number of independent random restarts of EM performed by {@code main}. */
    static final int NUM_ITERATIONS = 100;

    public static void main(String[] args) {
        String datafile = "gmmdata.txt";
        String classfile = "gmmclasses.txt";
        double[] data = readData(datafile);
        int[] classes = readClasses(classfile);

        for (int t = 0; t < NUM_ITERATIONS; ++t) {
            double[] mu = new double[NUM_CLASSES];
            double[] mixing = new double[NUM_CLASSES];
            double[][] membership = new double[NUM_CLASSES][data.length];
            // runEM's out-parameters are ordered (mixing, mu): passing them the other way
            // round silently swaps means and weights.
            runEM(data, mixing, mu, membership);

            double acc = calculateAccuracy(classes, membership);
            // Order components by mean and keep each mixing weight next to its own mean.
            int lo = mu[0] <= mu[1] ? 0 : 1;
            int hi = 1 - lo;
            System.out.println(mu[lo] + " " + mu[hi] + " "
                    + mixing[lo] + " " + mixing[hi] + " " + acc);
        }
    }

    /**
     * Reads {@link #NUM_DATA} observations, one number per line, from {@code datafile}.
     *
     * @param datafile path of the text file to read
     * @return the parsed observations, in file order
     * @throws UncheckedIOException if the file cannot be read or has fewer than NUM_DATA lines
     * @throws NumberFormatException if a line is not a valid double
     */
    static double[] readData(String datafile) {
        double[] data = new double[NUM_DATA];
        // try-with-resources closes the reader on every path, including errors.
        try (BufferedReader br = new BufferedReader(new FileReader(datafile))) {
            for (int j = 0; j < NUM_DATA; ++j) {
                data[j] = Double.parseDouble(readRequiredLine(br, datafile, j));
            }
        } catch (IOException e) {
            // Fail fast: running EM on a zero-filled array would produce meaningless output.
            throw new UncheckedIOException("Failed to read data from " + datafile, e);
        }
        return data;
    }

    /**
     * Reads {@link #NUM_DATA} integer class labels, one per line, from {@code classfile}.
     *
     * @param classfile path of the text file to read
     * @return the parsed labels, in file order
     * @throws UncheckedIOException if the file cannot be read or has fewer than NUM_DATA lines
     * @throws NumberFormatException if a line is not a valid int
     */
    static int[] readClasses(String classfile) {
        int[] classes = new int[NUM_DATA];
        try (BufferedReader br = new BufferedReader(new FileReader(classfile))) {
            for (int j = 0; j < NUM_DATA; ++j) {
                classes[j] = Integer.parseInt(readRequiredLine(br, classfile, j));
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read classes from " + classfile, e);
        }
        return classes;
    }

    /**
     * Returns the next line of {@code br} with surrounding whitespace removed.
     *
     * @param br    reader positioned at the line to read
     * @param file  file name, used only in the error message
     * @param index zero-based index of the line being read, used only in the error message
     * @throws EOFException if {@code br} is exhausted before the expected line
     */
    private static String readRequiredLine(BufferedReader br, String file, int index)
            throws IOException {
        String line = br.readLine();
        if (line == null) {
            throw new EOFException(file + ": expected " + NUM_DATA + " lines, found " + index);
        }
        return line.trim();
    }

    /**
     * Computes the fraction of points whose hard assignment matches {@code classes}.
     *
     * <p>A point is assigned to component 1 when its component-1 responsibility is strictly
     * larger, otherwise to component 0. Because EM's component numbering is arbitrary, the
     * result is the better of the two possible label correspondences, so it is always >= 0.5.
     *
     * @param classes    true labels (0 or 1), at least as long as each membership row
     * @param membership responsibilities indexed [component][point]
     * @return accuracy in [0.5, 1]; 1.0 when there are no points
     */
    static double calculateAccuracy(int[] classes, double[][] membership) {
        int n = membership[0].length;
        if (n == 0) {
            return 1.0; // no points, hence no errors
        }
        int errors = 0;
        for (int j = 0; j < n; ++j) {
            int predicted = membership[1][j] > membership[0][j] ? 1 : 0;
            if (classes[j] != predicted) {
                ++errors;
            }
        }
        // Divide by the real number of points, not a file-size constant.
        double errorRate = (double) errors / n;
        return Math.max(errorRate, 1 - errorRate);
    }

    /**
     * Runs EM from a random start until no mean and no mixing weight changes by more than
     * 1e-5 between successive iterations, then copies the result into the out-parameters.
     *
     * <p>Initial means are drawn uniformly from [-10, 10) and the initial weights are
     * {x, 1-x} for a uniform x. The returned responsibilities come from the final E-step,
     * i.e. they were computed with the parameters in effect before the final M-step.
     *
     * @param data          observations
     * @param outMixing     receives the fitted mixing weights (length >= NUM_CLASSES)
     * @param outMu         receives the fitted means (length >= NUM_CLASSES)
     * @param outMembership receives responsibilities [component][point]; each of its first
     *                      NUM_CLASSES rows must hold at least data.length entries
     */
    static void runEM(double[] data, double[] outMixing, double[] outMu,
                      double[][] outMembership) {
        int n = data.length;

        double[] mu = {Math.random() * 20 - 10, Math.random() * 20 - 10};
        double w = Math.random();
        double[] mixing = {w, 1 - w};
        double[][] membership = new double[NUM_CLASSES][n];

        // All-zero "previous" parameters guarantee at least one iteration: the current mixing
        // weights sum to 1, so they cannot both lie within epsilon of 0.
        double[] prevMu = {0, 0};
        double[] prevMixing = {0, 0};
        double epsilon = 1e-5;
        while (!converged(prevMu, mu, prevMixing, mixing, epsilon)) {
            prevMu = mu;
            prevMixing = mixing;

            membership = new double[NUM_CLASSES][n];
            EStep(data, prevMixing, prevMu, membership);
            mu = new double[NUM_CLASSES];
            mixing = new double[NUM_CLASSES];
            MStep(data, membership, mu, mixing);
        }

        System.arraycopy(mu, 0, outMu, 0, NUM_CLASSES);
        System.arraycopy(mixing, 0, outMixing, 0, NUM_CLASSES);
        for (int k = 0; k < NUM_CLASSES; ++k) {
            System.arraycopy(membership[k], 0, outMembership[k], 0, n);
        }
    }

    /**
     * E-step: writes each point's posterior responsibility under each of the two
     * unit-variance components into {@code membership[component][point]}.
     *
     * <p>Works in log space. Far from both means the two densities underflow to 0, and the
     * direct ratio would be 0/0 = NaN. Subtracting the larger log term before exponentiating
     * keeps at least one weight equal to 1. The Gaussian normalizing constant is identical
     * for both components and cancels, so it is omitted.
     *
     * @param data       observations
     * @param mixing     current mixing weights (two entries)
     * @param mu         current means (two entries)
     * @param membership output, indexed [component][point]
     */
    static void EStep(double[] data, double[] mixing, double[] mu,
                      double[][] membership) {
        double logMix0 = Math.log(mixing[0]);
        double logMix1 = Math.log(mixing[1]);
        for (int j = 0; j < data.length; ++j) {
            double d0 = data[j] - mu[0];
            double d1 = data[j] - mu[1];
            double log0 = logMix0 - 0.5 * d0 * d0;
            double log1 = logMix1 - 0.5 * d1 * d1;
            double max = Math.max(log0, log1);
            double w0 = Math.exp(log0 - max);
            double w1 = Math.exp(log1 - max);
            double sum = w0 + w1;
            membership[0][j] = w0 / sum;
            membership[1][j] = w1 / sum;
        }
    }

    /**
     * M-step: re-estimates the two mixing weights and means from the responsibilities.
     *
     * <p>Each weight is that component's share of the total responsibility. Each mean is the
     * responsibility-weighted average of the data. A component whose total responsibility is
     * exactly 0 gets a NaN mean (0/0).
     *
     * @param data       observations
     * @param membership responsibilities indexed [component][point]
     * @param mu         output: the two new means
     * @param mixing     output: the two new mixing weights
     */
    static void MStep(double[] data, double[][] membership, double[] mu,
                      double[] mixing) {
        double weight0 = 0;
        double weight1 = 0;
        double weighted0 = 0;
        double weighted1 = 0;
        for (int j = 0; j < data.length; ++j) {
            weight0 += membership[0][j];
            weight1 += membership[1][j];
            weighted0 += membership[0][j] * data[j];
            weighted1 += membership[1][j] * data[j];
        }
        double total = weight0 + weight1;
        mixing[0] = weight0 / total;
        mixing[1] = weight1 / total;
        mu[0] = weighted0 / weight0;
        mu[1] = weighted1 / weight1;
    }

    /**
     * Evaluates the normal probability density with mean {@code mu} and standard deviation
     * {@code sigma} at {@code x}.
     */
    static double normpdf(double x, double mu, double sigma) {
        double val = (1 / (sigma * Math.sqrt( 2 * Math.PI))) *
            Math.exp( - (x - mu) * (x - mu) / (2 * sigma * sigma));
        return val;
    }

    /**
     * Returns true when no mean and no mixing weight moved by more than {@code epsilon}.
     *
     * <p>Any comparison involving NaN is false in Java, so NaN parameters count as converged.
     * This ends the loop rather than spinning forever.
     */
    static boolean converged(double[] oldMu, double[] newMu,
                             double[] oldMixing, double[] newMixing,
                             double epsilon) {
        for (int i = 0; i < oldMu.length; ++i) {
            if (Math.abs(oldMu[i] - newMu[i]) > epsilon)
                return false;
            if (Math.abs(oldMixing[i] - newMixing[i]) > epsilon)
                return false;
        }
        return true;
    }


}