histograms;
double sumRespOnMV = 0.0;
double sumWeightOnMV = 0.0;
double[][] histogram = new double[size][3];
for (IntPair entry : indices) {
int index = entry.v1;
int w = entry.v2;
double weight = weights[index];
double value = fvalues[index];
double target = targets[index];
if (!Double.isNaN(value)) {
int idx = (int) value;
if (isClassification) {
histogram[idx][0] += target * w;
} else {
histogram[idx][0] += target * weight * w;
}
histogram[idx][1] += weight * w;
} else {
if (isClassification) {
sumRespOnMV += target * w;
} else {
sumRespOnMV += target * weight * w;
}
sumWeightOnMV += weight * w;
}
}
histograms = new ArrayList<>(histogram.length + 1);
for (int i = 0; i < histogram.length; i++) {
if (!MathUtils.isZero(histogram[i][1])) {
double[] hist = histogram[i];
histograms.add(new double[] {i, hist[0], hist[1]});
}
}
histograms.add(new double[] { Double.NaN, sumRespOnMV, sumWeightOnMV });
Function1D func = LineCutter.build(attIndex, histograms, numIntervals);
ensemble.add(func);
}
}
return ensemble;
}
}
================================================
FILE: src/main/java/mltk/predictor/function/BivariateFunction.java
================================================
package mltk.predictor.function;
/**
 * Contract for real-valued functions of two real arguments.
 *
 * @author Yin Lou
 *
 */
public interface BivariateFunction {

    /**
     * Evaluates this function at the point {@code (x, y)}.
     *
     * @param x the 1st argument.
     * @param y the 2nd argument.
     * @return the value of the function at {@code (x, y)}.
     */
    double evaluate(double x, double y);

}
================================================
FILE: src/main/java/mltk/predictor/function/CHistogram.java
================================================
package mltk.predictor.function;
/**
 * Plain data holder for a cumulative histogram: two equally sized per-bin arrays plus a separate sum/count pair
 * reserved for missing values.
 *
 * @author Yin Lou
 *
 */
public class CHistogram {

    /** Per-bin sums; allocated with the size given to the constructor. */
    public double[] sum;

    /** Per-bin counts; allocated with the size given to the constructor. */
    public double[] count;

    /** Sum accumulated on missing values; starts at zero. */
    public double sumOnMV;

    /** Count accumulated on missing values; starts at zero. */
    public double countOnMV;

    /**
     * Creates a histogram with {@code n} zero-initialized bins and zero missing-value totals.
     *
     * @param n the size of this cumulative histogram.
     */
    public CHistogram(int n) {
        // sumOnMV and countOnMV keep their default value of 0.0.
        this.sum = new double[n];
        this.count = new double[n];
    }

    /**
     * Reports the number of bins, taken from the length of {@link #sum}.
     *
     * @return the size of this cumulative histogram.
     */
    public int size() {
        return this.sum.length;
    }

    /**
     * Tells whether any weight has been recorded on missing values.
     *
     * @return {@code true} if {@link #countOnMV} is strictly positive.
     */
    public boolean hasMissingValue() {
        return 0.0 < this.countOnMV;
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/CompressionUtils.java
================================================
package mltk.predictor.function;
import mltk.core.Instance;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.BoostedEnsemble;
import mltk.predictor.Predictor;
import mltk.util.tuple.IntPair;
/**
 * Class for utility functions for compressing ensembles of univariate/bivariate functions into a single function
 * or lookup table.
 *
 * @author Yin Lou
 *
 */
public class CompressionUtils {

    /**
     * Compresses a bagged ensemble of 1D functions to a single 1D function. The members are summed and the sum is
     * divided by the ensemble size.
     *
     * @param attIndex the attribute index of this regressor.
     * @param baggedEnsemble the bagged ensemble; every member must be a {@link Function1D}.
     * @return a single compressed 1D function.
     * @throws IllegalArgumentException if a member is not a {@link Function1D}.
     */
    public static Function1D compress(int attIndex, BaggedEnsemble baggedEnsemble) {
        Function1D function = Function1D.getConstantFunction(attIndex, 0);
        for (int i = 0; i < baggedEnsemble.size(); i++) {
            Predictor predictor = baggedEnsemble.get(i);
            if (!(predictor instanceof Function1D)) {
                throw new IllegalArgumentException("Bagged ensemble member " + i + " is not a Function1D");
            }
            function.add((Function1D) predictor);
        }
        function.divide(baggedEnsemble.size());
        return function;
    }

    /**
     * Compresses a boosted ensemble to a single 1D function. Members are summed; a member that is itself a bagged
     * ensemble is first compressed with {@link #compress(int, BaggedEnsemble)}.
     *
     * @param attIndex the attribute of this regressor.
     * @param boostedEnsemble the boosted ensemble.
     * @return a single compressed 1D function.
     * @throws IllegalArgumentException if a member is neither a {@link Function1D} nor a {@link BaggedEnsemble}.
     */
    public static Function1D compress(int attIndex, BoostedEnsemble boostedEnsemble) {
        Function1D function = Function1D.getConstantFunction(attIndex, 0);
        for (int i = 0; i < boostedEnsemble.size(); i++) {
            Predictor predictor = boostedEnsemble.get(i);
            Function1D func;
            if (predictor instanceof Function1D) {
                func = (Function1D) predictor;
            } else if (predictor instanceof BaggedEnsemble) {
                func = compress(attIndex, (BaggedEnsemble) predictor);
            } else {
                throw new IllegalArgumentException(
                        "Boosted ensemble member " + i + " is neither a Function1D nor a BaggedEnsemble");
            }
            function.add(func);
        }
        return function;
    }

    /**
     * Converts a 1D function to 1D lookup table by evaluating it at the integers {@code 0 .. n - 1}.
     *
     * @param n the number of bins.
     * @param function the 1D function.
     * @return a 1D lookup table.
     */
    public static Array1D convert(int n, Function1D function) {
        double[] predictions = new double[n];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = function.evaluate(i);
        }
        return new Array1D(function.getAttributeIndex(), predictions, function.predictionOnMV);
    }

    /**
     * Compresses a bagged ensemble of 2D functions to a single 2D function. The members are summed and the sum is
     * divided by the ensemble size.
     *
     * @param attIndex1 the 1st attribute index.
     * @param attIndex2 the 2nd attribute index.
     * @param baggedEnsemble the bagged ensemble; every member must be a {@link Function2D}.
     * @return a single compressed 2D functions.
     * @throws IllegalArgumentException if a member is not a {@link Function2D}.
     */
    public static Function2D compress(int attIndex1, int attIndex2, BaggedEnsemble baggedEnsemble) {
        // TODO check consistency problem when missing values are present.
        Function2D function = Function2D.getConstantFunction(attIndex1, attIndex2, 0);
        for (int i = 0; i < baggedEnsemble.size(); i++) {
            Predictor predictor = baggedEnsemble.get(i);
            if (!(predictor instanceof Function2D)) {
                throw new IllegalArgumentException("Bagged ensemble member " + i + " is not a Function2D");
            }
            function.add((Function2D) predictor);
        }
        function.divide(baggedEnsemble.size());
        return function;
    }

    /**
     * Compresses a boosted ensemble to a single 2D function. Members are summed; a member that is itself a bagged
     * ensemble is first compressed with {@link #compress(int, int, BaggedEnsemble)}.
     *
     * @param attIndex1 the 1st attribute index.
     * @param attIndex2 the 2nd attribute index.
     * @param boostedEnsemble the boosted ensemble.
     * @return a single compressed 2D function.
     * @throws IllegalArgumentException if a member is neither a {@link Function2D} nor a {@link BaggedEnsemble}.
     */
    public static Function2D compress(int attIndex1, int attIndex2, BoostedEnsemble boostedEnsemble) {
        // TODO check consistency problem when missing values are present.
        Function2D function = Function2D.getConstantFunction(attIndex1, attIndex2, 0);
        for (int i = 0; i < boostedEnsemble.size(); i++) {
            Predictor predictor = boostedEnsemble.get(i);
            Function2D func;
            if (predictor instanceof Function2D) {
                func = (Function2D) predictor;
            } else if (predictor instanceof BaggedEnsemble) {
                func = compress(attIndex1, attIndex2, (BaggedEnsemble) predictor);
            } else {
                throw new IllegalArgumentException(
                        "Boosted ensemble member " + i + " is neither a Function2D nor a BaggedEnsemble");
            }
            function.add(func);
        }
        return function;
    }

    /**
     * Compresses and converts a boosted ensemble to a single 2D lookup table. The ensemble is queried through
     * {@code regress} on a probe instance whose two attribute values range over the integer bins and over
     * {@code NaN} (missing value).
     *
     * @param attIndex1 the 1st attribute index.
     * @param attIndex2 the 2nd attribute index.
     * @param n1 the number of bins for 1st attribute.
     * @param n2 the number of bins for 2nd attribute.
     * @param boostedEnsemble the boosted ensemble.
     * @return a 2D lookup table.
     */
    public static Array2D compress(int attIndex1, int attIndex2, int n1, int n2, BoostedEnsemble boostedEnsemble) {
        double[][] predictions = new double[n1][n2];
        double[] predictionsOnMV1 = new double[n2];
        double[] predictionsOnMV2 = new double[n1];
        // NOTE(review): the probe instance is driven by writing into this array after construction, exactly as
        // the previous implementation did; this presumes Instance keeps a reference to the array - confirm.
        double[] vector = new double[Math.max(attIndex1, attIndex2) + 1];
        Instance instance = new Instance(vector);

        // Predictions with the 1st attribute missing depend only on j, so compute them once instead of once per
        // value of i. This also fills them in when n1 == 0.
        vector[attIndex1] = Double.NaN;
        for (int j = 0; j < n2; j++) {
            vector[attIndex2] = j;
            predictionsOnMV1[j] = boostedEnsemble.regress(instance);
        }

        for (int i = 0; i < n1; i++) {
            vector[attIndex1] = i;
            vector[attIndex2] = Double.NaN;
            predictionsOnMV2[i] = boostedEnsemble.regress(instance);
            double[] preds = predictions[i];
            for (int j = 0; j < n2; j++) {
                vector[attIndex2] = j;
                preds[j] = boostedEnsemble.regress(instance);
            }
        }

        // Both attributes missing.
        vector[attIndex1] = Double.NaN;
        vector[attIndex2] = Double.NaN;
        return new Array2D(attIndex1, attIndex2, predictions,
                predictionsOnMV1, predictionsOnMV2, boostedEnsemble.regress(instance));
    }

    /**
     * Converts a 2D function to 2D lookup table by evaluating it on the integer grid
     * {@code [0, n1) x [0, n2)} and with {@code NaN} substituted for either argument.
     *
     * @param n1 the number of bins for 1st attribute.
     * @param n2 the number of bins for 2nd attribute.
     * @param function the 2D function.
     * @return a 2D lookup table.
     */
    public static Array2D convert(int n1, int n2, Function2D function) {
        double[][] predictions = new double[n1][n2];
        double[] predictionsOnMV1 = new double[n2];
        double[] predictionsOnMV2 = new double[n1];

        // Independent of i: evaluate once per j rather than n1 times.
        for (int j = 0; j < n2; j++) {
            predictionsOnMV1[j] = function.evaluate(Double.NaN, j);
        }

        for (int i = 0; i < n1; i++) {
            predictionsOnMV2[i] = function.evaluate(i, Double.NaN);
            double[] preds = predictions[i];
            for (int j = 0; j < n2; j++) {
                preds[j] = function.evaluate(i, j);
            }
        }

        IntPair attIndices = function.getAttributeIndices();
        return new Array2D(attIndices.v1, attIndices.v2, predictions,
                predictionsOnMV1, predictionsOnMV2, function.predictionOnMV12);
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/CubicSpline.java
================================================
package mltk.predictor.function;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.Arrays;
import mltk.core.Instance;
import mltk.predictor.Regressor;
import mltk.util.ArrayUtils;
/**
 * Class for cubic splines. Given knots k, the cubic spline uses the following basis: 1, x, x^2, x^3, (x - k[i])_+^3.
 *
 * <p>
 * The coefficient vector {@code w} therefore holds {@code knots.length + 3} entries: {@code w[0..2]} weight
 * {@code x, x^2, x^3} and {@code w[i + 3]} weights the truncated cubic attached to {@code knots[i]}. The constant
 * term is stored separately as the intercept.
 * </p>
 *
 * @author Yin Lou
 *
 */
public class CubicSpline implements Regressor, UnivariateFunction {

    /** Index of the attribute this spline is evaluated on; the convenience constructors use -1. */
    protected int attIndex;

    /** Constant term of the spline. */
    protected double intercept;

    /** Knot positions. */
    protected double[] knots;

    /** Coefficients for x, x^2, x^3 followed by one coefficient per knot. */
    protected double[] w;

    /**
     * Constructor. The arrays are stored by reference, not copied.
     *
     * @param attIndex the attribute index.
     * @param intercept the intercept.
     * @param knots the knots.
     * @param w the coefficient vector.
     */
    public CubicSpline(int attIndex, double intercept, double[] knots, double[] w) {
        this.attIndex = attIndex;
        this.intercept = intercept;
        this.knots = knots;
        this.w = w;
    }

    /**
     * Constructor. The attribute index is set to -1.
     *
     * @param intercept the intercept.
     * @param knots the knots.
     * @param w the coefficient vector.
     */
    public CubicSpline(double intercept, double[] knots, double[] w) {
        this(-1, intercept, knots, w);
    }

    /**
     * Constructor. The attribute index is set to -1 and the intercept to 0.
     *
     * @param knots the knots.
     * @param w the coefficient vector.
     */
    public CubicSpline(double[] knots, double[] w) {
        this(-1, 0, knots, w);
    }

    /**
     * Construct a cubic spline with specified knots and all coefficients set to zero.
     *
     * @param knots the knots.
     */
    public CubicSpline(double[] knots) {
        this(-1, 0, knots, new double[knots.length + 3]);
    }

    /**
     * Constructor. Leaves all fields at their default values; intended to be followed by
     * {@link #read(BufferedReader)}.
     */
    public CubicSpline() {

    }

    /**
     * Reads the spline from the layout produced by {@link #write(PrintWriter)}, starting at the
     * {@code AttIndex} line.
     *
     * @param in the reader to consume lines from.
     * @throws Exception if reading or parsing fails.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        attIndex = Integer.parseInt(in.readLine().split(": ")[1]);
        intercept = Double.parseDouble(in.readLine().split(": ")[1]);
        // Skip the "Knots: n" header; the array itself is on the next line.
        in.readLine();
        knots = ArrayUtils.parseDoubleArray(in.readLine());
        // Skip the "Coefficients: n" header.
        in.readLine();
        w = ArrayUtils.parseDoubleArray(in.readLine());
    }

    /**
     * Writes a predictor header line followed by the attribute index, intercept, knots and coefficients.
     *
     * @param out the writer to print to.
     * @throws Exception if writing fails.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName());
        out.println("AttIndex: " + attIndex);
        out.println("Intercept: " + intercept);
        out.println("Knots: " + knots.length);
        out.println(Arrays.toString(knots));
        out.println("Coefficients: " + w.length);
        out.println(Arrays.toString(w));
    }

    /**
     * Evaluates the spline: the intercept, plus the cubic polynomial in {@code x}, plus one truncated cubic
     * term per knot.
     *
     * @param x the point to evaluate at.
     * @return the value of the spline at {@code x}.
     */
    @Override
    public double evaluate(double x) {
        double pred = intercept + w[0] * x + w[1] * x * x + w[2] * x * x * x;
        for (int i = 0; i < knots.length; i++) {
            pred += h(x, knots[i]) * w[i + 3];
        }
        return pred;
    }

    /**
     * Calculate the basis.
     *
     * @param x a real.
     * @param k a knot.
     * @return h(x, k) = (x - k)^3 if x &gt;= k and 0 otherwise, a basis in cubic spline.
     */
    public static double h(double x, double k) {
        double t = x - k;
        if (t < 0) {
            return 0;
        }
        return t * t * t;
    }

    /**
     * Evaluates the spline on the value of attribute {@code attIndex} of the instance.
     *
     * @param instance the instance to predict.
     * @return the spline value for the instance.
     */
    @Override
    public double regress(Instance instance) {
        return evaluate(instance.getValue(attIndex));
    }

    /**
     * Returns the attribute index.
     *
     * @return the attribute index.
     */
    public int getAttributeIndex() {
        return attIndex;
    }

    /**
     * Sets the attribute index.
     *
     * @param attIndex the new attribute index.
     */
    public void setAttributeIndex(int attIndex) {
        this.attIndex = attIndex;
    }

    /**
     * Returns the coefficient vector.
     *
     * @return the internal coefficient vector (not a copy).
     */
    public double[] getCoefficients() {
        return w;
    }

    /**
     * Returns the knot vector.
     *
     * @return the internal knot vector (not a copy).
     */
    public double[] getKnots() {
        return knots;
    }

    /**
     * Returns the intercept.
     *
     * @return the intercept.
     */
    public double getIntercept() {
        return intercept;
    }

    /**
     * Returns a copy of this spline whose knot and coefficient arrays are independent clones.
     *
     * @return a copy of this spline.
     */
    @Override
    public CubicSpline copy() {
        double[] newKnots = knots.clone();
        double[] newW = w.clone();
        // The constructor takes the knots before the coefficients; passing them the other way round
        // produced a copy with the two arrays swapped.
        return new CubicSpline(attIndex, intercept, newKnots, newW);
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/EnsembledLineCutter.java
================================================
package mltk.predictor.function;
import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.Learner;
/**
 * Abstract learner that builds a bagged ensemble of 1D functions on a single attribute. Concrete subclasses
 * supply {@link #build(Instances, Attribute, int)}; the other {@code build} overloads delegate to it.
 */
public abstract class EnsembledLineCutter extends Learner {

    /** Index of the attribute, within the training set's attribute list, to build on. */
    protected int attIndex;

    /** Number of intervals passed to the builder. */
    protected int numIntervals;

    /** Number of bagging iterations. */
    protected int baggingIters;

    /** Whether the problem is a classification problem. */
    protected boolean isClassification;

    /**
     * Builds an 1D function ensemble using the attribute index and number of intervals configured on this
     * learner.
     *
     * @param instances the training set.
     * @return an 1D function ensemble.
     */
    @Override
    public BaggedEnsemble build(Instances instances) {
        return this.build(instances, this.attIndex, this.numIntervals);
    }

    /**
     * Builds an 1D function ensemble on the attribute found at the given position of the training set's
     * attribute list.
     *
     * @param instances the training set.
     * @param attIndex the attribute index.
     * @param numIntervals the number of intervals.
     * @return an 1D function ensemble.
     */
    public BaggedEnsemble build(Instances instances, int attIndex, int numIntervals) {
        final Attribute target = instances.getAttributes().get(attIndex);
        return this.build(instances, target, numIntervals);
    }

    /**
     * Builds an 1D function ensemble on the given attribute.
     *
     * @param instances the training set.
     * @param attribute the attribute to build on.
     * @param numIntervals the number of intervals.
     * @return an 1D function ensemble.
     */
    public abstract BaggedEnsemble build(Instances instances, Attribute attribute, int numIntervals);

    /**
     * Gets the index in the attribute list of the training set.
     *
     * @return the index in the attribute list of the training set.
     */
    public int getAttributeIndex() {
        return this.attIndex;
    }

    /**
     * Changes the index in the attribute list of the training set.
     *
     * @param attIndex the attribute index.
     */
    public void setAttributeIndex(final int attIndex) {
        this.attIndex = attIndex;
    }

    /**
     * Gets the number of bagging iterations.
     *
     * @return the number of bagging iterations.
     */
    public int getBaggingIters() {
        return this.baggingIters;
    }

    /**
     * Changes the number of bagging iterations.
     *
     * @param baggingIters the number of bagging iterations.
     */
    public void setBaggingIters(final int baggingIters) {
        this.baggingIters = baggingIters;
    }

    /**
     * Tells whether this learner is configured for a classification problem.
     *
     * @return {@code true} if it is a classification problem.
     */
    public boolean isClassification() {
        return this.isClassification;
    }

    /**
     * Marks whether this learner handles a classification problem.
     *
     * @param isClassification {@code true} if it is a classification problem.
     */
    public void setClassification(final boolean isClassification) {
        this.isClassification = isClassification;
    }

    /**
     * Gets the number of intervals.
     *
     * @return the number of intervals.
     */
    public int getNumIntervals() {
        return this.numIntervals;
    }

    /**
     * Changes the number of intervals.
     *
     * @param numIntervals the number of intervals.
     */
    public void setNumIntervals(final int numIntervals) {
        this.numIntervals = numIntervals;
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/Function1D.java
================================================
package mltk.predictor.function;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.Arrays;
import mltk.core.Instance;
import mltk.predictor.Regressor;
import mltk.util.ArrayUtils;
import mltk.util.MathUtils;
import mltk.util.VectorUtils;
/**
 * Piecewise-constant function of one attribute.
 *
 * <p>
 * The domain is cut into segments by the split array. For example, [3, 5, +INF] defines three segments:
 * (-INF, 3], (3, 5], (5, +INF). The last value in the split array is always +INF. The prediction array holds one
 * value per segment, and a separate value is kept for missing ({@code NaN}) inputs.
 * </p>
 *
 * @author Yin Lou
 *
 */
public class Function1D implements Regressor, UnivariateFunction {

    /**
     * Index of the attribute this function reads.
     */
    protected int attIndex;

    /**
     * Upper bounds of the segments. The last value is always Double.POSITIVE_INFINITY, e.g. [3, 5, +INF] defines
     * three segments: (-INF, 3], (3, 5], (5, +INF).
     */
    protected double[] splits;

    /**
     * One prediction per segment defined in splits.
     */
    protected double[] predictions;

    /**
     * Value returned for a missing ({@code NaN}) input.
     */
    protected double predictionOnMV;

    /**
     * Creates an uninitialized function; all fields keep their default values.
     */
    public Function1D() {

    }

    /**
     * Creates a function whose prediction on missing values is 0.
     *
     * @param attIndex the attribute index.
     * @param splits the splits.
     * @param predictions the predictions.
     */
    public Function1D(int attIndex, double[] splits, double[] predictions) {
        this(attIndex, splits, predictions, 0.0);
    }

    /**
     * Creates a function from its parts. The arrays are stored by reference.
     *
     * @param attIndex the attribute index.
     * @param splits the splits.
     * @param predictions the predictions.
     * @param predictionOnMissing prediction on missing value;
     */
    public Function1D(int attIndex, double[] splits, double[] predictions, double predictionOnMissing) {
        this.attIndex = attIndex;
        this.splits = splits;
        this.predictions = predictions;
        this.predictionOnMV = predictionOnMissing;
    }

    /**
     * Creates a single-segment function with the given value. The prediction on missing values is left at 0.
     *
     * @param attIndex the attribute index of this function.
     * @param prediction the constant.
     * @return a constant 1D function.
     */
    public static Function1D getConstantFunction(int attIndex, double prediction) {
        double[] singleSplit = { Double.POSITIVE_INFINITY };
        double[] singlePrediction = { prediction };
        return new Function1D(attIndex, singleSplit, singlePrediction);
    }

    /**
     * Replaces the contents with a single segment predicting 0, and sets the missing-value prediction to 0.
     */
    public void setZero() {
        this.splits = new double[] { Double.POSITIVE_INFINITY };
        this.predictions = new double[] { 0 };
        this.predictionOnMV = 0.0;
    }

    /**
     * Tells whether every prediction, including the one on missing values, is 0.
     *
     * @return {@code true} if the function is 0.
     */
    public boolean isZero() {
        boolean segmentsAreZero = ArrayUtils.isConstant(predictions, 0, predictions.length, 0);
        return segmentsAreZero && MathUtils.isZero(predictionOnMV);
    }

    /**
     * Tells whether every prediction, including the one on missing values, equals the first segment's value.
     *
     * @return {@code true} if the function is constant.
     */
    public boolean isConstant() {
        final double first = predictions[0];
        boolean segmentsAgree = ArrayUtils.isConstant(predictions, 1, predictions.length, first);
        return segmentsAgree && MathUtils.isZero(predictionOnMV - first);
    }

    /**
     * Scales every prediction, including the one on missing values, by a constant.
     *
     * @param c the constant.
     * @return this function.
     */
    public Function1D multiply(double c) {
        VectorUtils.multiply(this.predictions, c);
        this.predictionOnMV = this.predictionOnMV * c;
        return this;
    }

    /**
     * Divides every prediction, including the one on missing values, by a constant.
     *
     * @param c the constant.
     * @return this function.
     */
    public Function1D divide(double c) {
        VectorUtils.divide(this.predictions, c);
        this.predictionOnMV = this.predictionOnMV / c;
        return this;
    }

    /**
     * Shifts every prediction, including the one on missing values, up by a constant.
     *
     * @param c the constant.
     * @return this function.
     */
    public Function1D add(double c) {
        VectorUtils.add(this.predictions, c);
        this.predictionOnMV = this.predictionOnMV + c;
        return this;
    }

    /**
     * Shifts every prediction, including the one on missing values, down by a constant.
     *
     * @param c the constant.
     * @return this function.
     */
    public Function1D subtract(double c) {
        VectorUtils.subtract(this.predictions, c);
        this.predictionOnMV = this.predictionOnMV - c;
        return this;
    }

    /**
     * Adds another function on the same attribute to this one, in place. Splits of the other function that this
     * function lacks are merged in, so the result has a segment for every split of either operand.
     *
     * @param func the other function.
     * @return this function.
     * @throws IllegalArgumentException if the two functions are defined on different attributes.
     */
    public Function1D add(Function1D func) {
        if (func.attIndex != attIndex) {
            throw new IllegalArgumentException("Cannot add functions on different terms");
        }

        // The trailing +INF split of func is not examined; flag the remaining ones that this function lacks.
        final int candidates = func.splits.length - 1;
        boolean[] absent = new boolean[candidates];
        int absentCount = 0;
        for (int i = 0; i < candidates; i++) {
            absent[i] = Arrays.binarySearch(splits, func.splits[i]) < 0;
            if (absent[i]) {
                absentCount++;
            }
        }

        if (absentCount == 0) {
            // Nothing to merge: update the existing segments in place.
            for (int i = 0; i < splits.length; i++) {
                predictions[i] += func.evaluate(splits[i]);
            }
        } else {
            double[] merged = Arrays.copyOf(splits, splits.length + absentCount);
            int next = splits.length;
            for (int i = 0; i < candidates; i++) {
                if (absent[i]) {
                    merged[next++] = func.splits[i];
                }
            }
            Arrays.sort(merged);

            // Both operands are evaluated on their current (old) segmentation before the fields are replaced.
            double[] summed = new double[merged.length];
            for (int k = 0; k < merged.length; k++) {
                summed[k] = this.evaluate(merged[k]) + func.evaluate(merged[k]);
            }
            splits = merged;
            predictions = summed;
        }

        this.predictionOnMV += func.predictionOnMV;
        return this;
    }

    /**
     * Gets the attribute index.
     *
     * @return the attribute index.
     */
    public int getAttributeIndex() {
        return this.attIndex;
    }

    /**
     * Changes the attribute index.
     *
     * @param attIndex the new attribute index.
     */
    public void setAttributeIndex(final int attIndex) {
        this.attIndex = attIndex;
    }

    /**
     * Gets the internal split array (not a copy).
     *
     * @return the internal split array.
     */
    public double[] getSplits() {
        return this.splits;
    }

    /**
     * Replaces the split array; the given array is stored by reference.
     *
     * @param splits the new split array.
     */
    public void setSplits(final double[] splits) {
        this.splits = splits;
    }

    /**
     * Gets the internal prediction array (not a copy).
     *
     * @return the internal prediction array.
     */
    public double[] getPredictions() {
        return this.predictions;
    }

    /**
     * Replaces the prediction array; the given array is stored by reference.
     *
     * @param predictions the new prediction array.
     */
    public void setPredictions(final double[] predictions) {
        this.predictions = predictions;
    }

    /**
     * Gets the prediction on missing value.
     *
     * @return the prediction on missing value.
     */
    public double getPredictionOnMV() {
        return this.predictionOnMV;
    }

    /**
     * Changes the prediction on missing value.
     *
     * @param predictionOnMV the prediction on missing value.
     */
    public void setPredictionOnMV(final double predictionOnMV) {
        this.predictionOnMV = predictionOnMV;
    }

    /**
     * Reads the function from the layout produced by {@link #write(PrintWriter)}, starting at the
     * {@code AttIndex} line.
     *
     * @param in the reader to consume lines from.
     * @throws Exception if reading or parsing fails.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        attIndex = Integer.parseInt(in.readLine().split(": ")[1]);
        predictionOnMV = Double.parseDouble(in.readLine().split(": ")[1]);
        // Each array is preceded by a header line that is read and discarded.
        in.readLine();
        splits = ArrayUtils.parseDoubleArray(in.readLine());
        in.readLine();
        predictions = ArrayUtils.parseDoubleArray(in.readLine());
    }

    /**
     * Writes a predictor header line followed by the attribute index, the missing-value prediction, the splits
     * and the predictions.
     *
     * @param out the writer to print to.
     * @throws Exception if writing fails.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        final String className = this.getClass().getCanonicalName();
        out.printf("[Predictor: %s]\n", className);
        out.println("AttIndex: " + attIndex);
        out.println("PredictinOnMV: " + predictionOnMV);
        out.println("Splits: " + splits.length);
        out.println(Arrays.toString(splits));
        out.println("Predictions: " + predictions.length);
        out.println(Arrays.toString(predictions));
    }

    /**
     * Evaluates this function on the value of attribute {@code attIndex} of the instance.
     *
     * @param instance the instance to predict.
     * @return the function value for the instance.
     */
    @Override
    public double regress(Instance instance) {
        final double x = instance.getValue(attIndex);
        return evaluate(x);
    }

    /**
     * Returns the missing-value prediction for {@code NaN}, otherwise the prediction of the segment containing
     * {@code x}.
     *
     * @param x the input value.
     * @return the function value at {@code x}.
     */
    @Override
    public double evaluate(double x) {
        return Double.isNaN(x) ? predictionOnMV : predictions[getSegmentIndex(x)];
    }

    /**
     * Returns a copy of this function whose split and prediction arrays are independent copies.
     *
     * @return a copy of this function.
     */
    @Override
    public Function1D copy() {
        return new Function1D(attIndex, splits.clone(), predictions.clone(), predictionOnMV);
    }

    /**
     * Locates the segment index for x by looking up its insertion point in the split array.
     *
     * @param x the search key; must not be NaN.
     * @return the segment index given x.
     */
    protected int getSegmentIndex(double x) {
        // Callers handle NaN before reaching this point.
        return ArrayUtils.findInsertionPoint(splits, x);
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/Function2D.java
================================================
package mltk.predictor.function;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.Arrays;
import mltk.core.Instance;
import mltk.predictor.Regressor;
import mltk.util.ArrayUtils;
import mltk.util.VectorUtils;
import mltk.util.tuple.IntPair;
/**
* Class for 2D functions.
*
*
* This class represents a segmented 2D function. Segments are defined in split arrays for the two attributes. For
* example, [3, 5, +INF] defines three segments: (-INF, 3], (3, 5], (5, +INF). The last value in the split array is
* always +INF. The prediction matrix is the corresponding predictions for segments defined in splits.
*
*
* @author Yin Lou
*
*/
public class Function2D implements Regressor, BivariateFunction {
/**
* First attribute index.
*/
protected int attIndex1;
/**
* Second attribute index.
*/
protected int attIndex2;
/**
* Last value is always Double.POSITIVE_INFINITY. e.g. [3, 5, +INF] defines three segments: (-INF, 3], (3, 5], (5,
* +INF)
*/
protected double[] splits1;
protected double[] splits2;
/**
* Predictions.
*/
protected double[][] predictions;
/**
* Predictions on missing value for attribute 1.
*/
protected double[] predictionsOnMV1;
/**
* Predictions on missing value for attribute 2.
*/
protected double[] predictionsOnMV2;
/**
* Prediction when both attribute 1 and 2 are missing.
*/
protected double predictionOnMV12;
/**
 * Creates an uninitialized 2D function; all fields keep their default values.
 */
public Function2D() {
    super();
}
/**
 * Creates a 2D function whose missing-value predictions are all zero: one zero per entry of {@code splits2}
 * for a missing 1st attribute, one zero per entry of {@code splits1} for a missing 2nd attribute, and 0 when
 * both are missing.
 *
 * @param attIndex1 the 1st attribute index.
 * @param attIndex2 the 2nd attribute index.
 * @param splits1 the split array for the 1st attribute.
 * @param splits2 the split array for the 2nd attribute.
 * @param predictions the prediction matrix.
 */
public Function2D(int attIndex1, int attIndex2, double[] splits1, double[] splits2, double[][] predictions) {
    this(
            attIndex1,
            attIndex2,
            splits1,
            splits2,
            predictions,
            new double[splits2.length],
            new double[splits1.length],
            0.0);
}
/**
 * Creates a 2D function from all of its parts. Every array is stored by reference.
 *
 * @param attIndex1 the 1st attribute index.
 * @param attIndex2 the 2nd attribute index.
 * @param splits1 the split array for the 1st attribute.
 * @param splits2 the split array for the 2nd attribute.
 * @param predictions the prediction matrix.
 * @param predictionsOnMV1 the prediction array when the 1st attribute is missing.
 * @param predictionsOnMV2 the prediction array when the 2nd attribute is missing.
 * @param predictionOnMV12 the prediction when both attributes are missing.
 */
public Function2D(int attIndex1, int attIndex2,
        double[] splits1, double[] splits2, double[][] predictions,
        double[] predictionsOnMV1, double[] predictionsOnMV2, double predictionOnMV12) {
    // Attribute indices.
    this.attIndex1 = attIndex1;
    this.attIndex2 = attIndex2;
    // Segmentation and per-cell predictions.
    this.splits1 = splits1;
    this.splits2 = splits2;
    this.predictions = predictions;
    // Predictions on missing values.
    this.predictionsOnMV1 = predictionsOnMV1;
    this.predictionsOnMV2 = predictionsOnMV2;
    this.predictionOnMV12 = predictionOnMV12;
}
/**
 * Creates a 2D function with a single segment per attribute and the given value in its only cell. The
 * missing-value predictions are those of the 5-argument constructor, i.e. zero.
 *
 * @param attIndex1 the 1st attribute index.
 * @param attIndex2 the 2nd attribute index.
 * @param prediction the constant.
 * @return a constant 2D function.
 */
public static Function2D getConstantFunction(int attIndex1, int attIndex2, double prediction) {
    double[] singleSplit1 = { Double.POSITIVE_INFINITY };
    double[] singleSplit2 = { Double.POSITIVE_INFINITY };
    double[][] singleCell = { { prediction } };
    return new Function2D(attIndex1, attIndex2, singleSplit1, singleSplit2, singleCell);
}
/**
 * Gets the index of the 1st attribute.
 *
 * @return the index of 1st attribute.
 */
public int getAttributeIndex1() {
    return this.attIndex1;
}
/**
 * Gets the index of the 2nd attribute.
 *
 * @return the index of 2nd attribute.
 */
public int getAttributeIndex2() {
    return this.attIndex2;
}
/**
 * Gets both attribute indices as a newly created pair, 1st attribute first.
 *
 * @return the attribute indices pair.
 */
public IntPair getAttributeIndices() {
    final IntPair indices = new IntPair(this.attIndex1, this.attIndex2);
    return indices;
}
/**
 * Changes both attribute indices.
 *
 * @param attIndex1 the new index for the 1st attribute.
 * @param attIndex2 the new index for the 2nd attribute.
 */
public void setAttributeIndices(final int attIndex1, final int attIndex2) {
    this.attIndex1 = attIndex1;
    this.attIndex2 = attIndex2;
}
/**
 * Gets the internal prediction matrix (not a copy).
 *
 * @return the internal prediction matrix.
 */
public double[][] getPredictions() {
    return this.predictions;
}
/**
 * Replaces the prediction matrix; the given matrix is stored by reference.
 *
 * @param predictions the new prediction matrix.
 */
public void setPredictions(final double[][] predictions) {
    this.predictions = predictions;
}
/**
 * Gets the internal prediction array used when the 1st attribute is missing (not a copy).
 *
 * @return the prediction array when the 1st attribute is missing.
 */
public double[] getPredictionsOnMV1() {
    return this.predictionsOnMV1;
}
/**
 * Replaces the prediction array used when the 1st attribute is missing; the array is stored by reference.
 *
 * @param predictionsOnMV1 the new prediction array.
 */
public void setPredictionsOnMV1(final double[] predictionsOnMV1) {
    this.predictionsOnMV1 = predictionsOnMV1;
}
/**
 * Gets the internal prediction array used when the 2nd attribute is missing (not a copy).
 *
 * @return the prediction array when the 2nd attribute is missing.
 */
public double[] getPredictionsOnMV2() {
    return this.predictionsOnMV2;
}
/**
 * Replaces the prediction array used when the 2nd attribute is missing; the array is stored by reference.
 *
 * @param predictionsOnMV2 the new prediction array.
 */
public void setPredictionsOnMV2(final double[] predictionsOnMV2) {
    this.predictionsOnMV2 = predictionsOnMV2;
}
/**
 * Gets the prediction used when both attributes are missing.
 *
 * @return the prediction when both attributes are missing.
 */
public double getPredictionOnMV12() {
    return this.predictionOnMV12;
}
/**
 * Changes the prediction used when both attributes are missing.
 *
 * @param predictionOnMV12 the new prediction.
 */
public void setPredictionOnMV12(final double predictionOnMV12) {
    this.predictionOnMV12 = predictionOnMV12;
}
/**
 * Scales every prediction by a constant: the matrix rows, both missing-value arrays and the
 * both-missing prediction.
 *
 * @param c the constant.
 * @return this function.
 */
public Function2D multiply(double c) {
    for (int row = 0; row < predictions.length; row++) {
        VectorUtils.multiply(predictions[row], c);
    }
    VectorUtils.multiply(predictionsOnMV1, c);
    VectorUtils.multiply(predictionsOnMV2, c);
    predictionOnMV12 = predictionOnMV12 * c;
    return this;
}
/**
 * Divides every prediction by a constant: the matrix rows, both missing-value arrays and the
 * both-missing prediction.
 *
 * @param c the constant.
 * @return this function.
 */
public Function2D divide(double c) {
    for (int row = 0; row < predictions.length; row++) {
        VectorUtils.divide(predictions[row], c);
    }
    VectorUtils.divide(predictionsOnMV1, c);
    VectorUtils.divide(predictionsOnMV2, c);
    predictionOnMV12 = predictionOnMV12 / c;
    return this;
}
/**
 * Shifts every prediction up by a constant: the matrix rows, both missing-value arrays and the
 * both-missing prediction.
 *
 * @param c the constant.
 * @return this function.
 */
public Function2D add(double c) {
    for (int row = 0; row < predictions.length; row++) {
        VectorUtils.add(predictions[row], c);
    }
    VectorUtils.add(predictionsOnMV1, c);
    VectorUtils.add(predictionsOnMV2, c);
    predictionOnMV12 = predictionOnMV12 + c;
    return this;
}
/**
 * Shifts every prediction down by a constant: the matrix rows, both missing-value arrays and the
 * both-missing prediction.
 *
 * @param c the constant.
 * @return this function.
 */
public Function2D subtract(double c) {
    for (int row = 0; row < predictions.length; row++) {
        VectorUtils.subtract(predictions[row], c);
    }
    VectorUtils.subtract(predictionsOnMV1, c);
    VectorUtils.subtract(predictionsOnMV2, c);
    predictionOnMV12 = predictionOnMV12 - c;
    return this;
}
/**
 * Adds this function with another one, in place. Splits of the other function that this function lacks are
 * merged into the respective split array, so the result has a cell for every combination of splits of
 * either operand. The missing-value predictions of the two functions are added as well.
 *
 * @param func the other function.
 * @return this function.
 * @throws IllegalArgumentException if the two functions are defined on different attribute pairs.
 */
public Function2D add(Function2D func) {
    if (attIndex1 != func.attIndex1 || attIndex2 != func.attIndex2) {
        throw new IllegalArgumentException("Cannot add arrays on differnt terms");
    }

    // Merge the splits of the 1st attribute; the trailing +INF of func is not examined.
    double[] s1 = splits1;
    int[] insertionPoints1 = new int[func.splits1.length - 1];
    int newElements1 = 0;
    for (int i = 0; i < insertionPoints1.length; i++) {
        insertionPoints1[i] = Arrays.binarySearch(splits1, func.splits1[i]);
        if (insertionPoints1[i] < 0) {
            newElements1++;
        }
    }
    if (newElements1 > 0) {
        double[] newSplits1 = new double[splits1.length + newElements1];
        System.arraycopy(splits1, 0, newSplits1, 0, splits1.length);
        int k = splits1.length;
        for (int i = 0; i < insertionPoints1.length; i++) {
            if (insertionPoints1[i] < 0) {
                newSplits1[k++] = func.splits1[i];
            }
        }
        Arrays.sort(newSplits1);
        s1 = newSplits1;
    }

    // Same for the 2nd attribute.
    double[] s2 = splits2;
    int[] insertionPoints2 = new int[func.splits2.length - 1];
    int newElements2 = 0;
    for (int j = 0; j < insertionPoints2.length; j++) {
        insertionPoints2[j] = Arrays.binarySearch(splits2, func.splits2[j]);
        if (insertionPoints2[j] < 0) {
            newElements2++;
        }
    }
    if (newElements2 > 0) {
        double[] newSplits2 = new double[splits2.length + newElements2];
        System.arraycopy(splits2, 0, newSplits2, 0, splits2.length);
        int k = splits2.length;
        for (int j = 0; j < insertionPoints2.length; j++) {
            if (insertionPoints2[j] < 0) {
                newSplits2[k++] = func.splits2[j];
            }
        }
        Arrays.sort(newSplits2);
        s2 = newSplits2;
    }

    if (newElements1 == 0 && newElements2 == 0) {
        // No new splits: accumulate func into the existing cells.
        for (int i = 0; i < splits1.length; i++) {
            predictionsOnMV2[i] += func.evaluate(splits1[i], Double.NaN);
            double[] ps = predictions[i];
            for (int j = 0; j < splits2.length; j++) {
                ps[j] += func.evaluate(splits1[i], splits2[j]);
            }
        }
        // This accumulation must run exactly once per j. Nested inside the loop over i it added func's
        // contribution splits1.length times.
        for (int j = 0; j < splits2.length; j++) {
            predictionsOnMV1[j] += func.evaluate(Double.NaN, splits2[j]);
        }
        predictionOnMV12 += func.predictionOnMV12;
    } else {
        // Build the result in fresh arrays and publish them only at the end, so that every this.evaluate(...)
        // call below still sees this function's old, mutually consistent splits and predictions.
        double[][] newPredictions = new double[s1.length][s2.length];
        double[] newPredictionsOnMV1 = new double[s2.length];
        double[] newPredictionsOnMV2 = new double[s1.length];
        for (int i = 0; i < s1.length; i++) {
            newPredictionsOnMV2[i] = this.evaluate(s1[i], Double.NaN) + func.evaluate(s1[i], Double.NaN);
            double[] ps = newPredictions[i];
            for (int j = 0; j < s2.length; j++) {
                ps[j] = this.evaluate(s1[i], s2[j]) + func.evaluate(s1[i], s2[j]);
            }
        }
        for (int j = 0; j < s2.length; j++) {
            newPredictionsOnMV1[j] = this.evaluate(Double.NaN, s2[j]) + func.evaluate(Double.NaN, s2[j]);
        }
        splits1 = s1;
        splits2 = s2;
        predictions = newPredictions;
        predictionsOnMV1 = newPredictionsOnMV1;
        predictionsOnMV2 = newPredictionsOnMV2;
        predictionOnMV12 = this.predictionOnMV12 + func.predictionOnMV12;
    }
    return this;
}
@Override
public void read(BufferedReader in) throws Exception {
    // Two "<label>: <int>" lines carrying the attribute indices.
    attIndex1 = Integer.parseInt(in.readLine().split(": ")[1]);
    attIndex2 = Integer.parseInt(in.readLine().split(": ")[1]);

    // Each array is preceded by a one-line header that is skipped.
    in.readLine();
    splits1 = ArrayUtils.parseDoubleArray(in.readLine());
    in.readLine();
    splits2 = ArrayUtils.parseDoubleArray(in.readLine());

    // "<label>: <rows>x<cols>"; only the row count is needed, each row is parsed from its own line.
    final String[] shape = in.readLine().split(": ")[1].split("x");
    predictions = new double[Integer.parseInt(shape[0])][];
    for (int row = 0; row < predictions.length; row++) {
        predictions[row] = ArrayUtils.parseDoubleArray(in.readLine());
    }

    in.readLine();
    predictionsOnMV1 = ArrayUtils.parseDoubleArray(in.readLine());
    in.readLine();
    predictionsOnMV2 = ArrayUtils.parseDoubleArray(in.readLine());

    // Final "<label>: <double>" line: prediction when both arguments are missing.
    predictionOnMV12 = Double.parseDouble(in.readLine().split(": ")[1]);
}
@Override
public void write(PrintWriter out) throws Exception {
    final String className = this.getClass().getCanonicalName();
    out.printf("[Predictor: %s]\n", className);

    out.println("AttIndex1: " + attIndex1);
    out.println("AttIndex2: " + attIndex2);

    // Every array is written as a "<label>: <length>" header followed by its contents.
    out.println("Splits1: " + splits1.length);
    out.println(Arrays.toString(splits1));
    out.println("Splits2: " + splits2.length);
    out.println(Arrays.toString(splits2));

    // The prediction table is written one row per line after a "<rows>x<cols>" header.
    out.println("Predictions: " + predictions.length + "x" + predictions[0].length);
    for (double[] row : predictions) {
        out.println(Arrays.toString(row));
    }

    out.println("PredictionsOnMV1: " + predictionsOnMV1.length);
    out.println(Arrays.toString(predictionsOnMV1));
    out.println("PredictionsOnMV2: " + predictionsOnMV2.length);
    out.println(Arrays.toString(predictionsOnMV2));
    out.println("PredictionOnMV12: " + predictionOnMV12);
}
@Override
public double regress(Instance instance) {
    // Look up the two attribute values this function is bound to and evaluate on them.
    final double first = instance.getValue(attIndex1);
    final double second = instance.getValue(attIndex2);
    return evaluate(first, second);
}
@Override
public double evaluate(double x, double y) {
    // NaN encodes a missing value for either argument.
    final boolean xMissing = Double.isNaN(x);
    final boolean yMissing = Double.isNaN(y);

    if (xMissing && yMissing) {
        return predictionOnMV12;
    }
    if (xMissing) {
        // Only y is known: use the 1D table indexed by the segment of y.
        return predictionsOnMV1[ArrayUtils.findInsertionPoint(splits2, y)];
    }
    if (yMissing) {
        // Only x is known: use the 1D table indexed by the segment of x.
        return predictionsOnMV2[ArrayUtils.findInsertionPoint(splits1, x)];
    }

    // Both known: regular 2D lookup.
    final IntPair cell = getSegmentIndex(x, y);
    return predictions[cell.v1][cell.v2];
}
@Override
public Function2D copy() {
    // Deep-copy the prediction table row by row so the copy shares no arrays with this instance.
    final double[][] tableCopy = new double[predictions.length][];
    for (int row = 0; row < tableCopy.length; row++) {
        tableCopy[row] = predictions[row].clone();
    }
    return new Function2D(attIndex1, attIndex2, splits1.clone(), splits2.clone(), tableCopy,
            predictionsOnMV1.clone(), predictionsOnMV2.clone(), predictionOnMV12);
}
/**
 * Locates the cell of the prediction table that contains the point (x1, x2).
 * Neither argument may be a missing value.
 *
 * @param x1 the 1st search key, looked up in {@code splits1}.
 * @param x2 the 2nd search key, looked up in {@code splits2}.
 * @return the pair (index along the 1st dimension, index along the 2nd dimension).
 */
protected IntPair getSegmentIndex(double x1, double x2) {
    return new IntPair(ArrayUtils.findInsertionPoint(splits1, x1), ArrayUtils.findInsertionPoint(splits2, x2));
}
}
================================================
FILE: src/main/java/mltk/predictor/function/Histogram2D.java
================================================
package mltk.predictor.function;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.util.tuple.Pair;
/**
* Class for 2D histograms.
* <p>
* Accumulates weighted responses and weights over a grid indexed by the integer values of two features,
* with separate accumulators for instances where the 1st, the 2nd, or both features are missing.
*
* @author Yin Lou
*
*/
public class Histogram2D {
// resp[i][j] / count[i][j]: response sum and weight sum for instances with f1 == i and f2 == j.
public double[][] resp;
public double[][] count;
// Indexed by the f2 value: accumulators for instances whose f1 is missing.
public double[] respOnMV1;
public double[] countOnMV1;
// Indexed by the f1 value: accumulators for instances whose f2 is missing.
public double[] respOnMV2;
public double[] countOnMV2;
// Accumulators for instances where both features are missing.
public double respOnMV12;
public double countOnMV12;
/**
* Auxiliary lookup table built by {@link Histogram2D#computeTable}.
* <p>
* For every cell (i, j) the last dimension holds four aggregates: slot 0 is the 2D prefix sum over rows
* {@code <= i} and columns {@code <= j}; slots 1-3 are derived from slot 0 and the cumulative marginals in
* {@code fillTable}. When those marginals come from {@link Histogram2D#computeCHistogram()}, slot 1 covers
* rows {@code <= i} / columns {@code > j}, slot 2 rows {@code > i} / columns {@code <= j}, and slot 3 rows
* {@code > i} / columns {@code > j}.
* <p>
* The missing-value tables hold two aggregates per index: slot 0 is the running sum up to that index and
* slot 1 is the corresponding marginal missing-value total minus slot 0.
*/
public static class Table {
public double[][][] resp;
public double[][][] count;
public double[][] respOnMV1;
public double[][] countOnMV1;
public double[][] respOnMV2;
public double[][] countOnMV2;
public double respOnMV12;
public double countOnMV12;
/**
* Allocates a zero-filled table.
*
* @param n the size of the 1st dimension.
* @param m the size of the 2nd dimension.
*/
public Table(int n, int m) {
resp = new double[n][m][4];
count = new double[n][m][4];
respOnMV1 = new double[m][2];
countOnMV1 = new double[m][2];
respOnMV2 = new double[n][2];
countOnMV2 = new double[n][2];
respOnMV12 = 0.0;
countOnMV12 = 0.0;
}
}
/**
* Computes 2D histogram given (f1, f2).
* <p>
* Adds to whatever {@code hist2d} already contains; it is not reset first. Feature values are truncated to
* {@code int} and used directly as array indices.
*
* @param instances the data set.
* @param f1 the 1st feature.
* @param f2 the 2nd feature.
* @param hist2d the histogram to compute.
*/
public static void computeHistogram2D(Instances instances, int f1, int f2, Histogram2D hist2d) {
for (Instance instance : instances) {
// The response is always weighted by the instance weight here.
double resp = instance.getTarget() * instance.getWeight();
double weight = instance.getWeight();
if (!instance.isMissing(f1) && !instance.isMissing(f2)) {
int idx1 = (int) instance.getValue(f1);
int idx2 = (int) instance.getValue(f2);
hist2d.resp[idx1][idx2] += resp;
hist2d.count[idx1][idx2] += weight;
} else if (instance.isMissing(f1) && !instance.isMissing(f2)) {
int idx2 = (int) instance.getValue(f2);
hist2d.respOnMV1[idx2] += resp;
hist2d.countOnMV1[idx2] += weight;
} else if (!instance.isMissing(f1) && instance.isMissing(f2)) {
int idx1 = (int) instance.getValue(f1);
hist2d.respOnMV2[idx1] += resp;
hist2d.countOnMV2[idx1] += weight;
} else {
hist2d.respOnMV12 += resp;
hist2d.countOnMV12 += weight;
}
}
}
/**
* Computes auxiliary data structure given 2D histogram and cumulative 1D histograms.
* <p>
* Reads {@code hist2d.resp[0]}, so the histogram must have at least one row.
*
* @param hist2d the 2D histogram.
* @param cHist1 the cumulative histogram for the 1st feature.
* @param cHist2 the cumulative histogram for the 2nd feature.
* @return table auxiliary data structure.
*/
public static Table computeTable(Histogram2D hist2d, CHistogram cHist1, CHistogram cHist2) {
Table table = new Table(hist2d.resp.length, hist2d.resp[0].length);
// First row: slot 0 is simply the running sum along the row.
double sum = 0;
double count = 0;
for (int j = 0; j < hist2d.resp[0].length; j++) {
sum += hist2d.resp[0][j];
table.resp[0][j][0] = sum;
count += hist2d.count[0][j];
table.count[0][j][0] = count;
fillTable(table, 0, j, cHist1, cHist2);
}
// Remaining rows: slot 0 = prefix sum of the row above + running sum of the current row.
for (int i = 1; i < hist2d.resp.length; i++) {
sum = count = 0;
for (int j = 0; j < hist2d.resp[i].length; j++) {
sum += hist2d.resp[i][j];
table.resp[i][j][0] = table.resp[i - 1][j][0] + sum;
count += hist2d.count[i][j];
table.count[i][j][0] = table.count[i - 1][j][0] + count;
fillTable(table, i, j, cHist1, cHist2);
}
}
// Missing 1st feature: running sums over the 2nd feature and their complements.
// NOTE(review): when cHist1 comes from computeCHistogram(), sumOnMV also contains respOnMV12, so slot 1
// includes the both-missing mass - confirm this is intended.
double respOnMV1 = 0;
double countOnMV1 = 0;
for (int j = 0; j < hist2d.respOnMV1.length; j++) {
respOnMV1 += hist2d.respOnMV1[j];
countOnMV1 += hist2d.countOnMV1[j];
table.respOnMV1[j][0] = respOnMV1;
table.respOnMV1[j][1] = cHist1.sumOnMV - respOnMV1;
table.countOnMV1[j][0] = countOnMV1;
table.countOnMV1[j][1] = cHist1.countOnMV - countOnMV1;
}
// Missing 2nd feature: same construction over the 1st feature.
double respOnMV2 = 0;
double countOnMV2 = 0;
for (int i = 0; i < hist2d.respOnMV2.length; i++) {
respOnMV2 += hist2d.respOnMV2[i];
countOnMV2 += hist2d.countOnMV2[i];
table.respOnMV2[i][0] = respOnMV2;
table.respOnMV2[i][1] = cHist2.sumOnMV - respOnMV2;
table.countOnMV2[i][0] = countOnMV2;
table.countOnMV2[i][1] = cHist2.countOnMV - countOnMV2;
}
table.respOnMV12 = hist2d.respOnMV12;
table.countOnMV12 = hist2d.countOnMV12;
return table;
}
/**
* Derives slots 1-3 of cell (i, j) from slot 0 (which must already be set) and the cumulative marginals.
*
* @param table the table being filled.
* @param i the index along the 1st dimension.
* @param j the index along the 2nd dimension.
* @param cHist1 the cumulative histogram for the 1st feature.
* @param cHist2 the cumulative histogram for the 2nd feature.
*/
protected static void fillTable(Table table, int i, int j, CHistogram cHist1, CHistogram cHist2) {
double[] count = table.count[i][j];
double[] resp = table.resp[i][j];
// slot 1: cumulative marginal of feature 1 up to i, minus the prefix sum.
resp[1] = cHist1.sum[i] - resp[0];
// slot 2: cumulative marginal of feature 2 up to j, minus the prefix sum.
resp[2] = cHist2.sum[j] - resp[0];
// slot 3: what is left of the last cumulative entry of feature 1 beyond index i, minus slot 2.
resp[3] = cHist1.sum[cHist1.size() - 1] - cHist1.sum[i] - resp[2];
count[1] = cHist1.count[i] - count[0];
count[2] = cHist2.count[j] - count[0];
count[3] = cHist1.count[cHist1.size() - 1] - cHist1.count[i] - count[2];
}
/**
* Constructor.
*
* @param n the size of the 1st dimension.
* @param m the size of the 2nd dimension.
*/
public Histogram2D(int n, int m) {
resp = new double[n][m];
count = new double[n][m];
respOnMV1 = new double[m];
countOnMV1 = new double[m];
respOnMV2 = new double[n];
countOnMV2 = new double[n];
respOnMV12 = 0.0;
countOnMV12 = 0.0;
}
/**
* Computes the cumulative histograms on the margin.
* <p>
* The first element of the returned pair is the marginal over the 1st dimension, the second over the 2nd.
* Reads {@code resp[0]}, so the histogram must have at least one row.
*
* @return the cumulative histograms.
*/
public Pair computeCHistogram() {
CHistogram cHist1 = new CHistogram(resp.length);
CHistogram cHist2 = new CHistogram(resp[0].length);
// Plain (non-cumulative) marginals: row totals into cHist1, column totals into cHist2.
for (int i = 0; i < resp.length; i++) {
double[] r = resp[i];
double[] c = count[i];
for (int j = 0; j < r.length; j++) {
cHist1.sum[i] += r[j];
cHist1.count[i] += c[j];
cHist2.sum[j] += r[j];
cHist2.count[j] += c[j];
}
}
// Turn both marginals into running (cumulative) sums in place.
for (int i = 1; i < cHist1.size(); i++) {
cHist1.sum[i] += cHist1.sum[i - 1];
cHist1.count[i] += cHist1.count[i - 1];
}
for (int j = 1; j < cHist2.size(); j++) {
cHist2.sum[j] += cHist2.sum[j - 1];
cHist2.count[j] += cHist2.count[j - 1];
}
// Missing-value totals for feature 1: everything with f1 missing, including the both-missing bucket.
for (int j = 0; j < respOnMV1.length; j++) {
cHist1.sumOnMV += respOnMV1[j];
cHist1.countOnMV += countOnMV1[j];
}
cHist1.sumOnMV += respOnMV12;
cHist1.countOnMV += countOnMV12;
// Missing-value totals for feature 2, built the same way.
for (int i = 0; i < respOnMV2.length; i++) {
cHist2.sumOnMV += respOnMV2[i];
cHist2.countOnMV += countOnMV2[i];
}
cHist2.sumOnMV += respOnMV12;
cHist2.countOnMV += countOnMV12;
return new Pair(cHist1, cHist2);
}
}
================================================
FILE: src/main/java/mltk/predictor/function/LineCutter.java
================================================
package mltk.predictor.function;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.Learner;
import mltk.util.Random;
import mltk.util.Element;
import mltk.util.MathUtils;
import mltk.util.OptimUtils;
/**
* Class for cutting lines.
*
* @author Yin Lou
*
*/
public class LineCutter extends Learner {
static class Interval implements Comparable {
boolean finalized;
int start;
int end;
// INF: Can split but split point not computed
// NaN: Declared as leaf
// Other: Split point
double split;
double sum;
double weight;
double value; // mean * sum, or sum * sum /weight; negative gain
double gain;
Interval left;
Interval right;
Interval() {
split = Double.POSITIVE_INFINITY;
}
Interval(int start, int end, double sum, double weight) {
this.start = start;
this.end = end;
this.split = Double.POSITIVE_INFINITY;
this.sum = sum;
this.weight = weight;
}
@Override
public int compareTo(Interval o) {
if (this.value < o.value) {
return -1;
} else if (this.value > o.value) {
return 1;
} else {
return 0;
}
}
double getPrediction() {
return MathUtils.divide(sum, weight, 0.0);
}
boolean isFinalized() {
return finalized;
}
boolean isInteriorNode() {
return split < Double.POSITIVE_INFINITY;
}
boolean isLeaf() {
return Double.isNaN(split);
}
}
private int attIndex;
private int numIntervals;
private boolean isClassification;
/**
* Constructor.
*/
public LineCutter() {
this(false);
}
/**
* Constructor.
*
* @param isClassification {@code true} if it is a classification problem.
*/
public LineCutter(boolean isClassification) {
attIndex = -1;
this.isClassification = isClassification;
}
@Override
public Function1D build(Instances instances) {
return build(instances, attIndex, numIntervals);
}
/**
* Builds a 1D function.
*
* @param instances the training set.
* @param attribute the attribute.
* @param numIntervals the number of intervals.
* @return a 1D function.
*/
public Function1D build(Instances instances, Attribute attribute, int numIntervals) {
int attIndex = attribute.getIndex();
double sumRespOnMV = 0.0;
double sumWeightOnMV = 0.0;
List histograms;
if (attribute.getType() == Attribute.Type.NUMERIC) {
// weight: attribute value
// [feature value, sum, weight]
List> pairs = new ArrayList<>(instances.size());
for (Instance instance : instances) {
double weight = instance.getWeight();
double value = instance.getValue(attIndex);
double target = instance.getTarget();
if (!Double.isNaN(value)) {
if (isClassification) {
pairs.add(new Element<>(new double[] { target, weight }, value));
} else {
pairs.add(new Element<>(new double[] { target * weight, weight }, value));
}
} else {
if (isClassification) {
sumRespOnMV += target;
} else {
sumRespOnMV += target * weight;
}
sumWeightOnMV += weight;
}
}
Collections.sort(pairs);
histograms = new ArrayList<>(pairs.size() + 1);
getHistograms(pairs, histograms);
histograms.add(new double[] { Double.NaN, sumRespOnMV, sumWeightOnMV });
} else {
int size = 0;
if (attribute.getType() == Attribute.Type.BINNED) {
size = ((BinnedAttribute) attribute).getNumBins();
} else {
size = ((NominalAttribute) attribute).getCardinality();
}
double[][] histogram = new double[size][2];
for (Instance instance : instances) {
double weight = instance.getWeight();
double value = instance.getValue(attIndex);
double target = instance.getTarget();
if (!Double.isNaN(value)) {
int idx = (int) value;
if (isClassification) {
histogram[idx][0] += target;
} else {
histogram[idx][0] += target * weight;
}
histogram[idx][1] += weight;
} else {
if (isClassification) {
sumRespOnMV += target;
} else {
sumRespOnMV += target * weight;
}
sumWeightOnMV += weight;
}
}
histograms = new ArrayList<>(histogram.length + 1);
for (int i = 0; i < histogram.length; i++) {
if (!MathUtils.isZero(histogram[i][1])) {
double[] hist = histogram[i];
histograms.add(new double[] {i, hist[0], hist[1]});
}
}
histograms.add(new double[] { Double.NaN, sumRespOnMV, sumWeightOnMV });
}
return build(attIndex, histograms, numIntervals);
}
/**
* Builds a 1D function.
*
* @param instances the training set.
* @param attIndex the index in the attribute list of the training set.
* @param numIntervals the number of intervals.
* @return a 1D function.
*/
public Function1D build(Instances instances, int attIndex, int numIntervals) {
Attribute attribute = instances.getAttributes().get(attIndex);
return build(instances, attribute, numIntervals);
}
/**
* Returns the index in the attribute list of the training set.
*
* @return the index in the attribute list of the training set.
*/
public int getAttributeIndex() {
return attIndex;
}
/**
* Sets the index in the attribute list of the training set.
*
* @param attIndex the attribute index.
*/
public void setAttributeIndex(int attIndex) {
this.attIndex = attIndex;
}
/**
* Returns {@code true} if it is a classification problem.
*
* @return {@code true} if it is a classification problem.
*/
public boolean isClassification() {
return isClassification;
}
/**
* Sets {@code true} if it is a classification problem.
*
* @param isClassification {@code true} if it is a classification problem.
*/
public void setClassification(boolean isClassification) {
this.isClassification = isClassification;
}
/**
* Returns the number of intervals.
*
* @return the number of intervals.
*/
public int getNumIntervals() {
return numIntervals;
}
/**
* Sets the number of intervals.
*
* @param numIntervals the number of intervals.
*/
public void setNumIntervals(int numIntervals) {
this.numIntervals = numIntervals;
}
protected static void getHistograms(List> pairs, List histograms) {
if (pairs.size() == 0) {
return;
}
// Element(new double[] {sum, weight}, value)
double[] hist = pairs.get(0).element;
double lastValue = pairs.get(0).weight;
double sum = hist[0];
double weight = hist[1];
for (int i = 1; i < pairs.size(); i++) {
Element element = pairs.get(i);
hist = element.element;
double value = element.weight;
double s = hist[0];
double w = hist[1];
if (value != lastValue) {
histograms.add(new double[] { lastValue, sum, weight });
lastValue = value;
sum = s;
weight = w;
} else {
sum += s;
weight += w;
}
}
histograms.add(new double[] { lastValue, sum, weight });
}
protected static Function1D build(int attIndex, List histograms, int numIntervals) {
Function1D func = new Function1D();
func.attIndex = attIndex;
// [feature value, sum, weight]
double[] histOnMV = histograms.get(histograms.size() - 1);
func.predictionOnMV = MathUtils.divide(histOnMV[1], histOnMV[2], 0.0);
// 1. Check basic leaf conditions
if (histograms.size() <= 2) {
func.splits = new double[] { Double.POSITIVE_INFINITY };
double prediction = 0.0;
if (histograms.size() == 2) {
double[] hist = histograms.get(0);
prediction = MathUtils.divide(hist[1], hist[2], 0.0);
}
func.predictions = new double[] { prediction };
return func;
}
// 2. Cut the line
// 2.1 First cut
double[] stats = sumUp(histograms, 0, histograms.size() - 1);
Interval root = new Interval(0, histograms.size() - 1, stats[0], stats[1]);
split(histograms, root);
if (numIntervals == 2) {
func.splits = new double[] { root.split, Double.POSITIVE_INFINITY };
func.predictions = new double[] { root.left.getPrediction(), root.right.getPrediction() };
} else if (numIntervals > 2) {
PriorityQueue q = new PriorityQueue<>();
if (!root.isLeaf()) {
q.add(root);
}
int numSplits = 0;
while (!q.isEmpty()) {
Interval parent = q.remove();
parent.finalized = true;
split(histograms, parent.left);
split(histograms, parent.right);
if (!parent.left.isLeaf()) {
q.add(parent.left);
}
if (!parent.right.isLeaf()) {
q.add(parent.right);
}
numSplits++;
if (numSplits >= numIntervals - 1) {
break;
}
}
List splits = new ArrayList<>(numIntervals - 1);
List predictions = new ArrayList<>(numIntervals);
inorder(root, splits, predictions);
func.splits = new double[predictions.size()];
func.predictions = new double[predictions.size()];
for (int i = 0; i < func.predictions.length; i++) {
func.predictions[i] = predictions.get(i);
}
for (int i = 0; i < func.splits.length - 1; i++) {
func.splits[i] = splits.get(i);
}
func.splits[func.splits.length - 1] = Double.POSITIVE_INFINITY;
}
return func;
}
protected static double[] sumUp(List histograms, int start, int end) {
double sum = 0;
double weight = 0;
for (int i = start; i < end; i++) {
double[] hist = histograms.get(i);
sum += hist[1];
weight += hist[2];
}
return new double[] { sum, weight };
}
protected static void split(List histograms, Interval parent) {
split(histograms, parent, 5);
}
protected static void split(List histograms, Interval parent, double limit) {
// Test if we need to split
if (parent.weight <= limit || parent.end - parent.start <= 1) {
parent.split = Double.NaN; // Declared as leaf
} else {
parent.left = new Interval();
parent.right = new Interval();
int start = parent.left.start = parent.start;
int end = parent.right.end = parent.end;
final double sum = parent.sum;
final double totalWeight = parent.weight;
double sum1 = histograms.get(start)[1];
double sum2 = sum - sum1;
double weight1 = histograms.get(start)[2];
double weight2 = totalWeight - weight1;
double bestEval = -(OptimUtils.getGain(sum1, weight1) + OptimUtils.getGain(sum2, weight2));
List splits = new ArrayList<>();
splits.add(new double[] { (histograms.get(start)[0] + histograms.get(start + 1)[0]) / 2, start, sum1,
weight1, sum2, weight2 });
for (int i = start + 1; i < end - 1; i++) {
double[] hist = histograms.get(i);
final double s = hist[1];
final double w = hist[2];
sum1 += s;
sum2 -= s;
weight1 += w;
weight2 -= w;
double eval1 = OptimUtils.getGain(sum1, weight1);
double eval2 = OptimUtils.getGain(sum2, weight2);
double eval = -(eval1 + eval2);
if (eval <= bestEval) {
double split = (histograms.get(i)[0] + histograms.get(i + 1)[0]) / 2;
if (eval < bestEval) {
bestEval = eval;
splits.clear();
}
splits.add(new double[] { split, i, sum1, weight1, sum2, weight2 });
}
}
Random rand = Random.getInstance();
double[] split = splits.get(rand.nextInt(splits.size()));
parent.split = split[0];
parent.left.end = (int) split[1] + 1;
parent.right.start = (int) split[1] + 1;
parent.left.sum = split[2];
parent.left.weight = split[3];
parent.right.sum = split[4];
parent.right.weight = split[5];
parent.gain = OptimUtils.getGain(parent.left.sum, parent.left.weight)
+ OptimUtils.getGain(parent.right.sum, parent.right.weight);
parent.value = -parent.gain + (sum / totalWeight * sum);
}
}
protected static void inorder(Interval parent, List splits, List predictions) {
if (parent.isFinalized()) {
inorder(parent.left, splits, predictions);
splits.add(parent.split);
inorder(parent.right, splits, predictions);
} else {
predictions.add(parent.getPrediction());
}
}
}
================================================
FILE: src/main/java/mltk/predictor/function/LinearFunction.java
================================================
package mltk.predictor.function;
import java.io.BufferedReader;
import java.io.PrintWriter;
import mltk.core.Instance;
import mltk.predictor.Regressor;
/**
 * A univariate linear function without intercept, {@code f(x) = beta * x}, bound to a single attribute.
 *
 * @author Yin Lou
 *
 */
public class LinearFunction implements Regressor, UnivariateFunction {

    /**
     * Index of the attribute this function reads in {@link #regress(Instance)}.
     */
    protected int attIndex;

    /**
     * Slope of the line.
     */
    protected double beta;

    /**
     * Creates a function with default field values; intended to be populated afterwards,
     * e.g. through {@link #read(BufferedReader)}.
     */
    public LinearFunction() {
    }

    /**
     * Creates a function with the given slope and an attribute index of -1.
     *
     * @param beta the slope.
     */
    public LinearFunction(double beta) {
        this(-1, beta);
    }

    /**
     * Creates a function with the given attribute index and slope.
     *
     * @param attIndex the attribute index.
     * @param beta the slope.
     */
    public LinearFunction(int attIndex, double beta) {
        this.attIndex = attIndex;
        this.beta = beta;
    }

    @Override
    public void read(BufferedReader in) throws Exception {
        // Two "<label>: <value>" lines: attribute index first, then the slope.
        final String indexField = in.readLine().split(": ")[1];
        attIndex = Integer.parseInt(indexField);
        final String slopeField = in.readLine().split(": ")[1];
        beta = Double.parseDouble(slopeField);
    }

    @Override
    public void write(PrintWriter out) throws Exception {
        final String className = this.getClass().getCanonicalName();
        out.printf("[Predictor: %s]\n", className);
        out.println("AttIndex: " + attIndex);
        out.println("Beta: " + beta);
    }

    @Override
    public double evaluate(double x) {
        return beta * x;
    }

    @Override
    public double regress(Instance instance) {
        final double x = instance.getValue(attIndex);
        return evaluate(x);
    }

    /**
     * Returns the slope.
     *
     * @return the slope.
     */
    public double getSlope() {
        return beta;
    }

    /**
     * Sets the slope.
     *
     * @param beta the new slope.
     */
    public void setSlope(double beta) {
        this.beta = beta;
    }

    /**
     * Returns the attribute index.
     *
     * @return the attribute index.
     */
    public int getAttributeIndex() {
        return attIndex;
    }

    @Override
    public LinearFunction copy() {
        return new LinearFunction(attIndex, beta);
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/SquareCutter.java
================================================
package mltk.predictor.function;
import java.util.List;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.Learner;
import mltk.util.MathUtils;
import mltk.util.OptimUtils;
import mltk.util.tuple.Pair;
/**
* Class for cutting squares.
*
* @author Yin Lou
*
*/
public class SquareCutter extends Learner {
private int attIndex1;
private int attIndex2;
private boolean lineSearch;
/**
* Constructor. Leaves all fields at their Java defaults, so line search is disabled and both
* attribute indices are 0 until {@code setAttIndices} is called.
*/
public SquareCutter() {
}
/**
* Constructor. The attribute indices keep their default value of 0 until {@code setAttIndices} is called.
*
* @param lineSearch {@code true} if line search is performed in the end.
*/
public SquareCutter(boolean lineSearch) {
this.lineSearch = lineSearch;
}
/**
* Sets the attribute indices. They are used by {@code build} as positions in the attribute list of
* the training set.
*
* @param attIndex1 the 1st index of attribute.
* @param attIndex2 the 2nd index of attribute.
*/
public void setAttIndices(int attIndex1, int attIndex2) {
this.attIndex1 = attIndex1;
this.attIndex2 = attIndex2;
}
/**
* Builds a 2D function on the two configured attributes.
* <p>
* Tries every root cut on attribute 1 (with secondary cuts on attribute 2) and every root cut on
* attribute 2 (with secondary cuts on attribute 1), and keeps the configuration with the smallest
* score returned by {@code getRSS}.
*
* @param instances the training set.
* @return a 2D function.
*/
public Function2D build(Instances instances) {
// Note: Currently only support cutting on binned/nominal features
List attributes = instances.getAttributes();
// size stays 0 for any attribute type other than BINNED or NOMINAL.
int size1 = 0;
Attribute f1 = attributes.get(attIndex1);
if (f1.getType() == Attribute.Type.BINNED) {
size1 = ((BinnedAttribute) f1).getNumBins();
} else if (f1.getType() == Attribute.Type.NOMINAL) {
size1 = ((NominalAttribute) f1).getCardinality();
}
int size2 = 0;
Attribute f2 = attributes.get(attIndex2);
if (f2.getType() == Attribute.Type.BINNED) {
size2 = ((BinnedAttribute) f2).getNumBins();
} else if (f2.getType() == Attribute.Type.NOMINAL) {
size2 = ((NominalAttribute) f2).getCardinality();
}
Histogram2D hist2d = new Histogram2D(size1, size2);
Histogram2D.computeHistogram2D(instances, f1.getIndex(), f2.getIndex(), hist2d);
Pair cHist = hist2d.computeCHistogram();
if ((size1 == 1 && !cHist.v1.hasMissingValue()) || (size2 == 1 && !cHist.v2.hasMissingValue())) {
// Not an interaction
// Recommend: Use LineCutter to shape the non-trivial attribute
return new Function2D(f1.getIndex(), f2.getIndex(), new double[] { Double.POSITIVE_INFINITY },
new double[] { Double.POSITIVE_INFINITY }, new double[1][1]);
}
Histogram2D.Table table = Histogram2D.computeTable(hist2d, cHist.v1, cHist.v2);
// Pass 1: root cut on attribute 1, secondary cuts on attribute 2.
double bestRSS = Double.POSITIVE_INFINITY;
double[] predInt1 = new double[9];
int bestV1 = -1;
int[] bestV2s = new int[3];
int[] v2s = new int[3];
// -1 marks "no cut on missing values"; findCuts only overwrites it when missing values exist.
v2s[2] = -1;
for (int v1 = 0; v1 < size1 - 1; v1++) {
// v2s and predInt1 are scratch buffers overwritten on every iteration.
findCuts(table, v1, v2s, cHist.v1.hasMissingValue());
getPredictor(table, v1, v2s, predInt1);
double rss = getRSS(table, v1, v2s, predInt1);
if (rss < bestRSS) {
bestRSS = rss;
bestV1 = v1;
bestV2s[0] = v2s[0];
bestV2s[1] = v2s[1];
bestV2s[2] = v2s[2];
}
}
// Pass 2: root cut on attribute 2; it wins only if strictly better than the best of pass 1.
boolean cutOnAttr2 = false;
double[] predInt2 = new double[9];
int[] bestV1s = new int[3];
int bestV2 = -1;
int[] v1s = new int[3];
v1s[2] = -1;
for (int v2 = 0; v2 < size2 - 1; v2++) {
findCuts(table, v1s, v2, cHist.v2.hasMissingValue());
getPredictor(table, v1s, v2, predInt2);
double rss = getRSS(table, v1s, v2, predInt2);
if (rss < bestRSS) {
bestRSS = rss;
bestV2 = v2;
bestV1s[0] = v1s[0];
bestV1s[1] = v1s[1];
bestV1s[2] = v1s[2];
cutOnAttr2 = true;
}
}
// The prediction buffers hold the values of the last candidate tried, so they are recomputed
// for the winning cuts before being used.
if (cutOnAttr2) {
// Root cut on attribute 2 is better
getPredictor(table, bestV1s, bestV2, predInt2);
if (lineSearch) {
lineSearch(instances, f2.getIndex(), f1.getIndex(), bestV2, bestV1s[0], bestV1s[1], bestV1s[2], predInt2);
}
return getFunction2D(f1.getIndex(), f2.getIndex(), bestV1s, bestV2, predInt2);
} else {
// Root cut on attribute 1 is better
// NOTE(review): if neither loop above ran (size1 <= 1 and size2 <= 1 without hitting the early
// return), bestV1 is still -1 here and getPredictor indexes table.resp[-1] - confirm this cannot occur.
getPredictor(table, bestV1, bestV2s, predInt1);
if (lineSearch) {
lineSearch(instances, f1.getIndex(), f2.getIndex(), bestV1, bestV2s[0], bestV2s[1], bestV2s[2], predInt1);
}
return getFunction2D(f1.getIndex(), f2.getIndex(), bestV1, bestV2s, predInt1);
}
}
/**
 * Finds the secondary cuts along the 2nd dimension for a fixed root cut {@code v1} on the 1st dimension.
 * <p>
 * {@code v2[0]} receives the cut minimizing the score on table slots 0/1, {@code v2[1]} the cut minimizing
 * the score on slots 2/3, and, only when {@code hasMissingValue} is set, {@code v2[2]} the cut minimizing
 * the score on the missing-value table of the 1st feature. Entries are left untouched when there is no
 * candidate; on ties the first candidate wins.
 *
 * @param table the auxiliary table.
 * @param v1 the root cut on the 1st dimension.
 * @param v2 the output array of length 3.
 * @param hasMissingValue {@code true} if a cut on the missing-value table should be searched as well.
 */
protected static void findCuts(Histogram2D.Table table, int v1, int[] v2, boolean hasMissingValue) {
    final double[][] respRow = table.resp[v1];
    final double[][] countRow = table.count[v1];

    // The last index is not a candidate: nothing would remain on its far side.
    final int numCandidates = respRow.length - 1;
    double bestUpper = Double.POSITIVE_INFINITY;
    double bestLower = Double.POSITIVE_INFINITY;
    for (int cut = 0; cut < numCandidates; cut++) {
        final double[] r = respRow[cut];
        final double[] c = countRow[cut];

        // Upper cut: slots 0 and 1.
        final double upperScore = -(OptimUtils.getGain(r[0], c[0]) + OptimUtils.getGain(r[1], c[1]));
        if (upperScore < bestUpper) {
            bestUpper = upperScore;
            v2[0] = cut;
        }

        // Lower cut: slots 2 and 3.
        final double lowerScore = -(OptimUtils.getGain(r[2], c[2]) + OptimUtils.getGain(r[3], c[3]));
        if (lowerScore < bestLower) {
            bestLower = lowerScore;
            v2[1] = cut;
        }
    }

    if (!hasMissingValue) {
        return;
    }

    // Cut on missing value: every index of the missing-value table is a candidate.
    double bestOnMV = Double.POSITIVE_INFINITY;
    for (int cut = 0; cut < table.respOnMV1.length; cut++) {
        final double[] r = table.respOnMV1[cut];
        final double[] c = table.countOnMV1[cut];
        final double score = -(OptimUtils.getGain(r[0], c[0]) + OptimUtils.getGain(r[1], c[1]));
        if (score < bestOnMV) {
            bestOnMV = score;
            v2[2] = cut;
        }
    }
}
/**
 * Finds the secondary cuts along the 1st dimension for a fixed root cut {@code v2} on the 2nd dimension.
 * <p>
 * {@code v1[0]} receives the cut minimizing the score on table slots 0/2, {@code v1[1]} the cut minimizing
 * the score on slots 1/3, and, only when {@code hasMissingValue} is set, {@code v1[2]} the cut minimizing
 * the score on the missing-value table of the 2nd feature. Entries are left untouched when there is no
 * candidate; on ties the first candidate wins.
 *
 * @param table the auxiliary table.
 * @param v1 the output array of length 3.
 * @param v2 the root cut on the 2nd dimension.
 * @param hasMissingValue {@code true} if a cut on the missing-value table should be searched as well.
 */
protected static void findCuts(Histogram2D.Table table, int[] v1, int v2, boolean hasMissingValue) {
    // The last index is not a candidate: nothing would remain on its far side.
    final int numCandidates = table.resp.length - 1;
    double bestLeft = Double.POSITIVE_INFINITY;
    double bestRight = Double.POSITIVE_INFINITY;
    for (int cut = 0; cut < numCandidates; cut++) {
        final double[] r = table.resp[cut][v2];
        final double[] c = table.count[cut][v2];

        // Left cut: slots 0 and 2.
        final double leftScore = -(OptimUtils.getGain(r[0], c[0]) + OptimUtils.getGain(r[2], c[2]));
        if (leftScore < bestLeft) {
            bestLeft = leftScore;
            v1[0] = cut;
        }

        // Right cut: slots 1 and 3.
        final double rightScore = -(OptimUtils.getGain(r[1], c[1]) + OptimUtils.getGain(r[3], c[3]));
        if (rightScore < bestRight) {
            bestRight = rightScore;
            v1[1] = cut;
        }
    }

    if (!hasMissingValue) {
        return;
    }

    // Cut on missing value: every index of the missing-value table is a candidate.
    double bestOnMV = Double.POSITIVE_INFINITY;
    for (int cut = 0; cut < table.respOnMV2.length; cut++) {
        final double[] r = table.respOnMV2[cut];
        final double[] c = table.countOnMV2[cut];
        final double score = -(OptimUtils.getGain(r[0], c[0]) + OptimUtils.getGain(r[1], c[1]));
        if (score < bestOnMV) {
            bestOnMV = score;
            v1[2] = cut;
        }
    }
}
/**
 * Fills {@code pred} with the region means (response sum divided by weight sum) for a root cut on the
 * 1st dimension.
 * <p>
 * {@code pred[0..1]} come from slots 0/1 at the upper cut {@code v2[0]}, {@code pred[2..3]} from slots 2/3
 * at the lower cut {@code v2[1]}, {@code pred[4..5]} from the missing-value table of the 1st feature at
 * {@code v2[2]} (left untouched when {@code v2[2] < 0}), {@code pred[6..7]} from the missing-value table of
 * the 2nd feature at {@code v1}, and {@code pred[8]} from the both-missing bucket.
 *
 * @param table the auxiliary table.
 * @param v1 the root cut on the 1st dimension.
 * @param v2 the secondary cuts {upper, lower, on missing value}.
 * @param pred the output array of length 9.
 */
protected static void getPredictor(Histogram2D.Table table, int v1, int[] v2, double[] pred) {
    final int upperCut = v2[0];
    final int lowerCut = v2[1];
    final int cutOnMV = v2[2];

    final double[] upperResp = table.resp[v1][upperCut];
    final double[] upperCount = table.count[v1][upperCut];
    final double[] lowerResp = table.resp[v1][lowerCut];
    final double[] lowerCount = table.count[v1][lowerCut];

    pred[0] = MathUtils.divide(upperResp[0], upperCount[0], 0);
    pred[1] = MathUtils.divide(upperResp[1], upperCount[1], 0);
    pred[2] = MathUtils.divide(lowerResp[2], lowerCount[2], 0);
    pred[3] = MathUtils.divide(lowerResp[3], lowerCount[3], 0);

    if (cutOnMV >= 0) {
        final double[] mv1Resp = table.respOnMV1[cutOnMV];
        final double[] mv1Count = table.countOnMV1[cutOnMV];
        pred[4] = MathUtils.divide(mv1Resp[0], mv1Count[0], 0);
        pred[5] = MathUtils.divide(mv1Resp[1], mv1Count[1], 0);
    }

    final double[] mv2Resp = table.respOnMV2[v1];
    final double[] mv2Count = table.countOnMV2[v1];
    pred[6] = MathUtils.divide(mv2Resp[0], mv2Count[0], 0);
    pred[7] = MathUtils.divide(mv2Resp[1], mv2Count[1], 0);

    pred[8] = MathUtils.divide(table.respOnMV12, table.countOnMV12, 0);
}
/**
 * Fills {@code pred} with the region means (response sum divided by weight sum) for a root cut on the
 * 2nd dimension.
 * <p>
 * {@code pred[0..1]} come from slots 0/2 at the left cut {@code v1[0]}, {@code pred[2..3]} from slots 1/3
 * at the right cut {@code v1[1]}, {@code pred[4..5]} from the missing-value table of the 2nd feature at
 * {@code v1[2]} (left untouched when {@code v1[2] < 0}), {@code pred[6..7]} from the missing-value table of
 * the 1st feature at {@code v2}, and {@code pred[8]} from the both-missing bucket.
 *
 * @param table the auxiliary table.
 * @param v1 the secondary cuts {left, right, on missing value}.
 * @param v2 the root cut on the 2nd dimension.
 * @param pred the output array of length 9.
 */
protected static void getPredictor(Histogram2D.Table table, int[] v1, int v2, double[] pred) {
    final int leftCut = v1[0];
    final int rightCut = v1[1];
    final int cutOnMV = v1[2];

    final double[] leftResp = table.resp[leftCut][v2];
    final double[] leftCount = table.count[leftCut][v2];
    final double[] rightResp = table.resp[rightCut][v2];
    final double[] rightCount = table.count[rightCut][v2];

    pred[0] = MathUtils.divide(leftResp[0], leftCount[0], 0);
    pred[1] = MathUtils.divide(leftResp[2], leftCount[2], 0);
    pred[2] = MathUtils.divide(rightResp[1], rightCount[1], 0);
    pred[3] = MathUtils.divide(rightResp[3], rightCount[3], 0);

    if (cutOnMV >= 0) {
        final double[] mv2Resp = table.respOnMV2[cutOnMV];
        final double[] mv2Count = table.countOnMV2[cutOnMV];
        pred[4] = MathUtils.divide(mv2Resp[0], mv2Count[0], 0);
        pred[5] = MathUtils.divide(mv2Resp[1], mv2Count[1], 0);
    }

    final double[] mv1Resp = table.respOnMV1[v2];
    final double[] mv1Count = table.countOnMV1[v2];
    pred[6] = MathUtils.divide(mv1Resp[0], mv1Count[0], 0);
    pred[7] = MathUtils.divide(mv1Resp[1], mv1Count[1], 0);

    pred[8] = MathUtils.divide(table.respOnMV12, table.countOnMV12, 0);
}
/**
 * Scores a root cut on the 1st dimension: the sum over regions of {@code pred^2 * weight} minus twice
 * the sum of {@code pred * response}. The regions are those filled by the matching {@code getPredictor}
 * overload; the two regions of the missing-value cut are skipped when {@code v2[2] < 0}.
 *
 * @param table the auxiliary table.
 * @param v1 the root cut on the 1st dimension.
 * @param v2 the secondary cuts {upper, lower, on missing value}.
 * @param pred the nine region predictions.
 * @return the score; smaller is better.
 */
protected static double getRSS(Histogram2D.Table table, int v1, int v2[], double[] pred) {
    final int upperCut = v2[0];
    final int lowerCut = v2[1];
    final int cutOnMV = v2[2];

    final double[] upperResp = table.resp[v1][upperCut];
    final double[] lowerResp = table.resp[v1][lowerCut];
    final double[] upperCount = table.count[v1][upperCut];
    final double[] lowerCount = table.count[v1][lowerCut];

    // Two independent accumulators; each one adds its terms in region order 0..8.
    double squares = 0;
    double cross = 0;

    squares += pred[0] * pred[0] * upperCount[0];
    squares += pred[1] * pred[1] * upperCount[1];
    squares += pred[2] * pred[2] * lowerCount[2];
    squares += pred[3] * pred[3] * lowerCount[3];
    cross += pred[0] * upperResp[0];
    cross += pred[1] * upperResp[1];
    cross += pred[2] * lowerResp[2];
    cross += pred[3] * lowerResp[3];

    if (cutOnMV >= 0) {
        final double[] mv1Count = table.countOnMV1[cutOnMV];
        final double[] mv1Resp = table.respOnMV1[cutOnMV];
        squares += pred[4] * pred[4] * mv1Count[0];
        squares += pred[5] * pred[5] * mv1Count[1];
        cross += pred[4] * mv1Resp[0];
        cross += pred[5] * mv1Resp[1];
    }

    final double[] mv2Count = table.countOnMV2[v1];
    final double[] mv2Resp = table.respOnMV2[v1];
    squares += pred[6] * pred[6] * mv2Count[0];
    squares += pred[7] * pred[7] * mv2Count[1];
    cross += pred[6] * mv2Resp[0];
    cross += pred[7] * mv2Resp[1];

    squares += pred[8] * pred[8] * table.countOnMV12;
    cross += pred[8] * table.respOnMV12;

    return squares - 2 * cross;
}
/**
 * Computes {@code sum(pred[i]^2 * count[i]) - 2 * sum(pred[i] * resp[i])} over the nine cells selected by a
 * triple of cuts on the first attribute and one cut on the second attribute.
 *
 * <p>
 * The terms for {@code pred[4]} and {@code pred[5]} are only included when {@code v1[2]} is non-negative.
 * </p>
 *
 * @param table the 2D histogram table.
 * @param v1 the cuts on the first attribute: {@code v1[0]}, {@code v1[1]} and {@code v1[2]} (negative means unused).
 * @param v2 the cut on the second attribute.
 * @param pred the nine cell predictions.
 * @return the accumulated value described above.
 */
protected static double getRSS(Histogram2D.Table table, int[] v1, int v2, double[] pred) {
    final int firstCut = v1[0];
    final int secondCut = v1[1];
    final int mvCut = v1[2];

    final double[] respA = table.resp[firstCut][v2];
    final double[] respB = table.resp[secondCut][v2];
    final double[] countA = table.count[firstCut][v2];
    final double[] countB = table.count[secondCut][v2];

    // Two running sums; each keeps the same term order as a straight top-to-bottom accumulation
    double squared = 0;
    double cross = 0;

    squared += pred[0] * pred[0] * countA[0];
    cross += pred[0] * respA[0];
    squared += pred[1] * pred[1] * countA[2];
    cross += pred[1] * respA[2];
    squared += pred[2] * pred[2] * countB[1];
    cross += pred[2] * respB[1];
    squared += pred[3] * pred[3] * countB[3];
    cross += pred[3] * respB[3];

    if (mvCut >= 0) {
        final double[] countMV2 = table.countOnMV2[mvCut];
        final double[] respMV2 = table.respOnMV2[mvCut];
        squared += pred[4] * pred[4] * countMV2[0];
        cross += pred[4] * respMV2[0];
        squared += pred[5] * pred[5] * countMV2[1];
        cross += pred[5] * respMV2[1];
    }

    final double[] countMV1 = table.countOnMV1[v2];
    final double[] respMV1 = table.respOnMV1[v2];
    squared += pred[6] * pred[6] * countMV1[0];
    cross += pred[6] * respMV1[0];
    squared += pred[7] * pred[7] * countMV1[1];
    cross += pred[7] * respMV1[1];

    squared += pred[8] * pred[8] * table.countOnMV12;
    cross += pred[8] * table.respOnMV12;

    squared -= 2 * cross;
    return squared;
}
/**
 * Assembles a {@link Function2D} from one cut on the first attribute and up to three cuts on the second.
 *
 * <p>
 * The first axis always has the splits {@code { v1, +inf }}. The second axis merges {@code v2[0]} (the cut
 * applied in row 0, separating {@code predInt[0]} from {@code predInt[1]}), {@code v2[1]} (the cut applied in
 * row 1, separating {@code predInt[2]} from {@code predInt[3]}) and {@code v2[2]} (the cut separating
 * {@code predInt[4]} from {@code predInt[5]} in {@code predictionsOnMV1}) into one sorted split list without
 * duplicates. A negative {@code v2[2]} means there is no such cut and {@code predictionsOnMV1} is all zeros.
 * </p>
 *
 * @param attIndex1 the index of the first attribute.
 * @param attIndex2 the index of the second attribute.
 * @param v1 the cut on the first attribute.
 * @param v2 the cuts on the second attribute ({@code v2[2] < 0} means unused).
 * @param predInt the nine cell predictions.
 * @return the assembled 2D function.
 */
protected static Function2D getFunction2D(int attIndex1, int attIndex2, int v1, int[] v2, double[] predInt) {
    double[] splits1 = new double[] { v1, Double.POSITIVE_INFINITY };
    double[] splits2 = null;
    double[][] predictions = null;
    double[] predictionsOnMV1 = null;
    double[] predictionsOnMV2 = new double[] { predInt[6], predInt[7] };
    if (v2[0] < v2[1]) {
        if (v2[2] < 0 || v2[2] == v2[0] || v2[2] == v2[1]) {
            // The third cut is absent or coincides with an existing one: two distinct cuts
            splits2 = new double[] { v2[0], v2[1], Double.POSITIVE_INFINITY };
            predictions = new double[][] {
                { predInt[0], predInt[1], predInt[1] },
                { predInt[2], predInt[2], predInt[3] }
            };
            if (v2[2] < 0) {
                predictionsOnMV1 = new double[] { 0.0, 0.0, 0.0 };
            } else if (v2[2] == v2[0]) {
                predictionsOnMV1 = new double[] { predInt[4], predInt[5], predInt[5] };
            } else {
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[5] };
            }
        } else {
            if (v2[2] < v2[0]) {
                splits2 = new double[] { v2[2], v2[0], v2[1], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[0], predInt[1], predInt[1] },
                    { predInt[2], predInt[2], predInt[2], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[5], predInt[5], predInt[5] };
            } else if (v2[2] < v2[1]) {
                splits2 = new double[] { v2[0], v2[2], v2[1], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[1], predInt[1], predInt[1] },
                    { predInt[2], predInt[2], predInt[2], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[5], predInt[5] };
            } else {
                splits2 = new double[] { v2[0], v2[1], v2[2], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[1], predInt[1], predInt[1] },
                    { predInt[2], predInt[2], predInt[3], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[4], predInt[5] };
            }
        }
    } else if (v2[0] > v2[1]) {
        // Both coincidence cases must be tested here; otherwise v2[2] == v2[0] would fall through to the
        // four-split branch and emit the same split value twice.
        if (v2[2] < 0 || v2[2] == v2[0] || v2[2] == v2[1]) {
            splits2 = new double[] { v2[1], v2[0], Double.POSITIVE_INFINITY };
            predictions = new double[][] {
                { predInt[0], predInt[0], predInt[1] },
                { predInt[2], predInt[3], predInt[3] }
            };
            if (v2[2] < 0) {
                predictionsOnMV1 = new double[] { 0.0, 0.0, 0.0 };
            } else if (v2[2] == v2[1]) {
                predictionsOnMV1 = new double[] { predInt[4], predInt[5], predInt[5] };
            } else {
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[5] };
            }
        } else {
            if (v2[2] < v2[1]) {
                splits2 = new double[] { v2[2], v2[1], v2[0], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[0], predInt[0], predInt[1] },
                    { predInt[2], predInt[2], predInt[3], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[5], predInt[5], predInt[5] };
            } else if (v2[2] < v2[0]) {
                splits2 = new double[] { v2[1], v2[2], v2[0], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[0], predInt[0], predInt[1] },
                    { predInt[2], predInt[3], predInt[3], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[5], predInt[5] };
            } else {
                splits2 = new double[] { v2[1], v2[0], v2[2], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[0], predInt[1], predInt[1] },
                    { predInt[2], predInt[3], predInt[3], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[4], predInt[5] };
            }
        }
    } else {
        // v2[0] == v2[1]
        if (v2[2] < 0 || v2[2] == v2[0]) {
            splits2 = new double[] { v2[0], Double.POSITIVE_INFINITY };
            predictions = new double[][] {
                { predInt[0], predInt[1] },
                { predInt[2], predInt[3] }
            };
            if (v2[2] < 0) {
                predictionsOnMV1 = new double[] { 0.0, 0.0 };
            } else {
                predictionsOnMV1 = new double[] { predInt[4], predInt[5] };
            }
        } else {
            if (v2[2] < v2[0]) {
                splits2 = new double[] { v2[2], v2[0], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[0], predInt[1] },
                    { predInt[2], predInt[2], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[5], predInt[5] };
            } else {
                splits2 = new double[] { v2[0], v2[2], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[1], predInt[1] },
                    { predInt[2], predInt[3], predInt[3] }
                };
                predictionsOnMV1 = new double[] { predInt[4], predInt[4], predInt[5] };
            }
        }
    }
    return new Function2D(attIndex1, attIndex2, splits1, splits2, predictions,
            predictionsOnMV1, predictionsOnMV2, predInt[8]);
}
/**
 * Assembles a {@link Function2D} from up to three cuts on the first attribute and one cut on the second.
 *
 * <p>
 * The second axis always has the splits {@code { v2, +inf }}. The first axis merges {@code v1[0]} (the cut
 * applied in column 0, separating {@code predInt[0]} from {@code predInt[1]}), {@code v1[1]} (the cut applied
 * in column 1, separating {@code predInt[2]} from {@code predInt[3]}) and {@code v1[2]} (the cut separating
 * {@code predInt[4]} from {@code predInt[5]} in {@code predictionsOnMV2}) into one sorted split list without
 * duplicates. A negative {@code v1[2]} means there is no such cut and {@code predictionsOnMV2} is all zeros.
 * </p>
 *
 * @param attIndex1 the index of the first attribute.
 * @param attIndex2 the index of the second attribute.
 * @param v1 the cuts on the first attribute ({@code v1[2] < 0} means unused).
 * @param v2 the cut on the second attribute.
 * @param predInt the nine cell predictions.
 * @return the assembled 2D function.
 */
protected static Function2D getFunction2D(int attIndex1, int attIndex2, int[] v1, int v2, double[] predInt) {
    double[] splits1 = null;
    double[] splits2 = new double[] { v2, Double.POSITIVE_INFINITY };
    double[] predictionsOnMV1 = new double[] { predInt[6], predInt[7] };
    double[] predictionsOnMV2 = null;
    double[][] predictions = null;
    if (v1[0] < v1[1]) {
        if (v1[2] < 0 || v1[2] == v1[0] || v1[2] == v1[1]) {
            // The third cut is absent or coincides with an existing one: two distinct cuts
            splits1 = new double[] { v1[0], v1[1], Double.POSITIVE_INFINITY };
            predictions = new double[][] {
                { predInt[0], predInt[2] },
                { predInt[1], predInt[2] },
                { predInt[1], predInt[3] }
            };
            if (v1[2] < 0) {
                predictionsOnMV2 = new double[] { 0.0, 0.0, 0.0 };
            } else if (v1[2] == v1[0]) {
                predictionsOnMV2 = new double[] { predInt[4], predInt[5], predInt[5] };
            } else {
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[5] };
            }
        } else {
            if (v1[2] < v1[0]) {
                splits1 = new double[] { v1[2], v1[0], v1[1], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[0], predInt[2] },
                    { predInt[1], predInt[2] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[5], predInt[5], predInt[5] };
            } else if (v1[2] < v1[1]) {
                splits1 = new double[] { v1[0], v1[2], v1[1], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[1], predInt[2] },
                    { predInt[1], predInt[2] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[5], predInt[5] };
            } else {
                splits1 = new double[] { v1[0], v1[1], v1[2], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[1], predInt[2] },
                    { predInt[1], predInt[3] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[4], predInt[5] };
            }
        }
    } else if (v1[0] > v1[1]) {
        if (v1[2] < 0 || v1[2] == v1[0] || v1[2] == v1[1]) {
            splits1 = new double[] { v1[1], v1[0], Double.POSITIVE_INFINITY };
            predictions = new double[][] {
                { predInt[0], predInt[2] },
                { predInt[0], predInt[3] },
                { predInt[1], predInt[3] }
            };
            if (v1[2] < 0) {
                predictionsOnMV2 = new double[] { 0.0, 0.0, 0.0 };
            } else if (v1[2] == v1[0]) {
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[5] };
            } else {
                predictionsOnMV2 = new double[] { predInt[4], predInt[5], predInt[5] };
            }
        } else {
            if (v1[2] < v1[1]) {
                splits1 = new double[] { v1[2], v1[1], v1[0], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[0], predInt[2] },
                    { predInt[0], predInt[3] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[5], predInt[5], predInt[5] };
            } else if (v1[2] < v1[0]) {
                splits1 = new double[] { v1[1], v1[2], v1[0], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[0], predInt[3] },
                    { predInt[0], predInt[3] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[5], predInt[5] };
            } else {
                // v1[1] < v1[0] < v1[2]: column 0 switches at v1[0], column 1 switches at v1[1]
                splits1 = new double[] { v1[1], v1[0], v1[2], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[0], predInt[3] },
                    { predInt[1], predInt[3] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[4], predInt[5] };
            }
        }
    } else {
        // v1[0] == v1[1]
        if (v1[2] < 0 || v1[2] == v1[0]) {
            splits1 = new double[] { v1[0], Double.POSITIVE_INFINITY };
            predictions = new double[][] {
                { predInt[0], predInt[2] },
                { predInt[1], predInt[3] }
            };
            if (v1[2] < 0) {
                predictionsOnMV2 = new double[] { 0.0, 0.0 };
            } else {
                predictionsOnMV2 = new double[] { predInt[4], predInt[5] };
            }
        } else {
            if (v1[2] < v1[0]) {
                splits1 = new double[] { v1[2], v1[0], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[0], predInt[2] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[5], predInt[5] };
            } else {
                // v1[2] > v1[0]: the splits must stay in ascending order
                splits1 = new double[] { v1[0], v1[2], Double.POSITIVE_INFINITY };
                predictions = new double[][] {
                    { predInt[0], predInt[2] },
                    { predInt[1], predInt[3] },
                    { predInt[1], predInt[3] }
                };
                predictionsOnMV2 = new double[] { predInt[4], predInt[4], predInt[5] };
            }
        }
    }
    return new Function2D(attIndex1, attIndex2, splits1, splits2, predictions,
            predictionsOnMV1, predictionsOnMV2, predInt[8]);
}
/**
 * Recomputes the nine cell predictions as {@code sum(target * weight) / sum(|target| * (1 - |target|) * weight)}
 * over the instances falling into each cell.
 *
 * <p>
 * Cells 0-3 hold instances with both values present (split by {@code c1} on the first attribute, then by
 * {@code c21} or {@code c22} on the second); cells 4-5 hold instances missing the first attribute (split by
 * {@code cMV}); cells 6-7 hold instances missing the second attribute (split by {@code c1}); cell 8 holds
 * instances missing both.
 * </p>
 *
 * @param instances the instances.
 * @param attIndex1 the index of the first attribute.
 * @param attIndex2 the index of the second attribute.
 * @param c1 the cut on the first attribute.
 * @param c21 the cut on the second attribute used when the first value is at most {@code c1}.
 * @param c22 the cut on the second attribute used when the first value exceeds {@code c1}.
 * @param cMV the cut on the second attribute used when the first value is missing.
 * @param predictions the output array of nine predictions.
 * @throws RuntimeException if an instance misses only the first attribute while {@code cMV} is negative.
 */
protected static void lineSearch(Instances instances, int attIndex1, int attIndex2, int c1, int c21, int c22, int cMV,
        double[] predictions) {
    final double[] numer = new double[9];
    final double[] denom = new double[9];
    for (Instance inst : instances) {
        final double y = inst.getTarget();
        final double w = inst.getWeight();
        final double absY = Math.abs(y);
        final boolean missing1 = inst.isMissing(attIndex1);
        final boolean missing2 = inst.isMissing(attIndex2);

        // Work out which of the nine cells this instance belongs to
        final int cell;
        if (missing1 && missing2) {
            cell = 8;
        } else if (missing2) {
            final int x1 = (int) inst.getValue(attIndex1);
            cell = x1 <= c1 ? 6 : 7;
        } else if (missing1) {
            final int x2 = (int) inst.getValue(attIndex2);
            if (cMV < 0) {
                throw new RuntimeException("Something went wrong");
            }
            cell = x2 <= cMV ? 4 : 5;
        } else {
            final int x1 = (int) inst.getValue(attIndex1);
            final int x2 = (int) inst.getValue(attIndex2);
            if (x1 <= c1) {
                cell = x2 <= c21 ? 0 : 1;
            } else {
                cell = x2 <= c22 ? 2 : 3;
            }
        }

        numer[cell] += y * w;
        denom[cell] += absY * (1 - absY) * w;
    }
    for (int cell = 0; cell < numer.length; cell++) {
        predictions[cell] = MathUtils.divide(numer[cell], denom[cell], 0);
    }
}
}
================================================
FILE: src/main/java/mltk/predictor/function/SubagSequence.java
================================================
package mltk.predictor.function;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import mltk.util.Element;
import mltk.util.Permutation;
import mltk.util.Queue;
import mltk.util.UFSets;
import mltk.util.tuple.IntPair;
/**
 * A set of random subsamples together with an order in which to visit them.
 *
 * <p>
 * The subsamples are linked into a spanning tree whose edges are taken in the order produced by a priority
 * queue of pairwise distances (presumably ascending by distance - confirm against {@code Element}'s ordering).
 * A breadth-first walk of that tree yields parallel arrays {@link #start}, {@link #end} and {@link #deltas}:
 * edge {@code k} goes from sample {@code start[k]} to sample {@code end[k]} and {@code deltas[k]} lists the
 * indices to add and remove to turn the former into the latter.
 * </p>
 */
class SubagSequence {

    /** The indices to add to and remove from one sample to obtain another. */
    static class SampleDelta {
        int[] toAdd;
        int[] toDel;

        SampleDelta(Set<Integer> toAdd, Set<Integer> toDel) {
            this.toAdd = toIntArray(toAdd);
            this.toDel = toIntArray(toDel);
        }

        /** Returns the size of the symmetric difference. */
        int getDistance() {
            return toAdd.length + toDel.length;
        }
    }

    /** A subsample, kept both as a set and as an array of instance indices. */
    static class Sample {
        Set<Integer> set;
        int[] indices;

        Sample(Set<Integer> set) {
            this.set = set;
            this.indices = toIntArray(set);
        }

        static int computeDistance(Sample s1, Sample s2) {
            return computeDelta(s1, s2).getDistance();
        }

        /** Returns what must be added to and removed from {@code s1} to obtain {@code s2}. */
        static SampleDelta computeDelta(Sample s1, Sample s2) {
            Set<Integer> toAdd = new HashSet<>(s2.set);
            toAdd.removeAll(s1.set);
            Set<Integer> toDel = new HashSet<>(s1.set);
            toDel.removeAll(s2.set);
            return new SampleDelta(toAdd, toDel);
        }

        int getWeight() {
            return set.size();
        }
    }

    Sample[] samples;
    int[] start;
    int[] end;
    SampleDelta[] deltas;
    /** {@code count[i]} is the number of tree edges that start at sample {@code i}. */
    int[] count;

    /**
     * Constructor.
     *
     * @param n the size of the data set.
     * @param m the size of each subsample.
     * @param baggingIters the number of subsamples.
     */
    SubagSequence(int n, int m, int baggingIters) {
        samples = new Sample[baggingIters];
        count = new int[baggingIters];
        for (int i = 0; i < samples.length; i++) {
            samples[i] = createSubsample(n, m);
        }
        computeSequence(samples);
    }

    /** Draws {@code m} distinct indices from a random permutation of {@code 0..n-1}. */
    Sample createSubsample(int n, int m) {
        Permutation perm = new Permutation(n);
        perm.permute();
        int[] a = perm.getPermutation();
        Set<Integer> set = new HashSet<>(m);
        for (int i = 0; i < m; i++) {
            set.add(a[i]);
        }
        return new Sample(set);
    }

    private void computeSequence(Sample[] samples) {
        final int numSamples = samples.length;
        if (numSamples <= 1) {
            // Nothing to link; also avoids PriorityQueue rejecting an initial capacity of 0
            this.start = new int[0];
            this.end = new int[0];
            this.deltas = new SampleDelta[0];
            return;
        }

        // All pairwise distances
        PriorityQueue<Element<IntPair>> q = new PriorityQueue<>(numSamples * (numSamples - 1) / 2);
        for (int i = 0; i < numSamples - 1; i++) {
            for (int j = i + 1; j < numSamples; j++) {
                int distance = Sample.computeDistance(samples[i], samples[j]);
                q.add(new Element<IntPair>(new IntPair(i, j), distance));
            }
        }

        // Keep an edge only when it joins two components, which yields a spanning tree
        Map<Integer, Set<Integer>> map = new HashMap<>();
        UFSets ufsets = new UFSets(numSamples);
        while (!q.isEmpty()) {
            Element<IntPair> e = q.poll();
            int x = e.element.v1;
            int y = e.element.v2;
            int root1 = ufsets.find(x);
            int root2 = ufsets.find(y);
            if (root1 != root2) {
                ufsets.union(root1, root2);
                link(map, x, y);
                link(map, y, x);
            }
        }

        // Start from the smallest sample (the first one on ties)
        int s = 0;
        for (int i = 1; i < numSamples; i++) {
            if (samples[i].getWeight() < samples[s].getWeight()) {
                s = i;
            }
        }

        List<Integer> fromList = new ArrayList<>(numSamples);
        List<Integer> toList = new ArrayList<>(numSamples);
        List<SampleDelta> deltaList = new ArrayList<>(numSamples);
        // BFS
        Set<Integer> covered = new HashSet<>();
        covered.add(s);
        Queue<Integer> queue = new Queue<>();
        queue.enqueue(s);
        while (covered.size() < numSamples) {
            Integer node = queue.dequeue();
            Set<Integer> children = map.get(node);
            for (Integer child : children) {
                if (covered.contains(child)) {
                    continue;
                }
                covered.add(child);
                fromList.add(node);
                toList.add(child);
                deltaList.add(Sample.computeDelta(samples[node], samples[child]));
                queue.enqueue(child);
            }
        }

        this.start = new int[fromList.size()];
        this.end = new int[toList.size()];
        this.deltas = new SampleDelta[deltaList.size()];
        for (int i = 0; i < this.start.length; i++) {
            this.start[i] = fromList.get(i);
            this.end[i] = toList.get(i);
            this.deltas[i] = deltaList.get(i);
        }
        for (int i = 0; i < this.start.length; i++) {
            count[this.start[i]]++;
        }
    }

    /** Records {@code to} as a neighbor of {@code from}. */
    private static void link(Map<Integer, Set<Integer>> map, int from, int to) {
        Set<Integer> neighbors = map.get(from);
        if (neighbors == null) {
            neighbors = new HashSet<>();
            map.put(from, neighbors);
        }
        neighbors.add(to);
    }

    /** Copies a set of indices into an array, in the set's iteration order. */
    private static int[] toIntArray(Set<Integer> set) {
        int[] a = new int[set.size()];
        int k = 0;
        for (Integer idx : set) {
            a[k++] = idx;
        }
        return a;
    }
}
================================================
FILE: src/main/java/mltk/predictor/function/SubaggedLineCutter.java
================================================
package mltk.predictor.function;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.function.SubagSequence.SampleDelta;
import mltk.util.MathUtils;
/**
* Class for cutting lines with subagging.
*
* @author Yin Lou
*
*/
public class SubaggedLineCutter extends EnsembledLineCutter {

    private int subsampleSize;
    private SubagSequence ss;

    /**
     * Constructor.
     */
    public SubaggedLineCutter() {
        this(false);
    }

    /**
     * Constructor.
     *
     * @param isClassification {@code true} if it is a classification problem.
     */
    public SubaggedLineCutter(boolean isClassification) {
        attIndex = -1;
        this.isClassification = isClassification;
    }

    /**
     * Creates internal subsamples.
     *
     * @param n the size of the data set.
     * @param subsampleRatio the subsample ratio.
     * @param baggingIters the number of bagging iterations.
     */
    public void createSubags(int n, double subsampleRatio, int baggingIters) {
        this.subsampleSize = (int) (n * subsampleRatio);
        ss = new SubagSequence(n, subsampleSize, baggingIters);
    }

    @Override
    public BaggedEnsemble build(Instances instances) {
        return build(instances, attIndex, numIntervals);
    }

    /** A per-bin histogram; each row holds {@code { response sum, weight sum }}, the last row is for missing values. */
    class Histogram {

        double[][] histogram;

        Histogram(double[][] histogram) {
            this.histogram = histogram;
        }

        /** Returns a deep copy of this histogram. */
        Histogram copy() {
            double[][] newHistogram = new double[histogram.length][histogram[0].length];
            for (int i = 0; i < histogram.length; i++) {
                System.arraycopy(histogram[i], 0, newHistogram[i], 0, histogram[i].length);
            }
            return new Histogram(newHistogram);
        }

    }

    /**
     * Builds a 1D function ensemble.
     *
     * @param instances the training set.
     * @param attIndex the attribute index.
     * @param numIntervals the number of intervals.
     * @return a 1D function ensemble.
     */
    public BaggedEnsemble build(Instances instances, int attIndex, int numIntervals) {
        Attribute attribute = instances.getAttributes().get(attIndex);
        return build(instances, attribute, numIntervals);
    }

    /**
     * Builds a 1D function ensemble.
     *
     * <p>
     * The histogram of the first subsample is built from scratch; every other subsample's histogram is derived
     * from an already-built one by applying the add/remove delta between the two.
     * </p>
     *
     * @param instances the training set.
     * @param attribute the attribute.
     * @param numIntervals the number of intervals.
     * @return a 1D function ensemble.
     * @throws RuntimeException if the attribute is numeric, which is not supported.
     */
    public BaggedEnsemble build(Instances instances, Attribute attribute, int numIntervals) {
        if (ss == null) {
            ss = new SubagSequence(instances.size(), subsampleSize, baggingIters);
        }
        SubagSequence.Sample[] samples = ss.samples;
        int[] start = ss.start;
        int[] end = ss.end;
        SampleDelta[] deltas = ss.deltas;
        // Working copy: counts how many children of each sample still need its histogram
        int[] count = Arrays.copyOf(ss.count, ss.count.length);

        Function1D[] funcs = new Function1D[samples.length];
        int attIndex = attribute.getIndex();

        double[] targets = new double[instances.size()];
        double[] fvalues = new double[instances.size()];
        double[] weights = new double[instances.size()];
        for (int i = 0; i < instances.size(); i++) {
            Instance instance = instances.get(i);
            targets[i] = instance.getTarget();
            fvalues[i] = instance.getValue(attribute);
            weights[i] = instance.getWeight();
        }

        if (attribute.getType() == Attribute.Type.NUMERIC) {
            throw new RuntimeException("Not implemented yet!");
        } else {
            int size = 0;
            if (attribute.getType() == Attribute.Type.BINNED) {
                size = ((BinnedAttribute) attribute).getNumBins();
            } else {
                size = ((NominalAttribute) attribute).getCardinality();
            }
            Histogram[] histograms = new Histogram[samples.length];
            if (samples.length > 0) {
                // Initialization: the root is the source of the first edge, or the only sample when
                // there are no edges at all. Its results go into the root's own slot.
                final int root = start.length > 0 ? start[0] : 0;
                double[][] histogram = new double[size + 1][2];
                for (int index : samples[root].indices) {
                    accumulate(histogram, fvalues[index], weights[index], targets[index], true);
                }
                histograms[root] = new Histogram(histogram);
                funcs[root] = buildFromHistogram(attIndex, histograms[root], numIntervals);
            }

            for (int k = 0; k < start.length; k++) {
                int from = start[k];
                Histogram histStart = histograms[from];
                count[from]--;

                // Generate histogram for its child; the parent's arrays are reused in place once no
                // other child needs them, otherwise a copy is taken
                int to = end[k];
                Histogram histEnd = count[from] == 0 ? histStart : histStart.copy();

                SampleDelta delta = deltas[k];
                for (int index : delta.toAdd) {
                    accumulate(histEnd.histogram, fvalues[index], weights[index], targets[index], true);
                }
                for (int index : delta.toDel) {
                    accumulate(histEnd.histogram, fvalues[index], weights[index], targets[index], false);
                }
                histograms[to] = histEnd;
                funcs[to] = buildFromHistogram(attIndex, histEnd, numIntervals);
            }
        }

        BaggedEnsemble ensemble = new BaggedEnsemble(samples.length);
        for (Function1D func : funcs) {
            ensemble.add(func);
        }
        return ensemble;
    }

    /**
     * Builds a 1D function from a histogram using this cutter's configured number of intervals.
     *
     * @param attIndex the attribute index.
     * @param histogram the histogram.
     * @return the 1D function.
     */
    protected Function1D buildFromHistogram(int attIndex, Histogram histogram) {
        return buildFromHistogram(attIndex, histogram, numIntervals);
    }

    /** Builds a 1D function from a histogram with an explicit number of intervals. */
    private Function1D buildFromHistogram(int attIndex, Histogram histogram, int numIntervals) {
        final double[][] bins = histogram.histogram;
        final int size = bins.length;
        List<double[]> histograms = new ArrayList<>(size);
        // Regular bins: skip those with zero weight
        for (int i = 0; i < size - 1; i++) {
            if (!MathUtils.isZero(bins[i][1])) {
                histograms.add(new double[] { i, bins[i][0], bins[i][1] });
            }
        }
        // The last row is the missing-value bin and is always appended
        histograms.add(new double[] { Double.NaN, bins[size - 1][0], bins[size - 1][1] });
        return LineCutter.build(attIndex, histograms, numIntervals);
    }

    /**
     * Adds one instance to, or removes it from, a histogram. A {@code NaN} value goes to the last row.
     */
    private void accumulate(double[][] histogram, double value, double weight, double target, boolean add) {
        final int idx = Double.isNaN(value) ? histogram.length - 1 : (int) value;
        final double resp = isClassification ? target : target * weight;
        if (add) {
            histogram[idx][0] += resp;
            histogram[idx][1] += weight;
        } else {
            histogram[idx][0] -= resp;
            histogram[idx][1] -= weight;
        }
    }

}
================================================
FILE: src/main/java/mltk/predictor/function/UnivariateFunction.java
================================================
package mltk.predictor.function;
/**
* Interface for univariate functions.
*
* @author Yin Lou
*
*/
public interface UnivariateFunction {

    /**
     * Evaluates this function at a point.
     *
     * @param x the point at which to evaluate.
     * @return the function value at {@code x}.
     */
    double evaluate(double x);

}
================================================
FILE: src/main/java/mltk/predictor/function/package-info.java
================================================
/**
* Provides classes for simple functions, such as univariate/bivariate functions.
*/
package mltk.predictor.function;
================================================
FILE: src/main/java/mltk/predictor/gam/DenseDesignMatrix.java
================================================
package mltk.predictor.gam;
import java.util.HashSet;
import java.util.Set;
import mltk.predictor.function.CubicSpline;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
/**
 * Holds a design matrix as one block of basis columns per feature, together with the knots and the scaling
 * factor used for each column.
 */
class DenseDesignMatrix {

    double[][][] x;
    double[][] knots;
    double[][] std;

    DenseDesignMatrix(double[][][] x, double[][] knots, double[][] std) {
        this.x = x;
        this.knots = knots;
        this.std = std;
    }

    /**
     * Creates a cubic spline design matrix.
     *
     * <p>
     * A feature with at most {@code numKnots} distinct values gets a single column (the feature itself). Any
     * other feature gets the columns {@code x}, {@code x^2}, {@code x^3} followed by one
     * {@code CubicSpline.h(x, knot)} column per knot, with knots placed at {@code min + k * (max - min) / numKnots}.
     * Every column is finally divided in place by its scaling factor; note that the first column of each block
     * is the caller's own {@code dataset[j]} array.
     * </p>
     *
     * @param dataset the feature values, indexed as {@code dataset[feature][row]}.
     * @param stdList the value used, divided by {@code sqrt(n)}, to scale each feature's first column.
     * @param numKnots the number of knots.
     * @return the design matrix.
     */
    static DenseDesignMatrix createCubicSplineDesignMatrix(double[][] dataset, double[] stdList, int numKnots) {
        final int numRows = dataset[0].length;
        final int numFeatures = dataset.length;
        final double[][][] blocks = new double[numFeatures][][];
        final double[][] knotsPerFeature = new double[numFeatures][];
        final double[][] scales = new double[numFeatures][];
        final double rootN = Math.sqrt(numRows);

        for (int j = 0; j < numFeatures; j++) {
            final double[] raw = dataset[j];

            // Features with few distinct values get no spline expansion
            final int nKnots = countDistinct(raw, numRows) <= numKnots ? 0 : numKnots;
            final int numColumns = nKnots == 0 ? 1 : nKnots + 3;
            knotsPerFeature[j] = new double[nKnots];
            blocks[j] = new double[numColumns][];
            scales[j] = new double[numColumns];

            final double[][] columns = blocks[j];
            columns[0] = raw;
            scales[j][0] = stdList[j] / rootN;
            if (nKnots == 0) {
                continue;
            }

            // Polynomial columns
            final double[] squared = new double[numRows];
            final double[] cubed = new double[numRows];
            for (int i = 0; i < numRows; i++) {
                squared[i] = raw[i] * raw[i];
                cubed[i] = squared[i] * raw[i];
            }
            columns[1] = squared;
            columns[2] = cubed;
            scales[j][1] = StatUtils.sd(squared) / rootN;
            scales[j][2] = StatUtils.sd(cubed) / rootN;

            // One basis column per knot
            final double hi = StatUtils.max(raw);
            final double lo = StatUtils.min(raw);
            final double step = (hi - lo) / nKnots;
            for (int k = 0; k < nKnots; k++) {
                final double knot = lo + step * k;
                knotsPerFeature[j][k] = knot;
                final double[] basis = new double[numRows];
                for (int i = 0; i < numRows; i++) {
                    basis[i] = CubicSpline.h(raw[i], knot);
                }
                scales[j][k + 3] = StatUtils.sd(basis) / rootN;
                columns[k + 3] = basis;
            }
        }

        // Normalize the inputs
        for (int j = 0; j < numFeatures; j++) {
            final double[][] columns = blocks[j];
            for (int c = 0; c < columns.length; c++) {
                VectorUtils.divide(columns[c], scales[j][c]);
            }
        }

        return new DenseDesignMatrix(blocks, knotsPerFeature, scales);
    }

    /** Counts the distinct values among the first {@code n} entries of {@code values}. */
    private static int countDistinct(double[] values, int n) {
        Set<Double> seen = new HashSet<>();
        for (int i = 0; i < n; i++) {
            seen.add(values[i]);
        }
        return seen.size();
    }

}
================================================
FILE: src/main/java/mltk/predictor/gam/GA2MLearner.java
================================================
package mltk.predictor.gam;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.HoldoutValidatedLearnerWithTaskOptions;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.Sampling;
import mltk.core.Attribute.Type;
import mltk.core.io.InstancesReader;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.BaggedEnsembleLearner;
import mltk.predictor.BoostedEnsemble;
import mltk.predictor.HoldoutValidatedLearner;
import mltk.predictor.Regressor;
import mltk.predictor.evaluation.ConvergenceTester;
import mltk.predictor.evaluation.Metric;
import mltk.predictor.evaluation.MetricFactory;
import mltk.predictor.evaluation.SimpleMetric;
import mltk.predictor.function.Array2D;
import mltk.predictor.function.CompressionUtils;
import mltk.predictor.function.Function2D;
import mltk.predictor.function.SquareCutter;
import mltk.predictor.io.PredictorReader;
import mltk.predictor.io.PredictorWriter;
import mltk.util.OptimUtils;
import mltk.util.Random;
import mltk.util.tuple.IntPair;
/**
* Class for learning GA^2M models via gradient boosting.
*
*
* Reference:
* Y. Lou, R. Caruana, J. Gehrke, and G. Hooker. Accurate intelligible models with pairwise interactions. In
* Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD),
* Chicago, IL, USA, 2013.
*
*
* @author Yin Lou
*
*/
public class GA2MLearner extends HoldoutValidatedLearner {
// Command line options for GA2MLearner.main. Each field is bound to a flag through its @Argument
// annotation; options not declared here (for example the train/valid paths read in main) come from
// the parent options class.
static class Options extends HoldoutValidatedLearnerWithTaskOptions {
// Required: path of the existing GAM that the pairwise terms are added to
@Argument(name = "-i", description = "input model path", required = true)
String inputModelPath = null;
// Required: text file with one whitespace-separated pair of attribute indices per line
@Argument(name = "-I", description = "list of pairwise interactions path", required = true)
String interactionsPath = null;
// Required: -1 is only the placeholder before parsing
@Argument(name = "-m", description = "maximum number of iterations", required = true)
int maxNumIters = -1;
@Argument(name = "-b", description = "bagging iterations (default: 100)")
int baggingIters = 100;
@Argument(name = "-s", description = "seed of the random number generator (default: 0)")
long seed = 0L;
@Argument(name = "-l", description = "learning rate (default: 0.01)")
double learningRate = 0.01;
}
/**
* Trains a GA2M.
*
*
* Usage: mltk.predictor.gam.GA2MLearner
* -t train set path
* -i input model path
* -I list of pairwise interactions path
* -m maximum number of iterations
* [-g] task between classification (c) and regression (r) (default: r)
* [-v] valid set path
* [-e] evaluation metric (default: default metric of task)
* [-S] convergence criteria (default: -1)
* [-r] attribute file path
* [-o] output model path
* [-V] verbose (default: true)
* [-b] bagging iterations (default: 100)
* [-s] seed of the random number generator (default: 0)
* [-l] learning rate (default: 0.01)
*
*
* @param args the command line arguments.
* @throws Exception
*/
public static void main(String[] args) throws Exception {
    Options opts = new Options();
    CmdLineParser parser = new CmdLineParser(GA2MLearner.class, opts);
    Task task = null;
    Metric metric = null;
    try {
        parser.parse(args);
        task = Task.get(opts.task);
        if (opts.metric == null) {
            metric = task.getDefaultMetric();
        } else {
            metric = MetricFactory.getMetric(opts.metric);
        }
    } catch (IllegalArgumentException e) {
        parser.printUsage();
        System.exit(1);
    }

    Random.getInstance().setSeed(opts.seed);

    ConvergenceTester ct = ConvergenceTester.parse(opts.cc);

    Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);

    // Read the interaction pairs: one pair of attribute indices per line, separated by whitespace.
    // try-with-resources closes the reader even when a line fails to parse.
    List<IntPair> terms = new ArrayList<>();
    try (BufferedReader br = new BufferedReader(new FileReader(opts.interactionsPath))) {
        String line;
        while ((line = br.readLine()) != null) {
            // Tolerate blank lines and leading/trailing whitespace
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            String[] data = trimmed.split("\\s+");
            terms.add(new IntPair(Integer.parseInt(data[0]), Integer.parseInt(data[1])));
        }
    }

    GAM gam = PredictorReader.read(opts.inputModelPath, GAM.class);

    GA2MLearner learner = new GA2MLearner();
    learner.setBaggingIters(opts.baggingIters);
    learner.setGAM(gam);
    learner.setMaxNumIters(opts.maxNumIters);
    learner.setTask(task);
    learner.setMetric(metric);
    learner.setConvergenceTester(ct);
    learner.setPairs(terms);
    learner.setLearningRate(opts.learningRate);
    learner.setVerbose(opts.verbose);

    if (opts.validPath != null) {
        Instances validSet = InstancesReader.read(opts.attPath, opts.validPath);
        learner.setValidSet(validSet);
    }

    long start = System.currentTimeMillis();
    learner.build(trainSet);
    long end = System.currentTimeMillis();
    System.out.println("Time: " + (end - start) / 1000.0);

    if (opts.outputModelPath != null) {
        PredictorWriter.write(gam, opts.outputModelPath);
    }
}
private int baggingIters;
private int maxNumIters;
private Task task;
private double learningRate;
private GAM gam;
private List pairs;
/**
* Constructor.
*/
public GA2MLearner() {
    // Task first: the default metric is derived from it
    this.task = Task.REGRESSION;
    this.metric = this.task.getDefaultMetric();
    this.learningRate = 0.01;
    this.maxNumIters = -1;
    this.baggingIters = 100;
    this.verbose = false;
}
/**
* Returns the number of bagging iterations.
*
* @return the number of bagging iterations.
*/
public int getBaggingIters() {
    return this.baggingIters;
}
/**
* Sets the number of bagging iterations.
*
* @param baggingIters the number of bagging iterations.
*/
public void setBaggingIters(final int baggingIters) {
    this.baggingIters = baggingIters;
}
/**
* Returns the maximum number of iterations.
*
* @return the maximum number of iterations.
*/
public int getMaxNumIters() {
    return this.maxNumIters;
}
/**
* Sets the maximum number of iterations.
*
* @param maxNumIters the maximum number of iterations.
*/
public void setMaxNumIters(final int maxNumIters) {
    this.maxNumIters = maxNumIters;
}
/**
* Returns the learning rate.
*
* @return the learning rate.
*/
public double getLearningRate() {
    return this.learningRate;
}
/**
* Sets the learning rate.
*
* @param learningRate the learning rate.
*/
public void setLearningRate(final double learningRate) {
    this.learningRate = learningRate;
}
/**
* Returns the task of this learner.
*
* @return the task of this learner.
*/
public Task getTask() {
    return this.task;
}
/**
* Sets the task of this learner.
*
* @param task the new task.
*/
public void setTask(final Task task) {
    this.task = task;
}
/**
* Returns the GAM.
*
* @return the GAM.
*/
public GAM getGAM() {
    return this.gam;
}
/**
* Sets the GAM.
*
* @param gam the GAM.
*/
public void setGAM(final GAM gam) {
    this.gam = gam;
}
/**
* Returns the list of feature interaction pairs.
*
* @return the list of feature interaction pairs.
*/
public List getPairs() {
    return this.pairs;
}
/**
* Sets the list of feature interaction pairs.
*
* @param pairs the list of feature interaction pairs.
*/
public void setPairs(final List pairs) {
    this.pairs = pairs;
}
/**
* Builds a classifier by boosting the given interaction terms, using a validation set to choose the stopping point.
*
* The learned pairwise components are added to {@code gam} in place. While training, the targets of
* {@code trainSet} are overwritten with pseudo residuals; they are restored before the model is compressed.
*
* @param gam the GAM to extend; its current predictions initialize the boosting procedure.
* @param terms the list of feature interaction pairs.
* @param trainSet the training set.
* @param validSet the validation set.
* @param maxNumIters the maximum number of iterations.
*/
public void buildClassifier(GAM gam, List terms, Instances trainSet, Instances validSet, int maxNumIters) {
List regressors = new ArrayList<>(terms.size());
int[] indices = new int[terms.size()];
for (int i = 0; i < indices.length; i++) {
// Position of this term inside the existing GAM, or -1 when the GAM does not contain it yet
indices[i] = indexOf(gam.terms, terms.get(i));
regressors.add(new BoostedEnsemble());
}
// Backup targets
double[] target = new double[trainSet.size()];
for (int i = 0; i < target.length; i++) {
target[i] = trainSet.get(i).getTarget();
}
// Create bags
Instances[] bags = Sampling.createBags(trainSet, baggingIters);
SquareCutter cutter = new SquareCutter(true);
BaggedEnsembleLearner learner = new BaggedEnsembleLearner(bags.length, cutter);
// Initialize predictions and residuals
double[] pTrain = new double[trainSet.size()];
double[] rTrain = new double[trainSet.size()];
double[] pValid = new double[validSet.size()];
// Start from the predictions of the GAM as it is passed in
for (int i = 0; i < pTrain.length; i++) {
Instance instance = trainSet.get(i);
pTrain[i] = gam.regress(instance);
}
OptimUtils.computePseudoResidual(pTrain, target, rTrain);
for (int i = 0; i < pValid.length; i++) {
Instance instance = validSet.get(i);
pValid[i] = gam.regress(instance);
}
// Gradient boosting
// Resets the convergence tester
ct.setMetric(metric);
for (int iter = 0; iter < maxNumIters; iter++) {
for (int j = 0; j < terms.size(); j++) {
// Derivative with respect to term j
// Minimizes the loss function: log(1 + exp(-yF))
// The base learner fits the instance targets, so they are replaced by the current pseudo residuals
for (int i = 0; i < trainSet.size(); i++) {
trainSet.get(i).setTarget(rTrain[i]);
}
BoostedEnsemble boostedEnsemble = regressors.get(j);
// Train model
IntPair term = terms.get(j);
cutter.setAttIndices(term.v1, term.v2);
BaggedEnsemble baggedEnsemble = learner.build(bags);
// Shrink every bagged function by the learning rate
if (learningRate != 1) {
for (int i = 0; i < baggedEnsemble.size(); i++) {
Function2D func = (Function2D) baggedEnsemble.get(i);
func.multiply(learningRate);
}
}
boostedEnsemble.add(baggedEnsemble);
// Update predictions
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
double pred = baggedEnsemble.regress(instance);
pTrain[i] += pred;
rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], target[i]);
}
for (int i = 0; i < validSet.size(); i++) {
Instance instance = validSet.get(i);
double pred = baggedEnsemble.regress(instance);
pValid[i] += pred;
}
// One validation measure is recorded per (iteration, term) step
double measure = metric.eval(pValid, validSet);
ct.add(measure);
if (verbose) {
System.out.println("Iteration " + iter + " term " + j + ": " + measure);
}
}
// Convergence is only checked after a full pass over all terms
if (ct.isConverged()) {
break;
}
}
// Search the best model on validation set
int idx = ct.getBestIndex();
// Remove trees
// idx is split into an iteration n and a term m; this presumes getBestIndex() returns the position of the best
// recorded measure - TODO confirm against ConvergenceTester
int n = idx / terms.size();
int m = idx % terms.size();
for (int k = 0; k < terms.size(); k++) {
BoostedEnsemble boostedEnsemble = regressors.get(k);
// Keep at most n + 1 updates for this term
for (int i = boostedEnsemble.size(); i > n + 1; i--) {
boostedEnsemble.removeLast();
}
// Terms after m come later than the best step within iteration n, so one more update is dropped
if (k > m) {
boostedEnsemble.removeLast();
}
}
// Restore targets
for (int i = 0; i < target.length; i++) {
trainSet.get(i).setTarget(target[i]);
}
// Compress model
List attributes = trainSet.getAttributes();
for (int i = 0; i < regressors.size(); i++) {
BoostedEnsemble boostedEnsemble = regressors.get(i);
IntPair term = terms.get(i);
Attribute f1 = attributes.get(term.v1);
Attribute f2 = attributes.get(term.v2);
// Number of discrete values per attribute; stays -1 when the attribute is neither binned nor nominal
int n1 = -1;
if (f1.getType() == Type.BINNED) {
n1 = ((BinnedAttribute) f1).getNumBins();
} else if (f1.getType() == Type.NOMINAL) {
n1 = ((NominalAttribute) f1).getCardinality();
}
int n2 = -1;
if (f2.getType() == Type.BINNED) {
n2 = ((BinnedAttribute) f2).getNumBins();
} else if (f2.getType() == Type.NOMINAL) {
n2 = ((NominalAttribute) f2).getCardinality();
}
Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble);
if (indices[i] < 0) {
// New term: append it to the GAM
gam.add(new int[] { term.v1, term.v2 }, newRegressor);
} else {
// Existing term: accumulate into the component that is already in the GAM
Regressor regressor = gam.regressors.get(indices[i]);
if (regressor instanceof Array2D) {
Array2D ary = (Array2D) regressor;
ary.add(newRegressor);
} else {
throw new RuntimeException("Failed to add new regressor");
}
}
}
}
/**
* Builds a classifier by boosting the given interaction terms on the training set only.
*
* No validation set is used, so all {@code maxNumIters} passes are run and no update is removed afterwards. The
* learned pairwise components are added to {@code gam} in place. While training, the targets of {@code trainSet}
* are overwritten with pseudo residuals; they are restored before the model is compressed. The configured metric
* is cast to {@code SimpleMetric}, so a metric of another type fails with a {@code ClassCastException}.
*
* @param gam the GAM to extend; its current predictions initialize the boosting procedure.
* @param terms the list of feature interaction pairs.
* @param trainSet the training set.
* @param maxNumIters the maximum number of iterations.
*/
public void buildClassifier(GAM gam, List terms, Instances trainSet, int maxNumIters) {
SimpleMetric simpleMetric = (SimpleMetric) metric;
List regressors = new ArrayList<>();
int[] indices = new int[terms.size()];
for (int i = 0; i < indices.length; i++) {
// Position of this term inside the existing GAM, or -1 when the GAM does not contain it yet
indices[i] = indexOf(gam.terms, terms.get(i));
regressors.add(new BoostedEnsemble());
}
// Backup targets
double[] target = new double[trainSet.size()];
for (int i = 0; i < target.length; i++) {
target[i] = trainSet.get(i).getTarget();
}
// Create bags
Instances[] bags = Sampling.createBags(trainSet, baggingIters);
SquareCutter cutter = new SquareCutter(true);
BaggedEnsembleLearner learner = new BaggedEnsembleLearner(bags.length, cutter);
// Initialize predictions and residuals
double[] pTrain = new double[trainSet.size()];
double[] rTrain = new double[trainSet.size()];
// Start from the predictions of the GAM as it is passed in
for (int i = 0; i < pTrain.length; i++) {
Instance instance = trainSet.get(i);
pTrain[i] = gam.regress(instance);
}
OptimUtils.computePseudoResidual(pTrain, target, rTrain);
// Gradient boosting
for (int iter = 0; iter < maxNumIters; iter++) {
for (int j = 0; j < terms.size(); j++) {
// Derivative with respect to term j
// Minimizes the loss function: log(1 + exp(-yF))
// The base learner fits the instance targets, so they are replaced by the current pseudo residuals
for (int i = 0; i < trainSet.size(); i++) {
trainSet.get(i).setTarget(rTrain[i]);
}
BoostedEnsemble boostedEnsemble = regressors.get(j);
// Train model
IntPair term = terms.get(j);
cutter.setAttIndices(term.v1, term.v2);
BaggedEnsemble baggedEnsemble = learner.build(bags);
// Shrink every bagged function by the learning rate
if (learningRate != 1) {
for (int i = 0; i < baggedEnsemble.size(); i++) {
Function2D func = (Function2D) baggedEnsemble.get(i);
func.multiply(learningRate);
}
}
boostedEnsemble.add(baggedEnsemble);
// Update predictions
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
double pred = baggedEnsemble.regress(instance);
pTrain[i] += pred;
rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], target[i]);
}
// The training measure is only reported; it does not influence the loop
double measure = simpleMetric.eval(pTrain, target);
if (verbose) {
System.out.println("Iteration " + iter + " term " + j + ": " + measure);
}
}
}
// Restore targets
for (int i = 0; i < target.length; i++) {
trainSet.get(i).setTarget(target[i]);
}
// Compress model
List attributes = trainSet.getAttributes();
for (int i = 0; i < regressors.size(); i++) {
BoostedEnsemble boostedEnsemble = regressors.get(i);
IntPair term = terms.get(i);
Attribute f1 = attributes.get(term.v1);
Attribute f2 = attributes.get(term.v2);
// Number of discrete values per attribute; stays -1 when the attribute is neither binned nor nominal
int n1 = -1;
if (f1.getType() == Type.BINNED) {
n1 = ((BinnedAttribute) f1).getNumBins();
} else if (f1.getType() == Type.NOMINAL) {
n1 = ((NominalAttribute) f1).getCardinality();
}
int n2 = -1;
if (f2.getType() == Type.BINNED) {
n2 = ((BinnedAttribute) f2).getNumBins();
} else if (f2.getType() == Type.NOMINAL) {
n2 = ((NominalAttribute) f2).getCardinality();
}
Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble);
if (indices[i] < 0) {
// New term: append it to the GAM
gam.add(new int[] { term.v1, term.v2 }, newRegressor);
} else {
// Existing term: accumulate into the component that is already in the GAM
Regressor regressor = gam.regressors.get(indices[i]);
if (regressor instanceof Array2D) {
Array2D ary = (Array2D) regressor;
ary.add(newRegressor);
} else {
throw new RuntimeException("Failed to add new regressor");
}
}
}
}
/**
* Builds a regressor by boosting the given interaction terms, using a validation set to choose the stopping point.
*
* The learned pairwise components are added to {@code gam} in place. While training, the targets of
* {@code trainSet} are overwritten with residuals; they are restored before the model is compressed.
*
* @param gam the GAM to extend; its current predictions define the initial residuals.
* @param terms the list of feature interaction pairs.
* @param trainSet the training set.
* @param validSet the validation set.
* @param maxNumIters the maximum number of iterations.
*/
public void buildRegressor(GAM gam, List terms, Instances trainSet, Instances validSet, int maxNumIters) {
List regressors = new ArrayList<>();
int[] indices = new int[terms.size()];
for (int i = 0; i < indices.length; i++) {
// Position of this term inside the existing GAM, or -1 when the GAM does not contain it yet
indices[i] = indexOf(gam.terms, terms.get(i));
regressors.add(new BoostedEnsemble());
}
// Backup targets
double[] target = new double[trainSet.size()];
for (int i = 0; i < target.length; i++) {
target[i] = trainSet.get(i).getTarget();
}
// Create bags
Instances[] bags = Sampling.createBags(trainSet, baggingIters);
SquareCutter cutter = new SquareCutter();
BaggedEnsembleLearner learner = new BaggedEnsembleLearner(baggingIters, cutter);
// Initialize predictions and residuals
double[] rTrain = new double[trainSet.size()];
double[] pValid = new double[validSet.size()];
// rValid is kept in sync with pValid below but is not read elsewhere in this method
double[] rValid = new double[validSet.size()];
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
rTrain[i] = instance.getTarget() - gam.regress(instance);
}
for (int i = 0; i < validSet.size(); i++) {
Instance instance = validSet.get(i);
pValid[i] = gam.regress(instance);
rValid[i] = instance.getTarget() - pValid[i];
}
// Gradient boosting
// Resets the convergence tester
ct.setMetric(metric);
for (int iter = 0; iter < maxNumIters; iter++) {
for (int j = 0; j < terms.size(); j++) {
// Derivative with respect to term j
// Equivalent to residual
BoostedEnsemble boostedEnsemble = regressors.get(j);
// Prepare training set
for (int i = 0; i < rTrain.length; i++) {
trainSet.get(i).setTarget(rTrain[i]);
}
// Train model
IntPair term = terms.get(j);
cutter.setAttIndices(term.v1, term.v2);
BaggedEnsemble baggedEnsemble = learner.build(bags);
// Shrink every bagged function by the learning rate
if (learningRate != 1) {
for (int i = 0; i < baggedEnsemble.size(); i++) {
Function2D func = (Function2D) baggedEnsemble.get(i);
func.multiply(learningRate);
}
}
boostedEnsemble.add(baggedEnsemble);
// Update residuals
for (int i = 0; i < rTrain.length; i++) {
Instance instance = trainSet.get(i);
double pred = baggedEnsemble.regress(instance);
rTrain[i] -= pred;
}
for (int i = 0; i < rValid.length; i++) {
Instance instance = validSet.get(i);
double pred = baggedEnsemble.regress(instance);
pValid[i] += pred;
rValid[i] -= pred;
}
// One validation measure is recorded per (iteration, term) step
double measure = metric.eval(pValid, validSet);
ct.add(measure);
if (verbose) {
System.out.println("Iteration " + iter + " term " + j + ":" + measure);
}
}
// Convergence is only checked after a full pass over all terms
if (ct.isConverged()) {
break;
}
}
// Search the best model on validation set
int idx = ct.getBestIndex();
// Remove trees
// idx is split into an iteration n and a term m; this presumes getBestIndex() returns the position of the best
// recorded measure - TODO confirm against ConvergenceTester
int n = idx / terms.size();
int m = idx % terms.size();
for (int k = 0; k < terms.size(); k++) {
BoostedEnsemble boostedEnsemble = regressors.get(k);
// Keep at most n + 1 updates for this term
for (int i = boostedEnsemble.size(); i > n + 1; i--) {
boostedEnsemble.removeLast();
}
// Terms after m come later than the best step within iteration n, so one more update is dropped
if (k > m) {
boostedEnsemble.removeLast();
}
}
// Restore targets
for (int i = 0; i < target.length; i++) {
trainSet.get(i).setTarget(target[i]);
}
// Compress model
List attributes = trainSet.getAttributes();
for (int i = 0; i < regressors.size(); i++) {
BoostedEnsemble boostedEnsemble = regressors.get(i);
IntPair term = terms.get(i);
Attribute f1 = attributes.get(term.v1);
Attribute f2 = attributes.get(term.v2);
// Number of discrete values per attribute; stays -1 when the attribute is neither binned nor nominal
int n1 = -1;
if (f1.getType() == Type.BINNED) {
n1 = ((BinnedAttribute) f1).getNumBins();
} else if (f1.getType() == Type.NOMINAL) {
n1 = ((NominalAttribute) f1).getCardinality();
}
int n2 = -1;
if (f2.getType() == Type.BINNED) {
n2 = ((BinnedAttribute) f2).getNumBins();
} else if (f2.getType() == Type.NOMINAL) {
n2 = ((NominalAttribute) f2).getCardinality();
}
Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble);
if (indices[i] < 0) {
// New term: append it to the GAM
gam.add(new int[] { term.v1, term.v2 }, newRegressor);
} else {
// Existing term: accumulate into the component that is already in the GAM
Regressor regressor = gam.regressors.get(indices[i]);
if (regressor instanceof Array2D) {
Array2D ary = (Array2D) regressor;
ary.add(newRegressor);
} else {
throw new RuntimeException("Failed to add new regressor");
}
}
}
}
/**
 * Builds a regressor by boosting the given interaction terms on the training set only.
 *
 * <p>
 * No validation set is used, so all {@code maxNumIters} passes are run and no update is removed afterwards. The
 * learned pairwise components are added to {@code gam} in place. While training, the targets of {@code trainSet}
 * are overwritten with residuals; they are restored before the model is compressed.
 * </p>
 *
 * @param gam the GAM to extend; its current predictions define the initial residuals.
 * @param terms the list of feature interaction pairs.
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @throws ClassCastException if the configured metric is not a {@code SimpleMetric}.
 */
public void buildRegressor(GAM gam, List<IntPair> terms, Instances trainSet, int maxNumIters) {
    // The training-set-only path evaluates on plain arrays, which needs a SimpleMetric
    SimpleMetric simpleMetric = (SimpleMetric) metric;

    List<BoostedEnsemble> regressors = new ArrayList<>(terms.size());
    int[] indices = new int[terms.size()];
    for (int i = 0; i < indices.length; i++) {
        // Position of this term inside the existing GAM, or -1 when the GAM does not contain it yet
        indices[i] = indexOf(gam.terms, terms.get(i));
        regressors.add(new BoostedEnsemble());
    }

    // Backup targets
    double[] target = new double[trainSet.size()];
    for (int i = 0; i < target.length; i++) {
        target[i] = trainSet.get(i).getTarget();
    }

    // Create bags
    Instances[] bags = Sampling.createBags(trainSet, baggingIters);
    SquareCutter cutter = new SquareCutter();
    BaggedEnsembleLearner learner = new BaggedEnsembleLearner(baggingIters, cutter);

    // Initialize predictions and residuals from the GAM as it is passed in
    double[] pTrain = new double[trainSet.size()];
    double[] rTrain = new double[trainSet.size()];
    for (int i = 0; i < trainSet.size(); i++) {
        Instance instance = trainSet.get(i);
        pTrain[i] = gam.regress(instance);
        rTrain[i] = instance.getTarget() - pTrain[i];
    }

    // Gradient boosting
    for (int iter = 0; iter < maxNumIters; iter++) {
        for (int j = 0; j < terms.size(); j++) {
            // Derivative with respect to term j
            // Equivalent to residual
            BoostedEnsemble boostedEnsemble = regressors.get(j);

            // Prepare training set: the base learner fits the instance targets
            for (int i = 0; i < rTrain.length; i++) {
                trainSet.get(i).setTarget(rTrain[i]);
            }

            // Train model
            IntPair term = terms.get(j);
            cutter.setAttIndices(term.v1, term.v2);
            BaggedEnsemble baggedEnsemble = learner.build(bags);
            // Shrink every bagged function by the learning rate
            if (learningRate != 1) {
                for (int i = 0; i < baggedEnsemble.size(); i++) {
                    Function2D func = (Function2D) baggedEnsemble.get(i);
                    func.multiply(learningRate);
                }
            }
            boostedEnsemble.add(baggedEnsemble);

            // Update predictions and residuals together. The predictions must follow the model,
            // otherwise the measure reported below never changes from its initial value.
            for (int i = 0; i < rTrain.length; i++) {
                Instance instance = trainSet.get(i);
                double pred = baggedEnsemble.regress(instance);
                pTrain[i] += pred;
                rTrain[i] -= pred;
            }

            // The training measure is only reported; it does not influence the loop
            double measure = simpleMetric.eval(pTrain, target);
            if (verbose) {
                System.out.println("Iteration " + iter + " term " + j + ":" + measure);
            }
        }
    }

    // Restore targets
    for (int i = 0; i < target.length; i++) {
        trainSet.get(i).setTarget(target[i]);
    }

    // Compress model
    List<Attribute> attributes = trainSet.getAttributes();
    for (int i = 0; i < regressors.size(); i++) {
        BoostedEnsemble boostedEnsemble = regressors.get(i);
        IntPair term = terms.get(i);
        Attribute f1 = attributes.get(term.v1);
        Attribute f2 = attributes.get(term.v2);
        // Number of discrete values per attribute; stays -1 when the attribute is neither binned nor nominal
        int n1 = -1;
        if (f1.getType() == Type.BINNED) {
            n1 = ((BinnedAttribute) f1).getNumBins();
        } else if (f1.getType() == Type.NOMINAL) {
            n1 = ((NominalAttribute) f1).getCardinality();
        }
        int n2 = -1;
        if (f2.getType() == Type.BINNED) {
            n2 = ((BinnedAttribute) f2).getNumBins();
        } else if (f2.getType() == Type.NOMINAL) {
            n2 = ((NominalAttribute) f2).getCardinality();
        }
        Array2D newRegressor = CompressionUtils.compress(term.v1, term.v2, n1, n2, boostedEnsemble);
        if (indices[i] < 0) {
            // New term: append it to the GAM
            gam.add(new int[] { term.v1, term.v2 }, newRegressor);
        } else {
            // Existing term: accumulate into the component that is already in the GAM
            Regressor regressor = gam.regressors.get(indices[i]);
            if (regressor instanceof Array2D) {
                Array2D ary = (Array2D) regressor;
                ary.add(newRegressor);
            } else {
                throw new RuntimeException("Failed to add new regressor");
            }
        }
    }
}
/**
 * Extends the configured GAM with pairwise interaction terms learned from the given instances.
 *
 * If no pairs were set, every pair (i, j) with i &lt; j over the instance dimension is generated and stored. A
 * negative iteration cap is replaced by 20. The routine that runs depends on the task and on whether a validation
 * set is present.
 *
 * @param instances the training set.
 * @return the configured GAM, updated in place.
 */
@Override
public GAM build(Instances instances) {
    if (pairs == null) {
        final int dim = instances.dimension();
        pairs = new ArrayList();
        for (int first = 0; first < dim; first++) {
            for (int second = first + 1; second < dim; second++) {
                pairs.add(new IntPair(first, second));
            }
        }
    }

    if (maxNumIters < 0) {
        maxNumIters = 20;
    }

    switch (task) {
        case REGRESSION:
            if (validSet == null) {
                buildRegressor(gam, pairs, instances, maxNumIters);
            } else {
                buildRegressor(gam, pairs, instances, validSet, maxNumIters);
            }
            break;
        case CLASSIFICATION:
            if (validSet == null) {
                buildClassifier(gam, pairs, instances, maxNumIters);
            } else {
                buildClassifier(gam, pairs, instances, validSet, maxNumIters);
            }
            break;
        default:
            // Other tasks leave the GAM untouched
            break;
    }

    return gam;
}
/**
 * Finds the position of a two-attribute term that matches the given pair, in the same order.
 *
 * @param terms the terms to search; every element is an {@code int[]} of attribute indices.
 * @param pair the pair to look for.
 * @return the index of the first matching term, or -1 if there is none.
 */
private int indexOf(List terms, IntPair pair) {
    int position = 0;
    for (Object candidate : terms) {
        int[] term = (int[]) candidate;
        boolean matches = term.length == 2 && term[0] == pair.v1 && term[1] == pair.v2;
        if (matches) {
            return position;
        }
        position++;
    }
    return -1;
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/GAM.java
================================================
package mltk.predictor.gam;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import mltk.core.Instance;
import mltk.predictor.ProbabilisticClassifier;
import mltk.predictor.Regressor;
import mltk.util.ArrayUtils;
import mltk.util.MathUtils;
/**
 * Class for generalized additive models (GAMs).
 *
 * A GAM is stored as an intercept plus two parallel lists: {@code terms} holds the attribute indices used by each
 * component, and {@code regressors} holds the component itself. A prediction is the intercept plus the sum of all
 * component outputs.
 *
 * @author Yin Lou
 *
 */
public class GAM implements ProbabilisticClassifier, Regressor {

    /** Iterable wrapper around a list of regressors. */
    class RegressorList implements Iterable<Regressor> {

        List<Regressor> regressors;

        RegressorList() {
            regressors = new ArrayList<>();
        }

        @Override
        public Iterator<Regressor> iterator() {
            return regressors.iterator();
        }

    }

    /** Iterable wrapper around a list of terms (arrays of attribute indices). */
    class TermList implements Iterable<int[]> {

        List<int[]> terms;

        TermList() {
            terms = new ArrayList<>();
        }

        @Override
        public Iterator<int[]> iterator() {
            return terms.iterator();
        }

    }

    /** Constant added to every prediction. */
    protected double intercept;
    /** Components of the model; element i belongs to {@code terms.get(i)}. */
    protected List<Regressor> regressors;
    /** Attribute indices used by each component; element i belongs to {@code regressors.get(i)}. */
    protected List<int[]> terms;

    /**
     * Constructor. Creates an empty model with a zero intercept.
     */
    public GAM() {
        regressors = new ArrayList<>();
        terms = new ArrayList<>();
        intercept = 0;
    }

    /**
     * Returns the intercept.
     *
     * @return the intercept.
     */
    public double getIntercept() {
        return intercept;
    }

    /**
     * Sets the intercept.
     *
     * @param intercept the new intercept.
     */
    public void setIntercept(double intercept) {
        this.intercept = intercept;
    }

    /**
     * Reads this model from the given reader, replacing the current terms and regressors.
     *
     * The expected layout mirrors {@link #write(PrintWriter)} without its first header line: an intercept line, a
     * component-count line, a blank line, and then for each component a term line, a blank line, the bracketed
     * header naming the regressor class, the regressor's own content and one more line that is skipped.
     *
     * @param in the reader to read from.
     * @throws Exception if reading fails or a regressor class cannot be loaded or instantiated.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        intercept = Double.parseDouble(in.readLine().split(": ")[1]);
        int size = Integer.parseInt(in.readLine().split(": ")[1]);
        regressors = new ArrayList<>(size);
        terms = new ArrayList<>(size);
        in.readLine();
        for (int i = 0; i < size; i++) {
            int[] term = ArrayUtils.parseIntArray(in.readLine().split(": ")[1]);
            terms.add(term);
            in.readLine();

            // Header has the form "[<label>: <class name>]"; strip the brackets and take the class name
            String line = in.readLine();
            String regressorName = line.substring(1, line.length() - 1).split(": ")[1];
            Class<?> clazz = Class.forName(regressorName);
            Regressor regressor = (Regressor) clazz.getDeclaredConstructor().newInstance();
            regressor.read(in);
            regressors.add(regressor);

            in.readLine();
        }
    }

    /**
     * Writes this model: a header with the class name, the intercept, the number of components, and then each
     * component's term followed by the output of the component's own {@code write}.
     *
     * @param out the writer to write to.
     * @throws Exception if a component fails to write itself.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName());
        out.println("Intercept: " + intercept);
        out.println("Components: " + regressors.size());
        out.println();
        for (int i = 0; i < regressors.size(); i++) {
            out.println("Component: " + Arrays.toString(terms.get(i)));
            out.println();
            regressors.get(i).write(out);
            out.println();
        }
    }

    /**
     * Adds a new term into this GAM. The term is an array of attribute indices that are used in the regressor.
     *
     * @param term the new term to add.
     * @param regressor the new regressor to add.
     */
    public void add(int[] term, Regressor regressor) {
        terms.add(term);
        regressors.add(regressor);
    }

    /**
     * Returns the intercept plus the sum of all component outputs for the instance.
     *
     * @param instance the instance to predict.
     * @return the additive prediction.
     */
    @Override
    public double regress(Instance instance) {
        double pred = intercept;
        for (Regressor regressor : regressors) {
            pred += regressor.regress(instance);
        }
        return pred;
    }

    /**
     * Classifies the instance by the sign of the additive prediction.
     *
     * @param instance the instance to classify.
     * @return 1 if the prediction is non-negative, 0 otherwise.
     */
    @Override
    public int classify(Instance instance) {
        double pred = regress(instance);
        return pred >= 0 ? 1 : 0;
    }

    /**
     * Returns class probabilities obtained by passing the additive prediction through a sigmoid.
     *
     * @param instance the instance to predict.
     * @return a two-element array {@code { 1 - p, p }}, where p is the sigmoid of the prediction.
     */
    @Override
    public double[] predictProbabilities(Instance instance) {
        double pred = regress(instance);
        double prob = MathUtils.sigmoid(pred);
        return new double[] { 1 - prob, prob };
    }

    /**
     * Returns the term list. This is the internal list, not a copy.
     *
     * @return the term list.
     */
    public List<int[]> getTerms() {
        return terms;
    }

    /**
     * Returns the regressor list. This is the internal list, not a copy.
     *
     * @return the regressor list.
     */
    public List<Regressor> getRegressors() {
        return regressors;
    }

    /**
     * Returns a copy of this model in which every regressor is copied and every term array is cloned.
     *
     * @return the copy.
     */
    @Override
    public GAM copy() {
        GAM copy = new GAM();
        copy.intercept = intercept;
        for (Regressor regressor : regressors) {
            copy.regressors.add((Regressor) regressor.copy());
        }
        for (int[] term : terms) {
            copy.terms.add(term.clone());
        }
        return copy;
    }

}
================================================
FILE: src/main/java/mltk/predictor/gam/GAMLearner.java
================================================
package mltk.predictor.gam;
import java.util.ArrayList;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.HoldoutValidatedLearnerWithTaskOptions;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.Attribute.Type;
import mltk.core.io.InstancesReader;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.BoostedEnsemble;
import mltk.predictor.HoldoutValidatedLearner;
import mltk.predictor.Regressor;
import mltk.predictor.evaluation.ConvergenceTester;
import mltk.predictor.evaluation.Metric;
import mltk.predictor.evaluation.MetricFactory;
import mltk.predictor.evaluation.SimpleMetric;
import mltk.predictor.function.BaggedLineCutter;
import mltk.predictor.function.CompressionUtils;
import mltk.predictor.function.EnsembledLineCutter;
import mltk.predictor.function.Function1D;
import mltk.predictor.function.SubaggedLineCutter;
import mltk.predictor.io.PredictorWriter;
import mltk.util.OptimUtils;
import mltk.util.Random;
/**
* Class for learning GAMs via gradient tree boosting.
*
*
* Reference:
* Y. Lou, R. Caruana and J. Gehrke. Intelligible models for classification and regression. In Proceedings of the
* 18th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), Beijing, China, 2012.
*
* Y. Lou, Y. Wang, S. Liang and Y. Dong. Efficiently Training Intelligible Models for Global Explanations.
* In Proceedings of the 29th ACM International Conference on Information and Knowledge Management (CIKM),
* Virtual Event, Ireland, 2020.
*
*
* @author Yin Lou
*
*/
public class GAMLearner extends HoldoutValidatedLearner {
/**
* Command line options specific to this learner, in addition to those inherited from the parent options class.
*/
static class Options extends HoldoutValidatedLearnerWithTaskOptions {
// Base learner specification; see setBaseLearner for how the colon-separated fields are interpreted
@Argument(name = "-b", description = "base learner (default: tr:3:100:0.65)")
String baseLearner = "tr:3:100:0.65";
// Declared as required, so the -1 initializer is only a placeholder
@Argument(name = "-m", description = "maximum number of iterations", required = true)
int maxNumIters = -1;
@Argument(name = "-s", description = "seed of the random number generator (default: 0)")
long seed = 0L;
@Argument(name = "-l", description = "learning rate (default: 0.01)")
double learningRate = 0.01;
}
/**
* Trains a GAM.
*
*
* Usage: mltk.predictor.gam.GAMLearner
* -t train set path
* -m maximum number of iterations
* [-g] task between classification (c) and regression (r) (default: r)
* [-v] valid set path
* [-e] evaluation metric (default: default metric of task)
* [-S] convergence criteria (default: -1)
* [-r] attribute file path
* [-o] output model path
* [-V] verbose (default: true)
* [-b] base learner (default: tr:3:100:0.65)
* [-s] seed of the random number generator (default: 0)
* [-l] learning rate (default: 0.01)
*
*
* @param args the command line arguments.
* @throws Exception if an error is propagated while reading the data, training, or writing the model.
*/
public static void main(String[] args) throws Exception {
Options opts = new Options();
CmdLineParser parser = new CmdLineParser(GAMLearner.class, opts);
Task task = null;
Metric metric = null;
try {
parser.parse(args);
task = Task.get(opts.task);
// Fall back to the task's default metric when none was given
if (opts.metric == null) {
metric = task.getDefaultMetric();
} else {
metric = MetricFactory.getMetric(opts.metric);
}
} catch (IllegalArgumentException e) {
// Invalid arguments: show the usage text and stop
parser.printUsage();
System.exit(1);
}
Random.getInstance().setSeed(opts.seed);
ConvergenceTester ct = ConvergenceTester.parse(opts.cc);
Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);
GAMLearner learner = new GAMLearner();
learner.setBaseLearner(opts.baseLearner);
learner.setMaxNumIters(opts.maxNumIters);
learner.setLearningRate(opts.learningRate);
learner.setTask(task);
learner.setMetric(metric);
learner.setConvergenceTester(ct);
learner.setVerbose(opts.verbose);
// The validation set is optional
if (opts.validPath != null) {
Instances validSet = InstancesReader.read(opts.attPath, opts.validPath);
learner.setValidSet(validSet);
}
long start = System.currentTimeMillis();
GAM gam = learner.build(trainSet);
long end = System.currentTimeMillis();
// Training time in seconds
System.out.println("Time: " + (end - start) / 1000.0);
if (opts.outputModelPath != null) {
PredictorWriter.write(gam, opts.outputModelPath);
}
}
// Number of bags (or subsampled bags) created for the base learner
private int baggingIters;
// Maximum number of boosting iterations
private int maxNumIters;
// Maximum number of leaves; the build methods pass this to the line cutter as its number of intervals
private int maxNumLeaves;
// Task this learner is configured for
private Task task;
// Subsampling ratio; values in [0, 1] select subsampled bags, any other value selects bagging
private double alpha;
// Factor applied to every newly fitted function
private double learningRate;
/**
 * Constructor. Defaults: not verbose, 100 bagging iterations, no iteration cap set (-1), 3 leaves, subsampling
 * ratio 0.65, learning rate 0.01, and the regression task with its default metric.
 */
public GAMLearner() {
    this.verbose = false;
    this.baggingIters = 100;
    this.maxNumIters = -1;
    this.maxNumLeaves = 3;
    this.alpha = 0.65;
    this.learningRate = 0.01;
    this.task = Task.REGRESSION;
    // The metric follows the task chosen on the previous line
    this.metric = this.task.getDefaultMetric();
}
/**
 * Returns how many bags the base learner is trained on.
 *
 * @return the number of bagging iterations.
 */
public int getBaggingIters() {
    return this.baggingIters;
}
/**
* Sets the number of bagging iterations. The build methods pass this value to the line cutter when the bags (or
* subsampled bags) are created.
*
* @param baggingIters the number of bagging iterations.
*/
public void setBaggingIters(int baggingIters) {
this.baggingIters = baggingIters;
}
/**
 * Returns the configured cap on the number of boosting iterations.
 *
 * @return the maximum number of iterations.
 */
public int getMaxNumIters() {
    return this.maxNumIters;
}
/**
* Sets the maximum number of iterations. The value is stored as given; the constructor default is -1.
*
* @param maxNumIters the maximum number of iterations.
*/
public void setMaxNumIters(int maxNumIters) {
this.maxNumIters = maxNumIters;
}
/**
 * Returns the configured maximum number of leaves of the base learner.
 *
 * @return the maximum number of leaves.
 */
public int getMaxNumLeaves() {
    return this.maxNumLeaves;
}
/**
* Sets the maximum number of leaves. The value is stored as given; the constructor default is 3.
*
* @param maxNumLeaves the maximum number of leaves.
*/
public void setMaxNumLeaves(int maxNumLeaves) {
this.maxNumLeaves = maxNumLeaves;
}
/**
 * Returns the learning rate, i.e. the factor applied to every newly fitted function during boosting.
 *
 * @return the learning rate.
 */
public double getLearningRate() {
    return this.learningRate;
}
/**
* Sets the learning rate. The build methods multiply every newly fitted function by this value unless it equals 1.
*
* @param learningRate the learning rate.
*/
public void setLearningRate(double learningRate) {
this.learningRate = learningRate;
}
/**
 * Returns the subsampling ratio used when creating subsampled bags.
 *
 * @return the subsampling ratio.
 */
public double getSubsamplingRatio() {
    return this.alpha;
}
/**
* Sets the subsampling ratio. The build methods use subsampled bags when the ratio lies in [0, 1]; any other value
* (for example -1) makes them fall back to bagging.
*
* @param alpha the subsampling ratio.
*/
public void setSubsamplingRatio(double alpha) {
this.alpha = alpha;
}
/**
 * Returns the task this learner is configured for.
 *
 * @return the task of this learner.
 */
public Task getTask() {
    return this.task;
}
/**
* Sets the task of this learner. The metric is not changed by this call; the constructor sets it to the default
* metric of the regression task.
*
* @param task the task of this learner.
*/
public void setTask(Task task) {
this.task = task;
}
/**
 * Sets the base learner from a colon-separated option string.
 *
 * For the {@code tr} learner the fields are {@code tr:<max leaves>:<bagging iterations>[:<subsampling ratio>]};
 * when the ratio is omitted it is set to -1. Any other learner type, including {@code cs}, leaves the current
 * settings unchanged.
 *
 * @param option the option string.
 */
public void setBaseLearner(String option) {
    final String[] fields = option.split(":");
    if (!"tr".equals(fields[0])) {
        // "cs" and unknown learner types are accepted but change nothing
        return;
    }

    // Parse both integers before applying either, so a malformed field leaves the settings untouched
    final int leaves = Integer.parseInt(fields[1]);
    final int bags = Integer.parseInt(fields[2]);
    setMaxNumLeaves(leaves);
    setBaggingIters(bags);

    if (fields.length <= 3) {
        setSubsamplingRatio(-1);
    } else {
        final double ratio = Double.parseDouble(fields[3]);
        setSubsamplingRatio(ratio);
    }
}
/**
* Builds a classifier by cycling over the attributes and boosting one function per attribute, using a validation
* set to choose the stopping point.
*
* While training, the targets and weights of {@code trainSet} are overwritten; both are restored before the model
* is compressed.
*
* @param trainSet the training set.
* @param validSet the validation set.
* @param maxNumIters the maximum number of iterations.
* @param maxNumLeaves the maximum number of leaves.
* @return a classifier.
*/
public GAM buildClassifier(Instances trainSet, Instances validSet, int maxNumIters, int maxNumLeaves) {
GAM gam = new GAM();
// Backup targets and weights
double[] target = new double[trainSet.size()];
double[] weight = new double[trainSet.size()];
for (int i = 0; i < target.length; i++) {
Instance instance = trainSet.get(i);
target[i] = instance.getTarget();
weight[i] = instance.getWeight();
}
List attributes = trainSet.getAttributes();
// One boosted ensemble per attribute
List regressors = new ArrayList<>(attributes.size());
for (int i = 0; i < attributes.size(); i++) {
regressors.add(new BoostedEnsemble());
}
// Create bags
// A subsampling ratio in [0, 1] selects subsampled bags; any other value selects bagging
EnsembledLineCutter elc = null;
if (0 <= alpha & alpha <= 1) {
SubaggedLineCutter slc = new SubaggedLineCutter(true);
slc.createSubags(trainSet.size(), alpha, baggingIters);
elc = slc;
} else {
BaggedLineCutter blc = new BaggedLineCutter(true);
blc.createBags(trainSet.size(), baggingIters);
elc = blc;
}
elc.setNumIntervals(maxNumLeaves);
// Initialize predictions and residuals
// predTrain starts at zero, matching the empty GAM created above
double[] predTrain = new double[trainSet.size()];
double[] probTrain = new double[trainSet.size()];
double[] rTrain = new double[trainSet.size()];
OptimUtils.computeProbabilities(predTrain, probTrain);
OptimUtils.computePseudoResidual(predTrain, target, rTrain);
double[] pValid = new double[validSet.size()];
// Gradient boosting
// Resets the convergence tester
ct.setMetric(metric);
for (int iter = 0; iter < maxNumIters; iter++) {
for (int j = 0; j < attributes.size(); j++) {
// Derivative with respect to attribute j
// Minimizes the loss function: log(1 + exp(-yF))
// Each instance gets the weighted pseudo residual as target and p * (1 - p) times its original weight as weight
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
double prob = probTrain[i];
double w = prob * (1 - prob);
instance.setTarget(rTrain[i] * weight[i]);
instance.setWeight(w * weight[i]);
}
BoostedEnsemble boostedEnsemble = regressors.get(j);
// Train model
elc.setAttributeIndex(j);
BaggedEnsemble baggedEnsemble = elc.build(trainSet);
// Collapse the bagged ensemble into a single 1D function before shrinking it
Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
if (learningRate != 1) {
func.multiply(learningRate);
}
boostedEnsemble.add(func);
// Drop the reference to the bagged ensemble; only the compressed function is kept
baggedEnsemble = null;
// Update predictions
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
double pred = func.regress(instance);
predTrain[i] += pred;
}
OptimUtils.computeProbabilities(predTrain, probTrain);
OptimUtils.computePseudoResidual(predTrain, target, rTrain);
for (int i = 0; i < validSet.size(); i++) {
Instance instance = validSet.get(i);
double pred = func.regress(instance);
pValid[i] += pred;
}
// One validation measure is recorded per (iteration, attribute) step
double measure = metric.eval(pValid, validSet);
ct.add(measure);
if (verbose) {
System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
}
}
// Convergence is only checked after a full pass over all attributes
if (ct.isConverged()) {
break;
}
}
// Search the best model on validation set
int idx = ct.getBestIndex();
// Remove trees
// idx is split into an iteration n and an attribute m; this presumes getBestIndex() returns the position of the
// best recorded measure - TODO confirm against ConvergenceTester
int n = idx / attributes.size();
int m = idx % attributes.size();
for (int k = 0; k < regressors.size(); k++) {
BoostedEnsemble boostedEnsemble = regressors.get(k);
// Keep at most n + 1 updates for this attribute
for (int i = boostedEnsemble.size(); i > n + 1; i--) {
boostedEnsemble.removeLast();
}
// Attributes after m come later than the best step within iteration n, so one more update is dropped
if (k > m) {
boostedEnsemble.removeLast();
}
}
// Restore targets and weights
for (int i = 0; i < target.length; i++) {
trainSet.get(i).setTarget(target[i]);
trainSet.get(i).setWeight(weight[i]);
}
// Compress model
for (int i = 0; i < regressors.size(); i++) {
BoostedEnsemble boostedEnsemble = regressors.get(i);
Attribute attribute = attributes.get(i);
int attIndex = attribute.getIndex();
Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
Regressor regressor = function;
// Binned and nominal attributes are converted using their number of discrete values
if (attribute.getType() == Type.BINNED) {
int l = ((BinnedAttribute) attribute).getNumBins();
regressor = CompressionUtils.convert(l, function);
} else if (attribute.getType() == Type.NOMINAL) {
int l = ((NominalAttribute) attribute).getCardinality();
regressor = CompressionUtils.convert(l, function);
}
gam.add(new int[] { attIndex }, regressor);
}
return gam;
}
/**
* Builds a classifier by cycling over the attributes and boosting one function per attribute, using the training
* set only.
*
* No validation set is used, so all {@code maxNumIters} passes are run and no update is removed afterwards. While
* training, the targets and weights of {@code trainSet} are overwritten; both are restored before the model is
* compressed. The configured metric is cast to {@code SimpleMetric}, so a metric of another type fails with a
* {@code ClassCastException}.
*
* @param trainSet the training set.
* @param maxNumIters the maximum number of iterations.
* @param maxNumLeaves the maximum number of leaves.
* @return a classifier.
*/
public GAM buildClassifier(Instances trainSet, int maxNumIters, int maxNumLeaves) {
GAM gam = new GAM();
SimpleMetric simpleMetric = (SimpleMetric) metric;
// Backup targets and weights
double[] target = new double[trainSet.size()];
double[] weight = new double[trainSet.size()];
for (int i = 0; i < target.length; i++) {
Instance instance = trainSet.get(i);
target[i] = instance.getTarget();
weight[i] = instance.getWeight();
}
List attributes = trainSet.getAttributes();
// One boosted ensemble per attribute
List regressors = new ArrayList<>(attributes.size());
for (int i = 0; i < attributes.size(); i++) {
regressors.add(new BoostedEnsemble());
}
// Create bags
// A subsampling ratio in [0, 1] selects subsampled bags; any other value selects bagging
EnsembledLineCutter elc = null;
if (0 <= alpha & alpha <= 1) {
SubaggedLineCutter slc = new SubaggedLineCutter(true);
slc.createSubags(trainSet.size(), alpha, baggingIters);
elc = slc;
} else {
BaggedLineCutter blc = new BaggedLineCutter(true);
blc.createBags(trainSet.size(), baggingIters);
elc = blc;
}
elc.setNumIntervals(maxNumLeaves);
// Initialize predictions and residuals
// predTrain starts at zero, matching the empty GAM created above
double[] predTrain = new double[trainSet.size()];
double[] probTrain = new double[trainSet.size()];
double[] rTrain = new double[trainSet.size()];
OptimUtils.computeProbabilities(predTrain, probTrain);
OptimUtils.computePseudoResidual(predTrain, target, rTrain);
// Gradient boosting
for (int iter = 0; iter < maxNumIters; iter++) {
for (int j = 0; j < attributes.size(); j++) {
// Derivative with respect to attribute j
// Minimizes the loss function: log(1 + exp(-yF))
// Each instance gets the weighted pseudo residual as target and p * (1 - p) times its original weight as weight
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
double prob = probTrain[i];
double w = prob * (1 - prob);
instance.setTarget(rTrain[i] * weight[i]);
instance.setWeight(w * weight[i]);
}
BoostedEnsemble boostedEnsemble = regressors.get(j);
// Train model
elc.setAttributeIndex(j);
BaggedEnsemble baggedEnsemble = elc.build(trainSet);
// Collapse the bagged ensemble into a single 1D function before shrinking it
Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
if (learningRate != 1) {
func.multiply(learningRate);
}
boostedEnsemble.add(func);
// Drop the reference to the bagged ensemble; only the compressed function is kept
baggedEnsemble = null;
// Update predictions
for (int i = 0; i < trainSet.size(); i++) {
Instance instance = trainSet.get(i);
double pred = func.regress(instance);
predTrain[i] += pred;
}
OptimUtils.computeProbabilities(predTrain, probTrain);
OptimUtils.computePseudoResidual(predTrain, target, rTrain);
// The training measure is only reported; it does not influence the loop
double measure = simpleMetric.eval(predTrain, target);
if (verbose) {
System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
}
}
}
// Restore targets and weights
for (int i = 0; i < target.length; i++) {
trainSet.get(i).setTarget(target[i]);
trainSet.get(i).setWeight(weight[i]);
}
// Compress model
for (int i = 0; i < regressors.size(); i++) {
BoostedEnsemble boostedEnsemble = regressors.get(i);
Attribute attribute = attributes.get(i);
int attIndex = attribute.getIndex();
Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
Regressor regressor = function;
// Binned and nominal attributes are converted using their number of discrete values
if (attribute.getType() == Type.BINNED) {
int l = ((BinnedAttribute) attribute).getNumBins();
regressor = CompressionUtils.convert(l, function);
} else if (attribute.getType() == Type.NOMINAL) {
int l = ((NominalAttribute) attribute).getCardinality();
regressor = CompressionUtils.convert(l, function);
}
gam.add(new int[] { attIndex }, regressor);
}
return gam;
}
/**
 * Builds a regressor, using a validation set to pick where boosting stops.
 *
 * <p>
 * After every single-attribute update the metric is evaluated on {@code validSet} and
 * recorded in the convergence tester. Once boosting ends, every ensemble is pruned back
 * to the update with the best recorded value. The targets of {@code trainSet} are
 * overwritten with residuals during training and restored before this method returns,
 * even if training throws.
 * </p>
 *
 * @param trainSet the training set.
 * @param validSet the validation set.
 * @param maxNumIters the maximum number of iterations.
 * @param maxNumLeaves the maximum number of leaves.
 * @return a regressor.
 */
public GAM buildRegressor(Instances trainSet, Instances validSet, int maxNumIters, int maxNumLeaves) {
    GAM gam = new GAM();

    List<Attribute> attributes = trainSet.getAttributes();
    List<BoostedEnsemble> regressors = new ArrayList<>(attributes.size());
    for (int i = 0; i < attributes.size(); i++) {
        regressors.add(new BoostedEnsemble());
    }

    // Backup targets
    double[] target = new double[trainSet.size()];
    for (int i = 0; i < target.length; i++) {
        target[i] = trainSet.get(i).getTarget();
    }

    // Create bags: subagging when alpha is a valid fraction, plain bagging otherwise
    EnsembledLineCutter elc;
    if (0 <= alpha && alpha <= 1) {
        SubaggedLineCutter slc = new SubaggedLineCutter(false);
        slc.createSubags(trainSet.size(), alpha, baggingIters);
        elc = slc;
    } else {
        BaggedLineCutter blc = new BaggedLineCutter(false);
        blc.createBags(trainSet.size(), baggingIters);
        elc = blc;
    }
    elc.setNumIntervals(maxNumLeaves);

    // Initialize residuals (predictions start at zero) and validation predictions
    double[] rTrain = target.clone();
    double[] pValid = new double[validSet.size()];

    // Resets the convergence tester
    ct.setMetric(metric);

    try {
        // Gradient boosting
        for (int iter = 0; iter < maxNumIters; iter++) {
            for (int j = 0; j < attributes.size(); j++) {
                // Derivative to attribute j
                // Equivalent to residual
                BoostedEnsemble boostedEnsemble = regressors.get(j);

                // Prepare training set
                for (int i = 0; i < rTrain.length; i++) {
                    trainSet.get(i).setTarget(rTrain[i]);
                }

                // Train model
                elc.setAttributeIndex(j);
                BaggedEnsemble baggedEnsemble = elc.build(trainSet);
                Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
                if (learningRate != 1) {
                    func.multiply(learningRate);
                }
                boostedEnsemble.add(func);

                // Update residuals and validation predictions
                for (int i = 0; i < rTrain.length; i++) {
                    rTrain[i] -= func.regress(trainSet.get(i));
                }
                for (int i = 0; i < pValid.length; i++) {
                    pValid[i] += func.regress(validSet.get(i));
                }

                double measure = metric.eval(pValid, validSet);
                ct.add(measure);
                if (verbose) {
                    System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
                }
            }
            if (ct.isConverged()) {
                break;
            }
        }
    } finally {
        // Restore targets so the caller's data is never left holding residuals
        for (int i = 0; i < target.length; i++) {
            trainSet.get(i).setTarget(target[i]);
        }
    }

    // Search the best model on validation set
    int idx = ct.getBestIndex();

    // Prune tree ensembles: the best update was pass n, attribute m
    int n = idx / attributes.size();
    int m = idx % attributes.size();
    for (int k = 0; k < regressors.size(); k++) {
        BoostedEnsemble boostedEnsemble = regressors.get(k);
        for (int i = boostedEnsemble.size(); i > n + 1; i--) {
            boostedEnsemble.removeLast();
        }
        // Attributes after m had not yet been updated in pass n
        if (k > m) {
            boostedEnsemble.removeLast();
        }
    }

    // Compress model: one shape function per attribute
    for (int i = 0; i < regressors.size(); i++) {
        BoostedEnsemble boostedEnsemble = regressors.get(i);
        Attribute attribute = attributes.get(i);
        int attIndex = attribute.getIndex();
        Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
        Regressor regressor = function;
        if (attribute.getType() == Type.BINNED) {
            int l = ((BinnedAttribute) attribute).getNumBins();
            regressor = CompressionUtils.convert(l, function);
        } else if (attribute.getType() == Type.NOMINAL) {
            int l = ((NominalAttribute) attribute).getCardinality();
            regressor = CompressionUtils.convert(l, function);
        }
        gam.add(new int[] { attIndex }, regressor);
    }
    return gam;
}
/**
 * Builds a regressor using the training set only. No validation set is involved, so
 * every one of the {@code maxNumIters} boosting passes is kept.
 *
 * <p>
 * The targets of {@code trainSet} are overwritten with residuals during training and
 * restored before this method returns, even if training throws.
 * </p>
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param maxNumLeaves the maximum number of leaves.
 * @return a regressor.
 */
public GAM buildRegressor(Instances trainSet, int maxNumIters, int maxNumLeaves) {
    GAM gam = new GAM();
    SimpleMetric simpleMetric = (SimpleMetric) metric;

    List<Attribute> attributes = trainSet.getAttributes();
    List<BoostedEnsemble> regressors = new ArrayList<>(attributes.size());
    for (int i = 0; i < attributes.size(); i++) {
        regressors.add(new BoostedEnsemble());
    }

    // Backup targets
    double[] target = new double[trainSet.size()];
    for (int i = 0; i < target.length; i++) {
        target[i] = trainSet.get(i).getTarget();
    }

    // Create bags. The line cutters must run in regression mode (false), exactly as in
    // the validation-set variant of this method: in classification mode the response
    // sum is not multiplied by the instance weight, which is wrong for plain residuals.
    EnsembledLineCutter elc;
    if (0 <= alpha && alpha <= 1) {
        SubaggedLineCutter slc = new SubaggedLineCutter(false);
        slc.createSubags(trainSet.size(), alpha, baggingIters);
        elc = slc;
    } else {
        BaggedLineCutter blc = new BaggedLineCutter(false);
        blc.createBags(trainSet.size(), baggingIters);
        elc = blc;
    }
    elc.setNumIntervals(maxNumLeaves);

    // Initialize predictions and residuals (predictions start at zero)
    double[] pTrain = new double[trainSet.size()];
    double[] rTrain = target.clone();

    try {
        // Gradient boosting
        for (int iter = 0; iter < maxNumIters; iter++) {
            for (int j = 0; j < attributes.size(); j++) {
                // Derivative to attribute j
                // Equivalent to residual
                BoostedEnsemble boostedEnsemble = regressors.get(j);

                // Prepare training set
                for (int i = 0; i < rTrain.length; i++) {
                    trainSet.get(i).setTarget(rTrain[i]);
                }

                // Train model
                elc.setAttributeIndex(j);
                BaggedEnsemble baggedEnsemble = elc.build(trainSet);
                Function1D func = CompressionUtils.compress(attributes.get(j).getIndex(), baggedEnsemble);
                if (learningRate != 1) {
                    func.multiply(learningRate);
                }
                boostedEnsemble.add(func);

                // Update predictions and residuals
                for (int i = 0; i < rTrain.length; i++) {
                    double pred = func.regress(trainSet.get(i));
                    pTrain[i] += pred;
                    rTrain[i] -= pred;
                }

                double measure = simpleMetric.eval(pTrain, target);
                if (verbose) {
                    System.out.println("Iteration " + iter + " Feature " + j + ": " + measure);
                }
            }
        }
    } finally {
        // Restore targets so the caller's data is never left holding residuals
        for (int i = 0; i < target.length; i++) {
            trainSet.get(i).setTarget(target[i]);
        }
    }

    // Compress model: one shape function per attribute
    for (int i = 0; i < regressors.size(); i++) {
        BoostedEnsemble boostedEnsemble = regressors.get(i);
        Attribute attribute = attributes.get(i);
        int attIndex = attribute.getIndex();
        Function1D function = CompressionUtils.compress(attIndex, boostedEnsemble);
        Regressor regressor = function;
        if (attribute.getType() == Type.BINNED) {
            int l = ((BinnedAttribute) attribute).getNumBins();
            regressor = CompressionUtils.convert(l, function);
        } else if (attribute.getType() == Type.NOMINAL) {
            int l = ((NominalAttribute) attribute).getCardinality();
            regressor = CompressionUtils.convert(l, function);
        }
        gam.add(new int[] { attIndex }, regressor);
    }
    return gam;
}
/**
 * Builds a GAM for the configured task, dispatching to the matching
 * {@code buildRegressor}/{@code buildClassifier} overload depending on whether a
 * validation set has been provided.
 *
 * @param instances the training set.
 * @return the trained GAM, or {@code null} if the task is neither regression nor classification.
 */
@Override
public GAM build(Instances instances) {
    // Fill in defaults for settings that were left unset
    if (maxNumIters < 0) {
        maxNumIters = 20;
    }
    if (metric == null) {
        metric = task.getDefaultMetric();
    }

    final boolean hasValidSet = validSet != null;
    switch (task) {
        case REGRESSION:
            return hasValidSet
                    ? buildRegressor(instances, validSet, maxNumIters, maxNumLeaves)
                    : buildRegressor(instances, maxNumIters, maxNumLeaves);
        case CLASSIFICATION:
            return hasValidSet
                    ? buildClassifier(instances, validSet, maxNumIters, maxNumLeaves)
                    : buildClassifier(instances, maxNumIters, maxNumLeaves);
        default:
            return null;
    }
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/GAMUtils.java
================================================
package mltk.predictor.gam;
import java.util.List;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.NominalAttribute;
import mltk.core.NumericalAttribute;
import mltk.predictor.function.Array1D;
import mltk.predictor.function.LinearFunction;
import mltk.predictor.glm.GLM;
/**
 * Package-private helpers for converting other models into GAMs.
 */
class GAMUtils {

    /**
     * Converts the first coefficient vector of a GLM into a GAM with one term per
     * attribute.
     *
     * <p>
     * Coefficients are consumed in attribute order: a numerical attribute takes one
     * coefficient (a linear term); a binned or nominal attribute takes one coefficient
     * per bin/category (an array term). Attributes of any other kind are skipped. Every
     * coefficient and the intercept are negated when copied into the GAM.
     * NOTE(review): the sign flip presumably mirrors the GLM's sign convention — confirm.
     * </p>
     *
     * @param glm the GLM; {@code coefficients(0)} and {@code intercept(0)} are read.
     * @param attList the attributes, in the order their coefficients appear in the GLM.
     * @return the GAM.
     */
    static GAM getGAM(GLM glm, List<Attribute> attList) {
        double[] w = glm.coefficients(0);
        GAM gam = new GAM();
        int k = 0;
        for (Attribute attribute : attList) {
            int attIndex = attribute.getIndex();
            int[] term = new int[] { attIndex };

            if (attribute instanceof NumericalAttribute) {
                gam.add(term, new LinearFunction(attIndex, -w[k++]));
                continue;
            }

            // Binned and nominal attributes differ only in how many slots they occupy
            final int size;
            if (attribute instanceof BinnedAttribute) {
                size = ((BinnedAttribute) attribute).getNumBins();
            } else if (attribute instanceof NominalAttribute) {
                size = ((NominalAttribute) attribute).getCardinality();
            } else {
                continue;
            }

            double[] predictions = new double[size];
            // Stop early if the coefficient vector is exhausted; remaining slots stay 0
            for (int j = 0; j < predictions.length && k < w.length; j++) {
                predictions[j] = -w[k++];
            }
            gam.add(term, new Array1D(attIndex, predictions));
        }
        gam.setIntercept(-glm.intercept(0));
        return gam;
    }
}
================================================
FILE: src/main/java/mltk/predictor/gam/SPLAMLearner.java
================================================
package mltk.predictor.gam;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.LearnerWithTaskOptions;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.predictor.Learner;
import mltk.predictor.Regressor;
import mltk.predictor.function.CubicSpline;
import mltk.predictor.function.LinearFunction;
import mltk.predictor.glm.GLM;
import mltk.predictor.glm.RidgeLearner;
import mltk.predictor.io.PredictorWriter;
import mltk.util.ArrayUtils;
import mltk.util.MathUtils;
import mltk.util.OptimUtils;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
/**
* Class for learning SPLAM models. Currently only cubic spline basis is supported.
*
* @author Yin Lou
*
*/
public class SPLAMLearner extends Learner {
/**
 * Command-line options for {@link SPLAMLearner#main(String[])}. The fields are declared
 * with {@code @Argument} and handed to a {@code CmdLineParser}; presumably the parser
 * fills them in reflectively, so names and annotation strings must not be changed casually.
 */
static class Options extends LearnerWithTaskOptions {
// Number of knots, forwarded to setNumKnots.
@Argument(name = "-d", description = "number of knots (default: 10)")
int numKnots = 10;
// Iteration budget, forwarded to setMaxNumIters. NOTE(review): the default of 0 means the
// training loops (which run while iter < maxNumIters) never execute unless -m is given.
@Argument(name = "-m", description = "maximum number of iterations (default: 0)")
int maxNumIters = 0;
// Regularization strength, forwarded to setLambda.
@Argument(name = "-l", description = "lambda (default: 0)")
double lambda = 0;
// Regularization mixing parameter, forwarded to setAlpha.
@Argument(name = "-a", description = "alpha (default: 1, i.e., SPAM model)")
double alpha = 1;
// When set, main prints findMaxLambda(...) and exits instead of training a model.
@Argument(name = "-L", description = "whether to compute lambda max for a given a")
boolean lambdaMax = false;
}
/**
 * Trains a SPLAM.
 *
 * <pre>
 * Usage: mltk.predictor.gam.SPLAMLearner
 * -t	train set path
 * [-g]	task between classification (c) and regression (r) (default: r)
 * [-r]	attribute file path
 * [-o]	output model path
 * [-V]	verbose (default: true)
 * [-d]	number of knots (default: 10)
 * [-m]	maximum number of iterations (default: 0)
 * [-l]	lambda (default: 0)
 * [-a]	alpha (default: 1, i.e., SPAM model)
 * [-L]	whether to compute lambda max for a given a
 * </pre>
 *
 * <p>
 * On invalid arguments the reason is printed to standard error, followed by the usage
 * text, and the process exits with status 1. With {@code -L} the maximum lambda is
 * printed and the process exits without training.
 * </p>
 *
 * @param args the command line arguments.
 * @throws Exception
 */
public static void main(String[] args) throws Exception {
    Options opts = new Options();
    CmdLineParser parser = new CmdLineParser(SPLAMLearner.class, opts);
    Task task = null;
    try {
        parser.parse(args);
        task = Task.get(opts.task);
        // Zero knots is accepted; only negative values are rejected
        if (opts.numKnots < 0) {
            throw new IllegalArgumentException("Number of knots must be non-negative.");
        }
    } catch (IllegalArgumentException e) {
        // Say what was wrong before showing the usage text
        if (e.getMessage() != null) {
            System.err.println(e.getMessage());
        }
        parser.printUsage();
        System.exit(1);
    }

    Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);

    SPLAMLearner learner = new SPLAMLearner();
    learner.setNumKnots(opts.numKnots);
    learner.setMaxNumIters(opts.maxNumIters);
    learner.setLambda(opts.lambda);
    learner.setAlpha(opts.alpha);
    learner.setTask(task);
    learner.setVerbose(opts.verbose);

    if (opts.lambdaMax) {
        System.out.println(learner.findMaxLambda(trainSet, task, opts.numKnots, opts.alpha));
        System.exit(0);
    }

    long start = System.currentTimeMillis();
    GAM gam = learner.build(trainSet);
    long end = System.currentTimeMillis();
    System.out.println("Time: " + (end - start) / 1000.0);

    if (opts.outputModelPath != null) {
        PredictorWriter.write(gam, opts.outputModelPath);
    }
}
/**
 * Value wrapper around a per-term structure code array so that two structures can be
 * compared, and hashed, by content.
 */
static class ModelStructure {

    /** Code for a term that has been removed from the model. */
    static final byte ELIMINATED = 0;
    /** Code for a term kept as a linear function. */
    static final byte LINEAR = 1;
    /** Code for a term kept as a nonlinear function. */
    static final byte NONLINEAR = 2;

    /** The structure codes; the array is stored as given, not copied. */
    byte[] structure;

    ModelStructure(byte[] structure) {
        this.structure = structure;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        ModelStructure that = (ModelStructure) obj;
        return Arrays.equals(structure, that.structure);
    }

    @Override
    public int hashCode() {
        // Same value as the usual "31 * 1 + hash" recipe
        return 31 + Arrays.hashCode(structure);
    }
}
// Whether an intercept is fitted during training (default true, set in the constructor).
private boolean fitIntercept;
// Whether the model is refit on the extracted structure after the penalized fit (default false).
private boolean refit;
// Number of knots passed to the cubic spline design matrix builders.
private int numKnots;
// Iteration budget; a negative value is replaced by 20 in build().
private int maxNumIters;
// Learning task; selects between buildRegressor and buildClassifier in build().
private Task task;
// Regularization parameters handed to getRegularizationParameters.
private double lambda;
private double alpha;
// Convergence threshold used when comparing successive loss values.
private double epsilon;
/**
 * Creates a learner with its default configuration: regression task, 10 knots,
 * unset ({@code -1}) iteration limit, {@code lambda = 0}, {@code alpha = 1},
 * intercept fitting enabled, refitting and verbose output disabled, and
 * {@code MathUtils.EPSILON} as the convergence threshold.
 */
public SPLAMLearner() {
    // What to learn
    this.task = Task.REGRESSION;
    this.numKnots = 10;

    // How to regularize
    this.lambda = 0.0;
    this.alpha = 1;

    // How to optimize
    this.maxNumIters = -1;
    this.epsilon = MathUtils.EPSILON;
    this.fitIntercept = true;
    this.refit = false;

    this.verbose = false;
}
/**
 * Builds a SPLAM for the configured task with the learner's current settings.
 * Negative {@code maxNumIters} and {@code numKnots} are first replaced by their
 * fallbacks (20 and 10).
 *
 * @param instances the training set.
 * @return the trained GAM, or {@code null} if the task is neither regression nor classification.
 */
@Override
public GAM build(Instances instances) {
    if (maxNumIters < 0) {
        maxNumIters = 20;
    }
    if (numKnots < 0) {
        numKnots = 10;
    }

    switch (task) {
        case REGRESSION:
            return buildRegressor(instances, maxNumIters, numKnots, lambda, alpha);
        case CLASSIFICATION:
            return buildClassifier(instances, maxNumIters, numKnots, lambda, alpha);
        default:
            return null;
    }
}
/**
 * Fits a binary classifier on a dense design matrix with block coordinate gradient
 * descent: one full pass that may change the active set, followed by passes restricted
 * to the active set until the penalized logistic loss converges or the iteration
 * budget runs out.
 *
 * @param attrs the attribute list.
 * @param x the inputs; {@code x[j]} is the block of columns for {@code attrs[j]}.
 * @param y the targets.
 * @param knots the knots.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a binary classifier.
 */
public GAM buildBinaryClassifier(int[] attrs, double[][][] x, double[] y, double[][] knots, int maxNumIters, double lambda,
        double alpha) {
    final int numTerms = attrs.length;

    // One coefficient block per term; remember the widest block for the scratch buffers
    double[][] w = new double[numTerms][];
    int widest = 0;
    for (int j = 0; j < numTerms; j++) {
        w[j] = new double[x[j].length];
        if (widest < w[j].length) {
            widest = w[j].length;
        }
    }

    double[] tl1 = new double[numTerms];
    double[] tl2 = new double[numTerms];
    getRegularizationParameters(lambda, alpha, tl1, tl2, y.length);

    double[] pTrain = new double[y.length];
    double[] rTrain = new double[y.length];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Step size per block: reciprocal of the largest (sum of squares / 4) over its columns
    double[] stepSize = new double[numTerms];
    for (int j = 0; j < numTerms; j++) {
        double largest = 0;
        for (double[] column : x[j]) {
            double bound = StatUtils.sumSq(column) / 4;
            if (bound > largest) {
                largest = bound;
            }
        }
        stepSize[j] = 1.0 / largest;
    }

    // Scratch buffers shared by all passes
    double[] g = new double[widest];
    double[] gradient = new double[widest];
    double[] gamma1 = new double[widest];
    double[] gamma2 = new double[widest - 1];
    boolean[] activeSet = new boolean[numTerms];

    double intercept = 0;

    // Block coordinate gradient descent
    int iter = 0;
    while (iter < maxNumIters) {
        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
        }

        // Full pass over every block
        boolean activeSetChanged = doOnePass(x, y, tl1, tl2, true, activeSet, w, stepSize,
                g, gradient, gamma1, gamma2, pTrain, rTrain);
        iter++;

        if (!(activeSetChanged && iter <= maxNumIters)) {
            break;
        }

        // Passes over the active set only, until the loss stops improving
        while (iter < maxNumIters) {
            double prevLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2);

            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }
            doOnePass(x, y, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, pTrain, rTrain);

            double currLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2);
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }

            if (verbose) {
                System.out.println("Iteration " + iter + ": " + currLoss);
            }
            iter++;
        }
    }

    if (!refit) {
        return getGAM(attrs, knots, w, intercept);
    }
    byte[] structure = extractStructure(w);
    return refitClassifier(attrs, structure, x, y, knots, w, maxNumIters);
}
/**
 * Fits a binary classifier on a sparse design matrix with block coordinate gradient
 * descent: one full pass that may change the active set, followed by passes restricted
 * to the active set until the penalized logistic loss converges or the iteration
 * budget runs out.
 *
 * @param attrs the attribute list.
 * @param indices the indices; {@code indices[j]} belongs to {@code attrs[j]}.
 * @param values the values; {@code values[j]} is the block of columns for {@code attrs[j]}.
 * @param y the targets.
 * @param knots the knots.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a binary classifier.
 */
public GAM buildBinaryClassifier(int[] attrs, int[][] indices, double[][][] values, double[] y, double[][] knots,
        int maxNumIters, double lambda, double alpha) {
    double[][] w = new double[attrs.length][];
    int m = 0;
    for (int j = 0; j < attrs.length; j++) {
        w[j] = new double[values[j].length];
        if (w[j].length > m) {
            m = w[j].length;
        }
    }

    double[] tl1 = new double[attrs.length];
    double[] tl2 = new double[attrs.length];
    getRegularizationParameters(lambda, alpha, tl1, tl2, y.length);

    double[] pTrain = new double[y.length];
    double[] rTrain = new double[y.length];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Step size per block: reciprocal of the largest (sum of squares / 4) over its columns
    double[] stepSize = new double[attrs.length];
    for (int j = 0; j < stepSize.length; j++) {
        double max = 0;
        double[][] block = values[j];
        for (double[] t : block) {
            double l = StatUtils.sumSq(t) / 4;
            if (l > max) {
                max = l;
            }
        }
        stepSize[j] = 1.0 / max;
    }

    // Scratch buffers shared by all passes
    double[] g = new double[m];
    double[] gradient = new double[m];
    double[] gamma1 = new double[m];
    double[] gamma2 = new double[m - 1];
    boolean[] activeSet = new boolean[attrs.length];

    double intercept = 0;

    // Block coordinate gradient descent
    int iter = 0;
    while (iter < maxNumIters) {
        if (fitIntercept) {
            // Argument order is (prediction, residual, target), as at every other call
            // site; passing y and rTrain the other way round would treat the targets
            // as residuals.
            intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
        }

        // Full pass over every block
        boolean activeSetChanged = doOnePass(indices, values, y, tl1, tl2, true, activeSet, w, stepSize,
                g, gradient, gamma1, gamma2, pTrain, rTrain);
        iter++;

        if (!activeSetChanged || iter > maxNumIters) {
            break;
        }

        // Passes over the active set only, until the loss stops improving
        for (; iter < maxNumIters; iter++) {
            double prevLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2);

            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }
            doOnePass(indices, values, y, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, pTrain, rTrain);

            double currLoss = OptimUtils.computeLogisticLoss(pTrain, y) + getPenalty(w, tl1, tl2);
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }

            if (verbose) {
                System.out.println("Iteration " + iter + ": " + currLoss);
            }
        }
    }

    if (refit) {
        byte[] structure = extractStructure(w);
        GAM gam = refitClassifier(attrs, structure, indices, values, y, knots, w, maxNumIters * 10);
        return gam;
    } else {
        return getGAM(attrs, knots, w, intercept);
    }
}
/**
 * Builds a classifier: creates the cubic spline design matrix (sparse or dense), fits a
 * binary classifier on it, and rescales the fitted weights by the design matrix's
 * per-column standard deviations.
 *
 * <p>
 * Note the parameter order: {@code numKnots} comes <em>before</em> {@code maxNumIters}
 * here, unlike in the {@code buildRegressor} overload with the same parameter types.
 * </p>
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param numKnots the number of knots.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a classifier.
 */
public GAM buildClassifier(Instances trainSet, boolean isSparse, int numKnots, int maxNumIters, double lambda,
        double alpha) {
    if (isSparse) {
        SparseDataset sd = getSparseDataset(trainSet, false);
        SparseDesignMatrix sm = SparseDesignMatrix.createCubicSplineDesignMatrix(trainSet.size(), sd.indices,
                sd.values, sd.stdList, numKnots);
        double[] y = sd.y;
        int[][] indices = sm.indices;
        double[][][] values = sm.values;
        double[][] knots = sm.knots;
        int[] attrs = sd.attrs;

        // Mapping from attribute index to index in design matrix
        Map<Integer, Integer> map = new HashMap<>();
        for (int j = 0; j < attrs.length; j++) {
            map.put(attrs[j], j);
        }

        GAM gam = buildBinaryClassifier(attrs, indices, values, y, knots, maxNumIters, lambda, alpha);

        // Rescale weights in gam
        List<Regressor> regressors = gam.getRegressors();
        List<int[]> terms = gam.getTerms();
        double intercept = gam.getIntercept();
        for (int i = 0; i < regressors.size(); i++) {
            Regressor regressor = regressors.get(i);
            int attIndex = terms.get(i)[0];
            int idx = map.get(attIndex);
            double[] std = sm.std[idx];
            if (regressor instanceof LinearFunction) {
                LinearFunction func = (LinearFunction) regressor;
                func.setSlope(func.getSlope() / std[0]);
            } else if (regressor instanceof CubicSpline) {
                CubicSpline spline = (CubicSpline) regressor;
                // The array returned by getCoefficients() is rescaled in place
                double[] w = spline.getCoefficients();
                for (int j = 0; j < w.length; j++) {
                    w[j] /= std[j];
                }
                // Fold each knot basis function's value at 0 into the intercept
                double[] k = spline.getKnots();
                for (int j = 0; j < k.length; j++) {
                    intercept -= w[j + 3] * CubicSpline.h(0, k[j]);
                }
            }
        }
        if (fitIntercept) {
            gam.setIntercept(intercept);
        }
        return gam;
    } else {
        DenseDataset dd = getDenseDataset(trainSet, false);
        DenseDesignMatrix dm = DenseDesignMatrix.createCubicSplineDesignMatrix(dd.x, dd.stdList, numKnots);
        double[] y = dd.y;
        double[][][] x = dm.x;
        double[][] knots = dm.knots;
        int[] attrs = dd.attrs;

        // Mapping from attribute index to index in design matrix
        Map<Integer, Integer> map = new HashMap<>();
        for (int j = 0; j < attrs.length; j++) {
            map.put(attrs[j], j);
        }

        GAM gam = buildBinaryClassifier(attrs, x, y, knots, maxNumIters, lambda, alpha);

        // Rescale weights in gam
        List<Regressor> regressors = gam.getRegressors();
        List<int[]> terms = gam.getTerms();
        for (int i = 0; i < regressors.size(); i++) {
            Regressor regressor = regressors.get(i);
            int attIndex = terms.get(i)[0];
            int idx = map.get(attIndex);
            double[] std = dm.std[idx];
            if (regressor instanceof LinearFunction) {
                LinearFunction func = (LinearFunction) regressor;
                func.setSlope(func.getSlope() / std[0]);
            } else if (regressor instanceof CubicSpline) {
                CubicSpline spline = (CubicSpline) regressor;
                // The array returned by getCoefficients() is rescaled in place
                double[] w = spline.getCoefficients();
                for (int j = 0; j < w.length; j++) {
                    w[j] /= std[j];
                }
            }
        }
        return gam;
    }
}
/**
 * Builds a classifier, deciding via {@code isSparse(trainSet)} whether the training set
 * is treated as sparse.
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param numKnots the number of knots.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a classifier.
 */
public GAM buildClassifier(Instances trainSet, int maxNumIters, int numKnots, double lambda, double alpha) {
    // The overload taking isSparse declares numKnots BEFORE maxNumIters, so the two
    // ints must be passed in that order or they are silently swapped.
    return buildClassifier(trainSet, isSparse(trainSet), numKnots, maxNumIters, lambda, alpha);
}
/**
 * Builds a regressor: creates the cubic spline design matrix (sparse or dense), fits a
 * regressor on it, and rescales the fitted weights by the design matrix's per-column
 * standard deviations.
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param maxNumIters the maximum number of iterations.
 * @param numKnots the number of knots.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a regressor.
 */
public GAM buildRegressor(Instances trainSet, boolean isSparse, int maxNumIters, int numKnots, double lambda,
        double alpha) {
    if (isSparse) {
        SparseDataset sd = getSparseDataset(trainSet, false);
        SparseDesignMatrix sm = SparseDesignMatrix.createCubicSplineDesignMatrix(trainSet.size(), sd.indices,
                sd.values, sd.stdList, numKnots);
        double[] y = sd.y;
        int[][] indices = sm.indices;
        double[][][] values = sm.values;
        double[][] knots = sm.knots;
        int[] attrs = sd.attrs;

        // Mapping from attribute index to index in design matrix
        Map<Integer, Integer> map = new HashMap<>();
        for (int j = 0; j < attrs.length; j++) {
            map.put(attrs[j], j);
        }

        GAM gam = buildRegressor(attrs, indices, values, y, knots, maxNumIters, lambda, alpha);

        // Rescale weights in gam
        List<Regressor> regressors = gam.getRegressors();
        List<int[]> terms = gam.getTerms();
        double intercept = gam.getIntercept();
        for (int i = 0; i < regressors.size(); i++) {
            Regressor regressor = regressors.get(i);
            int attIndex = terms.get(i)[0];
            int idx = map.get(attIndex);
            double[] std = sm.std[idx];
            if (regressor instanceof LinearFunction) {
                LinearFunction func = (LinearFunction) regressor;
                func.setSlope(func.getSlope() / std[0]);
            } else if (regressor instanceof CubicSpline) {
                CubicSpline spline = (CubicSpline) regressor;
                // The array returned by getCoefficients() is rescaled in place
                double[] w = spline.getCoefficients();
                for (int j = 0; j < w.length; j++) {
                    w[j] /= std[j];
                }
                // Fold each knot basis function's value at 0 into the intercept
                double[] k = spline.getKnots();
                for (int j = 0; j < k.length; j++) {
                    intercept -= w[j + 3] * CubicSpline.h(0, k[j]);
                }
            }
        }
        if (fitIntercept) {
            gam.setIntercept(intercept);
        }
        return gam;
    } else {
        DenseDataset dd = getDenseDataset(trainSet, false);
        DenseDesignMatrix dm = DenseDesignMatrix.createCubicSplineDesignMatrix(dd.x, dd.stdList, numKnots);
        double[] y = dd.y;
        double[][][] x = dm.x;
        double[][] knots = dm.knots;
        int[] attrs = dd.attrs;

        // Mapping from attribute index to index in design matrix
        Map<Integer, Integer> map = new HashMap<>();
        for (int j = 0; j < attrs.length; j++) {
            map.put(attrs[j], j);
        }

        GAM gam = buildRegressor(attrs, x, y, knots, maxNumIters, lambda, alpha);

        // Rescale weights in gam
        List<Regressor> regressors = gam.getRegressors();
        List<int[]> terms = gam.getTerms();
        for (int i = 0; i < regressors.size(); i++) {
            Regressor regressor = regressors.get(i);
            int attIndex = terms.get(i)[0];
            int idx = map.get(attIndex);
            double[] std = dm.std[idx];
            if (regressor instanceof LinearFunction) {
                LinearFunction func = (LinearFunction) regressor;
                func.setSlope(func.getSlope() / std[0]);
            } else if (regressor instanceof CubicSpline) {
                CubicSpline spline = (CubicSpline) regressor;
                // The array returned by getCoefficients() is rescaled in place
                double[] w = spline.getCoefficients();
                for (int j = 0; j < w.length; j++) {
                    w[j] /= std[j];
                }
            }
        }
        return gam;
    }
}
/**
 * Builds a regressor, deciding via {@code isSparse(trainSet)} whether the training set
 * is treated as sparse.
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param numKnots the number of knots.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a regressor.
 */
public GAM buildRegressor(Instances trainSet, int maxNumIters, int numKnots, double lambda,
        double alpha) {
    final boolean sparse = isSparse(trainSet);
    return buildRegressor(trainSet, sparse, maxNumIters, numKnots, lambda, alpha);
}
/**
 * Fits a regressor on a dense design matrix with block coordinate gradient descent:
 * one full pass that may change the active set, followed by passes restricted to the
 * active set until the penalized quadratic loss converges or the iteration budget
 * runs out.
 *
 * @param attrs the attribute list.
 * @param x the inputs; {@code x[j]} is the block of columns for {@code attrs[j]}.
 * @param y the targets.
 * @param knots the knots.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a regressor.
 */
public GAM buildRegressor(int[] attrs, double[][][] x, double[] y, double[][] knots, int maxNumIters,
        double lambda, double alpha) {
    final int numTerms = attrs.length;

    // Residuals start as a copy of the targets; y itself is left untouched
    double[] rTrain = y.clone();

    // One coefficient block per term; remember the widest block for the scratch buffers
    double[][] w = new double[numTerms][];
    int widest = 0;
    for (int j = 0; j < numTerms; j++) {
        w[j] = new double[x[j].length];
        if (widest < w[j].length) {
            widest = w[j].length;
        }
    }

    double[] tl1 = new double[numTerms];
    double[] tl2 = new double[numTerms];
    getRegularizationParameters(lambda, alpha, tl1, tl2, y.length);

    // Step size per block: reciprocal of the largest sum of squares over its columns
    double[] stepSize = new double[numTerms];
    for (int j = 0; j < numTerms; j++) {
        double largest = 0;
        for (double[] column : x[j]) {
            double bound = StatUtils.sumSq(column);
            if (bound > largest) {
                largest = bound;
            }
        }
        stepSize[j] = 1.0 / largest;
    }

    // Scratch buffers shared by all passes
    double[] g = new double[widest];
    double[] gradient = new double[widest];
    double[] gamma1 = new double[widest];
    double[] gamma2 = new double[widest - 1];
    boolean[] activeSet = new boolean[numTerms];

    double intercept = 0;

    // Block coordinate gradient descent
    int iter = 0;
    while (iter < maxNumIters) {
        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(rTrain);
        }

        // Full pass over every block
        boolean activeSetChanged = doOnePass(x, tl1, tl2, true, activeSet, w, stepSize, g, gradient, gamma1, gamma2,
                rTrain);
        iter++;

        if (!(activeSetChanged && iter <= maxNumIters)) {
            break;
        }

        // Passes over the active set only, until the loss stops improving
        while (iter < maxNumIters) {
            double prevLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2);

            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(rTrain);
            }
            doOnePass(x, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, rTrain);

            double currLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2);
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }

            if (verbose) {
                System.out.println("Iteration " + iter + ": " + currLoss);
            }
            iter++;
        }
    }

    if (!refit) {
        return getGAM(attrs, knots, w, intercept);
    }
    byte[] struct = extractStructure(w);
    return refitRegressor(attrs, struct, x, y, knots, w, maxNumIters);
}
/**
 * Fits a regressor on a sparse design matrix with block coordinate gradient descent:
 * one full pass that may change the active set, followed by passes restricted to the
 * active set until the penalized quadratic loss converges or the iteration budget
 * runs out.
 *
 * @param attrs the attribute list.
 * @param indices the indices; {@code indices[j]} belongs to {@code attrs[j]}.
 * @param values the values; {@code values[j]} is the block of columns for {@code attrs[j]}.
 * @param y the targets.
 * @param knots the knots.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param alpha the alpha.
 * @return a regressor.
 */
public GAM buildRegressor(int[] attrs, int[][] indices, double[][][] values, double[] y, double[][] knots,
        int maxNumIters, double lambda, double alpha) {
    // Residuals start as a copy of the targets; y itself is left untouched
    double[] rTrain = new double[y.length];
    for (int i = 0; i < rTrain.length; i++) {
        rTrain[i] = y[i];
    }

    double[][] w = new double[attrs.length][];
    int m = 0;
    for (int j = 0; j < attrs.length; j++) {
        w[j] = new double[values[j].length];
        if (w[j].length > m) {
            m = w[j].length;
        }
    }

    double[] tl1 = new double[attrs.length];
    double[] tl2 = new double[attrs.length];
    getRegularizationParameters(lambda, alpha, tl1, tl2, y.length);

    // Step size per block: reciprocal of the largest sum of squares over its columns,
    // the same bound the dense regressor uses. Entries not stored in the sparse column
    // are zero and add nothing to the sum. The step size must not depend on the targets,
    // and must never index rows with the attribute counter j.
    double[] stepSize = new double[attrs.length];
    for (int j = 0; j < stepSize.length; j++) {
        double max = 0;
        double[][] block = values[j];
        for (double[] t : block) {
            double l = StatUtils.sumSq(t);
            if (l > max) {
                max = l;
            }
        }
        stepSize[j] = 1.0 / max;
    }

    // Scratch buffers shared by all passes
    double[] g = new double[m];
    double[] gradient = new double[m];
    double[] gamma1 = new double[m];
    double[] gamma2 = new double[m - 1];
    boolean[] activeSet = new boolean[attrs.length];

    double intercept = 0;

    // Block coordinate gradient descent
    int iter = 0;
    while (iter < maxNumIters) {
        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(rTrain);
        }

        // Full pass over every block
        boolean activeSetChanged = doOnePass(indices, values, tl1, tl2, true, activeSet, w, stepSize,
                g, gradient, gamma1, gamma2, rTrain);
        iter++;

        if (!activeSetChanged || iter > maxNumIters) {
            break;
        }

        // Passes over the active set only, until the loss stops improving
        for (; iter < maxNumIters; iter++) {
            double prevLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2);

            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(rTrain);
            }
            doOnePass(indices, values, tl1, tl2, false, activeSet, w, stepSize, g, gradient, gamma1, gamma2, rTrain);

            double currLoss = OptimUtils.computeQuadraticLoss(rTrain) + getPenalty(w, tl1, tl2);
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }

            if (verbose) {
                System.out.println("Iteration " + iter + ": " + currLoss);
            }
        }
    }

    if (refit) {
        byte[] structure = extractStructure(w);
        GAM gam = refitRegressor(attrs, structure, indices, values, y, knots, w, maxNumIters * 10);
        return gam;
    } else {
        return getGAM(attrs, knots, w, intercept);
    }
}
/**
 * Tells whether this learner fits an intercept.
 *
 * @return {@code true} if an intercept is fitted.
 */
public boolean fitIntercept() {
    return this.fitIntercept;
}
/**
 * Enables or disables fitting of the intercept term.
 *
 * @param fitIntercept {@code true} to fit an intercept; {@code false} to leave it out.
 */
public void fitIntercept(boolean fitIntercept) {
this.fitIntercept = fitIntercept;
}
/**
 * Returns alpha, the weight that splits lambda between the two penalty terms: alpha scales the
 * first penalty and {@code 1 - alpha} the second.
 *
 * @return the alpha.
 */
public double getAlpha() {
    return this.alpha;
}
/**
 * Returns epsilon, the threshold used when testing two successive losses for convergence.
 *
 * @return the convergence threshold epsilon.
 */
public double getEpsilon() {
    return this.epsilon;
}
/**
 * Returns lambda, the overall regularization strength.
 *
 * @return the lambda.
 */
public double getLambda() {
    return this.lambda;
}
/**
 * Returns the cap on the number of training iterations.
 *
 * @return the maximum number of iterations.
 */
public int getMaxNumIters() {
    return this.maxNumIters;
}
/**
 * Returns the number of knots configured for this learner.
 *
 * @return the number of knots.
 */
public int getNumKnots() {
    return this.numKnots;
}
/**
 * Returns the learning task this learner is configured for.
 *
 * @return the task of this learner.
 */
public Task getTask() {
    return this.task;
}
/**
 * Tells whether progress information is printed while training.
 *
 * @return {@code true} if training output is enabled.
 */
public boolean isVerbose() {
    return this.verbose;
}
/**
 * Tells whether the model is refitted after the regularized fit has selected its structure.
 *
 * @return {@code true} if the model is refitted.
 */
public boolean refit() {
    return this.refit;
}
/**
 * Enables or disables the refitting step.
 *
 * @param refit {@code true} to refit the model; {@code false} to keep the regularized fit.
 */
public void refit(boolean refit) {
this.refit = refit;
}
/**
 * Sets alpha, the weight that splits lambda between the two penalty terms.
 *
 * @param alpha the new alpha. No range check is performed here.
 */
public void setAlpha(double alpha) {
this.alpha = alpha;
}
/**
 * Sets epsilon, the threshold used when testing two successive losses for convergence.
 *
 * @param epsilon the new convergence threshold.
 */
public void setEpsilon(double epsilon) {
this.epsilon = epsilon;
}
/**
 * Sets lambda, the overall regularization strength.
 *
 * @param lambda the new lambda.
 */
public void setLambda(double lambda) {
this.lambda = lambda;
}
/**
 * Sets the cap on the number of training iterations.
 *
 * @param maxNumIters the new maximum number of iterations.
 */
public void setMaxNumIters(int maxNumIters) {
this.maxNumIters = maxNumIters;
}
/**
 * Sets the number of knots configured for this learner.
 *
 * @param numKnots the new number of knots.
 */
public void setNumKnots(int numKnots) {
this.numKnots = numKnots;
}
/**
 * Sets the learning task this learner is configured for.
 *
 * @param task the new task of this learner.
 */
public void setTask(Task task) {
this.task = task;
}
/**
 * Turns progress output during training on or off.
 *
 * @param verbose {@code true} to print progress information while training.
 */
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
 * Builds the dense cubic-spline design matrix for a training set and runs the task-specific
 * bisection search for lambda on it, starting from all-zero coefficients.
 *
 * @param trainSet the training set.
 * @param task the learning task; {@code Task.REGRESSION} selects the regression search, any
 *            other value the classification search.
 * @param numKnots the number of knots handed to the design-matrix builder.
 * @param alpha the weight that splits lambda between the two penalty terms.
 * @return the lambda found by the bisection search.
 */
public double findMaxLambda(Instances trainSet, Task task, int numKnots, double alpha) {
    DenseDataset dataset = getDenseDataset(trainSet, false);
    DenseDesignMatrix design = DenseDesignMatrix.createCubicSplineDesignMatrix(dataset.x, dataset.stdList,
            numKnots);
    double[] y = dataset.y;
    double[][][] x = design.x;
    int[] attrs = dataset.attrs;

    // One zero-initialized coefficient block per attribute; remember the widest block so the
    // scratch buffers below can be shared by all of them.
    double[][] w = new double[attrs.length][];
    int widest = 0;
    for (int j = 0; j < attrs.length; j++) {
        w[j] = new double[x[j].length];
        widest = Math.max(widest, w[j].length);
    }

    double[] tl1 = new double[attrs.length];
    double[] tl2 = new double[attrs.length];
    double[] g = new double[widest];
    double[] gradient = new double[widest];
    double[] gamma1 = new double[widest];
    double[] gamma2 = new double[widest - 1];

    if (task != Task.REGRESSION) {
        // Predictions start at zero; the residuals are derived from them and the targets.
        double[] pTrain = new double[y.length];
        double[] rTrain = new double[y.length];
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);
        return findMaxLambda(x, y, pTrain, rTrain, alpha, tl1, tl2, w, g, gradient, gamma1, gamma2);
    }
    return findMaxLambda(x, y, alpha, tl1, tl2, w, g, gradient, gamma1, gamma2);
}
/**
 * Bisection search for lambda in the regression setting. The search interval starts at
 * {@code [0, max_k ||gradient_k|| / sqrt(size_k) / alpha]} and is halved until it is narrower
 * than {@code MathUtils.EPSILON}; the upper end is kept wherever {@link #testZeroPoint} succeeds.
 *
 * NOTE(review): with {@code alpha == 0} the initial upper bound is not finite and the loop does
 * not appear to terminate — confirm callers always pass a positive alpha.
 *
 * @param x the design matrix, one block of basis columns per attribute.
 * @param rTrain the residuals. When an intercept is fitted they are shifted by
 *            {@code OptimUtils.fitIntercept} for the duration of the search and the returned
 *            value is added back before returning.
 * @param alpha the weight that splits lambda between the two penalty terms.
 * @param tl1 output buffer for the per-block first penalty weights (overwritten).
 * @param tl2 output buffer for the per-block second penalty weights (overwritten).
 * @param w the coefficient blocks; only their lengths are read here.
 * @param g scratch buffer.
 * @param gradient scratch buffer.
 * @param gamma1 scratch buffer.
 * @param gamma2 scratch buffer.
 * @return the upper end of the final search interval.
 */
protected double findMaxLambda(double[][][] x, double[] rTrain, double alpha, double[] tl1, double[] tl2,
        double[][] w, double[] g, double[] gradient, double[] gamma1, double[] gamma2) {
    final double mean = fitIntercept ? OptimUtils.fitIntercept(rTrain) : 0;

    // Initial upper bound: largest size-normalized gradient norm over all blocks.
    double upper = 0;
    for (int k = 0; k < x.length; k++) {
        final double[][] block = x[k];
        computeGradient(block, rTrain, gradient);
        final double scaledNorm = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length))
                / Math.sqrt(block.length);
        if (scaledNorm > upper) {
            upper = scaledNorm;
        }
    }
    upper /= alpha;

    double lower = 0;
    while (upper - lower > MathUtils.EPSILON) {
        final double mid = (upper + lower) / 2;
        for (int j = 0; j < x.length; j++) {
            final int size = w[j].length;
            tl1[j] = mid * alpha * Math.sqrt(size);
            tl2[j] = mid * (1 - alpha) * Math.sqrt(size - 1);
        }
        if (testZeroPoint(x, rTrain, tl1, tl2, w, g, gradient, gamma1, gamma2)) {
            upper = mid;
        } else {
            lower = mid;
        }
    }

    if (fitIntercept) {
        // Undo the shift applied at the top so the caller's array is left as it was passed in.
        VectorUtils.add(rTrain, mean);
    }
    return upper;
}
/**
 * Bisection search for lambda in the classification setting. Mirrors the regression variant but
 * works on the pseudo-residuals in {@code rTrain}.
 *
 * NOTE(review): with {@code alpha == 0} the initial upper bound is not finite and the loop does
 * not appear to terminate — confirm callers always pass a positive alpha.
 *
 * @param x the design matrix, one block of basis columns per attribute.
 * @param y the targets.
 * @param pTrain the current predictions; updated by {@code OptimUtils.fitIntercept} when an
 *            intercept is fitted and reset to zero before returning in that case.
 * @param rTrain the pseudo-residuals.
 * @param alpha the weight that splits lambda between the two penalty terms.
 * @param tl1 output buffer for the per-block first penalty weights (overwritten).
 * @param tl2 output buffer for the per-block second penalty weights (overwritten).
 * @param coefficients the coefficient blocks; only their lengths are read here.
 * @param g scratch buffer.
 * @param gradient scratch buffer.
 * @param gamma1 scratch buffer.
 * @param gamma2 scratch buffer.
 * @return the upper end of the final search interval.
 */
protected double findMaxLambda(double[][][] x, double[] y, double[] pTrain, double[] rTrain, double alpha, double[] tl1,
        double[] tl2, double[][] coefficients, double[] g, double[] gradient, double[] gamma1, double[] gamma2) {
    if (fitIntercept) {
        OptimUtils.fitIntercept(pTrain, rTrain, y);
    }

    // Initial upper bound: largest size-normalized gradient norm over all blocks.
    double upper = 0;
    for (int k = 0; k < x.length; k++) {
        final double[][] block = x[k];
        computeGradient(block, rTrain, gradient);
        final double scaledNorm = Math.sqrt(StatUtils.sumSq(gradient, 0, block.length))
                / Math.sqrt(block.length);
        if (scaledNorm > upper) {
            upper = scaledNorm;
        }
    }
    upper /= alpha;

    double lower = 0;
    while (upper - lower > MathUtils.EPSILON) {
        final double mid = (upper + lower) / 2;
        for (int j = 0; j < x.length; j++) {
            final int size = coefficients[j].length;
            tl1[j] = mid * alpha * Math.sqrt(size);
            tl2[j] = mid * (1 - alpha) * Math.sqrt(size - 1);
        }
        if (testZeroPoint(x, y, pTrain, rTrain, tl1, tl2, coefficients, g, gradient, gamma1, gamma2)) {
            upper = mid;
        } else {
            lower = mid;
        }
    }

    if (fitIntercept) {
        Arrays.fill(pTrain, 0);
    }
    return upper;
}
/**
 * Checks, block by block, whether the penalized update applied to the raw gradient leaves every
 * entry at zero for the given penalty weights (regression setting). The coefficient values in
 * {@code w} are never read; only the block lengths are used.
 *
 * The norm of {@code gamma2} is computed over the {@code size - 1} entries that belong to the
 * current block, exactly as {@code gamma1} is here and as both are in {@code doOnePass}. The
 * scratch arrays are shared between blocks, so taking the norm of the whole array could pick up
 * stale values left by a larger block.
 *
 * @param x the design matrix, one block of basis columns per attribute.
 * @param y the vector the gradient is computed against.
 * @param tl1 the per-block first penalty weights.
 * @param tl2 the per-block second penalty weights.
 * @param w the coefficient blocks; only their lengths are read.
 * @param g scratch buffer, at least as long as the largest block.
 * @param gradient scratch buffer, at least as long as the largest block.
 * @param gamma1 scratch buffer, at least as long as the largest block.
 * @param gamma2 scratch buffer, at least as long as the largest block minus one.
 * @return {@code false} as soon as one block fails the zero check, {@code true} otherwise.
 */
protected boolean testZeroPoint(double[][][] x, double[] y, double[] tl1, double[] tl2, double[][] w,
        double[] g, double[] gradient, double[] gamma1, double[] gamma2) {
    for (int k = 0; k < x.length; k++) {
        final double[][] block = x[k];
        final double lambda1 = tl1[k];
        final double lambda2 = tl2[k];
        final int size = w[k].length;

        // Proximal gradient method
        computeGradient(block, y, gradient);
        System.arraycopy(gradient, 0, g, 0, size);

        // Dual method
        if (size > 1) {
            System.arraycopy(g, 1, gamma2, 0, size - 1);
            // Restrict the norm to this block's own entries.
            final double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, size - 1));
            if (norm2 > lambda2) {
                VectorUtils.multiply(gamma2, lambda2 / norm2);
            }
        }
        gamma1[0] = g[0];
        for (int i = 1; i < size; i++) {
            gamma1[i] = g[i] - gamma2[i - 1];
        }
        final double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, size));
        if (norm1 > lambda1) {
            VectorUtils.multiply(gamma1, lambda1 / norm1);
        }

        g[0] -= gamma1[0];
        for (int i = 1; i < size; i++) {
            g[i] -= (gamma1[i] + gamma2[i - 1]);
        }
        if (!ArrayUtils.isConstant(g, 0, size, 0)) {
            return false;
        }
    }
    return true;
}
/**
 * Checks, block by block, whether the penalized update applied to the raw gradient leaves every
 * entry at zero for the given penalty weights (classification setting). The gradient is taken
 * against the pseudo-residuals {@code rTrain}; the coefficient values are never read, only the
 * block lengths.
 *
 * The norm of {@code gamma2} is computed over the {@code size - 1} entries that belong to the
 * current block, exactly as {@code gamma1} is here and as both are in {@code doOnePass}. The
 * scratch arrays are shared between blocks, so taking the norm of the whole array could pick up
 * stale values left by a larger block.
 *
 * @param x the design matrix, one block of basis columns per attribute.
 * @param y the targets (not read by this method; kept for signature symmetry).
 * @param pTrain the predictions (not read by this method; kept for signature symmetry).
 * @param rTrain the pseudo-residuals the gradient is computed against.
 * @param tl1 the per-block first penalty weights.
 * @param tl2 the per-block second penalty weights.
 * @param coefficients the coefficient blocks; only their lengths are read.
 * @param g scratch buffer, at least as long as the largest block.
 * @param gradient scratch buffer, at least as long as the largest block.
 * @param gamma1 scratch buffer, at least as long as the largest block.
 * @param gamma2 scratch buffer, at least as long as the largest block minus one.
 * @return {@code false} as soon as one block fails the zero check, {@code true} otherwise.
 */
protected boolean testZeroPoint(double[][][] x, double[] y, double[] pTrain, double[] rTrain, double[] tl1,
        double[] tl2, double[][] coefficients, double[] g, double[] gradient, double[] gamma1, double[] gamma2) {
    for (int k = 0; k < x.length; k++) {
        final double[][] block = x[k];
        final double lambda1 = tl1[k];
        final double lambda2 = tl2[k];
        final int size = coefficients[k].length;

        // Proximal gradient method
        computeGradient(block, rTrain, gradient);
        System.arraycopy(gradient, 0, g, 0, size);

        // Dual method
        if (size > 1) {
            System.arraycopy(g, 1, gamma2, 0, size - 1);
            // Restrict the norm to this block's own entries.
            final double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, size - 1));
            if (norm2 > lambda2) {
                VectorUtils.multiply(gamma2, lambda2 / norm2);
            }
        }
        gamma1[0] = g[0];
        for (int i = 1; i < size; i++) {
            gamma1[i] = g[i] - gamma2[i - 1];
        }
        final double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, size));
        if (norm1 > lambda1) {
            VectorUtils.multiply(gamma1, lambda1 / norm1);
        }

        g[0] -= gamma1[0];
        for (int i = 1; i < size; i++) {
            g[i] -= (gamma1[i] + gamma2[i - 1]);
        }
        if (!ArrayUtils.isConstant(g, 0, size, 0)) {
            return false;
        }
    }
    return true;
}
/**
 * Fills {@code gradient} with the dot product of every basis column of a dense block with the
 * residual vector.
 *
 * @param block the basis columns of one attribute.
 * @param rTrain the residuals.
 * @param gradient output buffer; entry {@code i} receives the value for column {@code i}.
 */
protected void computeGradient(double[][] block, double[] rTrain, double[] gradient) {
    int slot = 0;
    for (double[] column : block) {
        gradient[slot++] = VectorUtils.dotProduct(column, rTrain);
    }
}
/**
 * Fills {@code gradient} with the dot product of every basis column of a sparse block with the
 * residual vector. Entry {@code i} of a column is paired with the residual at position
 * {@code index[i]}.
 *
 * @param index the residual positions of the stored entries.
 * @param block the basis columns of one attribute, each aligned with {@code index}.
 * @param rTrain the residuals.
 * @param gradient output buffer; entry {@code j} receives the value for column {@code j}.
 */
protected void computeGradient(int[] index, double[][] block, double[] rTrain, double[] gradient) {
    int slot = 0;
    for (double[] column : block) {
        double sum = 0;
        for (int i = 0; i < column.length; i++) {
            sum += rTrain[index[i]] * column[i];
        }
        gradient[slot++] = sum;
    }
}
/**
 * Runs one sweep of block coordinate proximal gradient descent for regression on a dense design
 * matrix. Each visited block takes a gradient step, is shrunk through the two-level dual
 * projection, and the change is folded into the residuals before the coefficients are stored.
 *
 * @param x the design matrix, one block of basis columns per attribute.
 * @param tl1 the per-block first penalty weights.
 * @param tl2 the per-block second penalty weights.
 * @param isFullPass {@code true} to visit every block, {@code false} to visit only active blocks.
 * @param activeSet the per-block active flags; updated during a full pass.
 * @param w the coefficient blocks; updated in place.
 * @param stepSize the per-block step sizes.
 * @param g scratch buffer.
 * @param gradient scratch buffer.
 * @param gamma1 scratch buffer.
 * @param gamma2 scratch buffer.
 * @param rTrain the residuals; updated in place.
 * @return {@code true} if a full pass activated at least one previously inactive block.
 */
protected boolean doOnePass(double[][][] x, double[] tl1, double[] tl2, boolean isFullPass, boolean[] activeSet,
        double[][] w, double[] stepSize, double[] g, double[] gradient, double[] gamma1, double[] gamma2,
        double[] rTrain) {
    boolean changed = false;
    for (int k = 0; k < x.length; k++) {
        if (!(isFullPass || activeSet[k])) {
            continue;
        }
        final double[][] basis = x[k];
        final double[] coef = w[k];
        final int size = coef.length;
        final double step = stepSize[k];
        final double lambda1 = tl1[k];
        final double lambda2 = tl2[k];

        // Proximal gradient method
        computeGradient(basis, rTrain, gradient);
        for (int j = 0; j < size; j++) {
            g[j] = coef[j] + step * gradient[j];
        }

        // Dual method
        if (size > 1) {
            System.arraycopy(g, 1, gamma2, 0, size - 1);
            final double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, size - 1));
            final double limit2 = lambda2 * step;
            if (norm2 > limit2) {
                VectorUtils.multiply(gamma2, limit2 / norm2);
            }
        }
        gamma1[0] = g[0];
        for (int j = 1; j < size; j++) {
            gamma1[j] = g[j] - gamma2[j - 1];
        }
        final double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, size));
        final double limit1 = lambda1 * step;
        if (norm1 > limit1) {
            VectorUtils.multiply(gamma1, limit1 / norm1);
        }
        g[0] -= gamma1[0];
        for (int j = 1; j < size; j++) {
            g[j] -= (gamma1[j] + gamma2[j - 1]);
        }

        // Update residuals
        for (int j = 0; j < size; j++) {
            final double[] column = basis[j];
            final double delta = coef[j] - g[j];
            for (int i = 0; i < rTrain.length; i++) {
                rTrain[i] += delta * column[i];
            }
        }

        // Update weights
        System.arraycopy(g, 0, coef, 0, size);

        if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(coef, 0, size, 0)) {
            changed = true;
            activeSet[k] = true;
        }
    }
    return changed;
}
/**
 * Runs one sweep of block coordinate proximal gradient descent for classification on a dense
 * design matrix. Each visited block takes a gradient step, is shrunk through the two-level dual
 * projection, and the change is folded into the predictions; the pseudo-residuals are then
 * recomputed from the predictions and targets before the coefficients are stored.
 *
 * @param x the design matrix, one block of basis columns per attribute.
 * @param y the targets.
 * @param tl1 the per-block first penalty weights.
 * @param tl2 the per-block second penalty weights.
 * @param isFullPass {@code true} to visit every block, {@code false} to visit only active blocks.
 * @param activeSet the per-block active flags; updated during a full pass.
 * @param w the coefficient blocks; updated in place.
 * @param stepSize the per-block step sizes.
 * @param g scratch buffer.
 * @param gradient scratch buffer.
 * @param gamma1 scratch buffer.
 * @param gamma2 scratch buffer.
 * @param pTrain the predictions; updated in place.
 * @param rTrain the pseudo-residuals; updated in place.
 * @return {@code true} if a full pass activated at least one previously inactive block.
 */
protected boolean doOnePass(double[][][] x, double[] y, double[] tl1, double[] tl2, boolean isFullPass,
        boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] gamma1,
        double[] gamma2, double[] pTrain, double[] rTrain) {
    boolean changed = false;
    for (int k = 0; k < x.length; k++) {
        if (!(isFullPass || activeSet[k])) {
            continue;
        }
        final double[][] basis = x[k];
        final double[] coef = w[k];
        final int size = coef.length;
        final double step = stepSize[k];
        final double lambda1 = tl1[k];
        final double lambda2 = tl2[k];

        // Proximal gradient method
        computeGradient(basis, rTrain, gradient);
        for (int j = 0; j < size; j++) {
            g[j] = coef[j] + step * gradient[j];
        }

        // Dual method
        if (size > 1) {
            System.arraycopy(g, 1, gamma2, 0, size - 1);
            final double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, size - 1));
            final double limit2 = lambda2 * step;
            if (norm2 > limit2) {
                VectorUtils.multiply(gamma2, limit2 / norm2);
            }
        }
        gamma1[0] = g[0];
        for (int j = 1; j < size; j++) {
            gamma1[j] = g[j] - gamma2[j - 1];
        }
        final double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, size));
        final double limit1 = lambda1 * step;
        if (norm1 > limit1) {
            VectorUtils.multiply(gamma1, limit1 / norm1);
        }
        g[0] -= gamma1[0];
        for (int j = 1; j < size; j++) {
            g[j] -= (gamma1[j] + gamma2[j - 1]);
        }

        // Update predictions
        for (int j = 0; j < size; j++) {
            final double[] column = basis[j];
            final double delta = g[j] - coef[j];
            for (int i = 0; i < y.length; i++) {
                pTrain[i] += delta * column[i];
            }
        }
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);

        // Update weights
        System.arraycopy(g, 0, coef, 0, size);

        if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(coef, 0, size, 0)) {
            changed = true;
            activeSet[k] = true;
        }
    }
    return changed;
}
/**
 * Runs one sweep of block coordinate proximal gradient descent for regression on a sparse design
 * matrix. Each visited block takes a gradient step, is shrunk through the two-level dual
 * projection, and the change is folded into the residuals at the positions listed for that block
 * before the coefficients are stored.
 *
 * @param indices the per-block residual positions of the stored entries.
 * @param values the per-block basis columns, each aligned with the block's positions.
 * @param tl1 the per-block first penalty weights.
 * @param tl2 the per-block second penalty weights.
 * @param isFullPass {@code true} to visit every block, {@code false} to visit only active blocks.
 * @param activeSet the per-block active flags; updated during a full pass.
 * @param w the coefficient blocks; updated in place.
 * @param stepSize the per-block step sizes.
 * @param g scratch buffer.
 * @param gradient scratch buffer.
 * @param gamma1 scratch buffer.
 * @param gamma2 scratch buffer.
 * @param rTrain the residuals; updated in place.
 * @return {@code true} if a full pass activated at least one previously inactive block.
 */
protected boolean doOnePass(int[][] indices, double[][][] values, double[] tl1, double[] tl2, boolean isFullPass,
        boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient, double[] gamma1,
        double[] gamma2, double[] rTrain) {
    boolean changed = false;
    for (int k = 0; k < values.length; k++) {
        if (!(isFullPass || activeSet[k])) {
            continue;
        }
        final double[][] basis = values[k];
        final int[] rows = indices[k];
        final double[] coef = w[k];
        final int size = coef.length;
        final double step = stepSize[k];
        final double lambda1 = tl1[k];
        final double lambda2 = tl2[k];

        // Proximal gradient method
        computeGradient(rows, basis, rTrain, gradient);
        for (int j = 0; j < size; j++) {
            g[j] = coef[j] + step * gradient[j];
        }

        // Dual method
        if (size > 1) {
            System.arraycopy(g, 1, gamma2, 0, size - 1);
            final double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, size - 1));
            final double limit2 = lambda2 * step;
            if (norm2 > limit2) {
                VectorUtils.multiply(gamma2, limit2 / norm2);
            }
        }
        gamma1[0] = g[0];
        for (int j = 1; j < size; j++) {
            gamma1[j] = g[j] - gamma2[j - 1];
        }
        final double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, size));
        final double limit1 = lambda1 * step;
        if (norm1 > limit1) {
            VectorUtils.multiply(gamma1, limit1 / norm1);
        }
        g[0] -= gamma1[0];
        for (int j = 1; j < size; j++) {
            g[j] -= (gamma1[j] + gamma2[j - 1]);
        }

        // Update residuals at the stored positions only
        for (int j = 0; j < size; j++) {
            final double[] column = basis[j];
            final double delta = coef[j] - g[j];
            for (int i = 0; i < column.length; i++) {
                rTrain[rows[i]] += delta * column[i];
            }
        }

        // Update weights
        System.arraycopy(g, 0, coef, 0, size);

        if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(coef, 0, size, 0)) {
            changed = true;
            activeSet[k] = true;
        }
    }
    return changed;
}
/**
 * Runs one sweep of block coordinate proximal gradient descent for classification on a sparse
 * design matrix. Each visited block takes a gradient step, is shrunk through the two-level dual
 * projection, and the change is folded into the predictions at the positions listed for that
 * block; the pseudo-residuals at those positions are then refreshed before the coefficients are
 * stored.
 *
 * @param indices the per-block positions of the stored entries.
 * @param values the per-block basis columns, each aligned with the block's positions.
 * @param y the targets.
 * @param tl1 the per-block first penalty weights.
 * @param tl2 the per-block second penalty weights.
 * @param isFullPass {@code true} to visit every block, {@code false} to visit only active blocks.
 * @param activeSet the per-block active flags; updated during a full pass.
 * @param w the coefficient blocks; updated in place.
 * @param stepSize the per-block step sizes.
 * @param g scratch buffer.
 * @param gradient scratch buffer.
 * @param gamma1 scratch buffer.
 * @param gamma2 scratch buffer.
 * @param pTrain the predictions; updated in place.
 * @param rTrain the pseudo-residuals; updated in place.
 * @return {@code true} if a full pass activated at least one previously inactive block.
 */
protected boolean doOnePass(int[][] indices, double[][][] values, double[] y, double[] tl1, double[] tl2,
        boolean isFullPass, boolean[] activeSet, double[][] w, double[] stepSize, double[] g, double[] gradient,
        double[] gamma1, double[] gamma2, double[] pTrain, double[] rTrain) {
    boolean changed = false;
    for (int k = 0; k < values.length; k++) {
        if (!(isFullPass || activeSet[k])) {
            continue;
        }
        final int[] rows = indices[k];
        final double[][] basis = values[k];
        final double[] coef = w[k];
        final int size = coef.length;
        final double step = stepSize[k];
        final double lambda1 = tl1[k];
        final double lambda2 = tl2[k];

        // Proximal gradient method
        computeGradient(rows, basis, rTrain, gradient);
        for (int j = 0; j < size; j++) {
            g[j] = coef[j] + step * gradient[j];
        }

        // Dual method
        if (size > 1) {
            System.arraycopy(g, 1, gamma2, 0, size - 1);
            final double norm2 = Math.sqrt(StatUtils.sumSq(gamma2, 0, size - 1));
            final double limit2 = lambda2 * step;
            if (norm2 > limit2) {
                VectorUtils.multiply(gamma2, limit2 / norm2);
            }
        }
        gamma1[0] = g[0];
        for (int j = 1; j < size; j++) {
            gamma1[j] = g[j] - gamma2[j - 1];
        }
        final double norm1 = Math.sqrt(StatUtils.sumSq(gamma1, 0, size));
        final double limit1 = lambda1 * step;
        if (norm1 > limit1) {
            VectorUtils.multiply(gamma1, limit1 / norm1);
        }
        g[0] -= gamma1[0];
        for (int j = 1; j < size; j++) {
            g[j] -= (gamma1[j] + gamma2[j - 1]);
        }

        // Update predictions at the stored positions only
        for (int j = 0; j < size; j++) {
            final double[] column = basis[j];
            final double delta = g[j] - coef[j];
            for (int i = 0; i < column.length; i++) {
                pTrain[rows[i]] += delta * column[i];
            }
        }
        for (int row : rows) {
            rTrain[row] = OptimUtils.getPseudoResidual(pTrain[row], y[row]);
        }

        // Update weights
        System.arraycopy(g, 0, coef, 0, size);

        if (isFullPass && !activeSet[k] && !ArrayUtils.isConstant(coef, 0, size, 0)) {
            changed = true;
            activeSet[k] = true;
        }
    }
    return changed;
}
/**
 * Classifies every coefficient block. A block with a single coefficient, or one for which
 * {@code ArrayUtils.isConstant(beta, 1, beta.length, 0)} holds, is {@code LINEAR} when its first
 * coefficient is non-zero and {@code ELIMINATED} otherwise; any other block is {@code NONLINEAR}.
 *
 * @param w the coefficient blocks.
 * @return one {@code ModelStructure} code per block.
 */
protected byte[] extractStructure(double[][] w) {
    final byte[] structure = new byte[w.length];
    int slot = 0;
    for (double[] beta : w) {
        final boolean hasNonlinearPart = beta.length != 1 && !ArrayUtils.isConstant(beta, 1, beta.length, 0);
        if (hasNonlinearPart) {
            structure[slot] = ModelStructure.NONLINEAR;
        } else if (beta[0] != 0) {
            structure[slot] = ModelStructure.LINEAR;
        } else {
            structure[slot] = ModelStructure.ELIMINATED;
        }
        slot++;
    }
    return structure;
}
/**
 * Assembles a GAM from the coefficient blocks. A block with a non-linear part becomes a
 * {@code CubicSpline} over a copy of its coefficients; a linear block with a non-zero slope
 * becomes a {@code LinearFunction}; a linear block with zero slope contributes nothing.
 *
 * @param attrs the attribute index of every block.
 * @param knots the knots of every block.
 * @param w the coefficient blocks.
 * @param intercept the intercept of the resulting model.
 * @return the assembled GAM.
 */
protected GAM getGAM(int[] attrs, double[][] knots, double[][] w, double intercept) {
    final GAM model = new GAM();
    for (int j = 0; j < attrs.length; j++) {
        final int attIndex = attrs[j];
        final double[] beta = w[j];
        final boolean hasNonlinearPart = beta.length != 1 && !ArrayUtils.isConstant(beta, 1, beta.length, 0);
        if (hasNonlinearPart) {
            final double[] coef = beta.clone();
            final CubicSpline spline = new CubicSpline(attIndex, 0, knots[j], coef);
            model.add(new int[] { attIndex }, spline);
            continue;
        }
        // To rule out a feature, it has to be "linear" and 0 slope.
        if (beta[0] != 0) {
            model.add(new int[] { attIndex }, new LinearFunction(attIndex, beta[0]));
        }
    }
    model.setIntercept(intercept);
    return model;
}
/**
 * Computes the penalty of one coefficient block: {@code lambda1} times the Euclidean norm of the
 * whole block plus {@code lambda2} times the Euclidean norm of the block without its first entry.
 *
 * @param w the coefficient block.
 * @param lambda1 the weight of the whole-block norm.
 * @param lambda2 the weight of the norm of all entries but the first.
 * @return the penalty of the block.
 */
protected double getPenalty(double[] w, double lambda1, double lambda2) {
    final double totalSq = StatUtils.sumSq(w);
    final double wholeTerm = lambda1 * Math.sqrt(totalSq);
    // Removing the first entry's square leaves the squared norm of the remaining entries.
    final double tailSq = totalSq - w[0] * w[0];
    final double tailTerm = lambda2 * Math.sqrt(tailSq);
    return wholeTerm + tailTerm;
}
/**
 * Sums the per-block penalties over all coefficient blocks.
 *
 * @param coef the coefficient blocks.
 * @param lambda1 the per-block weights of the whole-block norm.
 * @param lambda2 the per-block weights of the norm of all entries but the first.
 * @return the total penalty.
 */
protected double getPenalty(double[][] coef, double[] lambda1, double[] lambda2) {
    double total = 0;
    int k = 0;
    while (k < coef.length) {
        total += getPenalty(coef[k], lambda1[k], lambda2[k]);
        k++;
    }
    return total;
}
/**
 * Fills the per-block penalty weights. Every entry of {@code tl1} is set to
 * {@code lambda * alpha * n} and the matching entry of {@code tl2} to
 * {@code lambda * (1 - alpha) * n}.
 *
 * @param lambda the overall regularization strength.
 * @param alpha the weight that splits lambda between the two penalty terms.
 * @param tl1 output buffer for the first penalty weights; its length drives the loop.
 * @param tl2 output buffer for the second penalty weights.
 * @param n the scaling factor applied to both weights.
 */
protected void getRegularizationParameters(double lambda, double alpha, double[] tl1, double[] tl2, int n) {
    // Both values are the same for every block, so compute them once.
    final double first = lambda * alpha * n;
    final double second = lambda * (1 - alpha) * n;
    for (int j = 0; j < tl1.length; j++) {
        tl1[j] = first;
        tl2[j] = second;
    }
}
/**
 * Refits a binary classifier on the structure selected by the regularized fit (dense design
 * matrix). The surviving basis columns are flattened into one matrix, a ridge model with a very
 * small regularization parameter is trained on it, and the resulting coefficients are written
 * back into the regressors of a GAM built from {@code w}.
 *
 * The local lists carry their element types; as raw lists, the element reads and the typed
 * iteration over the regressors do not compile.
 *
 * @param attrs the attribute index of every block.
 * @param struct the {@code ModelStructure} code of every block.
 * @param x the design matrix, one block of basis columns per attribute.
 * @param y the targets.
 * @param knots the knots of every block.
 * @param w the coefficient blocks from the regularized fit.
 * @param maxNumIters the maximum number of iterations for the ridge learner.
 * @return the refitted GAM.
 */
protected GAM refitClassifier(int[] attrs, byte[] struct, double[][][] x, double[] y, double[][] knots,
        double[][] w, int maxNumIters) {
    // Non-linear blocks keep all their columns, linear blocks only the first, eliminated none.
    List<double[]> xList = new ArrayList<>();
    for (int i = 0; i < struct.length; i++) {
        if (struct[i] == ModelStructure.NONLINEAR) {
            for (double[] column : x[i]) {
                xList.add(column);
            }
        } else if (struct[i] == ModelStructure.LINEAR) {
            xList.add(x[i][0]);
        }
    }
    double[][] xNew = xList.toArray(new double[0][]);
    int[] attrsNew = new int[xNew.length];
    for (int i = 0; i < attrsNew.length; i++) {
        attrsNew[i] = i;
    }

    RidgeLearner ridgeLearner = new RidgeLearner();
    ridgeLearner.setVerbose(verbose);
    ridgeLearner.setEpsilon(epsilon);
    ridgeLearner.fitIntercept(fitIntercept);
    // A ridge regression with very small regularization parameter
    // This often improves stability a lot
    GLM glm = ridgeLearner.buildBinaryClassifier(attrsNew, xNew, y, maxNumIters, 1e-8);

    GAM gam = getGAM(attrs, knots, w, glm.intercept(0));
    List<Regressor> regressors = gam.regressors;
    double[] coef = glm.coefficients(0);
    // Walk the regressors in the order the columns were flattened above.
    int k = 0;
    for (Regressor regressor : regressors) {
        if (regressor instanceof LinearFunction) {
            LinearFunction func = (LinearFunction) regressor;
            func.setSlope(coef[k++]);
        } else if (regressor instanceof CubicSpline) {
            CubicSpline spline = (CubicSpline) regressor;
            // Written in place; presumably this is the spline's live coefficient array.
            double[] beta = spline.getCoefficients();
            for (int i = 0; i < beta.length; i++) {
                beta[i] = coef[k++];
            }
        }
    }
    return gam;
}
/**
 * Refits a binary classifier on the structure selected by the regularized fit (sparse design
 * matrix). The surviving basis columns and their position arrays are flattened in parallel, a
 * ridge model with a very small regularization parameter is trained on them, and the resulting
 * coefficients are written back into the regressors of a GAM built from {@code w}.
 *
 * The local lists carry their element types; as raw lists, the element reads and the typed
 * iteration over the regressors do not compile.
 *
 * @param attrs the attribute index of every block.
 * @param struct the {@code ModelStructure} code of every block.
 * @param indices the per-block positions of the stored entries.
 * @param values the per-block basis columns.
 * @param y the targets.
 * @param knots the knots of every block.
 * @param w the coefficient blocks from the regularized fit.
 * @param maxNumIters the maximum number of iterations for the ridge learner.
 * @return the refitted GAM.
 */
protected GAM refitClassifier(int[] attrs, byte[] struct, int[][] indices, double[][][] values, double[] y,
        double[][] knots, double[][] w, int maxNumIters) {
    // Non-linear blocks keep all their columns, linear blocks only the first, eliminated none.
    // Every kept column is paired with its block's position array.
    List<int[]> iList = new ArrayList<>();
    List<double[]> vList = new ArrayList<>();
    for (int i = 0; i < struct.length; i++) {
        int[] index = indices[i];
        if (struct[i] == ModelStructure.NONLINEAR) {
            for (double[] column : values[i]) {
                iList.add(index);
                vList.add(column);
            }
        } else if (struct[i] == ModelStructure.LINEAR) {
            iList.add(index);
            vList.add(values[i][0]);
        }
    }
    int[][] iNew = iList.toArray(new int[0][]);
    double[][] vNew = vList.toArray(new double[0][]);
    int[] attrsNew = new int[iNew.length];
    for (int i = 0; i < attrsNew.length; i++) {
        attrsNew[i] = i;
    }

    RidgeLearner ridgeLearner = new RidgeLearner();
    ridgeLearner.setVerbose(verbose);
    ridgeLearner.setEpsilon(epsilon);
    ridgeLearner.fitIntercept(fitIntercept);
    // A ridge regression with very small regularization parameter
    // This often improves stability a lot
    GLM glm = ridgeLearner.buildBinaryClassifier(attrsNew, iNew, vNew, y, maxNumIters, 1e-8);

    GAM gam = getGAM(attrs, knots, w, glm.intercept(0));
    List<Regressor> regressors = gam.regressors;
    double[] coef = glm.coefficients(0);
    // Walk the regressors in the order the columns were flattened above.
    int k = 0;
    for (Regressor regressor : regressors) {
        if (regressor instanceof LinearFunction) {
            LinearFunction func = (LinearFunction) regressor;
            func.setSlope(coef[k++]);
        } else if (regressor instanceof CubicSpline) {
            CubicSpline spline = (CubicSpline) regressor;
            // Written in place; presumably this is the spline's live coefficient array.
            double[] beta = spline.getCoefficients();
            for (int j = 0; j < beta.length; j++) {
                beta[j] = coef[k++];
            }
        }
    }
    return gam;
}
/**
 * Refits a regressor on the structure selected by the regularized fit (dense design matrix). The
 * surviving basis columns are flattened into one matrix, a ridge model with a very small
 * regularization parameter is trained on it, and the resulting coefficients are written back
 * into the regressors of a GAM built from {@code w}. When no column survives, the result is an
 * empty GAM whose intercept is the mean of {@code y} if an intercept is fitted.
 *
 * The local lists carry their element types; as raw lists, the element reads and the typed
 * iteration over the regressors do not compile.
 *
 * @param attrs the attribute index of every block.
 * @param struct the {@code ModelStructure} code of every block.
 * @param x the design matrix, one block of basis columns per attribute.
 * @param y the targets.
 * @param knots the knots of every block.
 * @param w the coefficient blocks from the regularized fit.
 * @param maxNumIters the maximum number of iterations for the ridge learner.
 * @return the refitted GAM.
 */
protected GAM refitRegressor(int[] attrs, byte[] struct, double[][][] x, double[] y, double[][] knots,
        double[][] w, int maxNumIters) {
    // Non-linear blocks keep all their columns, linear blocks only the first, eliminated none.
    List<double[]> xList = new ArrayList<>();
    for (int i = 0; i < struct.length; i++) {
        if (struct[i] == ModelStructure.NONLINEAR) {
            for (double[] column : x[i]) {
                xList.add(column);
            }
        } else if (struct[i] == ModelStructure.LINEAR) {
            xList.add(x[i][0]);
        }
    }
    if (xList.isEmpty()) {
        // Nothing was selected: the model is the intercept alone (or empty).
        GAM gam = new GAM();
        if (fitIntercept) {
            gam.setIntercept(StatUtils.mean(y));
        }
        return gam;
    }
    double[][] xNew = xList.toArray(new double[0][]);
    int[] attrsNew = new int[xNew.length];
    for (int i = 0; i < attrsNew.length; i++) {
        attrsNew[i] = i;
    }

    RidgeLearner ridgeLearner = new RidgeLearner();
    ridgeLearner.setVerbose(verbose);
    ridgeLearner.setEpsilon(epsilon);
    ridgeLearner.fitIntercept(fitIntercept);
    // A ridge regression with very small regularization parameter
    // This often improves stability a lot
    GLM glm = ridgeLearner.buildGaussianRegressor(attrsNew, xNew, y, maxNumIters, 1e-8);

    GAM gam = getGAM(attrs, knots, w, glm.intercept(0));
    List<Regressor> regressors = gam.regressors;
    double[] coef = glm.coefficients(0);
    // Walk the regressors in the order the columns were flattened above.
    int k = 0;
    for (Regressor regressor : regressors) {
        if (regressor instanceof LinearFunction) {
            LinearFunction func = (LinearFunction) regressor;
            func.setSlope(coef[k++]);
        } else if (regressor instanceof CubicSpline) {
            CubicSpline spline = (CubicSpline) regressor;
            // Written in place; presumably this is the spline's live coefficient array.
            double[] beta = spline.getCoefficients();
            for (int i = 0; i < beta.length; i++) {
                beta[i] = coef[k++];
            }
        }
    }
    return gam;
}
/**
 * Refits a regressor on the structure selected by the regularized fit (sparse design matrix).
 * The surviving basis columns and their position arrays are flattened in parallel, a ridge model
 * with a very small regularization parameter is trained on them, and the resulting coefficients
 * are written back into the regressors of a GAM built from {@code w}.
 *
 * The local lists carry their element types; as raw lists, the element reads and the typed
 * iteration over the regressors do not compile.
 *
 * @param attrs the attribute index of every block.
 * @param struct the {@code ModelStructure} code of every block.
 * @param indices the per-block positions of the stored entries.
 * @param values the per-block basis columns.
 * @param y the targets.
 * @param knots the knots of every block.
 * @param w the coefficient blocks from the regularized fit.
 * @param maxNumIters the maximum number of iterations for the ridge learner.
 * @return the refitted GAM.
 */
protected GAM refitRegressor(int[] attrs, byte[] struct, int[][] indices, double[][][] values, double[] y,
        double[][] knots, double[][] w, int maxNumIters) {
    // Non-linear blocks keep all their columns, linear blocks only the first, eliminated none.
    // Every kept column is paired with its block's position array.
    List<int[]> iList = new ArrayList<>();
    List<double[]> vList = new ArrayList<>();
    for (int i = 0; i < struct.length; i++) {
        int[] index = indices[i];
        if (struct[i] == ModelStructure.NONLINEAR) {
            for (double[] column : values[i]) {
                iList.add(index);
                vList.add(column);
            }
        } else if (struct[i] == ModelStructure.LINEAR) {
            iList.add(index);
            vList.add(values[i][0]);
        }
    }
    int[][] iNew = iList.toArray(new int[0][]);
    double[][] vNew = vList.toArray(new double[0][]);
    int[] attrsNew = new int[iNew.length];
    for (int i = 0; i < attrsNew.length; i++) {
        attrsNew[i] = i;
    }

    RidgeLearner ridgeLearner = new RidgeLearner();
    ridgeLearner.setVerbose(verbose);
    ridgeLearner.setEpsilon(epsilon);
    ridgeLearner.fitIntercept(fitIntercept);
    // A ridge regression with very small regularization parameter
    // This often improves stability a lot
    GLM glm = ridgeLearner.buildGaussianRegressor(attrsNew, iNew, vNew, y, maxNumIters, 1e-8);

    GAM gam = getGAM(attrs, knots, w, glm.intercept(0));
    List<Regressor> regressors = gam.regressors;
    double[] coef = glm.coefficients(0);
    // Walk the regressors in the order the columns were flattened above.
    int k = 0;
    for (Regressor regressor : regressors) {
        if (regressor instanceof LinearFunction) {
            LinearFunction func = (LinearFunction) regressor;
            func.setSlope(coef[k++]);
        } else if (regressor instanceof CubicSpline) {
            CubicSpline spline = (CubicSpline) regressor;
            // Written in place; presumably this is the spline's live coefficient array.
            double[] beta = spline.getCoefficients();
            for (int j = 0; j < beta.length; j++) {
                beta[j] = coef[k++];
            }
        }
    }
    return gam;
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/ScorecardModelLearner.java
================================================
package mltk.predictor.gam;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.LearnerWithTaskOptions;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.core.processor.OneHotEncoder;
import mltk.predictor.Learner;
import mltk.predictor.glm.GLM;
import mltk.predictor.glm.RidgeLearner;
import mltk.predictor.io.PredictorWriter;
/**
* Class for learning scorecard models. Scorecard models are a special kind of
* generalized additive models where scores for each state of the feature are
* learned.
*
* @author Yin Lou
*
*/
public class ScorecardModelLearner extends Learner {
static class Options extends LearnerWithTaskOptions {
@Argument(name = "-m", description = "maximum number of iterations", required = true)
int maxNumIters = -1;
@Argument(name = "-l", description = "lambda (default: 0)")
double lambda = 0;
}
/**
* Trains a scorecard model.
*
*
* Usage: mltk.predictor.gam.ScorecardModelLearner
* -t train set path
* -m maximum number of iterations
* [-g] task between classification (c) and regression (r) (default: r)
* [-r] attribute file path
* [-o] output model path
* [-V] verbose (default: true)
* [-l] lambda (default: 0)
*
*
* @param args the command line arguments.
* @throws Exception
*/
public static void main(String[] args) throws Exception {
Options opts = new Options();
CmdLineParser parser = new CmdLineParser(ScorecardModelLearner.class, opts);
Task task = null;
try {
parser.parse(args);
task = Task.get(opts.task);
} catch (IllegalArgumentException e) {
parser.printUsage();
System.exit(1);
}
Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);
ScorecardModelLearner learner = new ScorecardModelLearner();
learner.setMaxNumIters(opts.maxNumIters);
learner.setLambda(opts.lambda);
learner.setTask(task);
learner.setVerbose(opts.verbose);
long start = System.currentTimeMillis();
GAM gam = learner.build(trainSet);
long end = System.currentTimeMillis();
System.out.println("Time: " + (end - start) / 1000.0);
if (opts.outputModelPath != null) {
PredictorWriter.write(gam, opts.outputModelPath);
}
}
private int maxNumIters;
private double lambda;
private Task task;
private OneHotEncoder encoder;
/**
* Constructor.
*/
public ScorecardModelLearner() {
verbose = false;
maxNumIters = -1;
encoder = new OneHotEncoder();
lambda = 0;
task = Task.REGRESSION;
}
/**
* Returns the lambda.
*
* @return the lambda.
*/
public double getLambda() {
return lambda;
}
/**
* Sets the lambda.
*
* @param lambda the lambda.
*/
public void setLambda(double lambda) {
this.lambda = lambda;
}
/**
* Returns the maximum number of iterations.
*
* @return the maximum number of iterations.
*/
public int getMaxNumIters() {
return maxNumIters;
}
/**
* Sets the maximum number of iterations.
*
* @param maxNumIters the maximum number of iterations.
*/
public void setMaxNumIters(int maxNumIters) {
this.maxNumIters = maxNumIters;
}
/**
* Returns the task of this learner.
*
* @return the task of this learner.
*/
public Task getTask() {
return task;
}
/**
* Sets the task of this learner.
*
* @param task the task of this learner.
*/
public void setTask(Task task) {
this.task = task;
}
/**
* Builds a classifier.
*
* @param trainSet the training set.
* @param maxNumIters the maximum number of iterations.
* @param lambda the L2 regularization parameter.
* @return a classifier.
*/
public GAM buildClassifier(Instances trainSet, int maxNumIters, double lambda) {
Instances trainSetNew = encoder.process(trainSet);
RidgeLearner learner = new RidgeLearner();
learner.setTask(Task.CLASSIFICATION);
learner.setLambda(lambda);
learner.setVerbose(verbose);
learner.setMaxNumIters(maxNumIters);
GLM glm = learner.build(trainSetNew);
return GAMUtils.getGAM(glm, trainSet.getAttributes());
}
/**
* Builds a regressor.
*
* @param trainSet the training set.
* @param maxNumIters the maximum number of iterations.
* @param lambda the L2 regularization parameter.
* @return a regressor.
*/
public GAM buildRegressor(Instances trainSet, int maxNumIters, double lambda) {
Instances trainSetNew = encoder.process(trainSet);
RidgeLearner learner = new RidgeLearner();
learner.setTask(Task.REGRESSION);
learner.setLambda(lambda);
learner.setVerbose(verbose);
learner.setMaxNumIters(maxNumIters);
GLM glm = learner.build(trainSetNew);
return GAMUtils.getGAM(glm, trainSet.getAttributes());
}
@Override
public GAM build(Instances instances) {
GAM gam = null;
if (maxNumIters < 0) {
maxNumIters = instances.dimension() * 20;
}
switch (task) {
case REGRESSION:
gam = buildRegressor(instances, maxNumIters, lambda);
break;
case CLASSIFICATION:
gam = buildClassifier(instances, maxNumIters, lambda);
break;
default:
break;
}
return gam;
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/SparseDesignMatrix.java
================================================
package mltk.predictor.gam;
import java.util.HashSet;
import java.util.Set;
import mltk.predictor.function.CubicSpline;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
/**
 * Sparse design matrix holding, per attribute, the positions of the stored entries, the block of
 * basis columns evaluated at those entries, the knots, and the scale each column was divided by.
 */
class SparseDesignMatrix {

    int[][] indices;
    double[][][] values;
    double[][] knots;
    double[][] std;

    SparseDesignMatrix(int[][] indices, double[][][] values, double[][] knots, double[][] std) {
        this.indices = indices;
        this.values = values;
        this.knots = knots;
        this.std = std;
    }

    /**
     * Builds the cubic-spline design matrix for sparse attributes.
     *
     * An attribute with too few distinct stored values gets a single (linear) column. Any other
     * attribute gets {@code numKnots + 3} columns: the value, its square, its cube, and one
     * {@code CubicSpline.h} basis column per knot, shifted so that it is zero at input zero.
     * Every column is finally divided by its scale.
     *
     * Note that the first column of every block is the caller's {@code values[j]} array itself,
     * so the final scaling modifies the input in place.
     *
     * @param n the number of instances; used for the scale factor {@code sqrt(n)} and handed to
     *            {@code StatUtils.sd}.
     * @param indices the per-attribute positions of the stored entries.
     * @param values the per-attribute stored values.
     * @param stdList the per-attribute standard deviations of the raw values.
     * @param numKnots the number of knots for non-linear attributes.
     * @return the design matrix.
     */
    static SparseDesignMatrix createCubicSplineDesignMatrix(int n, int[][] indices, double[][] values,
            double[] stdList, int numKnots) {
        final int p = indices.length;
        double[][][] x = new double[p][][];
        double[][] knots = new double[p][];
        double[][] std = new double[p][];
        double factor = Math.sqrt(n);
        for (int j = 0; j < values.length; j++) {
            double[] x1 = values[j];
            Set<Double> uniqueValues = new HashSet<>();
            for (double v : x1) {
                uniqueValues.add(v);
            }
            // The + 1 presumably accounts for the implicit zero of the entries that are not
            // stored — TODO confirm.
            int nKnots = uniqueValues.size() + 1 <= numKnots ? 0 : numKnots;
            knots[j] = new double[nKnots];
            if (nKnots != 0) {
                x[j] = new double[nKnots + 3][];
                std[j] = new double[nKnots + 3];
            } else {
                x[j] = new double[1][];
                std[j] = new double[1];
            }
            double[][] tX = x[j];
            tX[0] = x1;
            std[j][0] = stdList[j] / factor;
            if (nKnots != 0) {
                double[] x2 = new double[x1.length];
                for (int i = 0; i < x2.length; i++) {
                    x2[i] = x1[i] * x1[i];
                }
                tX[1] = x2;
                double[] x3 = new double[x1.length];
                for (int i = 0; i < x3.length; i++) {
                    x3[i] = x2[i] * x1[i];
                }
                tX[2] = x3;
                std[j][1] = StatUtils.sd(x2, n) / factor;
                std[j][2] = StatUtils.sd(x3, n) / factor;
                // The knot range always includes zero.
                double max = Math.max(StatUtils.max(x1), 0);
                double min = Math.min(StatUtils.min(x1), 0);
                double stepSize = (max - min) / nKnots;
                for (int k = 0; k < nKnots; k++) {
                    knots[j][k] = min + stepSize * k;
                    double[] basis = new double[x1.length];
                    // Shift the basis so that it evaluates to zero at input zero.
                    double zero = CubicSpline.h(0, knots[j][k]);
                    for (int i = 0; i < basis.length; i++) {
                        basis[i] = CubicSpline.h(x1[i], knots[j][k]) - zero;
                    }
                    std[j][k + 3] = StatUtils.sd(basis, n) / factor;
                    tX[k + 3] = basis;
                }
            }
        }
        // Normalize the inputs
        for (int j = 0; j < p; j++) {
            double[][] block = x[j];
            double[] s = std[j];
            for (int i = 0; i < block.length; i++) {
                VectorUtils.divide(block[i], s[i]);
            }
        }
        return new SparseDesignMatrix(indices, x, knots, std);
    }

}
================================================
FILE: src/main/java/mltk/predictor/gam/interaction/FAST.java
================================================
package mltk.predictor.gam.interaction;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.Attribute.Type;
import mltk.core.io.InstancesReader;
import mltk.core.processor.Discretizer;
import mltk.predictor.function.CHistogram;
import mltk.predictor.function.Histogram2D;
import mltk.util.Element;
import mltk.util.MathUtils;
import mltk.util.tuple.IntPair;
/**
* Class for fast interaction detection.
*
*
* Reference:
* Y. Lou, R. Caruana, J. Gehrke, and G. Hooker. Accurate intelligible models with pairwise interactions. In
* Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD),
* Chicago, IL, USA, 2013.
*
*
* @author Yin Lou
*
*/
public class FAST {
static class FASTThread extends Thread {
List> pairs;
Instances instances;
FASTThread(Instances instances) {
this.instances = instances;
this.pairs = new ArrayList<>();
}
public void add(Element pair) {
pairs.add(pair);
}
public void run() {
FAST.computeWeights(instances, pairs);
}
}
/**
 * Command line options of {@link FAST#main(String[])}. An instance is handed to the
 * {@code CmdLineParser} together with the argument array; the defaults below apply to
 * flags that are not given.
 */
static class Options {
// Optional attribute file; stays null when -r is not given.
@Argument(name = "-r", description = "attribute file path")
String attPath = null;
// Data set to rank interactions on.
@Argument(name = "-d", description = "dataset path", required = true)
String datasetPath = null;
// Residual file: main reads one number per line, one line per instance, and uses it as the target.
@Argument(name = "-R", description = "residual path", required = true)
String residualPath = null;
// File the ranked pairs are written to.
@Argument(name = "-o", description = "output path", required = true)
String outputPath = null;
// Passed to Discretizer.discretize for every numeric attribute.
@Argument(name = "-b", description = "number of bins (default: 256)")
int maxNumBins = 256;
// Number of FASTThread workers the pairs are distributed over.
@Argument(name = "-p", description = "number of threads (default: 1)")
int numThreads = 1;
}
/**
 * Ranks pairwise interactions using FAST.
 *
 * <pre>
 * Usage: mltk.predictor.gam.interaction.FAST
 * -d	dataset path
 * -R	residual path
 * -o	output path
 * [-r]	attribute file path
 * [-b]	number of bins (default: 256)
 * [-p]	number of threads (default: 1)
 * </pre>
 *
 * The residual file must contain one number per line, one line per instance of the data set.
 * Each output line holds the two attribute positions of a pair and its weight, tab separated,
 * in the order produced by sorting the pairs.
 *
 * @param args the command line arguments
 * @throws Exception if reading the inputs, running the workers or writing the output fails.
 */
public static void main(String[] args) throws Exception {
    Options opts = new Options();
    CmdLineParser parser = new CmdLineParser(FAST.class, opts);
    try {
        parser.parse(args);
    } catch (IllegalArgumentException e) {
        parser.printUsage();
        System.exit(1);
    }
    // Zero threads would make the round-robin assignment below divide by zero.
    if (opts.numThreads < 1) {
        System.err.println("Number of threads must be at least 1: " + opts.numThreads);
        parser.printUsage();
        System.exit(1);
    }

    Instances instances = InstancesReader.read(opts.attPath, opts.datasetPath);

    System.out.println("Reading residuals...");
    try (BufferedReader br = new BufferedReader(new FileReader(opts.residualPath), 65535)) {
        for (int i = 0; i < instances.size(); i++) {
            String line = br.readLine();
            if (line == null) {
                throw new IllegalArgumentException("Residual file " + opts.residualPath + " ends after " + i
                        + " lines but the dataset has " + instances.size() + " instances");
            }
            // The residual replaces the target of the instance.
            instances.get(i).setTarget(Double.parseDouble(line));
        }
    }

    List<Attribute> attributes = instances.getAttributes();

    System.out.println("Discretizing attribute...");
    for (int i = 0; i < attributes.size(); i++) {
        if (attributes.get(i).getType() == Type.NUMERIC) {
            Discretizer.discretize(instances, i, opts.maxNumBins);
        }
    }

    System.out.println("Generating all pairs of attributes...");
    List<Element<IntPair>> pairs = new ArrayList<>();
    for (int i = 0; i < attributes.size(); i++) {
        for (int j = i + 1; j < attributes.size(); j++) {
            pairs.add(new Element<IntPair>(new IntPair(i, j), 0.0));
        }
    }

    System.out.println("Creating threads...");
    FASTThread[] threads = new FASTThread[opts.numThreads];
    long start = System.currentTimeMillis();
    for (int i = 0; i < threads.length; i++) {
        threads[i] = new FASTThread(instances);
    }
    // Round-robin assignment of pairs to workers.
    for (int i = 0; i < pairs.size(); i++) {
        threads[i % threads.length].add(pairs.get(i));
    }
    for (int i = 0; i < threads.length; i++) {
        threads[i].start();
    }

    System.out.println("Running FAST...");
    for (int i = 0; i < threads.length; i++) {
        threads[i].join();
    }
    long end = System.currentTimeMillis();

    System.out.println("Sorting pairs...");
    Collections.sort(pairs);

    System.out.println("Time: " + (end - start) / 1000.0);

    // try-with-resources closes (and thereby flushes) the writer even if printing fails.
    try (PrintWriter out = new PrintWriter(opts.outputPath)) {
        for (int i = 0; i < pairs.size(); i++) {
            Element<IntPair> pair = pairs.get(i);
            out.println(pair.element.v1 + "\t" + pair.element.v2 + "\t" + pair.weight);
        }
    }
}
/**
* Computes the weights of pairwise interactions.
*
* @param instances the training set.
* @param pairs the list of pairs to compute.
*/
public static void computeWeights(Instances instances, List> pairs) {
List attributes = instances.getAttributes();
boolean[] used = new boolean[attributes.size()];
for (Element pair : pairs) {
int f1 = pair.element.v1;
int f2 = pair.element.v2;
used[f1] = used[f2] = true;
}
CHistogram[] cHist = new CHistogram[attributes.size()];
for (int i = 0; i < cHist.length; i++) {
if (used[i]) {
switch (attributes.get(i).getType()) {
case BINNED:
BinnedAttribute binnedAtt = (BinnedAttribute) attributes.get(i);
cHist[i] = new CHistogram(binnedAtt.getNumBins());
break;
case NOMINAL:
NominalAttribute nominalAtt = (NominalAttribute) attributes.get(i);
cHist[i] = new CHistogram(nominalAtt.getCardinality());
default:
break;
}
}
}
double ySq = computeCHistograms(instances, used, cHist);
for (Element pair : pairs) {
final int f1 = pair.element.v1;
final int f2 = pair.element.v2;
final int size1 = cHist[f1].size();
final int size2 = cHist[f2].size();
Histogram2D hist2d = new Histogram2D(size1, size2);
Histogram2D.computeHistogram2D(instances, f1, f2, hist2d);
computeWeight(pair, cHist, hist2d, ySq);
}
}
/**
 * Fills the cumulative histograms of all used attributes and returns the weighted sum of
 * squared targets.
 *
 * @param instances the data set.
 * @param used flags, per attribute position, whether a histogram has to be filled.
 * @param cHist the histograms to fill; entries of unused attributes are not touched.
 * @return the sum of {@code target * target * weight} over all instances.
 */
protected static double computeCHistograms(Instances instances, boolean[] used, CHistogram[] cHist) {
    double sumSqResp = 0;

    // Pass 1: plain (non-cumulative) weighted sums per bin, with missing values kept apart.
    for (Instance instance : instances) {
        final double resp = instance.getTarget();
        for (int att = 0; att < instances.getAttributes().size(); att++) {
            if (!used[att]) {
                continue;
            }
            if (instance.isMissing(att)) {
                cHist[att].sumOnMV += resp * instance.getWeight();
                cHist[att].countOnMV += instance.getWeight();
            } else {
                // The attribute value is the bin/state number.
                final int bin = (int) instance.getValue(att);
                cHist[att].sum[bin] += resp * instance.getWeight();
                cHist[att].count[bin] += instance.getWeight();
            }
        }
        sumSqResp += resp * resp * instance.getWeight();
    }

    // Pass 2: turn the per-bin sums into running totals.
    for (int att = 0; att < cHist.length; att++) {
        if (!used[att]) {
            continue;
        }
        for (int bin = 1; bin < cHist[att].size(); bin++) {
            cHist[att].sum[bin] += cHist[att].sum[bin - 1];
            cHist[att].count[bin] += cHist[att].count[bin - 1];
        }
    }
    return sumSqResp;
}
/**
 * Computes the weight of a single pair: the smallest value returned by {@link #getRSS} over all
 * cut positions {@code (v1, v2)} with {@code v1 < size1 - 1} and {@code v2 < size2 - 1}. The
 * result is stored in {@code pair.weight}.
 *
 * If either attribute has fewer than two bins no cut is evaluated and the weight is
 * {@code Double.POSITIVE_INFINITY}.
 *
 * @param pair the pair to score; its weight is overwritten.
 * @param cHist the cumulative histograms, indexed by attribute position.
 * @param hist2d the 2D histogram of the pair.
 * @param ySq the weighted sum of squared targets.
 */
protected static void computeWeight(Element<IntPair> pair, CHistogram[] cHist, Histogram2D hist2d, double ySq) {
    final int f1 = pair.element.v1;
    final int f2 = pair.element.v2;
    final int size1 = cHist[f1].size();
    final int size2 = cHist[f2].size();
    Histogram2D.Table table = Histogram2D.computeTable(hist2d, cHist[f1], cHist[f2]);
    double bestRSS = Double.POSITIVE_INFINITY;
    // Scratch buffers reused for every cut: 4 quadrant predictions and 2 predictions per missing-value side.
    double[] predInt = new double[4];
    double[] predOnMV1 = new double[2];
    double[] predOnMV2 = new double[2];
    // NOTE(review): this value is handed to getRSS, whose visible body never reads it, so the
    // cell where both attributes are missing currently has no influence on the weight.
    double predOnMV12 = MathUtils.divide(hist2d.respOnMV12, hist2d.countOnMV12, 0);
    for (int v1 = 0; v1 < size1 - 1; v1++) {
        for (int v2 = 0; v2 < size2 - 1; v2++) {
            getPredictor(table, v1, v2, predInt, predOnMV1, predOnMV2);
            double rss = getRSS(table, v1, v2, ySq, predInt, predOnMV1, predOnMV2, predOnMV12);
            if (rss < bestRSS) {
                bestRSS = rss;
            }
        }
    }
    pair.weight = bestRSS;
}
/**
 * Fills the prediction buffers for the cut {@code (v1, v2)}: each prediction is the response
 * sum of a table cell divided by its count, or 0 as supplied to {@code MathUtils.divide} as
 * the fallback value.
 *
 * @param table the lookup table of the pair.
 * @param v1 the cut position on the first attribute.
 * @param v2 the cut position on the second attribute.
 * @param pred output buffer for the cells of {@code table.resp[v1][v2]}.
 * @param predOnMV1 output buffer for the cells of {@code table.respOnMV1[v2]}.
 * @param predOnMV2 output buffer for the cells of {@code table.respOnMV2[v1]}.
 */
protected static void getPredictor(Histogram2D.Table table, int v1, int v2,
        double[] pred, double[] predOnMV1, double[] predOnMV2) {
    final double[] cellCount = table.count[v1][v2];
    final double[] cellResp = table.resp[v1][v2];

    int q = 0;
    while (q < pred.length) {
        pred[q] = MathUtils.divide(cellResp[q], cellCount[q], 0);
        q++;
    }

    int s = 0;
    while (s < predOnMV1.length) {
        predOnMV1[s] = MathUtils.divide(table.respOnMV1[v2][s], table.countOnMV1[v2][s], 0);
        s++;
    }

    int t = 0;
    while (t < predOnMV2.length) {
        predOnMV2[t] = MathUtils.divide(table.respOnMV2[v1][t], table.countOnMV2[v1][t], 0);
        t++;
    }
}
/**
 * Computes the residual sum of squares for the cut {@code (v1, v2)} from sufficient statistics:
 * starting from {@code ySq}, each group of cells adds {@code sum(pred^2 * count)} and subtracts
 * {@code 2 * sum(pred * resp)}. The groups are the main cells, the cells where the first
 * attribute is missing and the cells where the second attribute is missing.
 *
 * @param table the lookup table of the pair.
 * @param v1 the cut position on the first attribute.
 * @param v2 the cut position on the second attribute.
 * @param ySq the weighted sum of squared targets.
 * @param pred predictions for the main cells.
 * @param predOnMV1 predictions for the cells of {@code table.respOnMV1[v2]}.
 * @param predOnMV2 predictions for the cells of {@code table.respOnMV2[v1]}.
 * @param predOnMV12 prediction where both attributes are missing; not read by this method.
 * @return the residual sum of squares.
 */
protected static double getRSS(Histogram2D.Table table, int v1, int v2, double ySq,
        double[] pred, double[] predOnMV1, double[] predOnMV2, double predOnMV12) {
    final double[] mainCount = table.count[v1][v2];
    final double[] mainResp = table.resp[v1][v2];
    final double[] mv1Resp = table.respOnMV1[v2];
    final double[] mv1Count = table.countOnMV1[v2];
    final double[] mv2Resp = table.respOnMV2[v1];
    final double[] mv2Count = table.countOnMV2[v1];

    double total = ySq;

    // Main area
    total += weightedSquareSum(pred, mainCount);
    total -= 2 * crossSum(pred, mainResp);

    // First attribute missing
    total += weightedSquareSum(predOnMV1, mv1Count);
    total -= 2 * crossSum(predOnMV1, mv1Resp);

    // Second attribute missing
    total += weightedSquareSum(predOnMV2, mv2Count);
    total -= 2 * crossSum(predOnMV2, mv2Resp);

    return total;
}

/** Returns the sum of {@code p[i] * p[i] * count[i]} over the length of {@code p}. */
private static double weightedSquareSum(double[] p, double[] count) {
    double acc = 0;
    for (int i = 0; i < p.length; i++) {
        acc += p[i] * p[i] * count[i];
    }
    return acc;
}

/** Returns the sum of {@code p[i] * resp[i]} over the length of {@code p}. */
private static double crossSum(double[] p, double[] resp) {
    double acc = 0;
    for (int i = 0; i < p.length; i++) {
        acc += p[i] * resp[i];
    }
    return acc;
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/interaction/package-info.java
================================================
/**
* Provides algorithms for feature interaction detection.
*/
package mltk.predictor.gam.interaction;
================================================
FILE: src/main/java/mltk/predictor/gam/package-info.java
================================================
/**
* Provides algorithms for fitting generalized additive models (GAMs).
*/
package mltk.predictor.gam;
================================================
FILE: src/main/java/mltk/predictor/gam/tool/Diagnostics.java
================================================
package mltk.predictor.gam.tool;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.predictor.Regressor;
import mltk.predictor.gam.GAM;
import mltk.predictor.io.PredictorReader;
import mltk.util.StatUtils;
import mltk.util.Element;
/**
* Class for GAM diagnostics.
*
* @author Yin Lou
*
*/
public class Diagnostics {
/**
 * Enumeration of methods for calculating term importance.
 *
 * @author Yin Lou
 *
 */
public enum Mode {

    /**
     * L1 mode.
     */
    L1("L1"),
    /**
     * L2 mode.
     */
    L2("L2");

    String mode;

    Mode(String mode) {
        this.mode = mode;
    }

    @Override
    public String toString() {
        return mode;
    }

    /**
     * Parses a mode from a string. The comparison is case sensitive.
     *
     * @param mode the mode.
     * @return the parsed mode.
     * @throws IllegalArgumentException if the string (including {@code null}) names no mode.
     */
    public static Mode getEnum(String mode) {
        for (Mode re : Mode.values()) {
            // equals (unlike compareTo) tolerates null, so null is reported as an invalid mode.
            if (re.mode.equals(mode)) {
                return re;
            }
        }
        throw new IllegalArgumentException("Invalid mode: " + mode);
    }

}
/**
* Computes the weights for each term in a GAM.
*
* @param gam the GAM model.
* @param instances the training set.
* @return the list of weights for each term in a GAM.
*/
public static List> diagnose(GAM gam, Instances instances) {
return diagnose(gam, instances, Mode.L2);
}
/**
* Computes the weights for each term in a GAM.
*
* @param gam the GAM model.
* @param instances the training set.
* @return the list of weights for each term in a GAM.
*/
public static List> diagnose(GAM gam, Instances instances, Mode mode) {
List> list = new ArrayList<>();
Map> map = new HashMap<>();
List terms = gam.getTerms();
List regressors = gam.getRegressors();
for (int i = 0; i < terms.size(); i++) {
int[] term = terms.get(i);
if (!map.containsKey(term)) {
map.put(term, new ArrayList());
}
Regressor regressor = regressors.get(i);
map.get(term).add(regressor);
}
double[] predictions = new double[instances.size()];
for (int[] term : map.keySet()) {
List regressorList = map.get(term);
for (int i = 0; i < instances.size(); i++) {
predictions[i] = 0;
Instance instance = instances.get(i);
for (Regressor regressor : regressorList) {
predictions[i] += regressor.regress(instance);
}
}
double weight = 0;
if (mode == Mode.L2) {
weight = StatUtils.variance(predictions);
} else {
double mean = StatUtils.mean(predictions);
weight = StatUtils.mad(predictions, mean);
}
list.add(new Element(term, weight));
}
return list;
}
/**
 * Command line options of {@link Diagnostics#main(String[])}. An instance is handed to the
 * {@code CmdLineParser} together with the argument array.
 */
static class Options {
// Optional attribute file; stays null when -r is not given.
@Argument(name = "-r", description = "attribute file path")
String attPath = null;
// Data set the term predictions are computed on.
@Argument(name = "-d", description = "dataset path", required = true)
String datasetPath = null;
// Importance mode as text; note the field itself is null (not "L2") when -m is not given.
@Argument(name = "-m", description = "mode (L1 or L2, default: L2)")
String mode = null;
// Path of the serialized GAM to diagnose.
@Argument(name = "-i", description = "input model path", required = true)
String inputModelPath = null;
// File the term weights are written to.
@Argument(name = "-o", description = "output path", required = true)
String outputPath = null;
}
/**
 * Generates term importance for GAMs. Terms are written one per line as
 * {@code [attribute indices]: weight}, largest weight first.
 *
 * <pre>
 * Usage: mltk.predictor.gam.tool.Diagnostics
 * -d	dataset path
 * -i	input model path
 * -o	output path
 * [-r]	attribute file path
 * [-m]	mode (L1 or L2, default: L2)
 * </pre>
 *
 * @param args the command line arguments.
 * @throws Exception if reading the inputs or writing the output fails.
 */
public static void main(String[] args) throws Exception {
    Options opts = new Options();
    CmdLineParser parser = new CmdLineParser(Diagnostics.class, opts);
    Mode mode = null;
    try {
        parser.parse(args);
        // The option is null when -m is omitted; fall back to the documented default.
        // Parsing inside the try block reports an invalid mode like any other bad argument.
        mode = opts.mode == null ? Mode.L2 : Mode.getEnum(opts.mode);
    } catch (IllegalArgumentException e) {
        parser.printUsage();
        System.exit(1);
    }

    Instances dataset = InstancesReader.read(opts.attPath, opts.datasetPath);
    GAM gam = PredictorReader.read(opts.inputModelPath, GAM.class);

    List<Element<int[]>> list = Diagnostics.diagnose(gam, dataset, mode);
    // Ascending sort followed by reverse yields descending weights.
    Collections.sort(list);
    Collections.reverse(list);

    // try-with-resources closes (and thereby flushes) the writer even if printing fails.
    try (PrintWriter out = new PrintWriter(opts.outputPath)) {
        for (Element<int[]> element : list) {
            int[] term = element.element;
            double weight = element.weight;
            out.println(Arrays.toString(term) + ": " + weight);
        }
    }
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/tool/Visualizer.java
================================================
package mltk.predictor.gam.tool;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Bins;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.io.InstancesReader;
import mltk.predictor.Regressor;
import mltk.predictor.function.Array1D;
import mltk.predictor.function.Array2D;
import mltk.predictor.function.CubicSpline;
import mltk.predictor.function.Function1D;
import mltk.predictor.gam.GAM;
import mltk.predictor.io.PredictorReader;
import mltk.util.MathUtils;
/**
* Class for visualizing 1D and 2D components in a GAM.
*
* @author Yin Lou
*
*/
public class Visualizer {
/**
 * Output terminals supported by the generated Gnuplot scripts.
 *
 * @author Yin Lou
 *
 */
public enum Terminal {

    /**
     * PNG terminal.
     */
    PNG("png"),
    /**
     * PDF terminal.
     */
    PDF("pdf");

    String term;

    Terminal(String term) {
        this.term = term;
    }

    @Override
    public String toString() {
        return this.term;
    }

    /**
     * Looks up the terminal whose name equals the given string (case sensitive).
     *
     * @param term the string.
     * @return the matching terminal.
     * @throws IllegalArgumentException if no terminal has this name.
     */
    public static Terminal getEnum(String term) {
        final Terminal[] candidates = Terminal.values();
        for (int i = 0; i < candidates.length; i++) {
            final Terminal candidate = candidates[i];
            if (candidate.term.compareTo(term) != 0) {
                continue;
            }
            return candidate;
        }
        throw new IllegalArgumentException("Invalid Terminal value: " + term);
    }

}
/**
 * Generates a set of Gnuplot scripts for visualizing low dimensional components in a GAM.
 *
 * For every term with one attribute a script {@code NAME.plt} is written, and for every term
 * with two attributes a script {@code NAME1_NAME2.plt}; plot data that is not inlined goes to
 * {@code .dat} files next to the scripts. Terms with more than two attributes are ignored, as
 * are single binned/nominal attributes with only one bin/state. Attributes of two-attribute
 * terms must be binned or nominal.
 *
 * @param gam the GAM model.
 * @param instances the training set.
 * @param dirPath the directory path to write to.
 * @param outputTerminal output plot format (png or pdf).
 * @throws IOException if the directory cannot be created or a file cannot be written.
 * @throws IllegalArgumentException if a two-attribute term uses an attribute that is neither
 *         binned nor nominal.
 */
public static void generateGnuplotScripts(GAM gam, Instances instances, String dirPath, Terminal outputTerminal)
        throws IOException {
    List<Attribute> attributes = instances.getAttributes();
    // The scratch instance below must be indexable by the largest attribute index.
    int p = -1;
    Map<Integer, Attribute> attMap = new HashMap<>(attributes.size());
    for (Attribute attribute : attributes) {
        int attIndex = attribute.getIndex();
        attMap.put(attIndex, attribute);
        if (attIndex > p) {
            p = attIndex;
        }
    }
    p++;

    List<int[]> terms = gam.getTerms();
    List<Regressor> regressors = gam.getRegressors();

    File dir = new File(dirPath);
    if (!dir.exists() && !dir.mkdirs()) {
        throw new IOException("Cannot create output directory: " + dirPath);
    }

    // Scratch instance: only the attributes of the term being plotted are varied.
    double[] value = new double[p];
    Instance point = new Instance(value);
    String terminal = outputTerminal.toString();
    for (int i = 0; i < terms.size(); i++) {
        int[] term = terms.get(i);
        Regressor regressor = regressors.get(i);
        if (term.length == 1) {
            writeUnivariatePlot(dir, terminal, attMap.get(term[0]), term[0], regressor, point, instances);
            point.setValue(term[0], 0);
        } else if (term.length == 2) {
            writeBivariatePlot(dir, terminal, attMap.get(term[0]), attMap.get(term[1]), term, regressor, point);
            // Restore the all-zero baseline so no later term sees values left by this one.
            point.setValue(term[0], 0);
            point.setValue(term[1], 0);
        }
    }
}

/** Writes the script (and data file, if any) for a term with a single attribute. */
private static void writeUnivariatePlot(File dir, String terminal, Attribute f, int attIndex,
        Regressor regressor, Instance point, Instances instances) throws IOException {
    // A feature with a single bin/state has nothing to plot.
    switch (f.getType()) {
        case BINNED:
            if (((BinnedAttribute) f).getNumBins() == 1) {
                return;
            }
            break;
        case NOMINAL:
            if (((NominalAttribute) f).getStates().length == 1) {
                return;
            }
            break;
        default:
            break;
    }

    // A second panel for missing values is only drawn when that prediction is non-zero.
    Double predictionOnMV = null;
    if (regressor instanceof Function1D) {
        double predOnMV = ((Function1D) regressor).getPredictionOnMV();
        if (!MathUtils.isZero(predOnMV)) {
            predictionOnMV = predOnMV;
        }
    } else if (regressor instanceof Array1D) {
        double predOnMV = ((Array1D) regressor).getPredictionOnMV();
        if (!MathUtils.isZero(predOnMV)) {
            predictionOnMV = predOnMV;
        }
    }

    String basePath = dir.getAbsolutePath() + File.separator;
    try (PrintWriter out = new PrintWriter(basePath + f.getName() + ".plt")) {
        out.printf("set term %s\n", terminal);
        out.printf("set output \"%s.%s\"\n", f.getName(), terminal);
        out.println("set datafile separator \"\t\"");
        out.println("set grid");
        if (predictionOnMV != null) {
            out.println("set multiplot layout 1,2 rowsfirst");
        }

        // Plot main function
        switch (f.getType()) {
            case BINNED: {
                int numBins = ((BinnedAttribute) f).getNumBins();
                Bins bins = ((BinnedAttribute) f).getBins();
                double[] boundaries = bins.getBoundaries();
                // Left edge of the first bin: one bin width (or 1 if there is a single boundary)
                // below the first boundary.
                double start = boundaries[0] - 1;
                if (boundaries.length >= 2) {
                    start = boundaries[0] - (boundaries[1] - boundaries[0]);
                }
                out.printf("set xrange[%f:%f]\n", start, boundaries[boundaries.length - 1]);
                List<Double> predList = new ArrayList<>();
                for (int j = 0; j < numBins; j++) {
                    point.setValue(attIndex, j);
                    predList.add(regressor.regress(point));
                }

                String fileName = f.getName() + ".dat";
                out.println("plot \"" + fileName + "\" u 1:2 w l t \"\"");
                try (PrintWriter writer = new PrintWriter(basePath + fileName)) {
                    // Two points per boundary produce a step function.
                    writer.printf("%f\t%f\n", start, predList.get(0));
                    for (int j = 0; j < numBins; j++) {
                        writer.printf("%f\t%f\n", boundaries[j], predList.get(j));
                        if (j < numBins - 1) {
                            writer.printf("%f\t%f\n", boundaries[j], predList.get(j + 1));
                        }
                    }
                }
                break;
            }
            case NOMINAL: {
                out.println("set style data histogram");
                out.println("set style histogram cluster gap 1");
                out.println("set style fill solid border -1");
                out.println("set boxwidth 0.9");
                out.println("set xtic rotate by -90");
                String[] states = ((NominalAttribute) f).getStates();

                String fileName = f.getName() + ".dat";
                out.println("plot \"" + fileName + "\" u 2:xtic(1) t \"\"");
                try (PrintWriter writer = new PrintWriter(basePath + fileName)) {
                    for (int j = 0; j < states.length; j++) {
                        point.setValue(attIndex, j);
                        writer.printf("%s\t%f\n", states[j], regressor.regress(point));
                    }
                }
                break;
            }
            default: {
                // Evaluate the function at the distinct values observed in the data.
                Set<Double> values = new HashSet<>();
                for (Instance instance : instances) {
                    values.add(instance.getValue(attIndex));
                }
                List<Double> list = new ArrayList<>(values);
                Collections.sort(list);
                out.printf("set xrange[%f:%f]\n", list.get(0), list.get(list.size() - 1));
                if (regressor instanceof CubicSpline) {
                    // Splines are plotted analytically from their coefficients.
                    CubicSpline spline = (CubicSpline) regressor;
                    out.println("z(x) = x < 0 ? 0 : x ** 3");
                    out.println("h(x, k) = z(x - k)");
                    double[] knots = spline.getKnots();
                    double[] w = spline.getCoefficients();
                    StringBuilder sb = new StringBuilder();
                    sb.append("plot ").append(spline.getIntercept());
                    sb.append(" + ").append(w[0]).append(" * x");
                    sb.append(" + ").append(w[1]).append(" * (x ** 2)");
                    sb.append(" + ").append(w[2]).append(" * (x ** 3)");
                    for (int j = 0; j < knots.length; j++) {
                        sb.append(" + ").append(w[j + 3]).append(" * ");
                        sb.append("h(x, ").append(knots[j]).append(")");
                    }
                    sb.append(" t \"\"");
                    out.println(sb.toString());
                } else {
                    out.println("plot \"-\" u 1:2 w lp t \"\"");
                    for (double v : list) {
                        point.setValue(attIndex, v);
                        out.printf("%f\t%f\n", v, regressor.regress(point));
                    }
                    out.println("e");
                }
                break;
            }
        }

        // Plot prediction on missing value
        if (predictionOnMV != null) {
            out.println("set style fill solid border -1");
            out.println("set xtic rotate by 0");
            out.println("plot \"-\" using 2:xtic(1) with histogram t \"\"");
            out.println("missing value\t" + predictionOnMV);
            out.println("e");
        }
    }
}

/** Writes the script and data files for a term with two binned/nominal attributes. */
private static void writeBivariatePlot(File dir, String terminal, Attribute f1, Attribute f2, int[] term,
        Regressor regressor, Instance point) throws IOException {
    // Fail before any file is created; otherwise the axis ranges below would be null.
    if (!(f1 instanceof BinnedAttribute || f1 instanceof NominalAttribute)
            || !(f2 instanceof BinnedAttribute || f2 instanceof NominalAttribute)) {
        throw new IllegalArgumentException("Attributes of a pairwise term must be binned or nominal: "
                + f1.getName() + ", " + f2.getName());
    }

    String fileName = f1.getName() + "_" + f2.getName();
    String basePath = dir.getAbsolutePath() + File.separator + fileName;
    try (PrintWriter out = new PrintWriter(basePath + ".plt")) {
        out.printf("set term %s\n", terminal);
        out.printf("set output \"%s_%s.%s\"\n", f1.getName(), f2.getName(), terminal);
        out.println("set datafile separator \"\t\"");

        // Extra panels are added only for missing-value predictions that are non-zero.
        int numRow = 1;
        int numCol = 1;
        double[] predictionsOnMV1 = null;
        double[] predictionsOnMV2 = null;
        double predictionsOnMV12 = 0.0;
        if (regressor instanceof Array2D) {
            Array2D ary2d = (Array2D) regressor;
            for (double v : ary2d.getPredictionsOnMV1()) {
                if (!MathUtils.isZero(v)) {
                    numRow = 2;
                    predictionsOnMV1 = ary2d.getPredictionsOnMV1();
                    break;
                }
            }
            for (double v : ary2d.getPredictionsOnMV2()) {
                if (!MathUtils.isZero(v)) {
                    numCol = 2;
                    predictionsOnMV2 = ary2d.getPredictionsOnMV2();
                    break;
                }
            }
            if (!MathUtils.isZero(ary2d.getPredictionOnMV12())) {
                numRow = 2;
                numCol = 2;
                predictionsOnMV12 = ary2d.getPredictionOnMV12();
            }
        }
        if (numRow > 1 || numCol > 1) {
            out.printf("set multiplot layout %d,%d rowsfirst\n", numRow, numCol);
        }

        int size1 = 0;
        if (f1.getType() == Attribute.Type.BINNED) {
            size1 = ((BinnedAttribute) f1).getNumBins();
        } else if (f1.getType() == Attribute.Type.NOMINAL) {
            size1 = ((NominalAttribute) f1).getCardinality();
        }
        int size2 = 0;
        if (f2.getType() == Attribute.Type.BINNED) {
            size2 = ((BinnedAttribute) f2).getNumBins();
        } else if (f2.getType() == Attribute.Type.NOMINAL) {
            size2 = ((NominalAttribute) f2).getCardinality();
        }

        // Nominal axes are labelled with the state names.
        if (f1.getType() == Attribute.Type.NOMINAL) {
            out.print("set ytics(");
            String[] states = ((NominalAttribute) f1).getStates();
            for (int j = 0; j < states.length - 1; j++) {
                out.printf("\"%s\" %d, ", states[j], j);
            }
            out.printf("\"%s\" %d)\n", states[states.length - 1], states.length - 1);
        }
        if (f2.getType() == Attribute.Type.NOMINAL) {
            out.print("set xtics(");
            String[] states = ((NominalAttribute) f2).getStates();
            for (int j = 0; j < states.length - 1; j++) {
                out.printf("\"%s\" %d, ", states[j], j);
            }
            out.printf("\"%s\" %d) rotate\n", states[states.length - 1], states.length - 1);
        }
        out.println("unset border");
        out.println("set view map");
        out.println("set style data pm3d");
        out.println("set style function pm3d");
        out.println("set pm3d corners2color c4");

        // range*: coordinates written to the data file; value*: attribute values fed to the regressor.
        double[] rangeY = null;
        double[] valueY = null;
        double[] rangeX = null;
        double[] valueX = null;
        if (f1 instanceof BinnedAttribute) {
            rangeY = new double[size1 + 1];
            valueY = new double[size1 + 1];
            Bins bins = ((BinnedAttribute) f1).getBins();
            double[] boundaries = bins.getBoundaries();
            double start = boundaries[0] - 1;
            if (boundaries.length >= 2) {
                start = boundaries[0] - (boundaries[1] - boundaries[0]);
            }
            rangeY[0] = start;
            valueY[0] = 0;
            for (int k = 0; k < boundaries.length; k++) {
                rangeY[k + 1] = boundaries[k];
                valueY[k + 1] = k;
            }
        } else if (f1 instanceof NominalAttribute) {
            rangeY = new double[size1];
            valueY = new double[size1];
            for (int k = 0; k < size1; k++) {
                rangeY[k] = k;
                valueY[k] = k;
            }
        }
        out.printf("set yrange[%f:%f]\n", rangeY[0] - 1, rangeY[rangeY.length - 1] + 1);
        if (f2 instanceof BinnedAttribute) {
            rangeX = new double[size2 + 1];
            valueX = new double[size2 + 1];
            Bins bins = ((BinnedAttribute) f2).getBins();
            double[] boundaries = bins.getBoundaries();
            double start = boundaries[0] - 1;
            if (boundaries.length >= 2) {
                start = boundaries[0] - (boundaries[1] - boundaries[0]);
            }
            rangeX[0] = start;
            valueX[0] = 0;
            for (int k = 0; k < boundaries.length; k++) {
                rangeX[k + 1] = boundaries[k];
                valueX[k + 1] = k;
            }
        } else if (f2 instanceof NominalAttribute) {
            rangeX = new double[size2];
            valueX = new double[size2];
            for (int k = 0; k < size2; k++) {
                rangeX[k] = k;
                valueX[k] = k;
            }
        }
        out.printf("set xrange[%f:%f]\n", rangeX[0] - 1, rangeX[rangeX.length - 1] + 1);

        out.println("splot \"" + fileName + ".dat\" with image t \"\"");
        try (PrintWriter writer = new PrintWriter(basePath + ".dat")) {
            for (int r = 0; r < rangeY.length; r++) {
                point.setValue(term[0], valueY[r]);
                for (int c = 0; c < rangeX.length; c++) {
                    point.setValue(term[1], valueX[c]);
                    // Evaluate this term's own regressor (as the single-attribute plots do); the
                    // full model would mix in every other term that reads these attributes.
                    writer.println(rangeX[c] + "\t" + rangeY[r] + "\t" + regressor.regress(point));
                }
                writer.println();
            }
        }

        if (numRow > 1 || numCol > 1) {
            // multiplot is on
            if (predictionsOnMV2 != null) {
                out.println("reset");
                // "reset" also discards the separator set at the top of the script.
                out.println("set datafile separator \"\t\"");
                try (PrintWriter writer = new PrintWriter(basePath + "_mv2.dat")) {
                    if (f1.getType() == Attribute.Type.NOMINAL) {
                        String[] states = ((NominalAttribute) f1).getStates();
                        out.println("set style data histogram");
                        out.println("set style histogram cluster gap 1");
                        out.println("set style fill solid border -1");
                        out.println("set boxwidth 0.9");
                        out.println("set xtic rotate by -90");
                        out.println("plot \"" + fileName + "_mv2.dat\" u 2:xtic(1) t \"\"");
                        for (int k = 0; k < states.length; k++) {
                            writer.println(states[k] + "\t" + predictionsOnMV2[k]);
                        }
                    } else if (f1.getType() == Attribute.Type.BINNED) {
                        out.println("plot \"" + fileName + "_mv2.dat\" u 1:2 w l t \"\"");
                        writer.println(rangeY[0] + "\t" + predictionsOnMV2[0]);
                        for (int k = 0; k < predictionsOnMV2.length; k++) {
                            writer.println(rangeY[k + 1] + "\t" + predictionsOnMV2[k]);
                        }
                    }
                }
            } else if (numCol > 1) {
                // Leave the panel empty.
                out.println("set multiplot next");
            }
            if (predictionsOnMV1 != null) {
                out.println("reset");
                // "reset" also discards the separator set at the top of the script.
                out.println("set datafile separator \"\t\"");
                try (PrintWriter writer = new PrintWriter(basePath + "_mv1.dat")) {
                    if (f2.getType() == Attribute.Type.NOMINAL) {
                        String[] states = ((NominalAttribute) f2).getStates();
                        out.println("set style data histogram");
                        out.println("set style histogram cluster gap 1");
                        out.println("set style fill solid border -1");
                        out.println("set boxwidth 0.9");
                        out.println("set xtic rotate by -90");
                        out.println("plot \"" + fileName + "_mv1.dat\" u 2:xtic(1) t \"\"");
                        for (int k = 0; k < states.length; k++) {
                            writer.println(states[k] + "\t" + predictionsOnMV1[k]);
                        }
                    } else if (f2.getType() == Attribute.Type.BINNED) {
                        out.println("plot \"" + fileName + "_mv1.dat\" u 1:2 w l t \"\"");
                        writer.println(rangeX[0] + "\t" + predictionsOnMV1[0]);
                        for (int k = 0; k < predictionsOnMV1.length; k++) {
                            writer.println(rangeX[k + 1] + "\t" + predictionsOnMV1[k]);
                        }
                    }
                }
            } else if (numRow > 1) {
                // Leave the panel empty.
                out.println("set multiplot next");
            }
            if (!MathUtils.isZero(predictionsOnMV12)) {
                out.println("reset");
                out.println("set datafile separator \"\t\"");
                out.println("set style fill solid border -1");
                out.println("set xtic rotate by 0");
                out.println("plot \"-\" using 2:xtic(1) with histogram t \"\"");
                out.println("missing value\t" + predictionsOnMV12);
                out.println("e");
            } else if (numRow > 1 && numCol > 1) {
                // Leave the panel empty.
                out.println("set multiplot next");
            }
        }
    }
}
/**
 * Command line options of {@link Visualizer#main(String[])}. An instance is handed to the
 * {@code CmdLineParser} together with the argument array.
 */
static class Options {
// Attribute file; unlike in the other tools of this package it is marked required here.
@Argument(name = "-r", description = "attribute file path", required = true)
String attPath = null;
// Data set passed to generateGnuplotScripts.
@Argument(name = "-d", description = "dataset path", required = true)
String datasetPath = null;
// Path of the serialized GAM to visualize.
@Argument(name = "-i", description = "input model path", required = true)
String inputModelPath = null;
// Directory the scripts and data files are written to.
@Argument(name = "-o", description = "output directory path", required = true)
String dirPath = null;
// Terminal name, parsed with Terminal.getEnum in main.
@Argument(name = "-t", description = "output terminal (default: png)")
String terminal = "png";
}
/**
 * Generates scripts for visualizing GAMs.
 *
 * <pre>
 * Usage: mltk.predictor.gam.tool.Visualizer
 * -r	attribute file path
 * -d	dataset path
 * -i	input model path
 * -o	output directory path
 * [-t]	output terminal (default: png)
 * </pre>
 *
 * @param args the command line arguments.
 * @throws Exception if reading the inputs or writing the scripts fails.
 */
public static void main(String[] args) throws Exception {
    Options opts = new Options();
    CmdLineParser parser = new CmdLineParser(Visualizer.class, opts);
    Terminal terminal = null;
    try {
        parser.parse(args);
        // Parsed inside the try block so an unknown terminal prints the usage
        // like any other invalid argument instead of a stack trace.
        terminal = Terminal.getEnum(opts.terminal);
    } catch (IllegalArgumentException e) {
        parser.printUsage();
        System.exit(1);
    }

    Instances dataset = InstancesReader.read(opts.attPath, opts.datasetPath);
    GAM gam = PredictorReader.read(opts.inputModelPath, GAM.class);

    Visualizer.generateGnuplotScripts(gam, dataset, opts.dirPath, terminal);
}
}
================================================
FILE: src/main/java/mltk/predictor/gam/tool/package-info.java
================================================
/**
* Provides tools for diagnosing and visualizing GAMs.
*/
package mltk.predictor.gam.tool;
================================================
FILE: src/main/java/mltk/predictor/glm/ElasticNetLearner.java
================================================
package mltk.predictor.glm;
import java.util.Arrays;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.cmdline.options.LearnerWithTaskOptions;
import mltk.core.Attribute;
import mltk.core.DenseVector;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.SparseVector;
import mltk.core.io.InstancesReader;
import mltk.predictor.Family;
import mltk.predictor.LinkFunction;
import mltk.predictor.io.PredictorWriter;
import mltk.util.MathUtils;
import mltk.util.OptimUtils;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
/**
* Class for learning elastic-net penalized linear model via coordinate descent.
*
* @author Yin Lou
*
*/
public class ElasticNetLearner extends GLMLearner {
/**
 * Command line options of {@link ElasticNetLearner#main(String[])}, in addition to those
 * inherited from {@code LearnerWithTaskOptions}.
 */
static class Options extends LearnerWithTaskOptions {
// Forwarded to setMaxNumIters in main.
@Argument(name = "-m", description = "maximum number of iterations (default: 0)")
int maxIter = 0;
// Forwarded to setLambda in main.
@Argument(name = "-l", description = "lambda (default: 0)")
double lambda = 0;
// Forwarded to setL1Ratio in main.
@Argument(name = "-a", description = "L1 ratio (default: 0)")
double l1Ratio = 0;
}
/**
 * Command line entry point: trains an elastic-net-regularized GLM and optionally writes it
 * to disk.
 *
 * <pre>
 * Usage: mltk.predictor.glm.ElasticNetLearner
 * -t	train set path
 * [-g]	task between classification (c) and regression (r) (default: r)
 * [-r]	attribute file path
 * [-o]	output model path
 * [-V]	verbose (default: true)
 * [-m]	maximum number of iterations (default: 0)
 * [-l]	lambda (default: 0)
 * [-a]	L1 ratio (default: 0)
 * </pre>
 *
 * @param args the command line arguments.
 * @throws Exception if reading the training set or writing the model fails.
 */
public static void main(String[] args) throws Exception {
    final Options options = new Options();
    final CmdLineParser cmdLine = new CmdLineParser(ElasticNetLearner.class, options);
    Task parsedTask = null;
    try {
        cmdLine.parse(args);
        parsedTask = Task.get(options.task);
    } catch (IllegalArgumentException e) {
        // Bad flags and unknown tasks are both answered with the usage text.
        cmdLine.printUsage();
        System.exit(1);
    }

    final Instances trainSet = InstancesReader.read(options.attPath, options.trainPath);

    final ElasticNetLearner learner = new ElasticNetLearner();
    learner.setVerbose(options.verbose);
    learner.setTask(parsedTask);
    learner.setLambda(options.lambda);
    learner.setL1Ratio(options.l1Ratio);
    learner.setMaxNumIters(options.maxIter);

    final long startMillis = System.currentTimeMillis();
    final GLM model = learner.build(trainSet);
    final long endMillis = System.currentTimeMillis();
    final double elapsedSeconds = (endMillis - startMillis) / 1000.0;
    System.out.println("Time: " + elapsedSeconds);

    // The model is only persisted when an output path was given.
    if (options.outputModelPath == null) {
        return;
    }
    PredictorWriter.write(model, options.outputModelPath);
}
/** Regularization strength; 0 means no regularization. */
protected double lambda;
/** Mixing ratio of the penalty: 0 is ridge, 1 is lasso, values in between are elastic net. */
protected double l1Ratio;
/** Learning task used by {@code build(Instances)}. */
protected Task task;

/**
 * Creates a learner for regression without regularization.
 */
public ElasticNetLearner() {
    this.task = Task.REGRESSION;
    this.l1Ratio = 0;
    this.lambda = 0;
}
/**
 * Builds a model for the configured task. A negative iteration limit is replaced by 20 and
 * the replacement is stored in the learner.
 *
 * @return the trained model, or {@code null} for a task other than regression and classification.
 */
@Override
public GLM build(Instances instances) {
    if (maxNumIters < 0) {
        maxNumIters = 20;
    }
    switch (task) {
        case REGRESSION:
            return buildGaussianRegressor(instances, maxNumIters, lambda, l1Ratio);
        case CLASSIFICATION:
            return buildClassifier(instances, maxNumIters, lambda, l1Ratio);
        default:
            return null;
    }
}
/**
 * Builds a model for the given family. A negative iteration limit is replaced by 20 and the
 * replacement is stored in the learner.
 *
 * @return the trained model.
 * @throws IllegalArgumentException if the family is neither Gaussian nor binomial.
 */
@Override
public GLM build(Instances trainSet, Family family) {
    if (maxNumIters < 0) {
        maxNumIters = 20;
    }
    switch (family) {
        case GAUSSIAN:
            return buildGaussianRegressor(trainSet, maxNumIters, lambda, l1Ratio);
        case BINOMIAL:
            return buildClassifier(trainSet, maxNumIters, lambda, l1Ratio);
        default:
            throw new IllegalArgumentException("Unsupported family: " + family);
    }
}
/**
 * Fits an elastic-net penalized binary classifier by coordinate descent on dense inputs. The
 * input matrix is feature-major: {@code x[j]} holds the values of feature {@code j} for all
 * data points, i.e. it is the transpose of the usual row-oriented data matrix. The data does
 * not need to be normalized or centered.
 *
 * @param attrs the attribute list.
 * @param x the inputs.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the L1 ratio.
 * @return an elastic-net penalized binary classifier.
 */
public GLM buildBinaryClassifier(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda, double l1Ratio) {
    final int numPoints = y.length;
    double[] w = new double[attrs.length];
    double intercept = 0;

    // Current predictions and pseudo residuals on the training set.
    double[] pTrain = new double[numPoints];
    double[] rTrain = new double[numPoints];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Per-feature constant: a quarter of the feature's sum of squares.
    double[] theta = new double[x.length];
    int feature = 0;
    for (double[] column : x) {
        theta[feature++] = StatUtils.sumSq(column) / 4;
    }

    // Split lambda into its L1 and L2 parts, plus versions scaled by the sample size.
    final double lambda1 = lambda * l1Ratio;
    final double lambda2 = lambda * (1 - l1Ratio);
    final double tl1 = lambda1 * numPoints;
    final double tl2 = lambda2 * numPoints;

    // Coordinate descent until the loss converges or the iteration budget is used up.
    for (int iter = 0; iter < maxNumIters; iter++) {
        final double lossBefore = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2);

        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
        }
        doOnePassBinomial(x, theta, y, tl1, tl2, w, pTrain, rTrain);

        final double lossAfter = GLMOptimUtils.computeElasticNetLoss(pTrain, y, w, lambda1, lambda2);
        if (verbose) {
            System.out.println("Iteration " + iter + ":  " + lossAfter);
        }
        if (OptimUtils.isConverged(lossBefore, lossAfter, epsilon)) {
            break;
        }
    }

    return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.LOGIT);
}
/**
 * Fits a binary classifier with an elastic-net penalty on sparse, column-oriented inputs. For feature {@code j},
 * {@code indices[j]} lists the data points that have a stored value and {@code values[j]} holds those values. The
 * data is not assumed to be normalized or centered.
 *
 * <p>Training stops after {@code maxNumIters} passes or as soon as the penalized loss is reported as converged.</p>
 *
 * @param attrs the attribute list.
 * @param indices the indices.
 * @param values the values.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the L1 ratio.
 * @return an elastic-net penalized classifier.
 */
public GLM buildBinaryClassifier(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
        double lambda, double l1Ratio) {
    final int numPoints = y.length;
    double[] coef = new double[attrs.length];
    double bias = 0;
    double[] pred = new double[numPoints];
    double[] resid = new double[numPoints];
    OptimUtils.computePseudoResidual(pred, y, resid);

    // Per-feature scaling terms: a quarter of each feature's sum of squares
    double[] theta = new double[values.length];
    int f = 0;
    for (double[] feature : values) {
        theta[f++] = StatUtils.sumSq(feature) / 4;
    }

    // Split lambda into its L1 and L2 shares; the tl* values are additionally scaled by the number of points
    final double lambda1 = lambda * l1Ratio;
    final double lambda2 = lambda * (1 - l1Ratio);
    final double tl1 = lambda1 * numPoints;
    final double tl2 = lambda2 * numPoints;

    int iter = 0;
    boolean converged = false;
    while (!converged && iter < maxNumIters) {
        double lossBefore = GLMOptimUtils.computeElasticNetLoss(pred, y, coef, lambda1, lambda2);
        if (fitIntercept) {
            bias += OptimUtils.fitIntercept(pred, resid, y);
        }
        doOnePassBinomial(indices, values, theta, y, tl1, tl2, coef, pred, resid);
        double lossAfter = GLMOptimUtils.computeElasticNetLoss(pred, y, coef, lambda1, lambda2);
        if (verbose) {
            System.out.println("Iteration " + iter + ": " + " " + lossAfter);
        }
        converged = OptimUtils.isConverged(lossBefore, lossAfter, epsilon);
        iter++;
    }
    return GLMOptimUtils.getGLM(attrs, coef, bias, LinkFunction.LOGIT);
}
/**
 * Fits a sequence of elastic-net penalized binary classifiers, one per regularization strength. The matrix
 * {@code x} is feature-major: {@code x[j]} holds the j-th feature over all data points, i.e., {@code x} is the
 * transpose of the row-oriented data matrix. The data is not assumed to be normalized or centered.
 *
 * <p>The first lambda is the value returned by {@code findMaxLambdaBinomial}; every following lambda is the
 * previous one multiplied by {@code minLambdaRatio^(1 / numLambdas)}. Coefficients and intercept are carried over
 * from one lambda to the next rather than being reset.</p>
 *
 * @param attrs the attribute list.
 * @param x the inputs, one row per feature.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations per lambda.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the L1 ratio.
 * @return elastic-net penalized classifiers, ordered from the largest to the smallest lambda.
 */
public GLM[] buildBinaryClassifiers(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas,
        double minLambdaRatio, double l1Ratio) {
    final int numPoints = y.length;
    double[] coef = new double[attrs.length];
    double bias = 0;
    double[] pred = new double[numPoints];
    double[] resid = new double[numPoints];
    OptimUtils.computePseudoResidual(pred, y, resid);

    // Per-feature scaling terms: a quarter of each feature's sum of squares
    double[] theta = new double[x.length];
    for (int j = 0; j < theta.length; j++) {
        theta[j] = StatUtils.sumSq(x[j]) / 4;
    }

    final double maxLambda = findMaxLambdaBinomial(x, y, pred, resid, l1Ratio);
    // Multiplicative step between consecutive lambdas
    final double decay = Math.pow(minLambdaRatio, 1.0 / numLambdas);

    GLM[] path = new GLM[numLambdas];
    double lambda = maxLambda;
    for (int g = 0; g < numLambdas; g++, lambda *= decay) {
        final double lambda1 = lambda * l1Ratio;
        final double lambda2 = lambda * (1 - l1Ratio);
        final double tl1 = lambda1 * numPoints;
        final double tl2 = lambda2 * numPoints;

        int iter = 0;
        boolean converged = false;
        while (!converged && iter < maxNumIters) {
            double lossBefore = GLMOptimUtils.computeElasticNetLoss(pred, y, coef, lambda1, lambda2);
            if (fitIntercept) {
                bias += OptimUtils.fitIntercept(pred, resid, y);
            }
            doOnePassBinomial(x, theta, y, tl1, tl2, coef, pred, resid);
            double lossAfter = GLMOptimUtils.computeElasticNetLoss(pred, y, coef, lambda1, lambda2);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + " " + lossAfter);
            }
            converged = OptimUtils.isConverged(lossBefore, lossAfter, epsilon);
            iter++;
        }
        path[g] = GLMOptimUtils.getGLM(attrs, coef, bias, LinkFunction.LOGIT);
    }
    return path;
}
/**
 * Fits a sequence of elastic-net penalized binary classifiers on sparse, column-oriented inputs, one per
 * regularization strength. For feature {@code j}, {@code indices[j]} lists the data points that have a stored value
 * and {@code values[j]} holds those values. The data is not assumed to be normalized or centered.
 *
 * <p>The first lambda is the value returned by {@code findMaxLambdaBinomial}; every following lambda is the
 * previous one multiplied by {@code minLambdaRatio^(1 / numLambdas)}. Coefficients and intercept are carried over
 * from one lambda to the next rather than being reset.</p>
 *
 * @param attrs the attribute list.
 * @param indices the indices.
 * @param values the values.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations per lambda.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the L1 ratio.
 * @return elastic-net penalized classifiers, ordered from the largest to the smallest lambda.
 */
public GLM[] buildBinaryClassifiers(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
        int numLambdas, double minLambdaRatio, double l1Ratio) {
    final int numPoints = y.length;
    double[] coef = new double[attrs.length];
    double bias = 0;
    double[] pred = new double[numPoints];
    double[] resid = new double[numPoints];
    OptimUtils.computePseudoResidual(pred, y, resid);

    // Per-feature scaling terms: a quarter of each feature's sum of squares
    double[] theta = new double[values.length];
    for (int j = 0; j < theta.length; j++) {
        theta[j] = StatUtils.sumSq(values[j]) / 4;
    }

    final double maxLambda = findMaxLambdaBinomial(indices, values, y, pred, resid, l1Ratio);
    // Multiplicative step between consecutive lambdas
    final double decay = Math.pow(minLambdaRatio, 1.0 / numLambdas);

    GLM[] path = new GLM[numLambdas];
    double lambda = maxLambda;
    for (int g = 0; g < path.length; g++, lambda *= decay) {
        final double lambda1 = lambda * l1Ratio;
        final double lambda2 = lambda * (1 - l1Ratio);
        final double tl1 = lambda1 * numPoints;
        final double tl2 = lambda2 * numPoints;

        int iter = 0;
        boolean converged = false;
        while (!converged && iter < maxNumIters) {
            double lossBefore = GLMOptimUtils.computeElasticNetLoss(pred, y, coef, lambda1, lambda2);
            if (fitIntercept) {
                bias += OptimUtils.fitIntercept(pred, resid, y);
            }
            doOnePassBinomial(indices, values, theta, y, tl1, tl2, coef, pred, resid);
            double lossAfter = GLMOptimUtils.computeElasticNetLoss(pred, y, coef, lambda1, lambda2);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + " " + lossAfter);
            }
            converged = OptimUtils.isConverged(lossBefore, lossAfter, epsilon);
            iter++;
        }
        path[g] = GLMOptimUtils.getGLM(attrs, coef, bias, LinkFunction.LOGIT);
    }
    return path;
}
/**
 * Builds an elastic-net penalized classifier.
 *
 * <p>The target attribute must be nominal. With exactly two classes a single binary model is fitted in which class
 * index 0 is encoded as the positive target (1). With any other number of classes one binary model per class is
 * fitted (one-vs-the-rest) and its weights and intercept are copied into row {@code k} of the returned model. In
 * both cases every fitted weight is multiplied by the matching {@code cList} entry of the prepared dataset
 * (presumably the per-attribute scaling constants applied during preparation -- TODO confirm).</p>
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the l1 ratio.
 * @return an elastic-net penalized classifier.
 * @throws IllegalArgumentException if the target attribute is not nominal.
 */
public GLM buildClassifier(Instances trainSet, boolean isSparse, int maxNumIters, double lambda, double l1Ratio) {
    Attribute classAttribute = trainSet.getTargetAttribute();
    if (classAttribute.getType() != Attribute.Type.NOMINAL) {
        throw new IllegalArgumentException("Class attribute must be nominal.");
    }
    NominalAttribute clazz = (NominalAttribute) classAttribute;
    int numClasses = clazz.getCardinality();
    if (isSparse) {
        SparseDataset sd = getSparseDataset(trainSet, true);
        int[] attrs = sd.attrs;
        int[][] indices = sd.indices;
        double[][] values = sd.values;
        double[] y = new double[sd.y.length];
        double[] cList = sd.cList;
        if (numClasses == 2) {
            // Class 0 is the positive class of the binary problem
            for (int i = 0; i < y.length; i++) {
                int label = (int) sd.y[i];
                y[i] = label == 0 ? 1 : 0;
            }
            GLM glm = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda, l1Ratio);
            double[] w = glm.w[0];
            for (int j = 0; j < cList.length; j++) {
                int attIndex = attrs[j];
                w[attIndex] *= cList[j];
            }
            return glm;
        } else {
            // Size the model by the largest attribute index instead of the last entry, so that the result does
            // not depend on attrs being sorted (this matches the dense branch and buildClassifiers)
            int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
            GLM glm = new GLM(numClasses, p);
            for (int k = 0; k < numClasses; k++) {
                // One-vs-the-rest
                for (int i = 0; i < y.length; i++) {
                    int label = (int) sd.y[i];
                    y[i] = label == k ? 1 : 0;
                }
                GLM binaryClassifier = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda,
                        l1Ratio);
                double[] w = binaryClassifier.w[0];
                for (int j = 0; j < cList.length; j++) {
                    int attIndex = attrs[j];
                    glm.w[k][attIndex] = w[attIndex] * cList[j];
                }
                glm.intercept[k] = binaryClassifier.intercept[0];
            }
            return glm;
        }
    } else {
        DenseDataset dd = getDenseDataset(trainSet, true);
        int[] attrs = dd.attrs;
        double[][] x = dd.x;
        double[] y = new double[dd.y.length];
        double[] cList = dd.cList;
        if (numClasses == 2) {
            // Class 0 is the positive class of the binary problem
            for (int i = 0; i < y.length; i++) {
                int label = (int) dd.y[i];
                y[i] = label == 0 ? 1 : 0;
            }
            GLM glm = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda, l1Ratio);
            double[] w = glm.w[0];
            for (int j = 0; j < cList.length; j++) {
                int attIndex = attrs[j];
                w[attIndex] *= cList[j];
            }
            return glm;
        } else {
            int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
            GLM glm = new GLM(numClasses, p);
            for (int k = 0; k < numClasses; k++) {
                // One-vs-the-rest
                for (int i = 0; i < y.length; i++) {
                    int label = (int) dd.y[i];
                    y[i] = label == k ? 1 : 0;
                }
                GLM binaryClassifier = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda, l1Ratio);
                double[] w = binaryClassifier.w[0];
                for (int j = 0; j < cList.length; j++) {
                    int attIndex = attrs[j];
                    glm.w[k][attIndex] = w[attIndex] * cList[j];
                }
                glm.intercept[k] = binaryClassifier.intercept[0];
            }
            return glm;
        }
    }
}
/**
 * Builds an elastic-net penalized classifier, choosing the sparse or dense code path from the result of
 * {@code isSparse(trainSet)}.
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the l1 ratio.
 * @return an elastic-net penalized classifier.
 */
public GLM buildClassifier(Instances trainSet, int maxNumIters, double lambda, double l1Ratio) {
    final boolean sparse = isSparse(trainSet);
    return buildClassifier(trainSet, sparse, maxNumIters, lambda, l1Ratio);
}
/**
 * Builds elastic-net penalized classifiers for a sequence of lambdas.
 *
 * <p>The target attribute must be nominal. With exactly two classes one binary path is fitted in which class index
 * 0 is encoded as the positive target (1). Otherwise one binary path per class is fitted (one-vs-the-rest) and the
 * l-th model of class {@code k} fills row {@code k} of the l-th returned model. Every fitted weight is multiplied
 * by the matching {@code cList} entry of the prepared dataset.</p>
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param maxNumIters the maximum number of iterations.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the l1 ratio.
 * @return elastic-net penalized classifiers.
 */
public GLM[] buildClassifiers(Instances trainSet, boolean isSparse, int maxNumIters, int numLambdas,
        double minLambdaRatio, double l1Ratio) {
    final Attribute target = trainSet.getTargetAttribute();
    if (target.getType() != Attribute.Type.NOMINAL) {
        throw new IllegalArgumentException("Class attribute must be nominal.");
    }
    final int numClasses = ((NominalAttribute) target).getCardinality();
    final boolean binary = numClasses == 2;

    if (isSparse) {
        final SparseDataset sd = getSparseDataset(trainSet, true);
        final int[] attrs = sd.attrs;
        final int[][] indices = sd.indices;
        final double[][] values = sd.values;
        final double[] y = new double[sd.y.length];
        final double[] scale = sd.cList;

        if (binary) {
            // Class 0 is the positive class of the binary problem
            for (int i = 0; i < y.length; i++) {
                y[i] = ((int) sd.y[i]) == 0 ? 1 : 0;
            }
            GLM[] binaryPath = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas,
                    minLambdaRatio, l1Ratio);
            for (GLM model : binaryPath) {
                double[] w = model.w[0];
                for (int j = 0; j < scale.length; j++) {
                    w[attrs[j]] *= scale[j];
                }
            }
            return binaryPath;
        }

        final int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
        GLM[] path = new GLM[numLambdas];
        for (int l = 0; l < path.length; l++) {
            path[l] = new GLM(numClasses, p);
        }
        for (int k = 0; k < numClasses; k++) {
            // One-vs-the-rest targets for class k
            for (int i = 0; i < y.length; i++) {
                y[i] = ((int) sd.y[i]) == k ? 1 : 0;
            }
            GLM[] oneVsRest = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas,
                    minLambdaRatio, l1Ratio);
            for (int l = 0; l < path.length; l++) {
                GLM source = oneVsRest[l];
                GLM dest = path[l];
                double[] w = source.w[0];
                for (int j = 0; j < scale.length; j++) {
                    int attIndex = attrs[j];
                    dest.w[k][attIndex] = w[attIndex] * scale[j];
                }
                dest.intercept[k] = source.intercept[0];
            }
        }
        return path;
    }

    final DenseDataset dd = getDenseDataset(trainSet, true);
    final int[] attrs = dd.attrs;
    final double[][] x = dd.x;
    final double[] y = new double[dd.y.length];
    final double[] scale = dd.cList;

    if (binary) {
        // Class 0 is the positive class of the binary problem
        for (int i = 0; i < y.length; i++) {
            y[i] = ((int) dd.y[i]) == 0 ? 1 : 0;
        }
        GLM[] binaryPath = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio);
        for (GLM model : binaryPath) {
            double[] w = model.w[0];
            for (int j = 0; j < scale.length; j++) {
                w[attrs[j]] *= scale[j];
            }
        }
        return binaryPath;
    }

    final int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
    GLM[] path = new GLM[numLambdas];
    for (int l = 0; l < path.length; l++) {
        path[l] = new GLM(numClasses, p);
    }
    for (int k = 0; k < numClasses; k++) {
        // One-vs-the-rest targets for class k
        for (int i = 0; i < y.length; i++) {
            y[i] = ((int) dd.y[i]) == k ? 1 : 0;
        }
        GLM[] oneVsRest = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio);
        for (int l = 0; l < path.length; l++) {
            GLM source = oneVsRest[l];
            GLM dest = path[l];
            double[] w = source.w[0];
            for (int j = 0; j < scale.length; j++) {
                int attIndex = attrs[j];
                dest.w[k][attIndex] = w[attIndex] * scale[j];
            }
            dest.intercept[k] = source.intercept[0];
        }
    }
    return path;
}
/**
 * Builds elastic-net penalized classifiers for a sequence of lambdas, choosing the sparse or dense code path from
 * the result of {@code isSparse(trainSet)}.
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the l1 ratio.
 * @return elastic-net penalized classifiers.
 */
public GLM[] buildClassifiers(Instances trainSet, int maxNumIters, int numLambdas, double minLambdaRatio,
        double l1Ratio) {
    final boolean sparse = isSparse(trainSet);
    return buildClassifiers(trainSet, sparse, maxNumIters, numLambdas, minLambdaRatio, l1Ratio);
}
/**
 * Builds an elastic-net penalized regressor from a training set. The set is converted to a sparse or a dense
 * dataset, the model is fitted on it, and every fitted weight is then multiplied by the matching {@code cList}
 * entry of that dataset.
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the L1 ratio.
 * @return an elastic-net penalized regressor.
 */
public GLM buildGaussianRegressor(Instances trainSet, boolean isSparse, int maxNumIters, double lambda, double l1Ratio) {
    final int[] attrs;
    final double[] scale;
    final GLM model;
    if (isSparse) {
        SparseDataset sd = getSparseDataset(trainSet, true);
        attrs = sd.attrs;
        scale = sd.cList;
        model = buildGaussianRegressor(attrs, sd.indices, sd.values, sd.y, maxNumIters, lambda, l1Ratio);
    } else {
        DenseDataset dd = getDenseDataset(trainSet, true);
        attrs = dd.attrs;
        scale = dd.cList;
        model = buildGaussianRegressor(attrs, dd.x, dd.y, maxNumIters, lambda, l1Ratio);
    }
    // Apply the dataset's per-attribute constants to the fitted weights
    double[] w = model.w[0];
    for (int j = 0; j < scale.length; j++) {
        w[attrs[j]] *= scale[j];
    }
    return model;
}
/**
 * Builds an elastic-net penalized regressor, choosing the sparse or dense code path from the result of
 * {@code isSparse(trainSet)}.
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the L1 ratio.
 * @return an elastic-net penalized regressor.
 */
public GLM buildGaussianRegressor(Instances trainSet, int maxNumIters, double lambda, double l1Ratio) {
    final boolean sparse = isSparse(trainSet);
    return buildGaussianRegressor(trainSet, sparse, maxNumIters, lambda, l1Ratio);
}
/**
 * Fits an elastic-net penalized regressor by repeated coordinate-descent passes. The matrix {@code x} is
 * feature-major: {@code x[j]} holds the j-th feature over all data points, i.e., {@code x} is the transpose of the
 * row-oriented data matrix. The data is not assumed to be normalized or centered.
 *
 * <p>Training stops after {@code maxNumIters} passes or as soon as the penalized loss is reported as converged.
 * The array {@code y} itself is not modified; the residuals live in a private copy.</p>
 *
 * @param attrs the attribute list.
 * @param x the inputs, one row per feature.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the L1 ratio.
 * @return an elastic-net penalized regressor.
 */
public GLM buildGaussianRegressor(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda, double l1Ratio) {
    double[] coef = new double[attrs.length];
    double bias = 0;

    // With all coefficients at zero the residuals equal the targets
    double[] resid = y.clone();

    // Sum of squares of every feature
    double[] sq = new double[x.length];
    int f = 0;
    for (double[] feature : x) {
        sq[f++] = StatUtils.sumSq(feature);
    }

    // Split lambda into its L1 and L2 shares; the tl* values are additionally scaled by the number of points
    final int numPoints = y.length;
    final double lambda1 = lambda * l1Ratio;
    final double lambda2 = lambda * (1 - l1Ratio);
    final double tl1 = lambda1 * numPoints;
    final double tl2 = lambda2 * numPoints;

    int iter = 0;
    boolean converged = false;
    while (!converged && iter < maxNumIters) {
        double lossBefore = GLMOptimUtils.computeElasticNetLoss(resid, coef, lambda1, lambda2);
        if (fitIntercept) {
            bias += OptimUtils.fitIntercept(resid);
        }
        doOnePassGaussian(x, sq, tl1, tl2, coef, resid);
        double lossAfter = GLMOptimUtils.computeElasticNetLoss(resid, coef, lambda1, lambda2);
        if (verbose) {
            System.out.println("Iteration " + iter + ": " + " " + lossAfter);
        }
        converged = OptimUtils.isConverged(lossBefore, lossAfter, epsilon);
        iter++;
    }
    return GLMOptimUtils.getGLM(attrs, coef, bias, LinkFunction.IDENTITY);
}
/**
 * Fits an elastic-net penalized regressor on sparse, column-oriented inputs. For feature {@code j},
 * {@code indices[j]} lists the data points that have a stored value and {@code values[j]} holds those values. The
 * data is not assumed to be normalized or centered.
 *
 * <p>Training stops after {@code maxNumIters} passes or as soon as the penalized loss is reported as converged.
 * The array {@code y} itself is not modified; the residuals live in a private copy.</p>
 *
 * @param attrs the attribute list.
 * @param indices the indices.
 * @param values the values.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @param l1Ratio the L1 ratio.
 * @return an elastic-net penalized regressor.
 */
public GLM buildGaussianRegressor(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
        double lambda, double l1Ratio) {
    double[] w = new double[attrs.length];
    double intercept = 0;

    // With all coefficients at zero the residuals equal the targets
    double[] rTrain = y.clone();

    // Sum of squares of every feature. The array is sized by the data it is filled from (values), as in the other
    // sparse builders, instead of by attrs.
    double[] sq = new double[values.length];
    for (int i = 0; i < values.length; i++) {
        sq[i] = StatUtils.sumSq(values[i]);
    }

    // Split lambda into its L1 and L2 shares; the tl* values are additionally scaled by the number of points
    final double lambda1 = lambda * l1Ratio;
    final double lambda2 = lambda * (1 - l1Ratio);
    final double tl1 = lambda1 * y.length;
    final double tl2 = lambda2 * y.length;

    // Coordinate descent
    for (int iter = 0; iter < maxNumIters; iter++) {
        double prevLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2);
        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(rTrain);
        }
        doOnePassGaussian(indices, values, sq, tl1, tl2, w, rTrain);
        double currLoss = GLMOptimUtils.computeElasticNetLoss(rTrain, w, lambda1, lambda2);
        if (verbose) {
            System.out.println("Iteration " + iter + ": " + " " + currLoss);
        }
        if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
            break;
        }
    }
    return GLMOptimUtils.getGLM(attrs, w, intercept, LinkFunction.IDENTITY);
}
/**
 * Builds elastic-net penalized regressors for a sequence of regularization parameter lambdas. The training set is
 * converted to a sparse or a dense dataset, the path is fitted on it, and every fitted weight of every model is
 * then multiplied by the matching {@code cList} entry of that dataset.
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param maxNumIters the maximum number of iterations.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the L1 ratio.
 * @return elastic-net penalized regressors.
 */
public GLM[] buildGaussianRegressors(Instances trainSet, boolean isSparse, int maxNumIters, int numLambdas,
        double minLambdaRatio, double l1Ratio) {
    final int[] attrs;
    final double[] scale;
    final GLM[] path;
    if (isSparse) {
        SparseDataset sd = getSparseDataset(trainSet, true);
        attrs = sd.attrs;
        scale = sd.cList;
        path = buildGaussianRegressors(attrs, sd.indices, sd.values, sd.y, maxNumIters, numLambdas,
                minLambdaRatio, l1Ratio);
    } else {
        DenseDataset dd = getDenseDataset(trainSet, true);
        attrs = dd.attrs;
        scale = dd.cList;
        path = buildGaussianRegressors(attrs, dd.x, dd.y, maxNumIters, numLambdas, minLambdaRatio, l1Ratio);
    }
    // Apply the dataset's per-attribute constants to the fitted weights of every model
    for (GLM model : path) {
        double[] w = model.w[0];
        for (int j = 0; j < scale.length; j++) {
            w[attrs[j]] *= scale[j];
        }
    }
    return path;
}
/**
 * Builds elastic-net penalized regressors for a sequence of regularization parameter lambdas, choosing the sparse
 * or dense code path from the result of {@code isSparse(trainSet)}.
 *
 * @param trainSet the training set.
 * @param maxNumIters the maximum number of iterations.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the L1 ratio.
 * @return elastic-net penalized regressors.
 */
public GLM[] buildGaussianRegressors(Instances trainSet, int maxNumIters, int numLambdas, double minLambdaRatio,
        double l1Ratio) {
    final boolean sparse = isSparse(trainSet);
    return buildGaussianRegressors(trainSet, sparse, maxNumIters, numLambdas, minLambdaRatio, l1Ratio);
}
/**
 * Fits a sequence of elastic-net penalized regressors, one per regularization strength. The matrix {@code x} is
 * feature-major: {@code x[j]} holds the j-th feature over all data points, i.e., {@code x} is the transpose of the
 * row-oriented data matrix. The data is not assumed to be normalized or centered.
 *
 * <p>The first lambda is the value returned by {@code findMaxLambdaGaussian}; every following lambda is the
 * previous one multiplied by {@code minLambdaRatio^(1 / numLambdas)}. Coefficients, intercept and residuals are
 * carried over from one lambda to the next rather than being reset.</p>
 *
 * @param attrs the attribute list.
 * @param x the inputs, one row per feature.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations per lambda.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the L1 ratio.
 * @return elastic-net penalized regressors, ordered from the largest to the smallest lambda.
 */
public GLM[] buildGaussianRegressors(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas,
        double minLambdaRatio, double l1Ratio) {
    double[] coef = new double[attrs.length];
    double bias = 0;
    GLM[] path = new GLM[numLambdas];

    // Residuals start out as a private copy of the targets
    double[] resid = y.clone();

    // Sum of squares of every feature
    double[] sq = new double[x.length];
    for (int j = 0; j < sq.length; j++) {
        sq[j] = StatUtils.sumSq(x[j]);
    }

    final double maxLambda = findMaxLambdaGaussian(x, y, l1Ratio);
    // Multiplicative step between consecutive lambdas
    final double decay = Math.pow(minLambdaRatio, 1.0 / numLambdas);

    final int numPoints = y.length;
    double lambda = maxLambda;
    for (int g = 0; g < path.length; g++) {
        final double lambda1 = lambda * l1Ratio;
        final double lambda2 = lambda * (1 - l1Ratio);
        final double tl1 = lambda1 * numPoints;
        final double tl2 = lambda2 * numPoints;

        int iter = 0;
        boolean converged = false;
        while (!converged && iter < maxNumIters) {
            double lossBefore = GLMOptimUtils.computeElasticNetLoss(resid, coef, lambda1, lambda2);
            if (fitIntercept) {
                bias += OptimUtils.fitIntercept(resid);
            }
            doOnePassGaussian(x, sq, tl1, tl2, coef, resid);
            double lossAfter = GLMOptimUtils.computeElasticNetLoss(resid, coef, lambda1, lambda2);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + lossAfter);
            }
            converged = OptimUtils.isConverged(lossBefore, lossAfter, epsilon);
            iter++;
        }
        path[g] = GLMOptimUtils.getGLM(attrs, coef, bias, LinkFunction.IDENTITY);
        lambda *= decay;
    }
    return path;
}
/**
 * Fits a sequence of elastic-net penalized regressors on sparse, column-oriented inputs, one per regularization
 * strength. For feature {@code j}, {@code indices[j]} lists the data points that have a stored value and
 * {@code values[j]} holds those values. The data is not assumed to be normalized or centered.
 *
 * <p>The first lambda is the value returned by {@code findMaxLambdaGaussian}; every following lambda is the
 * previous one multiplied by {@code minLambdaRatio^(1 / numLambdas)}. Coefficients, intercept and residuals are
 * carried over from one lambda to the next rather than being reset.</p>
 *
 * @param attrs the attribute list.
 * @param indices the indices.
 * @param values the values.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations per lambda.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the ratio that controls how far lambda decays from the max lambda.
 * @param l1Ratio the L1 ratio.
 * @return elastic-net penalized regressors, ordered from the largest to the smallest lambda.
 */
public GLM[] buildGaussianRegressors(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
        int numLambdas, double minLambdaRatio, double l1Ratio) {
    double[] coef = new double[attrs.length];
    double bias = 0;
    GLM[] path = new GLM[numLambdas];

    // Residuals start out as a private copy of the targets
    double[] resid = y.clone();

    // Sum of squares of every feature
    double[] sq = new double[values.length];
    for (int j = 0; j < sq.length; j++) {
        sq[j] = StatUtils.sumSq(values[j]);
    }

    final double maxLambda = findMaxLambdaGaussian(indices, values, y, l1Ratio);
    // Multiplicative step between consecutive lambdas
    final double decay = Math.pow(minLambdaRatio, 1.0 / numLambdas);

    final int numPoints = y.length;
    double lambda = maxLambda;
    for (int g = 0; g < path.length; g++) {
        final double lambda1 = lambda * l1Ratio;
        final double lambda2 = lambda * (1 - l1Ratio);
        final double tl1 = lambda1 * numPoints;
        final double tl2 = lambda2 * numPoints;

        int iter = 0;
        boolean converged = false;
        while (!converged && iter < maxNumIters) {
            double lossBefore = GLMOptimUtils.computeElasticNetLoss(resid, coef, lambda1, lambda2);
            if (fitIntercept) {
                bias += OptimUtils.fitIntercept(resid);
            }
            doOnePassGaussian(indices, values, sq, tl1, tl2, coef, resid);
            double lossAfter = GLMOptimUtils.computeElasticNetLoss(resid, coef, lambda1, lambda2);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + lossAfter);
            }
            converged = OptimUtils.isConverged(lossBefore, lossAfter, epsilon);
            iter++;
        }
        path[g] = GLMOptimUtils.getGLM(attrs, coef, bias, LinkFunction.IDENTITY);
        lambda *= decay;
    }
    return path;
}
/**
 * Performs one coordinate-descent sweep over all features for the squared-error objective on dense inputs.
 * For every feature the new weight is obtained by soft-thresholding with {@code tl1} and dividing by
 * {@code sq[j] + tl2}; the residuals are updated in place right after each weight change.
 *
 * @param x the inputs, one row per feature.
 * @param sq the sum of squares of every feature.
 * @param tl1 the L1 threshold.
 * @param tl2 the L2 term added to the denominator.
 * @param w the weights; updated in place.
 * @param rTrain the residuals; updated in place.
 */
protected void doOnePassGaussian(double[][] x, double[] sq, final double tl1, final double tl2, double[] w, double[] rTrain) {
    for (int j = 0; j < x.length; j++) {
        final double[] column = x[j];
        // Unpenalized update direction, computed from the current residuals (naive updates)
        final double rho = w[j] * sq[j] + VectorUtils.dotProduct(column, rTrain);

        // Soft-threshold by tl1
        final double shrunk;
        if (Math.abs(rho) <= tl1) {
            shrunk = 0;
        } else {
            shrunk = rho > 0 ? rho - tl1 : rho + tl1;
        }
        final double updated = shrunk / (sq[j] + tl2);

        final double delta = updated - w[j];
        w[j] = updated;

        // Keep the residuals in sync with the changed weight
        for (int i = 0; i < rTrain.length; i++) {
            rTrain[i] -= delta * column[i];
        }
    }
}
/**
 * Performs one coordinate-descent sweep over all features for the squared-error objective on sparse inputs.
 * For every feature the new weight is obtained by soft-thresholding with {@code tl1} and dividing by
 * {@code sq[j] + tl2}; only the residuals of data points stored for that feature are touched.
 *
 * @param indices the data point indices stored for every feature.
 * @param values the values stored for every feature, aligned with {@code indices}.
 * @param sq the sum of squares of every feature.
 * @param tl1 the L1 threshold.
 * @param tl2 the L2 term added to the denominator.
 * @param w the weights; updated in place.
 * @param rTrain the residuals; updated in place.
 */
protected void doOnePassGaussian(int[][] indices, double[][] values, double[] sq, final double tl1, final double tl2,
        double[] w, double[] rTrain) {
    for (int j = 0; j < indices.length; j++) {
        // Unpenalized update direction, computed from the current residuals (naive updates)
        double rho = w[j] * sq[j];
        final int[] rows = indices[j];
        final double[] vals = values[j];
        for (int i = 0; i < rows.length; i++) {
            rho += rTrain[rows[i]] * vals[i];
        }

        // Soft-threshold by tl1
        final double shrunk;
        if (Math.abs(rho) <= tl1) {
            shrunk = 0;
        } else {
            shrunk = rho > 0 ? rho - tl1 : rho + tl1;
        }
        final double updated = shrunk / (sq[j] + tl2);

        final double delta = updated - w[j];
        w[j] = updated;

        // Keep the residuals in sync with the changed weight
        for (int i = 0; i < rows.length; i++) {
            rTrain[rows[i]] -= delta * vals[i];
        }
    }
}
/**
 * Performs one coordinate-descent sweep over all features for the binomial objective on dense inputs. Features
 * whose {@code theta} is within {@code MathUtils.EPSILON} of zero are skipped. For every other feature the new
 * weight is obtained by soft-thresholding with {@code tl1} and dividing by {@code theta[j] + tl2}; predictions and
 * pseudo residuals are refreshed in place right after each weight change.
 *
 * @param x the inputs, one row per feature.
 * @param theta the per-feature scaling terms.
 * @param y the targets.
 * @param tl1 the L1 threshold.
 * @param tl2 the L2 term added to the denominator.
 * @param w the weights; updated in place.
 * @param pTrain the predictions; updated in place.
 * @param rTrain the pseudo residuals; updated in place.
 */
protected void doOnePassBinomial(double[][] x, double[] theta, double[] y, final double tl1, final double tl2, double[] w,
        double[] pTrain, double[] rTrain) {
    for (int j = 0; j < x.length; j++) {
        if (Math.abs(theta[j]) <= MathUtils.EPSILON) {
            continue;
        }
        final double[] column = x[j];
        final double eta = VectorUtils.dotProduct(rTrain, column);
        final double raw = w[j] * theta[j] + eta;

        // Soft-threshold by tl1
        final double shrunk;
        if (raw > tl1) {
            shrunk = raw - tl1;
        } else if (raw < -tl1) {
            shrunk = raw + tl1;
        } else {
            shrunk = 0;
        }
        final double updated = shrunk / (theta[j] + tl2);

        final double delta = updated - w[j];
        w[j] = updated;

        // Refresh predictions and pseudo residuals
        for (int i = 0; i < pTrain.length; i++) {
            pTrain[i] += delta * column[i];
            rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], y[i]);
        }
    }
}
/**
 * Performs one coordinate-descent sweep over all features for the binomial objective on sparse inputs. Features
 * whose {@code theta} is within {@code MathUtils.EPSILON} of zero are skipped. For every other feature the new
 * weight is obtained by soft-thresholding with {@code tl1} and dividing by {@code theta[j] + tl2}; only the
 * predictions and pseudo residuals of data points stored for that feature are refreshed.
 *
 * @param indices the data point indices stored for every feature.
 * @param values the values stored for every feature, aligned with {@code indices}.
 * @param theta the per-feature scaling terms.
 * @param y the targets.
 * @param tl1 the L1 threshold.
 * @param tl2 the L2 term added to the denominator.
 * @param w the weights; updated in place.
 * @param pTrain the predictions; updated in place.
 * @param rTrain the pseudo residuals; updated in place.
 */
protected void doOnePassBinomial(int[][] indices, double[][] values, double[] theta, double[] y, final double tl1,
        final double tl2, double[] w, double[] pTrain, double[] rTrain) {
    for (int j = 0; j < indices.length; j++) {
        if (Math.abs(theta[j]) <= MathUtils.EPSILON) {
            continue;
        }
        final int[] rows = indices[j];
        final double[] vals = values[j];
        double eta = 0;
        for (int i = 0; i < rows.length; i++) {
            eta += rTrain[rows[i]] * vals[i];
        }
        final double raw = w[j] * theta[j] + eta;

        // Soft-threshold by tl1
        final double shrunk;
        if (raw > tl1) {
            shrunk = raw - tl1;
        } else if (raw < -tl1) {
            shrunk = raw + tl1;
        } else {
            shrunk = 0;
        }
        final double updated = shrunk / (theta[j] + tl2);

        final double delta = updated - w[j];
        w[j] = updated;

        // Refresh predictions and pseudo residuals of the affected data points
        for (int i = 0; i < rows.length; i++) {
            final int row = rows[i];
            pTrain[row] += delta * vals[i];
            rTrain[row] = OptimUtils.getPseudoResidual(pTrain[row], y[row]);
        }
    }
}
/**
 * Determines the starting (largest) lambda of a regularization path for the squared-error objective on dense
 * inputs: the largest absolute dot product between a feature and the targets, divided by the number of data points
 * and by the L1 ratio.
 *
 * <p>When an intercept is fitted, {@code OptimUtils.fitIntercept(y)} is applied first and its return value is added
 * back to {@code y} before returning (presumably this centers {@code y} temporarily -- TODO confirm).</p>
 *
 * <p>A pure ridge penalty ({@code l1Ratio == 0}) would make the division yield an infinite lambda, which in turn
 * turns the L1 share {@code lambda * l1Ratio} into NaN. A non-positive ratio is therefore replaced by 1e-3 for
 * this computation only; positive ratios are used unchanged.</p>
 *
 * @param x the inputs, one row per feature.
 * @param y the targets.
 * @param l1Ratio the L1 ratio.
 * @return the starting lambda; always finite for finite inputs.
 */
protected double findMaxLambdaGaussian(double[][] x, double[] y, double l1Ratio) {
    double mean = 0;
    if (fitIntercept) {
        mean = OptimUtils.fitIntercept(y);
    }
    // Determine max lambda
    double maxLambda = 0;
    for (double[] col : x) {
        double dot = Math.abs(VectorUtils.dotProduct(col, y));
        if (dot > maxLambda) {
            maxLambda = dot;
        }
    }
    maxLambda /= y.length;
    // Guard against division by zero for ridge (l1Ratio == 0)
    final double effectiveRatio = l1Ratio > 0 ? l1Ratio : 1e-3;
    maxLambda /= effectiveRatio;
    if (fitIntercept) {
        VectorUtils.add(y, mean);
    }
    return maxLambda;
}
/**
 * Determines the starting (largest) lambda of a regularization path for the squared-error objective on sparse
 * inputs: the largest absolute dot product between a feature and the targets, divided by the number of data points
 * and by the L1 ratio.
 *
 * <p>When an intercept is fitted, {@code OptimUtils.fitIntercept(y)} is applied first and its return value is added
 * back to {@code y} before returning (presumably this centers {@code y} temporarily -- TODO confirm).</p>
 *
 * <p>A pure ridge penalty ({@code l1Ratio == 0}) would make the division yield an infinite lambda, which in turn
 * turns the L1 share {@code lambda * l1Ratio} into NaN. A non-positive ratio is therefore replaced by 1e-3 for
 * this computation only; positive ratios are used unchanged.</p>
 *
 * @param indices the data point indices stored for every feature.
 * @param values the values stored for every feature, aligned with {@code indices}.
 * @param y the targets.
 * @param l1Ratio the L1 ratio.
 * @return the starting lambda; always finite for finite inputs.
 */
protected double findMaxLambdaGaussian(int[][] indices, double[][] values, double[] y, double l1Ratio) {
    double mean = 0;
    if (fitIntercept) {
        mean = OptimUtils.fitIntercept(y);
    }
    DenseVector v = new DenseVector(y);
    // Determine max lambda
    double maxLambda = 0;
    for (int i = 0; i < indices.length; i++) {
        int[] index = indices[i];
        double[] value = values[i];
        double dot = Math.abs(VectorUtils.dotProduct(new SparseVector(index, value), v));
        if (dot > maxLambda) {
            maxLambda = dot;
        }
    }
    maxLambda /= y.length;
    // Guard against division by zero for ridge (l1Ratio == 0)
    final double effectiveRatio = l1Ratio > 0 ? l1Ratio : 1e-3;
    maxLambda /= effectiveRatio;
    if (fitIntercept) {
        VectorUtils.add(y, mean);
    }
    return maxLambda;
}
/**
 * Determines the starting (largest) lambda of a regularization path for the binomial objective on dense inputs:
 * the largest absolute dot product between a feature and the pseudo residuals, divided by the number of data
 * points and by the L1 ratio.
 *
 * <p>When an intercept is fitted, it is fitted once before the scan; afterwards the predictions are reset to zero
 * and the pseudo residuals are recomputed, so the caller gets {@code pTrain} and {@code rTrain} back in their
 * initial state.</p>
 *
 * <p>A pure ridge penalty ({@code l1Ratio == 0}) would make the division yield an infinite lambda, which in turn
 * turns the L1 share {@code lambda * l1Ratio} into NaN. A non-positive ratio is therefore replaced by 1e-3 for
 * this computation only; positive ratios are used unchanged.</p>
 *
 * @param x the inputs, one row per feature.
 * @param y the targets.
 * @param pTrain the predictions; used as scratch space.
 * @param rTrain the pseudo residuals; used as scratch space.
 * @param l1Ratio the L1 ratio.
 * @return the starting lambda; always finite for finite inputs.
 */
protected double findMaxLambdaBinomial(double[][] x, double[] y, double[] pTrain, double[] rTrain, double l1Ratio) {
    if (fitIntercept) {
        OptimUtils.fitIntercept(pTrain, rTrain, y);
    }
    double maxLambda = 0;
    for (double[] col : x) {
        double eta = 0;
        for (int i = 0; i < col.length; i++) {
            eta += rTrain[i] * col[i];
        }
        double t = Math.abs(eta);
        if (t > maxLambda) {
            maxLambda = t;
        }
    }
    maxLambda /= y.length;
    // Guard against division by zero for ridge (l1Ratio == 0)
    final double effectiveRatio = l1Ratio > 0 ? l1Ratio : 1e-3;
    maxLambda /= effectiveRatio;
    if (fitIntercept) {
        // Undo the temporary intercept fit
        Arrays.fill(pTrain, 0);
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);
    }
    return maxLambda;
}
/**
 * Determines the starting (largest) lambda of a regularization path for the binomial objective on sparse inputs:
 * the largest absolute dot product between a feature and the pseudo residuals (recomputed from the predictions
 * for every stored entry), divided by the number of data points and by the L1 ratio.
 *
 * <p>When an intercept is fitted, it is fitted once before the scan; afterwards the predictions are reset to zero
 * and the pseudo residuals are recomputed, so the caller gets {@code pTrain} and {@code rTrain} back in their
 * initial state.</p>
 *
 * <p>A pure ridge penalty ({@code l1Ratio == 0}) would make the division yield an infinite lambda, which in turn
 * turns the L1 share {@code lambda * l1Ratio} into NaN. A non-positive ratio is therefore replaced by 1e-3 for
 * this computation only; positive ratios are used unchanged.</p>
 *
 * @param indices the data point indices stored for every feature.
 * @param values the values stored for every feature, aligned with {@code indices}.
 * @param y the targets.
 * @param pTrain the predictions; used as scratch space.
 * @param rTrain the pseudo residuals; used as scratch space.
 * @param l1Ratio the L1 ratio.
 * @return the starting lambda; always finite for finite inputs.
 */
protected double findMaxLambdaBinomial(int[][] indices, double[][] values, double[] y, double[] pTrain, double[] rTrain, double l1Ratio) {
    if (fitIntercept) {
        OptimUtils.fitIntercept(pTrain, rTrain, y);
    }
    double maxLambda = 0;
    for (int k = 0; k < values.length; k++) {
        double eta = 0;
        int[] index = indices[k];
        double[] value = values[k];
        for (int i = 0; i < index.length; i++) {
            int idx = index[i];
            double r = OptimUtils.getPseudoResidual(pTrain[idx], y[idx]);
            r *= value[i];
            eta += r;
        }
        double t = Math.abs(eta);
        if (t > maxLambda) {
            maxLambda = t;
        }
    }
    maxLambda /= y.length;
    // Guard against division by zero for ridge (l1Ratio == 0)
    final double effectiveRatio = l1Ratio > 0 ? l1Ratio : 1e-3;
    maxLambda /= effectiveRatio;
    if (fitIntercept) {
        // Undo the temporary intercept fit
        Arrays.fill(pTrain, 0);
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);
    }
    return maxLambda;
}
/**
 * Returns the share of lambda that is given to the L1 penalty.
 *
 * @return the l1 ratio.
 */
public double getL1Ratio() {
    return this.l1Ratio;
}
/**
 * Returns the regularization strength.
 *
 * @return the lambda.
 */
public double getLambda() {
    return this.lambda;
}
/**
 * Returns the task this learner is configured for.
 *
 * @return the task of this learner.
 */
public Task getTask() {
    return this.task;
}
/**
 * Sets the share of lambda that is given to the L1 penalty.
 *
 * @param l1Ratio the l1 ratio.
 */
public void setL1Ratio(double l1Ratio) {
    final double ratio = l1Ratio;
    this.l1Ratio = ratio;
}
/**
 * Sets the regularization strength.
 *
 * @param lambda the lambda.
 */
public void setLambda(double lambda) {
    final double strength = lambda;
    this.lambda = strength;
}
/**
 * Sets the task of this learner. The value is stored as given, including {@code null}.
 *
 * @param task the task of this learner.
 */
public void setTask(Task task) {
this.task = task;
}
}
================================================
FILE: src/main/java/mltk/predictor/glm/GLM.java
================================================
package mltk.predictor.glm;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.Arrays;
import mltk.core.Instance;
import mltk.core.SparseVector;
import mltk.predictor.LinkFunction;
import mltk.predictor.ProbabilisticClassifier;
import mltk.predictor.Regressor;
import mltk.util.ArrayUtils;
import mltk.util.MathUtils;
import mltk.util.StatUtils;
/**
 * Class for generalized linear models (GLMs).
 *
 * <p>A model stores one coefficient vector and one intercept per class (a single row is used
 * by {@link #GLM(int)} and by {@link #regress(Instance)}) together with a link function.
 *
 * @author Yin Lou
 *
 */
public class GLM implements ProbabilisticClassifier, Regressor {

    /**
     * The coefficient vectors.
     */
    protected double[][] w;

    /**
     * The intercept vector.
     */
    protected double[] intercept;

    /**
     * The link function.
     */
    protected LinkFunction link;

    /**
     * Constructor. Leaves all fields unset; they are populated by {@link #read(BufferedReader)}.
     */
    public GLM() {

    }

    /**
     * Constructs a GLM with the specified dimension.
     *
     * @param dimension the dimension.
     */
    public GLM(int dimension) {
        this(1, dimension);
    }

    /**
     * Constructs a GLM with the specified number of classes and dimension. The link function
     * is not set by this constructor.
     *
     * @param numClasses the number of classes.
     * @param dimension the dimension.
     */
    public GLM(int numClasses, int dimension) {
        w = new double[numClasses][dimension];
        intercept = new double[numClasses];
    }

    /**
     * Constructs a GLM with the intercept vector and the coefficient vectors, using the
     * identity link.
     *
     * @param intercept the intercept vector.
     * @param w the coefficient vectors.
     */
    public GLM(double[] intercept, double[][] w) {
        this(intercept, w, LinkFunction.IDENTITY);
    }

    /**
     * Constructs a GLM with the intercept vector, the coefficient vectors and its link function.
     * The arrays are stored by reference, not copied.
     *
     * @param intercept the intercept vector.
     * @param w the coefficient vectors.
     * @param link the link function.
     * @throws IllegalArgumentException if {@code intercept} and {@code w} differ in length.
     */
    public GLM(double[] intercept, double[][] w, LinkFunction link) {
        if (intercept.length != w.length) {
            throw new IllegalArgumentException("Dimensions of intercept and w must match.");
        }
        this.intercept = intercept;
        this.w = w;
        this.link = link;
    }

    /**
     * Returns the coefficient vectors.
     *
     * @return the coefficient vectors.
     */
    public double[][] coefficients() {
        return w;
    }

    /**
     * Returns the coefficient vectors for class k.
     *
     * @param k the index of the class.
     * @return the coefficient vectors for class k.
     */
    public double[] coefficients(int k) {
        return w[k];
    }

    /**
     * Returns the intercept vector.
     *
     * @return the intercept vector.
     */
    public double[] intercept() {
        return intercept;
    }

    /**
     * Returns the intercept for class k.
     *
     * @param k the index of the class.
     * @return the intercept for class k.
     */
    public double intercept(int k) {
        return intercept[k];
    }

    /**
     * Reads the model in the layout produced by {@link #write(PrintWriter)}, starting at the
     * {@code Link:} line. The leading {@code [Predictor: ...]} line is not consumed here and
     * is presumably read by the caller.
     */
    @Override
    public void read(BufferedReader in) throws Exception {
        link = LinkFunction.get(in.readLine().split(": ")[1]);
        // Skip the "Intercept: n" header; the length comes from the parsed array itself.
        in.readLine();
        intercept = ArrayUtils.parseDoubleArray(in.readLine());
        int p = Integer.parseInt(in.readLine().split(": ")[1]);
        w = new double[intercept.length][p];
        // One line per coefficient index, holding one value per class.
        for (int j = 0; j < p; j++) {
            String[] data = in.readLine().split("\\s+");
            for (int i = 0; i < w.length; i++) {
                w[i][j] = Double.parseDouble(data[i]);
            }
        }
    }

    /**
     * Writes the model: a predictor header, the link, the intercept vector, and then one
     * line per coefficient index with the values of all classes separated by spaces.
     */
    @Override
    public void write(PrintWriter out) throws Exception {
        out.printf("[Predictor: %s]\n", this.getClass().getCanonicalName());
        out.println("Link: " + link);
        out.println("Intercept: " + intercept.length);
        out.println(Arrays.toString(intercept));
        final int p = w[0].length;
        out.println("Coefficients: " + p);
        for (int j = 0; j < p; j++) {
            out.print(w[0][j]);
            for (int i = 1; i < w.length; i++) {
                out.print(" " + w[i][j]);
            }
            out.println();
        }
    }

    /**
     * Returns the linear predictor of the first coefficient vector for the instance.
     */
    @Override
    public double regress(Instance instance) {
        return regress(intercept[0], w[0], instance);
    }

    /**
     * Returns the index of the largest predicted probability.
     */
    @Override
    public int classify(Instance instance) {
        double[] prob = predictProbabilities(instance);
        return StatUtils.indexOfMax(prob);
    }

    /**
     * Returns the prediction of this GLM on the scale of the response variable.
     *
     * @param instance the instance to predict.
     * @return the prediction of this GLM on the scale of the response variable.
     */
    public double predict(Instance instance) {
        return link.applyInverse(regress(instance));
    }

    /**
     * Returns class probabilities. With a single coefficient vector the result has two
     * entries: the sigmoid of the linear predictor at index 0 and its complement at index 1.
     * Otherwise each class gets the sigmoid of its own linear predictor, and the values are
     * normalized to sum to one.
     */
    @Override
    public double[] predictProbabilities(Instance instance) {
        if (w.length == 1) {
            double[] prob = new double[2];
            double pred = regress(intercept[0], w[0], instance);
            prob[0] = MathUtils.sigmoid(pred);
            prob[1] = 1 - prob[0];
            return prob;
        } else {
            double[] prob = new double[w.length];
            double[] pred = new double[w.length];
            double sum = 0;
            for (int i = 0; i < w.length; i++) {
                pred[i] = regress(intercept[i], w[i], instance);
                prob[i] = MathUtils.sigmoid(pred[i]);
                sum += prob[i];
            }
            for (int i = 0; i < prob.length; i++) {
                prob[i] /= sum;
            }
            return prob;
        }
    }

    /**
     * Returns a copy of this model whose coefficient vectors and intercept vector are
     * independent arrays; the link function reference is shared.
     */
    @Override
    public GLM copy() {
        double[][] copyW = new double[w.length][];
        for (int i = 0; i < copyW.length; i++) {
            copyW[i] = Arrays.copyOf(w[i], w[i].length);
        }
        // The intercepts must be duplicated too: sharing the array would let changes made
        // through the copy leak back into this model.
        double[] copyIntercept = Arrays.copyOf(intercept, intercept.length);
        return new GLM(copyIntercept, copyW, link);
    }

    /**
     * Computes {@code intercept + coef . instance}.
     *
     * @param intercept the intercept.
     * @param coef the coefficient vector.
     * @param instance the instance.
     * @return the linear predictor.
     */
    protected double regress(double intercept, double[] coef, Instance instance) {
        if (!instance.isSparse()) {
            double pred = intercept;
            for (int i = 0; i < coef.length; i++) {
                pred += coef[i] * instance.getValue(i);
            }
            return pred;
        } else {
            double pred = intercept;
            SparseVector vector = (SparseVector) instance.getVector();
            int[] indices = vector.getIndices();
            double[] values = vector.getValues();
            for (int i = 0; i < indices.length; i++) {
                int index = indices[i];
                // Entries beyond the model dimension have no coefficient and are ignored.
                if (index < coef.length) {
                    pred += coef[index] * values[i];
                }
            }
            return pred;
        }
    }

}
================================================
FILE: src/main/java/mltk/predictor/glm/GLMLearner.java
================================================
package mltk.predictor.glm;
import mltk.core.Instances;
import mltk.predictor.Family;
import mltk.predictor.Learner;
import mltk.util.MathUtils;
/**
 * Base class for learners that produce generalized linear models (GLMs). It holds the
 * settings shared by those learners: whether an intercept is fitted, the iteration cap,
 * the convergence threshold and the response distribution family.
 *
 * @author Yin Lou
 *
 */
public abstract class GLMLearner extends Learner {

    protected boolean fitIntercept;
    protected int maxNumIters;
    protected double epsilon;
    protected Family family;

    /**
     * Creates a learner with verbose output off, intercept fitting on, an unset (negative)
     * iteration cap, {@code MathUtils.EPSILON} as threshold and the Gaussian family.
     */
    public GLMLearner() {
        verbose = false;
        this.fitIntercept = true;
        this.maxNumIters = -1;
        this.epsilon = MathUtils.EPSILON;
        this.family = Family.GAUSSIAN;
    }

    /**
     * Tells whether an intercept is fitted.
     *
     * @return {@code true} if we fit intercept.
     */
    public boolean fitIntercept() {
        return this.fitIntercept;
    }

    /**
     * Chooses whether an intercept is fitted.
     *
     * @param fitIntercept whether we fit intercept.
     */
    public void fitIntercept(boolean fitIntercept) {
        this.fitIntercept = fitIntercept;
    }

    /**
     * Gets the convergence threshold epsilon.
     *
     * @return the convergence threshold epsilon.
     */
    public double getEpsilon() {
        return this.epsilon;
    }

    /**
     * Changes the convergence threshold epsilon.
     *
     * @param epsilon the convergence threshold epsilon.
     */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * Gets the maximum number of iterations.
     *
     * @return the maximum number of iterations.
     */
    public int getMaxNumIters() {
        return this.maxNumIters;
    }

    /**
     * Changes the maximum number of iterations.
     *
     * @param maxNumIters the maximum number of iterations.
     */
    public void setMaxNumIters(int maxNumIters) {
        this.maxNumIters = maxNumIters;
    }

    /**
     * Gets the response distribution family.
     *
     * @return the response distribution family.
     */
    public Family getFamily() {
        return this.family;
    }

    /**
     * Changes the response distribution family.
     *
     * @param family the response distribution family.
     */
    public void setFamily(Family family) {
        this.family = family;
    }

    /**
     * Trains a generalized linear model for the given response distribution family.
     * The default link function for the family will be used.
     *
     * @param trainSet the training set.
     * @param family the response distribution family.
     * @return a generalized linear model.
     */
    public abstract GLM build(Instances trainSet, Family family);

}
================================================
FILE: src/main/java/mltk/predictor/glm/GLMOptimUtils.java
================================================
package mltk.predictor.glm;
import mltk.predictor.LinkFunction;
import mltk.util.OptimUtils;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
/**
 * Package-private helpers for assembling {@link GLM}s and evaluating penalized losses.
 * Each loss is a data term (quadratic loss of the residuals, or logistic loss of the
 * predictions against the targets) plus a penalty on the coefficients.
 */
class GLMOptimUtils {

    /**
     * Builds a single-row GLM whose dimension is one more than the largest attribute index
     * and places {@code w[i]} at position {@code attrs[i]}.
     */
    static GLM getGLM(int[] attrs, double[] w, double intercept, LinkFunction link) {
        final int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
        GLM glm = new GLM(p);
        for (int i = 0; i < attrs.length; i++) {
            glm.w[0][attrs[i]] = w[i];
        }
        glm.intercept[0] = intercept;
        glm.link = link;
        return glm;
    }

    /** Quadratic loss plus {@code lambda / 2 * sum(w^2)}. */
    static double computeRidgeLoss(double[] residual, double[] w, double lambda) {
        double loss = OptimUtils.computeQuadraticLoss(residual);
        loss += lambda / 2 * StatUtils.sumSq(w);
        return loss;
    }

    /** Logistic loss plus {@code lambda / 2 * sum(w^2)}. */
    static double computeRidgeLoss(double[] pred, double[] y, double[] w, double lambda) {
        double loss = OptimUtils.computeLogisticLoss(pred, y);
        loss += lambda / 2 * StatUtils.sumSq(w);
        return loss;
    }

    /** Quadratic loss plus {@code lambda * l1norm(w)}. */
    static double computeLassoLoss(double[] residual, double[] w, double lambda) {
        double loss = OptimUtils.computeQuadraticLoss(residual);
        loss += lambda * VectorUtils.l1norm(w);
        return loss;
    }

    /** Logistic loss plus {@code lambda * l1norm(w)}. */
    static double computeLassoLoss(double[] pred, double[] y, double[] w, double lambda) {
        double loss = OptimUtils.computeLogisticLoss(pred, y);
        loss += lambda * VectorUtils.l1norm(w);
        return loss;
    }

    /** Quadratic loss plus {@code lambda1 * l1norm(w) + lambda2 / 2 * sum(w^2)}. */
    static double computeElasticNetLoss(double[] residual, double[] w, double lambda1, double lambda2) {
        double loss = OptimUtils.computeQuadraticLoss(residual);
        loss += lambda1 * VectorUtils.l1norm(w) + lambda2 / 2 * StatUtils.sumSq(w);
        return loss;
    }

    /** Logistic loss plus {@code lambda1 * l1norm(w) + lambda2 / 2 * sum(w^2)}. */
    static double computeElasticNetLoss(double[] pred, double[] y, double[] w, double lambda1, double lambda2) {
        double loss = OptimUtils.computeLogisticLoss(pred, y);
        loss += lambda1 * VectorUtils.l1norm(w) + lambda2 / 2 * StatUtils.sumSq(w);
        return loss;
    }

    /**
     * Quadratic loss plus the group-lasso penalty {@code sum_k tl1[k] * ||w[k]||_2}.
     *
     * @param residual the residuals.
     * @param w the coefficients, one array per group.
     * @param tl1 the penalty weight of each group.
     * @return the penalized loss.
     */
    static double computeGroupLassoLoss(double[] residual, double[][] w, double[] tl1) {
        return OptimUtils.computeQuadraticLoss(residual) + groupLassoPenalty(w, tl1);
    }

    /**
     * Logistic loss plus the group-lasso penalty {@code sum_k tl1[k] * ||w[k]||_2}.
     *
     * @param pred the predictions.
     * @param y the targets.
     * @param w the coefficients, one array per group.
     * @param tl1 the penalty weight of each group.
     * @return the penalized loss.
     */
    static double computeGroupLassoLoss(double[] pred, double[] y, double[][] w, double[] tl1) {
        return OptimUtils.computeLogisticLoss(pred, y) + groupLassoPenalty(w, tl1);
    }

    /**
     * Sums the weighted L2 norms of the groups. The group-lasso penalty uses the norm
     * itself, not its square; squaring would turn it into a ridge-type penalty.
     */
    private static double groupLassoPenalty(double[][] w, double[] tl1) {
        double penalty = 0;
        for (int k = 0; k < w.length; k++) {
            penalty += tl1[k] * Math.sqrt(StatUtils.sumSq(w[k]));
        }
        return penalty;
    }

}
================================================
FILE: src/main/java/mltk/predictor/glm/GroupLassoLearner.java
================================================
package mltk.predictor.glm;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.Family;
import mltk.predictor.LinkFunction;
import mltk.predictor.glm.GLM;
import mltk.util.ArrayUtils;
import mltk.util.OptimUtils;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
/**
* Class for learning group-lasso models via block coordinate gradient descent.
*
*
* Reference:
* M Yuan and Y Lin. Model selection and estimation in regression with grouped variables. In
* Journal of the Royal Statistical Society: Series B (Statistical Methodology),
* 68(1):49-67, 2006.
*
*
* @author Yin Lou
*
*/
public class GroupLassoLearner extends GLMLearner {
class DenseDesignMatrix {
int[][] groups;
double[][][] x;
DenseDesignMatrix(int[][] groups, double[][][] x) {
this.groups = groups;
this.x = x;
}
}
/**
 * Value wrapper around a selection mask so that masks can be compared and hashed by
 * content (for example inside a {@code HashSet}).
 */
static class ModelStructure {

    boolean[] structure;

    ModelStructure(boolean[] structure) {
        this.structure = structure;
    }

    /**
     * Two structures are equal when they are of the same class and their masks hold the
     * same values.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Arrays.equals(structure, ((ModelStructure) obj).structure);
    }

    /**
     * Content-based hash, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        // Same value as the conventional "prime * 1 + hash" form.
        return 31 + Arrays.hashCode(structure);
    }

}
class SparseDesignMatrix {
int[][] group;
int[][][] indices;
double[][][] values;
SparseDesignMatrix(int[][] groups, int[][][] indices, double[][][] values) {
this.indices = indices;
this.values = values;
}
}
/** When {@code true}, the build methods pass the selected groups to a refit step instead of returning the penalized fit. */
protected boolean refit;
/** Number of lambdas; presumably the length of the regularization path (not read in the nearby code, confirm against its users). */
protected int numLambdas;
/** The lambda handed to the single-model build methods. */
protected double lambda;
/** The task that {@code build(Instances)} dispatches on. */
protected Task task;
/** The groups of variables; the {@code build} methods throw if this is still {@code null}. */
protected List groups;
/**
 * Creates a learner for regression with refitting off, a lambda of zero and no groups.
 * The groups must be provided before a model can be built.
 */
public GroupLassoLearner() {
    this.groups = null;
    this.task = Task.REGRESSION;
    this.lambda = 0.0;
    this.refit = false;
}
/**
 * Builds a model for the configured task using the configured groups and lambda. A
 * negative iteration cap is replaced by 20 (and stored back into the learner).
 *
 * @param instances the training set.
 * @return the model, or {@code null} if the task is neither regression nor classification.
 * @throws IllegalArgumentException if the groups are not set.
 */
@Override
public GLM build(Instances instances) {
    if (groups == null) {
        throw new IllegalArgumentException("Groups are not set.");
    }
    if (maxNumIters < 0) {
        maxNumIters = 20;
    }
    switch (task) {
        case REGRESSION:
            return buildGaussianRegressor(instances, groups, maxNumIters, lambda);
        case CLASSIFICATION:
            return buildClassifier(instances, groups, maxNumIters, lambda);
        default:
            return null;
    }
}
/**
 * Builds a model for the given response distribution family using the configured groups
 * and lambda. A negative iteration cap is replaced by 20 (and stored back into the learner).
 *
 * @param trainSet the training set.
 * @param family the response distribution family; Gaussian and binomial are supported.
 * @return the model.
 * @throws IllegalArgumentException if the groups are not set or the family is unsupported.
 */
@Override
public GLM build(Instances trainSet, Family family) {
    if (groups == null) {
        throw new IllegalArgumentException("Groups are not set.");
    }
    if (maxNumIters < 0) {
        maxNumIters = 20;
    }
    switch (family) {
        case GAUSSIAN:
            return buildGaussianRegressor(trainSet, groups, maxNumIters, lambda);
        case BINOMIAL:
            return buildClassifier(trainSet, groups, maxNumIters, lambda);
        default:
            throw new IllegalArgumentException("Unsupported family: " + family);
    }
}
/**
 * Builds a group-lasso penalized binary classifier. The input matrix is grouped by groups. This procedure does not
 * assume the data is normalized or centered.
 *
 * <p>When {@code refit} is enabled, the selected groups are handed to {@code refitClassifier}
 * and the refitted model is returned instead of the penalized fit.
 *
 * @param attrs the groups of variables.
 * @param x the inputs.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @return a group-lasso penalized classifier.
 */
public GLM buildBinaryClassifier(int[][] attrs, double[][][] x, double[] y, int maxNumIters, double lambda) {
    // Model dimension: one more than the largest attribute index in any group.
    int p = 0;
    if (attrs.length > 0) {
        for (int[] attr : attrs) {
            p = Math.max(p, StatUtils.max(attr));
        }
        p += 1;
    }

    // Per-group coefficients and penalty weights; m is the size of the largest group.
    double[][] w = new double[attrs.length][];
    double[] tl1 = new double[attrs.length];
    int m = 0;
    for (int j = 0; j < attrs.length; j++) {
        w[j] = new double[attrs[j].length];
        tl1[j] = lambda * Math.sqrt(w[j].length);
        if (w[j].length > m) {
            m = w[j].length;
        }
    }

    double[] pTrain = new double[y.length];
    double[] rTrain = new double[y.length];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Step size per group: reciprocal of the largest sumSq(row) / 4 within the block.
    double[] stepSize = new double[attrs.length];
    for (int j = 0; j < stepSize.length; j++) {
        double max = 0;
        double[][] block = x[j];
        for (double[] t : block) {
            double l = StatUtils.sumSq(t) / 4;
            if (l > max) {
                max = l;
            }
        }
        stepSize[j] = 1.0 / max;
    }

    // Scratch buffers sized for the largest group.
    double[] g = new double[m];
    double[] gradient = new double[m];

    boolean[] activeSet = new boolean[attrs.length];

    double intercept = 0;

    // Block coordinate gradient descent
    int iter = 0;
    while (iter < maxNumIters) {
        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
        }

        // Pass with the flag set to true; its result tells whether the active set changed.
        boolean activeSetChanged = doOnePassBinomial(x, y, tl1, true, activeSet, w, stepSize,
                g, gradient, pTrain, rTrain);

        iter++;

        if (!activeSetChanged || iter > maxNumIters) {
            break;
        }

        // Passes with the flag set to false until the loss converges.
        for (; iter < maxNumIters; iter++) {
            double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }

            doOnePassBinomial(x, y, tl1, false, activeSet, w, stepSize, g, gradient, pTrain, rTrain);

            double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }

            if (verbose) {
                System.out.println("Iteration " + iter + ": " + currLoss);
            }
        }
    }

    if (refit) {
        // A group is selected unless ArrayUtils.isConstant reports its coefficients as constant 0.
        boolean[] selected = new boolean[attrs.length];
        for (int i = 0; i < selected.length; i++) {
            selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0);
        }
        // This is a classifier, so the refit must use the classifier routine (as the other
        // classifier builders do), not the Gaussian regressor.
        return refitClassifier(p, attrs, selected, x, y, w, maxNumIters);
    } else {
        return getGLM(p, attrs, w, intercept, LinkFunction.LOGIT);
    }
}
/**
 * Builds a group-lasso penalized binary classifier on sparse inputs. The input matrix is grouped by groups. This procedure does not
 * assume the data is normalized or centered.
 *
 * <p>When {@code refit} is enabled, the selected groups are handed to {@code refitClassifier}
 * and the refitted model is returned instead of the penalized fit.
 *
 * @param attrs the groups of variables.
 * @param indices the indices.
 * @param values the values.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @return a group-lasso penalized classifier.
 */
public GLM buildBinaryClassifier(int[][] attrs, int[][][] indices, double[][][] values, double[] y, int maxNumIters, double lambda) {
    // Model dimension: one more than the largest attribute index in any group.
    int p = 0;
    if (attrs.length > 0) {
        for (int[] attr : attrs) {
            p = Math.max(p, StatUtils.max(attr));
        }
        p += 1;
    }

    // For every group, the distinct indices stored by any of its columns.
    int[][] indexUnion = new int[attrs.length][];
    for (int b = 0; b < attrs.length; b++) {
        int[][] index = indices[b];
        Set<Integer> set = new HashSet<>();
        for (int[] idx : index) {
            for (int i : idx) {
                set.add(i);
            }
        }
        int[] idxUnion = new int[set.size()];
        int k = 0;
        for (int idx : set) {
            idxUnion[k++] = idx;
        }
        indexUnion[b] = idxUnion;
    }

    // Per-group coefficients and penalty weights; m is the size of the largest group.
    double[][] w = new double[attrs.length][];
    double[] tl1 = new double[attrs.length];
    int m = 0;
    for (int j = 0; j < attrs.length; j++) {
        w[j] = new double[attrs[j].length];
        tl1[j] = lambda * Math.sqrt(w[j].length);
        if (w[j].length > m) {
            m = w[j].length;
        }
    }

    double[] pTrain = new double[y.length];
    double[] rTrain = new double[y.length];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Step size per group: reciprocal of the largest sumSq(column values) / 4 within the block.
    double[] stepSize = new double[attrs.length];
    for (int j = 0; j < stepSize.length; j++) {
        double max = 0;
        double[][] block = values[j];
        for (double[] t : block) {
            double l = StatUtils.sumSq(t) / 4;
            if (l > max) {
                max = l;
            }
        }
        stepSize[j] = 1.0 / max;
    }

    // Scratch buffers sized for the largest group.
    double[] g = new double[m];
    double[] gradient = new double[m];

    boolean[] activeSet = new boolean[values.length];

    double intercept = 0;

    // Block coordinate gradient descent
    int iter = 0;
    while (iter < maxNumIters) {
        if (fitIntercept) {
            intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
        }

        // Pass with the flag set to true; its result tells whether the active set changed.
        boolean activeSetChanged = doOnePassBinomial(indices, indexUnion, values, y, tl1, true,
                activeSet, w, stepSize, g, gradient, pTrain, rTrain);

        iter++;

        if (!activeSetChanged || iter > maxNumIters) {
            break;
        }

        for (; iter < maxNumIters; iter++) {
            double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }

            // The refinement passes run with the flag set to false, matching the dense
            // builder and the regularization-path builders.
            doOnePassBinomial(indices, indexUnion, values, y, tl1, false, activeSet, w, stepSize,
                    g, gradient, pTrain, rTrain);

            double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }

            if (verbose) {
                System.out.println("Iteration " + iter + ": " + currLoss);
            }
        }
    }

    if (refit) {
        // A group is selected unless ArrayUtils.isConstant reports its coefficients as constant 0.
        boolean[] selected = new boolean[attrs.length];
        for (int i = 0; i < selected.length; i++) {
            selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0);
        }
        return refitClassifier(p, attrs, selected, indices, values, y, w, maxNumIters);
    } else {
        return getGLM(p, attrs, w, intercept, LinkFunction.LOGIT);
    }
}
/**
 * Builds group-lasso penalized binary classifiers for a sequence of regularization parameter lambdas. The input matrix is grouped by groups. This procedure does not
 * assume the data is normalized or centered.
 *
 * <p>The sequence starts at the value returned by {@code findMaxLambdaBinomial} and is
 * multiplied by {@code minLambdaRatio^(1 / numLambdas)} after every fit. Coefficients,
 * predictions and the intercept carry over from one lambda to the next. When {@code refit}
 * is enabled, one refitted model is produced per distinct set of selected groups.
 *
 * @param attrs the groups of variables.
 * @param x the inputs.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda.
 * @return group-lasso penalized classifiers.
 */
public List<GLM> buildBinaryClassifiers(int[][] attrs, double[][][] x, double[] y, int maxNumIters, int numLambdas, double minLambdaRatio) {
    // Model dimension: one more than the largest attribute index in any group.
    int p = 0;
    if (attrs.length > 0) {
        for (int[] attr : attrs) {
            p = Math.max(p, StatUtils.max(attr));
        }
        p += 1;
    }

    // Per-group coefficients; m is the size of the largest group.
    double[][] w = new double[attrs.length][];
    int m = 0;
    for (int j = 0; j < attrs.length; j++) {
        w[j] = new double[attrs[j].length];
        if (w[j].length > m) {
            m = w[j].length;
        }
    }

    double[] pTrain = new double[y.length];
    double[] rTrain = new double[y.length];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Scratch buffers sized for the largest group.
    double[] g = new double[m];
    double[] gradient = new double[m];

    double[] tl1 = new double[x.length];
    // Step size per group: reciprocal of the largest sumSq(row) / 4 within the block.
    double[] stepSize = new double[x.length];
    for (int j = 0; j < stepSize.length; j++) {
        double max = 0;
        double[][] block = x[j];
        for (double[] t : block) {
            double l = StatUtils.sumSq(t) / 4;
            if (l > max) {
                max = l;
            }
        }
        stepSize[j] = 1.0 / max;
    }

    boolean[] activeSet = new boolean[x.length];

    double maxLambda = findMaxLambdaBinomial(x, y, pTrain, rTrain, gradient);

    // Dampening factor for lambda
    double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas);

    // Compute the regularization path
    List<GLM> glms = new ArrayList<>(numLambdas);
    double lambda = maxLambda;
    double intercept = 0;
    // Structures already refitted, so each distinct selection is refitted only once.
    Set<ModelStructure> structures = new HashSet<>();
    for (int l = 0; l < numLambdas; l++) {
        // Initialize regularization parameters
        for (int j = 0; j < tl1.length; j++) {
            tl1[j] = lambda * Math.sqrt(w[j].length);
        }

        // Block coordinate gradient descent
        int iter = 0;
        while (iter < maxNumIters) {
            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }

            // Pass with the flag set to true; its result tells whether the active set changed.
            boolean activeSetChanged = doOnePassBinomial(x, y, tl1, true, activeSet, w,
                    stepSize, g, gradient, pTrain, rTrain);

            iter++;

            if (!activeSetChanged || iter > maxNumIters) {
                break;
            }

            // Passes with the flag set to false until the loss converges.
            for (; iter < maxNumIters; iter++) {
                double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

                if (fitIntercept) {
                    intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
                }

                doOnePassBinomial(x, y, tl1, false, activeSet, w, stepSize, g, gradient, pTrain, rTrain);

                double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

                if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                    break;
                }

                if (verbose) {
                    System.out.println("Iteration " + iter + ": " + currLoss);
                }
            }
        }

        lambda *= alpha;

        if (refit) {
            // A group is selected unless ArrayUtils.isConstant reports its coefficients as constant 0.
            boolean[] selected = new boolean[attrs.length];
            for (int i = 0; i < selected.length; i++) {
                selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0);
            }
            ModelStructure structure = new ModelStructure(selected);
            if (!structures.contains(structure)) {
                GLM glm = refitClassifier(p, attrs, selected, x, y, w, maxNumIters);
                glms.add(glm);
                structures.add(structure);
            }
        } else {
            GLM glm = getGLM(p, attrs, w, intercept, LinkFunction.LOGIT);
            glms.add(glm);
        }
    }

    return glms;
}
/**
 * Builds group-lasso penalized binary classifiers on sparse inputs for a sequence of regularization parameter lambdas. The input matrix is grouped by groups. This procedure does not
 * assume the data is normalized or centered.
 *
 * <p>The sequence starts at the value returned by {@code findMaxLambdaBinomial} and is
 * multiplied by {@code minLambdaRatio^(1 / numLambdas)} after every fit. Coefficients,
 * predictions and the intercept carry over from one lambda to the next. When {@code refit}
 * is enabled, one refitted model is produced per distinct set of selected groups.
 *
 * @param attrs the groups of variables.
 * @param indices the indices.
 * @param values the values.
 * @param y the targets.
 * @param maxNumIters the maximum number of iterations.
 * @param numLambdas the number of lambdas.
 * @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda.
 * @return group-lasso penalized classifiers.
 */
public List<GLM> buildBinaryClassifiers(int[][] attrs, int[][][] indices, double[][][] values, double[] y,
        int maxNumIters, int numLambdas, double minLambdaRatio) {
    // Model dimension: one more than the largest attribute index in any group.
    int p = 0;
    if (attrs.length > 0) {
        for (int[] attr : attrs) {
            p = Math.max(p, StatUtils.max(attr));
        }
        p += 1;
    }

    // For every group, the distinct indices stored by any of its columns.
    int[][] indexUnion = new int[indices.length][];
    for (int b = 0; b < indices.length; b++) {
        int[][] index = indices[b];
        Set<Integer> set = new HashSet<>();
        for (int[] idx : index) {
            for (int i : idx) {
                set.add(i);
            }
        }
        int[] idxUnion = new int[set.size()];
        int k = 0;
        for (int idx : set) {
            idxUnion[k++] = idx;
        }
        indexUnion[b] = idxUnion;
    }

    // Per-group coefficients; m is the size of the largest group.
    double[][] w = new double[attrs.length][];
    double[] tl1 = new double[attrs.length];
    int m = 0;
    for (int j = 0; j < values.length; j++) {
        w[j] = new double[values[j].length];
        if (w[j].length > m) {
            m = w[j].length;
        }
    }

    double[] pTrain = new double[y.length];
    double[] rTrain = new double[y.length];
    OptimUtils.computePseudoResidual(pTrain, y, rTrain);

    // Step size per group: reciprocal of the largest sumSq(column values) / 4 within the block.
    double[] stepSize = new double[values.length];
    for (int j = 0; j < stepSize.length; j++) {
        double max = 0;
        double[][] block = values[j];
        for (double[] t : block) {
            double l = StatUtils.sumSq(t) / 4;
            if (l > max) {
                max = l;
            }
        }
        stepSize[j] = 1.0 / max;
    }

    // Scratch buffers sized for the largest group.
    double[] g = new double[m];
    double[] gradient = new double[m];

    boolean[] activeSet = new boolean[values.length];

    double maxLambda = findMaxLambdaBinomial(indices, values, y, pTrain, rTrain, gradient);

    // Dampening factor for lambda
    double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas);

    // Compute the regularization path
    List<GLM> glms = new ArrayList<>(numLambdas);
    double lambda = maxLambda;
    double intercept = 0;
    // Structures already refitted, so each distinct selection is refitted only once.
    Set<ModelStructure> structures = new HashSet<>();
    for (int l = 0; l < numLambdas; l++) {
        // Initialize regularization parameters
        for (int j = 0; j < tl1.length; j++) {
            tl1[j] = lambda * Math.sqrt(w[j].length);
        }

        // Block coordinate gradient descent
        int iter = 0;
        while (iter < maxNumIters) {
            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }

            // Pass with the flag set to true; its result tells whether the active set changed.
            boolean activeSetChanged = doOnePassBinomial(indices, indexUnion, values, y, tl1, true,
                    activeSet, w, stepSize, g, gradient, pTrain, rTrain);

            iter++;

            if (!activeSetChanged || iter > maxNumIters) {
                break;
            }

            // Passes with the flag set to false until the loss converges.
            for (; iter < maxNumIters; iter++) {
                double prevLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

                if (fitIntercept) {
                    intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
                }

                doOnePassBinomial(indices, indexUnion, values, y, tl1, false, activeSet, w, stepSize,
                        g, gradient, pTrain, rTrain);

                double currLoss = GLMOptimUtils.computeGroupLassoLoss(pTrain, y, w, tl1);

                if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                    break;
                }

                if (verbose) {
                    System.out.println("Iteration " + iter + ": " + currLoss);
                }
            }
        }

        lambda *= alpha;

        if (refit) {
            // A group is selected unless ArrayUtils.isConstant reports its coefficients as constant 0.
            boolean[] selected = new boolean[attrs.length];
            for (int i = 0; i < selected.length; i++) {
                selected[i] = !ArrayUtils.isConstant(w[i], 0, w[i].length, 0);
            }
            ModelStructure structure = new ModelStructure(selected);
            if (!structures.contains(structure)) {
                GLM glm = refitClassifier(p, attrs, selected, indices, values, y, w, maxNumIters);
                glms.add(glm);
                // Remember the structure; without this the duplicate check above never fires
                // and the same selection is refitted for every lambda.
                structures.add(structure);
            }
        } else {
            GLM glm = getGLM(p, attrs, w, intercept, LinkFunction.LOGIT);
            glms.add(glm);
        }
    }

    return glms;
}
/**
 * Trains a group-lasso penalized classifier on a training set with a nominal target.
 *
 * <p>With two classes a single binary model is trained in which class 0 is the positive
 * label. With more classes one binary model is trained per class (one-vs-the-rest) and the
 * results are assembled into one multi-row model. In both cases every coefficient is
 * multiplied by the matching entry of the dataset's {@code cList} (presumably the
 * per-attribute scaling factors of the normalized dataset; confirm against the dataset
 * helpers).
 *
 * @param trainSet the training set.
 * @param isSparse {@code true} if the training set is treated as sparse.
 * @param groups the groups of variables.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @return a group-lasso penalized classifier.
 * @throws IllegalArgumentException if the target attribute is not nominal.
 */
public GLM buildClassifier(Instances trainSet, boolean isSparse, List groups, int maxNumIters, double lambda) {
    final Attribute target = trainSet.getTargetAttribute();
    if (target.getType() != Attribute.Type.NOMINAL) {
        throw new IllegalArgumentException("Class attribute must be nominal.");
    }
    final int numClasses = ((NominalAttribute) target).getCardinality();

    if (!isSparse) {
        DenseDataset dense = getDenseDataset(trainSet, true);
        DenseDesignMatrix design = createDesignMatrix(dense, groups);
        final int[] attIndices = dense.attrs;
        final int[][] groupAttrs = design.groups;
        final double[][][] blocks = design.x;
        final double[] labels = new double[dense.y.length];
        final double[] scales = dense.cList;

        if (numClasses == 2) {
            for (int i = 0; i < labels.length; i++) {
                final int code = (int) dense.y[i];
                labels[i] = (code == 0) ? 1 : 0;
            }
            GLM model = buildBinaryClassifier(groupAttrs, blocks, labels, maxNumIters, lambda);
            double[] coef = model.coefficients(0);
            for (int j = 0; j < scales.length; j++) {
                coef[attIndices[j]] *= scales[j];
            }
            return model;
        }

        final int dim = attIndices.length == 0 ? 0 : (StatUtils.max(attIndices) + 1);
        GLM model = new GLM(numClasses, dim);
        for (int k = 0; k < numClasses; k++) {
            // One-vs-the-rest
            for (int i = 0; i < labels.length; i++) {
                final int code = (int) dense.y[i];
                labels[i] = (code == k) ? 1 : 0;
            }
            GLM oneVsRest = buildBinaryClassifier(groupAttrs, blocks, labels, maxNumIters, lambda);
            double[] coef = oneVsRest.w[0];
            for (int j = 0; j < scales.length; j++) {
                final int a = attIndices[j];
                model.w[k][a] = coef[a] * scales[j];
            }
            model.intercept[k] = oneVsRest.intercept[0];
        }
        return model;
    }

    SparseDataset sparse = getSparseDataset(trainSet, true);
    SparseDesignMatrix design = createDesignMatrix(sparse, groups);
    final int[] attIndices = sparse.attrs;
    final int[][] groupAttrs = design.group;
    final int[][][] rowIndices = design.indices;
    final double[][][] entries = design.values;
    final double[] labels = new double[sparse.y.length];
    final double[] scales = sparse.cList;

    if (numClasses == 2) {
        for (int i = 0; i < labels.length; i++) {
            final int code = (int) sparse.y[i];
            labels[i] = (code == 0) ? 1 : 0;
        }
        GLM model = buildBinaryClassifier(groupAttrs, rowIndices, entries, labels, maxNumIters, lambda);
        double[] coef = model.coefficients(0);
        for (int j = 0; j < scales.length; j++) {
            coef[attIndices[j]] *= scales[j];
        }
        return model;
    }

    final int dim = attIndices.length == 0 ? 0 : (StatUtils.max(attIndices) + 1);
    GLM model = new GLM(numClasses, dim);
    for (int k = 0; k < numClasses; k++) {
        // One-vs-the-rest
        for (int i = 0; i < labels.length; i++) {
            final int code = (int) sparse.y[i];
            labels[i] = (code == k) ? 1 : 0;
        }
        GLM oneVsRest = buildBinaryClassifier(groupAttrs, rowIndices, entries, labels, maxNumIters, lambda);
        double[] coef = oneVsRest.w[0];
        for (int j = 0; j < scales.length; j++) {
            final int a = attIndices[j];
            model.w[k][a] = coef[a] * scales[j];
        }
        model.intercept[k] = oneVsRest.intercept[0];
    }
    return model;
}
/**
 * Trains a group-lasso penalized classifier, letting {@code isSparse(trainSet)} decide
 * whether the training set is treated as sparse.
 *
 * @param trainSet the training set.
 * @param groups the groups of variables.
 * @param maxNumIters the maximum number of iterations.
 * @param lambda the lambda.
 * @return a group-lasso penalized classifier.
 */
public GLM buildClassifier(Instances trainSet, List groups, int maxNumIters, double lambda) {
    final boolean sparse = isSparse(trainSet);
    return buildClassifier(trainSet, sparse, groups, maxNumIters, lambda);
}
/**
* Builds group-lasso penalized classifiers.
*
* @param trainSet the training set.
* @param isSparse {@code true} if the training set is treated as sparse.
* @param groups the groups of variables.
* @param maxNumIters the maximum number of iterations.
* @param numLambdas the number of lambdas.
* @param minLambdaRatio the minimum lambda is minLambdaRatio * max lambda.
* @return group-lasso penalized classifiers.
*/
public List buildClassifiers(Instances trainSet, boolean isSparse, List groups, int maxNumIters,
int numLambdas, double minLambdaRatio) {
Attribute classAttribute = trainSet.getTargetAttribute();
if (classAttribute.getType() != Attribute.Type.NOMINAL) {
throw new IllegalArgumentException("Class attribute must be nominal.");
}
NominalAttribute clazz = (NominalAttribute) classAttribute;
int numClasses = clazz.getCardinality();
if (isSparse) {
SparseDataset sd = getSparseDataset(trainSet, true);
SparseDesignMatrix sm = createDesignMatrix(sd, groups);
int[] attrs = sd.attrs;
int[][] group = sm.group;
int[][][] indices = sm.indices;
double[][][] values = sm.values;
double[] y = new double[sd.y.length];
double[] cList = sd.cList;
if (numClasses == 2) {
for (int i = 0; i < y.length; i++) {
int label = (int) sd.y[i];
y[i] = label == 0 ? 1 : 0;
}
List glms = buildBinaryClassifiers(sm.group, sm.indices, sm.values, y, maxNumIters, numLambdas,
minLambdaRatio);
for (GLM glm : glms) {
double[] w = glm.coefficients(0);
for (int j = 0; j < cList.length; j++) {
int attIndex = sd.attrs[j];
w[attIndex] *= cList[j];
}
}
return glms;
} else {
boolean refit = this.refit;
this.refit = false; // Not supported in multiclass
// classification
int p = attrs.length == 0 ? 0 : (StatUtils.max(attrs) + 1);
List glms = new ArrayList<>();
for (int i = 0; i < numLambdas; i++) {
GLM glm = new GLM(numClasses, p);
glms.add(glm);
}
for (int k = 0; k < numClasses; k++) {
// One-vs-the-rest
for (int i = 0; i < y.length; i++) {
int label = (int) sd.y[i];
y[i] = label == k ? 1 : 0;
}
List