/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *   EMDataGenerator.java
 *   Copyright (C) 2002 Mark Hall
 *
 */

package weka.gui.boundaryvisualizer;

import weka.core.*;
import weka.clusterers.*;
import java.io.*;
import java.util.Random;


/**
 * Class that uses EM to build a probabilistic clustering model of
 * supplied input data and then generates new random instances based
 * on that model.
 *
 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
 * @version $Revision: 1.1.1.1 $
 */
public class EMDataGenerator implements DataGenerator, Serializable {

  // the instances to cluster
  private Instances m_instancesToCluster;

  // the clusterer; null until buildGenerator() has completed successfully
  private EM m_clusterer;

  // number of clusters generated by EM
  private int m_numClusters = -1;

  // parameters of the normal distribution for each attribute in each
  // cluster, indexed [cluster][attribute][param]; param 0 is used as the
  // mean and param 1 as the standard deviation
  private double [][][] m_normalDistributions;

  // prior probabilities for each cluster
  private double [] m_clusterPriors;

  // random number seed
  private int m_seed = 1;

  // random number generator
  private Random m_random;

  // the cluster that the most recently generated instance was drawn from
  private int m_clusterToGenerateFrom = 0;

  // which dimensions to use for computing a weight for each generated
  // instance; null means that no dimension is used for weighting
  private boolean [] m_weightingDimensions;
  
  // the values for the weighting dimensions to use for computing the weight
  // for the next instance to be generated
  private double [] m_weightingValues;

  // created once only - reused by generateInstanceFast()
  private Instance m_instance;
  
  // value array of m_instance; the fast path writes into it directly
  // (assumes the Instance constructor keeps a reference to the array rather
  // than copying it - NOTE(review): confirm for the Weka version in use)
  private double [] m_instanceVals;

  // cumulative distribution for cluster priors
  private double [] m_cumDist;

  // sqrt(2 * pi), the normalising constant of the normal density
  private static final double m_normConst = Math.sqrt(2 * Math.PI);

  /**
   * Builds the data generator: clusters the supplied data with EM, caches
   * the per-cluster normal models and cluster priors, and resets the random
   * number generator to the fixed seed. If building fails, the generator is
   * left in the "not built" state.
   *
   * @param inputInstances Instances to construct the clusterer with
   * @exception Exception if an error occurs
   */
  public void buildGenerator(Instances inputInstances) throws Exception {
    // Only publish the clusterer once EM has succeeded, so a failed build is
    // reported by generateInstance() rather than using a half-built model.
    m_clusterer = null;
    m_random = new Random(m_seed);
    m_clusterToGenerateFrom = 0;
    m_instancesToCluster = inputInstances;

    EM clusterer = new EM();
    clusterer.buildClusterer(m_instancesToCluster);
    m_numClusters = clusterer.numberOfClusters();
    m_normalDistributions = clusterer.getClusterModelsNumericAtts();
    m_clusterPriors = clusterer.getClusterPriors();
    // prints the learned model to stderr
    System.err.println(clusterer);
    m_instanceVals = new double [m_instancesToCluster.numAttributes()];
    m_instance = new Instance(1.0, m_instanceVals);

    // Compute cumulative distribution for cluster priors
    m_cumDist = computeCumulativeDistribution(m_clusterPriors);
    m_clusterer = clusterer;
  }
  
  /**
   * Return a cumulative distribution from a discrete distribution.
   *
   * @param dist the distribution to use
   * @return the running sums of dist; element i is dist[0] + ... + dist[i]
   */
  private static double [] computeCumulativeDistribution(double [] dist) {
    double [] cumDist = new double[dist.length];
    double sum = 0;
    for (int i = 0; i < dist.length; i++) {
      sum += dist[i];
      cumDist[i] = sum;
    }

    return cumDist;
  }

  /**
   * Generate a new instance. Returns the instance in a brand new
   * Instance object.
   *
   * @return an <code>Instance</code> value
   * @exception Exception if an error occurs
   */
  public Instance generateInstance() throws Exception {
    return generateInstance(false);
  }

  /**
   * Generate a new instance. Reuses an existing instance object to
   * speed up the process, so the returned object is overwritten by the
   * next call to this method.
   *
   * @return an <code>Instance</code> value
   * @exception Exception if an error occurs
   */
  public Instance generateInstanceFast() throws Exception {
    return generateInstance(true);
  }

  /**
   * Randomly generates an instance. The cluster to sample from is chosen at
   * random according to the cluster priors. Each weighting attribute is set
   * to its weighting value. Each other numeric attribute is drawn from the
   * chosen cluster's normal distribution. Other nominal attributes are not
   * sampled: they are 0 in a new instance and unchanged in the reused one.
   * The instance weight is the product of the cluster's normal densities at
   * the weighting values, or 1 when no weighting dimensions are set.
   *
   * @param fast true to overwrite and return the shared instance object,
   * false to return a newly allocated instance
   * @return the generated instance
   * @exception Exception if the generator has not been built, the weighting
   * dimension array does not match the number of attributes, or weighting
   * dimensions are set without weighting values
   */
  private Instance generateInstance(boolean fast) throws Exception {
    if (m_clusterer == null) {
      throw new Exception("Generator has not been built yet!");
    }
    int numAtts = m_instancesToCluster.numAttributes();
    boolean [] weightingDims = m_weightingDimensions;
    if (weightingDims != null) {
      if (weightingDims.length != numAtts) {
        throw new Exception("Weighting dimension array != num attributes!");
      }
      if (m_weightingValues == null && anyTrue(weightingDims)) {
        throw new Exception("Weighting values have not been set!");
      }
    }

    // Use a local array: the fast path must always fill the array backing
    // m_instance, even after non-fast calls have allocated fresh arrays.
    double [] vals;
    Instance newInst;
    if (fast) {
      vals = m_instanceVals;
      newInst = m_instance;
    } else {
      vals = new double [numAtts];
      newInst = new Instance(1.0, vals);
    }

    // choose cluster to generate from
    m_clusterToGenerateFrom = chooseCluster(m_random.nextDouble());
    double [][] model = m_normalDistributions[m_clusterToGenerateFrom];

    // set instance values and weight
    double weight = 1;
    for (int i = 0; i < numAtts; i++) {
      if (weightingDims != null && weightingDims[i]) {
        weight *= normalDens(m_weightingValues[i], model[i][0], model[i][1]);
        vals[i] = m_weightingValues[i];
      } else if (m_instancesToCluster.attribute(i).isNumeric()) {
        // de-standardize a standard normal draw w.r.t. this cluster's model
        vals[i] = m_random.nextGaussian() * model[i][1] + model[i][0];
      }
      // nominal attributes are not sampled
    }
    newInst.setWeight(weight);
    return newInst;
  }

  /**
   * Picks a cluster index by inverting the cumulative prior distribution.
   *
   * @param r a uniform random number in [0, 1)
   * @return the first cluster whose cumulative prior is >= r, or the last
   * cluster if floating-point rounding left the total prior mass below r
   */
  private int chooseCluster(double r) {
    for (int i = 0; i < m_cumDist.length; i++) {
      if (r <= m_cumDist[i]) {
        return i;
      }
    }
    return m_cumDist.length - 1;
  }

  /**
   * Returns true if at least one entry of the array is true.
   *
   * @param flags the array to inspect
   * @return whether any flag is set
   */
  private static boolean anyTrue(boolean [] flags) {
    for (int i = 0; i < flags.length; i++) {
      if (flags[i]) {
        return true;
      }
    }
    return false;
  }

  /**
   * Set which dimensions to use when computing a weight for the next
   * instance to generate.
   *
   * @param dims an array of booleans indicating which dimensions to use
   * (one entry per attribute), or null to use no weighting dimensions
   */
  public void setWeightingDimensions(boolean [] dims) {
    m_weightingDimensions = dims;
  }
  
  /**
   * Set the values for the weighting dimensions to be used when computing
   * the weight for the next instance to be generated.
   *
   * @param vals an array of doubles, indexed by attribute, containing the
   * values of the weighting dimensions (the entries that are set to true
   * through setWeightingDimensions)
   */
  public void setWeightingValues(double [] vals) {
    m_weightingValues = vals;
  }

  /**
   * Density function of normal distribution.
   *
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution (must be positive;
   * 0 yields NaN)
   * @return the normal density at x
   */
  private double normalDens (double x, double mean, double stdDev) {
    double diff = x - mean;
   
    return  (1/(m_normConst*stdDev))*Math.exp(-(diff*diff/(2*stdDev*stdDev)));
  }

  /**
   * Return the EM model of the data.
   *
   * @return the built <code>EM</code> clusterer, or null if the generator
   * has not been (successfully) built
   */
  public EM getEMModel() {
    return m_clusterer;
  }

  /**
   * Return the number of clusters generated by EM.
   *
   * @return the number of clusters, or -1 before the first build
   */
  public int getNumGeneratingModels() {
    return m_numClusters;
  }

  /**
   * Return the index of the cluster from which the most recently
   * generated instance was drawn (0 before any instance has been
   * generated).
   *
   * @return an <code>int</code> value
   */
  public int getClusterUsedToGenerateLastInstanceFrom() {
    return m_clusterToGenerateFrom;
  }

  /**
   * Main method for testing this class: builds a generator from an ARFF
   * file and prints as many generated instances as the file contains.
   *
   * @param args a single argument, the name of the data file
   */
  public static void main(String [] args) {
    Reader r = null;
    try {
      if (args.length != 1) {
        throw new Exception("Usage: EMDataGenerator <filename>");
      }
      r = new BufferedReader(new FileReader(args[0]));
      Instances insts = new Instances(r);
      EMDataGenerator dg = new EMDataGenerator();
      dg.buildGenerator(insts);
      Instances header = new Instances(insts,0);
      System.out.println(header);
      for (int i = 0; i < insts.numInstances(); i++) {
        Instance newInst = dg.generateInstance();
        newInst.setDataset(header);
        System.out.println(newInst);
      }
    } catch (Exception ex) {
      ex.printStackTrace();
    } finally {
      if (r != null) {
        try {
          r.close();
        } catch (IOException ignored) {
          // nothing useful can be done if closing the input fails
        }
      }
    }
  }
}
